Merge commit from fork

Vinayak Mishra created

* fix: apply SSRF protection to LFS HTTP client

The LFS HTTP client uses http.DefaultClient which has no SSRF
protection. This allows server-side requests from LFS operations
to reach private/internal networks. The webhook subsystem already
has SSRF protection via secureHTTPClient with IP validation and
redirect blocking, but the LFS code path was missed.

Add a shared pkg/ssrf package with a secure HTTP client constructor
that validates resolved IPs before dialing (blocking private, link-
local, loopback, CGNAT, and reserved ranges) and blocks redirects.
Replace http.DefaultClient in newHTTPClient() with ssrf.NewSecureClient()
at both locations (batch API client and BasicTransferAdapter).

* refactor: consolidate webhook SSRF protection into pkg/ssrf

Pull shared IP validation into pkg/ssrf so both the LFS client and
webhook client use the same SSRF protection. The webhook validator
becomes a thin wrapper and the inline secureHTTPClient is replaced
with ssrf.NewSecureClient().

Two latent issues in the webhook path fixed in the process:
- nil ParseIP result was silently allowed through (now fail-closed)
- IPv6-mapped IPv4 bypassed manual range checks (now normalized)

Error aliases kept in pkg/webhook for backward compatibility.

Change summary

pkg/lfs/http_client.go        |   6 
pkg/ssrf/ssrf.go              | 178 +++++++++++++++++++++
pkg/ssrf/ssrf_test.go         | 208 +++++++++++++++++++++++++
pkg/webhook/ssrf_test.go      | 196 ++---------------------
pkg/webhook/validator.go      | 163 -------------------
pkg/webhook/validator_test.go | 303 ++----------------------------------
pkg/webhook/webhook.go        |  41 ----
7 files changed, 438 insertions(+), 657 deletions(-)

Detailed changes

pkg/lfs/http_client.go 🔗

@@ -9,6 +9,7 @@ import (
 	"net/http"
 
 	"charm.land/log/v2"
+	"github.com/charmbracelet/soft-serve/pkg/ssrf"
 )
 
 // httpClient is a Git LFS client to communicate with a LFS source API.
@@ -22,11 +23,12 @@ var _ Client = (*httpClient)(nil)
 
 // newHTTPClient returns a new Git LFS client.
 func newHTTPClient(endpoint Endpoint) *httpClient {
+	client := ssrf.NewSecureClient()
 	return &httpClient{
-		client:   http.DefaultClient,
+		client:   client,
 		endpoint: endpoint,
 		transfers: map[string]TransferAdapter{
-			TransferBasic: &BasicTransferAdapter{http.DefaultClient},
+			TransferBasic: &BasicTransferAdapter{client},
 		},
 	}
 }

pkg/ssrf/ssrf.go 🔗

@@ -0,0 +1,178 @@
+package ssrf
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"net"
+	"net/http"
+	"net/url"
+	"slices"
+	"strings"
+	"time"
+)
+
+var (
+	// ErrPrivateIP is returned when a connection to a private or internal IP is blocked.
+	ErrPrivateIP = errors.New("connection to private or internal IP address is not allowed")
+	// ErrInvalidScheme is returned when a URL scheme is not http or https.
+	ErrInvalidScheme = errors.New("URL must use http or https scheme")
+	// ErrInvalidURL is returned when a URL is invalid.
+	ErrInvalidURL = errors.New("invalid URL")
+)
+
+// NewSecureClient returns an HTTP client with SSRF protection.
+// It validates resolved IPs at dial time to block connections to private
+// and internal networks. Since validation uses the already-resolved IP
+// from the Transport's DNS lookup, there is no TOCTOU gap between
+// resolution and connection. Redirects are disabled to match the
+// webhook client convention and prevent redirect-based SSRF.
+func NewSecureClient() *http.Client {
+	return &http.Client{
+		Timeout: 30 * time.Second,
+		Transport: &http.Transport{
+			DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+				host, _, err := net.SplitHostPort(addr)
+				if err != nil {
+					return nil, err //nolint:wrapcheck
+				}
+
+				ip := net.ParseIP(host)
+				if ip == nil {
+					return nil, fmt.Errorf("unexpected non-IP address in dial: %s", host)
+				}
+				if isPrivateOrInternal(ip) {
+					return nil, fmt.Errorf("%w", ErrPrivateIP)
+				}
+
+				dialer := &net.Dialer{
+					Timeout:   10 * time.Second,
+					KeepAlive: 30 * time.Second,
+				}
+				return dialer.DialContext(ctx, network, addr)
+			},
+			MaxIdleConns:          100,
+			IdleConnTimeout:       90 * time.Second,
+			TLSHandshakeTimeout:   10 * time.Second,
+			ExpectContinueTimeout: 1 * time.Second,
+		},
+		CheckRedirect: func(*http.Request, []*http.Request) error {
+			return http.ErrUseLastResponse
+		},
+	}
+}
+
+// isPrivateOrInternal checks if an IP address is private, internal, or reserved.
+func isPrivateOrInternal(ip net.IP) bool {
+	// Normalize IPv6-mapped IPv4 (e.g. ::ffff:127.0.0.1) to IPv4 form
+	// so all checks apply consistently.
+	if ip4 := ip.To4(); ip4 != nil {
+		ip = ip4
+	}
+
+	if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() ||
+		ip.IsPrivate() || ip.IsUnspecified() || ip.IsMulticast() {
+		return true
+	}
+
+	if ip4 := ip.To4(); ip4 != nil {
+		// 0.0.0.0/8
+		if ip4[0] == 0 {
+			return true
+		}
+		// 100.64.0.0/10 (Shared Address Space / CGNAT)
+		if ip4[0] == 100 && ip4[1] >= 64 && ip4[1] <= 127 {
+			return true
+		}
+		// 192.0.0.0/24 (IETF Protocol Assignments)
+		if ip4[0] == 192 && ip4[1] == 0 && ip4[2] == 0 {
+			return true
+		}
+		// 192.0.2.0/24 (TEST-NET-1)
+		if ip4[0] == 192 && ip4[1] == 0 && ip4[2] == 2 {
+			return true
+		}
+		// 198.18.0.0/15 (benchmarking)
+		if ip4[0] == 198 && (ip4[1] == 18 || ip4[1] == 19) {
+			return true
+		}
+		// 198.51.100.0/24 (TEST-NET-2)
+		if ip4[0] == 198 && ip4[1] == 51 && ip4[2] == 100 {
+			return true
+		}
+		// 203.0.113.0/24 (TEST-NET-3)
+		if ip4[0] == 203 && ip4[1] == 0 && ip4[2] == 113 {
+			return true
+		}
+		// 240.0.0.0/4 (Reserved, includes 255.255.255.255 broadcast)
+		if ip4[0] >= 240 {
+			return true
+		}
+	}
+
+	return false
+}
+
+// ValidateURL validates that a URL is safe to make requests to.
+// It checks that the scheme is http/https, the hostname is not localhost,
+// and all resolved IPs are public.
+func ValidateURL(rawURL string) error {
+	if rawURL == "" {
+		return ErrInvalidURL
+	}
+
+	u, err := url.Parse(rawURL)
+	if err != nil {
+		return fmt.Errorf("%w: %v", ErrInvalidURL, err)
+	}
+
+	if u.Scheme != "http" && u.Scheme != "https" {
+		return ErrInvalidScheme
+	}
+
+	hostname := u.Hostname()
+	if hostname == "" {
+		return fmt.Errorf("%w: missing hostname", ErrInvalidURL)
+	}
+
+	if isLocalhost(hostname) {
+		return ErrPrivateIP
+	}
+
+	if ip := net.ParseIP(hostname); ip != nil {
+		if isPrivateOrInternal(ip) {
+			return ErrPrivateIP
+		}
+		return nil
+	}
+
+	ips, err := net.DefaultResolver.LookupIPAddr(context.Background(), hostname)
+	if err != nil {
+		return fmt.Errorf("%w: cannot resolve hostname: %v", ErrInvalidURL, err)
+	}
+
+	if slices.ContainsFunc(ips, func(addr net.IPAddr) bool {
+		return isPrivateOrInternal(addr.IP)
+	}) {
+		return ErrPrivateIP
+	}
+
+	return nil
+}
+
+// ValidateIPBeforeDial validates an IP address before establishing a connection.
+// This prevents DNS rebinding attacks by checking the resolved IP at dial time.
+func ValidateIPBeforeDial(ip net.IP) error {
+	if isPrivateOrInternal(ip) {
+		return ErrPrivateIP
+	}
+	return nil
+}
+
+// isLocalhost checks if the hostname is localhost or similar.
+func isLocalhost(hostname string) bool {
+	hostname = strings.ToLower(hostname)
+	return hostname == "localhost" ||
+		hostname == "localhost.localdomain" ||
+		strings.HasSuffix(hostname, ".localhost")
+}

pkg/ssrf/ssrf_test.go 🔗

@@ -0,0 +1,208 @@
+package ssrf
+
+import (
+	"context"
+	"errors"
+	"net"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+	"time"
+)
+
+func TestNewSecureClientBlocksPrivateIPs(t *testing.T) {
+	client := NewSecureClient()
+	transport := client.Transport.(*http.Transport)
+
+	tests := []struct {
+		name    string
+		addr    string
+		wantErr bool
+	}{
+		{"block loopback", "127.0.0.1:80", true},
+		{"block private 10.x", "10.0.0.1:80", true},
+		{"block link-local", "169.254.169.254:80", true},
+		{"block CGNAT", "100.64.0.1:80", true},
+		{"allow public IP", "8.8.8.8:80", false},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
+			defer cancel()
+
+			conn, err := transport.DialContext(ctx, "tcp", tt.addr)
+			if conn != nil {
+				conn.Close()
+			}
+
+			if tt.wantErr {
+				if err == nil {
+					t.Errorf("expected error for %s, got none", tt.addr)
+				}
+			} else {
+				if err != nil && errors.Is(err, ErrPrivateIP) {
+					t.Errorf("should not block %s with SSRF error, got: %v", tt.addr, err)
+				}
+			}
+		})
+	}
+}
+
+func TestNewSecureClientNilIPNotErrPrivateIP(t *testing.T) {
+	client := NewSecureClient()
+	transport := client.Transport.(*http.Transport)
+
+	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
+	defer cancel()
+
+	conn, err := transport.DialContext(ctx, "tcp", "not-an-ip:80")
+	if conn != nil {
+		conn.Close()
+	}
+	if err == nil {
+		t.Fatal("expected error for non-IP address, got none")
+	}
+	if errors.Is(err, ErrPrivateIP) {
+		t.Errorf("nil-IP path should not wrap ErrPrivateIP, got: %v", err)
+	}
+}
+
+func TestNewSecureClientBlocksRedirects(t *testing.T) {
+	redirectServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		http.Redirect(w, r, "http://8.8.8.8:8080/safe", http.StatusFound)
+	}))
+	defer redirectServer.Close()
+
+	client := NewSecureClient()
+	req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, redirectServer.URL, nil)
+	if err != nil {
+		t.Fatalf("Failed to create request: %v", err)
+	}
+
+	resp, err := client.Do(req)
+	if err != nil {
+		// httptest uses 127.0.0.1, blocked by SSRF protection
+		if !errors.Is(err, ErrPrivateIP) {
+			t.Fatalf("Request failed with non-SSRF error: %v", err)
+		}
+		return
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusFound {
+		t.Errorf("Expected redirect response (302), got %d", resp.StatusCode)
+	}
+}
+
+func TestIsPrivateOrInternal(t *testing.T) {
+	tests := []struct {
+		ip   string
+		want bool
+	}{
+		// Public
+		{"8.8.8.8", false},
+		{"2001:4860:4860::8888", false},
+
+		// Loopback
+		{"127.0.0.1", true},
+		{"::1", true},
+
+		// Private ranges
+		{"10.0.0.1", true},
+		{"192.168.1.1", true},
+		{"172.16.0.1", true},
+
+		// Link-local (cloud metadata)
+		{"169.254.169.254", true},
+
+		// CGNAT boundaries
+		{"100.64.0.1", true},
+		{"100.127.255.255", true},
+
+		// IPv6-mapped IPv4 (bypass vector the old webhook code missed)
+		{"::ffff:127.0.0.1", true},
+		{"::ffff:169.254.169.254", true},
+		{"::ffff:8.8.8.8", false},
+
+		// Reserved
+		{"0.0.0.0", true},
+		{"240.0.0.1", true},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.ip, func(t *testing.T) {
+			ip := net.ParseIP(tt.ip)
+			if ip == nil {
+				t.Fatalf("failed to parse IP: %s", tt.ip)
+			}
+			if got := isPrivateOrInternal(ip); got != tt.want {
+				t.Errorf("isPrivateOrInternal(%s) = %v, want %v", tt.ip, got, tt.want)
+			}
+		})
+	}
+}
+
+func TestValidateURL(t *testing.T) {
+	tests := []struct {
+		name    string
+		url     string
+		wantErr bool
+		errType error
+	}{
+		// Valid
+		{"valid https", "https://1.1.1.1/webhook", false, nil},
+
+		// Scheme validation
+		{"ftp scheme", "ftp://example.com/webhook", true, ErrInvalidScheme},
+		{"no scheme", "example.com/webhook", true, ErrInvalidScheme},
+
+		// Localhost
+		{"localhost", "http://localhost/webhook", true, ErrPrivateIP},
+		{"subdomain.localhost", "http://test.localhost/webhook", true, ErrPrivateIP},
+
+		// IP-based blocking (one per category -- range coverage is in TestIsPrivateOrInternal)
+		{"loopback IP", "http://127.0.0.1/webhook", true, ErrPrivateIP},
+		{"metadata IP", "http://169.254.169.254/latest/meta-data/", true, ErrPrivateIP},
+
+		// Invalid URLs
+		{"empty", "", true, ErrInvalidURL},
+		{"missing hostname", "http:///webhook", true, ErrInvalidURL},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			err := ValidateURL(tt.url)
+			if (err != nil) != tt.wantErr {
+				t.Errorf("ValidateURL(%q) error = %v, wantErr %v", tt.url, err, tt.wantErr)
+				return
+			}
+			if tt.wantErr && tt.errType != nil {
+				if !errors.Is(err, tt.errType) {
+					t.Errorf("ValidateURL(%q) error = %v, want error type %v", tt.url, err, tt.errType)
+				}
+			}
+		})
+	}
+}
+
+func TestIsLocalhost(t *testing.T) {
+	tests := []struct {
+		hostname string
+		want     bool
+	}{
+		{"localhost", true},
+		{"LOCALHOST", true},
+		{"test.localhost", true},
+		{"example.com", false},
+		{"localhost.com", false},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.hostname, func(t *testing.T) {
+			if got := isLocalhost(tt.hostname); got != tt.want {
+				t.Errorf("isLocalhost(%s) = %v, want %v", tt.hostname, got, tt.want)
+			}
+		})
+	}
+}

pkg/webhook/ssrf_test.go 🔗

@@ -2,217 +2,57 @@ package webhook
 
 import (
 	"context"
+	"errors"
 	"net/http"
-	"net/http/httptest"
 	"testing"
 	"time"
 
 	"github.com/charmbracelet/soft-serve/pkg/db/models"
+	"github.com/charmbracelet/soft-serve/pkg/ssrf"
 )
 
-// TestSSRFProtection tests that the webhook system blocks SSRF attempts.
+// TestSSRFProtection is an integration test verifying the webhook send path
+// blocks private IPs end-to-end (models.Webhook -> secureHTTPClient -> ssrf).
 func TestSSRFProtection(t *testing.T) {
 	tests := []struct {
 		name        string
 		webhookURL  string
 		shouldBlock bool
-		description string
 	}{
-		{
-			name:        "block localhost",
-			webhookURL:  "http://localhost:8080/webhook",
-			shouldBlock: true,
-			description: "should block localhost addresses",
-		},
-		{
-			name:        "block 127.0.0.1",
-			webhookURL:  "http://127.0.0.1:8080/webhook",
-			shouldBlock: true,
-			description: "should block loopback addresses",
-		},
-		{
-			name:        "block 169.254.169.254",
-			webhookURL:  "http://169.254.169.254/latest/meta-data/",
-			shouldBlock: true,
-			description: "should block cloud metadata service",
-		},
-		{
-			name:        "block private network",
-			webhookURL:  "http://192.168.1.1/webhook",
-			shouldBlock: true,
-			description: "should block private networks",
-		},
-		{
-			name:        "allow public IP",
-			webhookURL:  "http://8.8.8.8/webhook",
-			shouldBlock: false,
-			description: "should allow public IP addresses",
-		},
+		{"block loopback", "http://127.0.0.1:8080/webhook", true},
+		{"block metadata", "http://169.254.169.254/latest/meta-data/", true},
+		{"allow public IP", "http://8.8.8.8/webhook", false},
 	}
 
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			// Create a test webhook
-			webhook := models.Webhook{
+			w := models.Webhook{
 				URL:         tt.webhookURL,
 				ContentType: int(ContentTypeJSON),
-				Secret:      "",
 			}
 
-			// Try to send a webhook
 			ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
 			defer cancel()
 
-			// Create a simple payload
-			payload := map[string]string{"test": "data"}
-
-			err := sendWebhookWithContext(ctx, webhook, EventPush, payload)
-
-			if tt.shouldBlock {
-				if err == nil {
-					t.Errorf("%s: expected error but got none", tt.description)
-				}
-			} else {
-				// For public IPs, we expect a connection error (since 8.8.8.8 won't be listening)
-				// but NOT an SSRF blocking error
-				if err != nil && isSSRFError(err) {
-					t.Errorf("%s: should not block public IPs, got: %v", tt.description, err)
-				}
+			req, err := http.NewRequestWithContext(ctx, "POST", w.URL, nil)
+			if err != nil {
+				t.Fatalf("failed to create request: %v", err)
 			}
-		})
-	}
-}
-
-// TestSecureHTTPClientBlocksRedirects tests that redirects are not followed.
-func TestSecureHTTPClientBlocksRedirects(t *testing.T) {
-	// Create a test server on a public-looking address that redirects
-	redirectServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		http.Redirect(w, r, "http://8.8.8.8:8080/safe", http.StatusFound)
-	}))
-	defer redirectServer.Close()
-
-	// Try to make a request that would redirect
-	req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, redirectServer.URL, nil)
-	if err != nil {
-		t.Fatalf("Failed to create request: %v", err)
-	}
 
-	resp, err := secureHTTPClient.Do(req)
-	if err != nil {
-		// httptest.NewServer uses 127.0.0.1, which will be blocked by our SSRF protection
-		// This is actually correct behavior - we're blocking the initial connection
-		if !isSSRFError(err) {
-			t.Fatalf("Request failed with non-SSRF error: %v", err)
-		}
-		// Test passed - we blocked the loopback connection
-		return
-	}
-	defer resp.Body.Close()
-
-	// If we got here, check that we got the redirect response (not followed)
-	if resp.StatusCode != http.StatusFound {
-		t.Errorf("Expected redirect response (302), got %d", resp.StatusCode)
-	}
-}
-
-// TestDialContextBlocksPrivateIPs tests the DialContext function directly.
-func TestDialContextBlocksPrivateIPs(t *testing.T) {
-	transport := secureHTTPClient.Transport.(*http.Transport)
-
-	tests := []struct {
-		name    string
-		addr    string
-		wantErr bool
-	}{
-		{"block loopback", "127.0.0.1:80", true},
-		{"block private 10.x", "10.0.0.1:80", true},
-		{"block private 192.168.x", "192.168.1.1:80", true},
-		{"block link-local", "169.254.169.254:80", true},
-		{"allow public IP", "8.8.8.8:80", false},
-	}
-
-	for _, tt := range tests {
-		t.Run(tt.name, func(t *testing.T) {
-			ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
-			defer cancel()
-
-			conn, err := transport.DialContext(ctx, "tcp", tt.addr)
-			if conn != nil {
-				conn.Close()
+			resp, err := secureHTTPClient.Do(req)
+			if resp != nil {
+				resp.Body.Close()
 			}
 
-			if tt.wantErr {
+			if tt.shouldBlock {
 				if err == nil {
-					t.Errorf("Expected error for %s, got none", tt.addr)
+					t.Errorf("%s: expected error but got none", tt.name)
 				}
 			} else {
-				// For public IPs, we expect a connection timeout/refused (not an SSRF block)
-				if err != nil && isSSRFError(err) {
-					t.Errorf("Should not block %s with SSRF error, got: %v", tt.addr, err)
+				if err != nil && errors.Is(err, ssrf.ErrPrivateIP) {
+					t.Errorf("%s: should not block public IPs, got: %v", tt.name, err)
 				}
 			}
 		})
 	}
 }
-
-// sendWebhookWithContext is a test helper that doesn't require database.
-func sendWebhookWithContext(ctx context.Context, w models.Webhook, _ Event, _ any) error {
-	// This is a simplified version for testing that just attempts the HTTP connection
-	req, err := http.NewRequestWithContext(ctx, "POST", w.URL, nil)
-	if err != nil {
-		return err //nolint:wrapcheck
-	}
-	req = req.WithContext(ctx)
-
-	resp, err := secureHTTPClient.Do(req)
-	if resp != nil {
-		resp.Body.Close()
-	}
-	return err //nolint:wrapcheck
-}
-
-// isSSRFError checks if an error is related to SSRF blocking.
-func isSSRFError(err error) bool {
-	if err == nil {
-		return false
-	}
-	errMsg := err.Error()
-	return contains(errMsg, "private IP") ||
-		contains(errMsg, "blocked connection") ||
-		err == ErrPrivateIP
-}
-
-func contains(s, substr string) bool {
-	return len(s) >= len(substr) && (s == substr || len(substr) == 0 || indexOfSubstring(s, substr) >= 0)
-}
-
-func indexOfSubstring(s, substr string) int {
-	for i := 0; i <= len(s)-len(substr); i++ {
-		if s[i:i+len(substr)] == substr {
-			return i
-		}
-	}
-	return -1
-}
-
-// TestPrivateIPResolution tests that hostnames resolving to private IPs are blocked.
-func TestPrivateIPResolution(t *testing.T) {
-	// This test verifies that even if a hostname looks public, if it resolves to a private IP, it's blocked
-	webhook := models.Webhook{
-		URL:         "http://127.0.0.1:9999/webhook",
-		ContentType: int(ContentTypeJSON),
-	}
-
-	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
-	defer cancel()
-
-	err := sendWebhookWithContext(ctx, webhook, EventPush, map[string]string{"test": "data"})
-	if err == nil {
-		t.Error("Expected error when connecting to loopback address")
-		return
-	}
-
-	if !isSSRFError(err) {
-		t.Errorf("Expected SSRF blocking error, got: %v", err)
-	}
-}

pkg/webhook/validator.go 🔗

@@ -1,169 +1,20 @@
 package webhook
 
 import (
-	"context"
-	"errors"
-	"fmt"
-	"net"
-	"net/url"
-	"slices"
-	"strings"
+	"github.com/charmbracelet/soft-serve/pkg/ssrf"
 )
 
+// Error aliases for backward compatibility.
 var (
-	// ErrInvalidScheme is returned when the webhook URL scheme is not http or https.
-	ErrInvalidScheme = errors.New("webhook URL must use http or https scheme")
-	// ErrPrivateIP is returned when the webhook URL resolves to a private IP address.
-	ErrPrivateIP = errors.New("webhook URL cannot resolve to private or internal IP addresses")
-	// ErrInvalidURL is returned when the webhook URL is invalid.
-	ErrInvalidURL = errors.New("invalid webhook URL")
+	ErrInvalidScheme = ssrf.ErrInvalidScheme
+	ErrPrivateIP     = ssrf.ErrPrivateIP
+	ErrInvalidURL    = ssrf.ErrInvalidURL
 )
 
 // ValidateWebhookURL validates that a webhook URL is safe to use.
-// It checks:
-// - URL is properly formatted
-// - Scheme is http or https
-// - Hostname does not resolve to private/internal IP addresses
-// - Hostname is not localhost or similar.
 func ValidateWebhookURL(rawURL string) error {
-	if rawURL == "" {
-		return ErrInvalidURL
-	}
-
-	// Parse the URL
-	u, err := url.Parse(rawURL)
-	if err != nil {
-		return fmt.Errorf("%w: %v", ErrInvalidURL, err)
-	}
-
-	// Check scheme
-	if u.Scheme != "http" && u.Scheme != "https" {
-		return ErrInvalidScheme
-	}
-
-	// Extract hostname (without port)
-	hostname := u.Hostname()
-	if hostname == "" {
-		return fmt.Errorf("%w: missing hostname", ErrInvalidURL)
-	}
-
-	// Check for localhost variations
-	if isLocalhost(hostname) {
-		return ErrPrivateIP
-	}
-
-	// If it's an IP address, validate it directly
-	if ip := net.ParseIP(hostname); ip != nil {
-		if isPrivateOrInternalIP(ip) {
-			return ErrPrivateIP
-		}
-		return nil
-	}
-
-	// Resolve hostname to IP addresses
-	ips, err := net.DefaultResolver.LookupIPAddr(context.Background(), hostname)
-	if err != nil {
-		return fmt.Errorf("%w: cannot resolve hostname: %v", ErrInvalidURL, err)
-	}
-
-	// Check all resolved IPs
-	if slices.ContainsFunc(ips, isPrivateOrInternalIPAddr) {
-		return ErrPrivateIP
-	}
-
-	return nil
-}
-
-// isLocalhost checks if the hostname is localhost or similar.
-func isLocalhost(hostname string) bool {
-	hostname = strings.ToLower(hostname)
-	return hostname == "localhost" ||
-		hostname == "localhost.localdomain" ||
-		strings.HasSuffix(hostname, ".localhost")
-}
-
-// isPrivateOrInternalIPAddr is a helper function that users net.IPAddr instead of net.IP.
-func isPrivateOrInternalIPAddr(ipAddr net.IPAddr) bool {
-	return isPrivateOrInternalIP(ipAddr.IP)
-}
-
-// isPrivateOrInternalIP checks if an IP address is private, internal, or reserved.
-func isPrivateOrInternalIP(ip net.IP) bool {
-	// Loopback addresses (127.0.0.0/8, ::1)
-	if ip.IsLoopback() {
-		return true
-	}
-
-	// Link-local addresses (169.254.0.0/16, fe80::/10)
-	// This blocks AWS/GCP/Azure metadata services
-	if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
-		return true
-	}
-
-	// Private addresses (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, fc00::/7)
-	if ip.IsPrivate() {
-		return true
-	}
-
-	// Unspecified addresses (0.0.0.0, ::)
-	if ip.IsUnspecified() {
-		return true
-	}
-
-	// Multicast addresses
-	if ip.IsMulticast() {
-		return true
-	}
-
-	// Additional checks for IPv4
-	if ip4 := ip.To4(); ip4 != nil {
-		// 0.0.0.0/8 (current network)
-		if ip4[0] == 0 {
-			return true
-		}
-		// 100.64.0.0/10 (Shared Address Space)
-		if ip4[0] == 100 && ip4[1] >= 64 && ip4[1] <= 127 {
-			return true
-		}
-		// 192.0.0.0/24 (IETF Protocol Assignments)
-		if ip4[0] == 192 && ip4[1] == 0 && ip4[2] == 0 {
-			return true
-		}
-		// 192.0.2.0/24 (TEST-NET-1)
-		if ip4[0] == 192 && ip4[1] == 0 && ip4[2] == 2 {
-			return true
-		}
-		// 198.18.0.0/15 (benchmarking)
-		if ip4[0] == 198 && (ip4[1] == 18 || ip4[1] == 19) {
-			return true
-		}
-		// 198.51.100.0/24 (TEST-NET-2)
-		if ip4[0] == 198 && ip4[1] == 51 && ip4[2] == 100 {
-			return true
-		}
-		// 203.0.113.0/24 (TEST-NET-3)
-		if ip4[0] == 203 && ip4[1] == 0 && ip4[2] == 113 {
-			return true
-		}
-		// 224.0.0.0/4 (Multicast - already handled by IsMulticast)
-		// 240.0.0.0/4 (Reserved for future use)
-		if ip4[0] >= 240 {
-			return true
-		}
-		// 255.255.255.255/32 (Broadcast)
-		if ip4[0] == 255 && ip4[1] == 255 && ip4[2] == 255 && ip4[3] == 255 {
-			return true
-		}
-	}
-
-	return false
+	return ssrf.ValidateURL(rawURL) //nolint:wrapcheck
 }
 
 // ValidateIPBeforeDial validates an IP address before establishing a connection.
-// This is used to prevent DNS rebinding attacks.
-func ValidateIPBeforeDial(ip net.IP) error {
-	if isPrivateOrInternalIP(ip) {
-		return ErrPrivateIP
-	}
-	return nil
-}
+var ValidateIPBeforeDial = ssrf.ValidateIPBeforeDial

pkg/webhook/validator_test.go 🔗

@@ -1,315 +1,52 @@
 package webhook
 
 import (
-	"net"
+	"errors"
 	"testing"
+
+	"github.com/charmbracelet/soft-serve/pkg/ssrf"
 )
 
+// TestValidateWebhookURL verifies the wrapper delegates correctly and
+// error aliases work across the package boundary. IP range coverage
+// is in pkg/ssrf/ssrf_test.go -- here we just confirm the plumbing.
 func TestValidateWebhookURL(t *testing.T) {
 	tests := []struct {
 		name    string
 		url     string
 		wantErr bool
 		errType error
-		skip    string
 	}{
-		// Valid URLs (these will perform DNS lookups, so may fail in some environments)
-		{
-			name:    "valid https URL",
-			url:     "https://1.1.1.1/webhook",
-			wantErr: false,
-		},
-		{
-			name:    "valid http URL",
-			url:     "http://8.8.8.8/webhook",
-			wantErr: false,
-		},
-		{
-			name:    "valid URL with port",
-			url:     "https://1.1.1.1:8080/webhook",
-			wantErr: false,
-		},
-		{
-			name:    "valid URL with path and query",
-			url:     "https://8.8.8.8/webhook?token=abc123",
-			wantErr: false,
-		},
-
-		// Invalid schemes
-		{
-			name:    "ftp scheme",
-			url:     "ftp://example.com/webhook",
-			wantErr: true,
-			errType: ErrInvalidScheme,
-		},
-		{
-			name:    "file scheme",
-			url:     "file:///etc/passwd",
-			wantErr: true,
-			errType: ErrInvalidScheme,
-		},
-		{
-			name:    "gopher scheme",
-			url:     "gopher://example.com",
-			wantErr: true,
-			errType: ErrInvalidScheme,
-		},
-		{
-			name:    "no scheme",
-			url:     "example.com/webhook",
-			wantErr: true,
-			errType: ErrInvalidScheme,
-		},
-
-		// Localhost variations
-		{
-			name:    "localhost",
-			url:     "http://localhost/webhook",
-			wantErr: true,
-			errType: ErrPrivateIP,
-		},
-		{
-			name:    "localhost with port",
-			url:     "http://localhost:8080/webhook",
-			wantErr: true,
-			errType: ErrPrivateIP,
-		},
-		{
-			name:    "localhost.localdomain",
-			url:     "http://localhost.localdomain/webhook",
-			wantErr: true,
-			errType: ErrPrivateIP,
-		},
-
-		// Loopback IPs
-		{
-			name:    "127.0.0.1",
-			url:     "http://127.0.0.1/webhook",
-			wantErr: true,
-			errType: ErrPrivateIP,
-		},
-		{
-			name:    "127.0.0.1 with port",
-			url:     "http://127.0.0.1:8080/webhook",
-			wantErr: true,
-			errType: ErrPrivateIP,
-		},
-		{
-			name:    "127.1.2.3",
-			url:     "http://127.1.2.3/webhook",
-			wantErr: true,
-			errType: ErrPrivateIP,
-		},
-		{
-			name:    "IPv6 loopback",
-			url:     "http://[::1]/webhook",
-			wantErr: true,
-			errType: ErrPrivateIP,
-		},
-
-		// Private IPv4 ranges
-		{
-			name:    "10.0.0.0",
-			url:     "http://10.0.0.1/webhook",
-			wantErr: true,
-			errType: ErrPrivateIP,
-		},
-		{
-			name:    "192.168.0.0",
-			url:     "http://192.168.1.1/webhook",
-			wantErr: true,
-			errType: ErrPrivateIP,
-		},
-		{
-			name:    "172.16.0.0",
-			url:     "http://172.16.0.1/webhook",
-			wantErr: true,
-			errType: ErrPrivateIP,
-		},
-		{
-			name:    "172.31.255.255",
-			url:     "http://172.31.255.255/webhook",
-			wantErr: true,
-			errType: ErrPrivateIP,
-		},
-
-		// Link-local (AWS/GCP/Azure metadata)
-		{
-			name:    "AWS metadata service",
-			url:     "http://169.254.169.254/latest/meta-data/",
-			wantErr: true,
-			errType: ErrPrivateIP,
-		},
-		{
-			name:    "link-local",
-			url:     "http://169.254.1.1/webhook",
-			wantErr: true,
-			errType: ErrPrivateIP,
-		},
-
-		// Other reserved ranges
-		{
-			name:    "0.0.0.0",
-			url:     "http://0.0.0.0/webhook",
-			wantErr: true,
-			errType: ErrPrivateIP,
-		},
-		{
-			name:    "broadcast",
-			url:     "http://255.255.255.255/webhook",
-			wantErr: true,
-			errType: ErrPrivateIP,
-		},
-
-		// Invalid URLs
-		{
-			name:    "empty URL",
-			url:     "",
-			wantErr: true,
-			errType: ErrInvalidURL,
-		},
-		{
-			name:    "missing hostname",
-			url:     "http:///webhook",
-			wantErr: true,
-			errType: ErrInvalidURL,
-		},
+		{"valid", "https://1.1.1.1/webhook", false, nil},
+		{"bad scheme", "ftp://example.com", true, ErrInvalidScheme},
+		{"private IP", "http://127.0.0.1/webhook", true, ErrPrivateIP},
+		{"empty", "", true, ErrInvalidURL},
 	}
 
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			if tt.skip != "" {
-				t.Skip(tt.skip)
-			}
 			err := ValidateWebhookURL(tt.url)
 			if (err != nil) != tt.wantErr {
-				t.Errorf("ValidateWebhookURL() error = %v, wantErr %v", err, tt.wantErr)
+				t.Errorf("ValidateWebhookURL(%q) error = %v, wantErr %v", tt.url, err, tt.wantErr)
 				return
 			}
 			if tt.wantErr && tt.errType != nil {
-				if !isErrorType(err, tt.errType) {
-					t.Errorf("ValidateWebhookURL() error = %v, want error type %v", err, tt.errType)
+				if !errors.Is(err, tt.errType) {
+					t.Errorf("ValidateWebhookURL(%q) error = %v, want %v", tt.url, err, tt.errType)
 				}
 			}
 		})
 	}
 }
 
-func TestIsPrivateOrInternalIP(t *testing.T) {
-	tests := []struct {
-		name   string
-		ip     string
-		isPriv bool
-	}{
-		// Public IPs
-		{"Google DNS", "8.8.8.8", false},
-		{"Cloudflare DNS", "1.1.1.1", false},
-		{"Public IPv6", "2001:4860:4860::8888", false},
-
-		// Loopback
-		{"127.0.0.1", "127.0.0.1", true},
-		{"127.1.2.3", "127.1.2.3", true},
-		{"::1", "::1", true},
-
-		// Private ranges
-		{"10.0.0.1", "10.0.0.1", true},
-		{"192.168.1.1", "192.168.1.1", true},
-		{"172.16.0.1", "172.16.0.1", true},
-		{"172.31.255.255", "172.31.255.255", true},
-
-		// Link-local
-		{"169.254.169.254", "169.254.169.254", true},
-		{"169.254.1.1", "169.254.1.1", true},
-		{"fe80::1", "fe80::1", true},
-
-		// Other reserved
-		{"0.0.0.0", "0.0.0.0", true},
-		{"255.255.255.255", "255.255.255.255", true},
-		{"240.0.0.1", "240.0.0.1", true},
-
-		// Shared address space
-		{"100.64.0.1", "100.64.0.1", true},
-		{"100.127.255.255", "100.127.255.255", true},
+func TestErrorAliases(t *testing.T) {
+	if ErrPrivateIP != ssrf.ErrPrivateIP {
+		t.Error("ErrPrivateIP should alias ssrf.ErrPrivateIP")
 	}
-
-	for _, tt := range tests {
-		t.Run(tt.name, func(t *testing.T) {
-			ip := net.ParseIP(tt.ip)
-			if ip == nil {
-				t.Fatalf("Failed to parse IP: %s", tt.ip)
-			}
-			if got := isPrivateOrInternalIP(ip); got != tt.isPriv {
-				t.Errorf("isPrivateOrInternalIP(%s) = %v, want %v", tt.ip, got, tt.isPriv)
-			}
-		})
-	}
-}
-
-func TestIsLocalhost(t *testing.T) {
-	tests := []struct {
-		name     string
-		hostname string
-		want     bool
-	}{
-		{"localhost", "localhost", true},
-		{"LOCALHOST", "LOCALHOST", true},
-		{"localhost.localdomain", "localhost.localdomain", true},
-		{"test.localhost", "test.localhost", true},
-		{"example.com", "example.com", false},
-		{"localhos", "localhos", false},
-		{"localhost.com", "localhost.com", false},
-	}
-
-	for _, tt := range tests {
-		t.Run(tt.name, func(t *testing.T) {
-			if got := isLocalhost(tt.hostname); got != tt.want {
-				t.Errorf("isLocalhost(%s) = %v, want %v", tt.hostname, got, tt.want)
-			}
-		})
-	}
-}
-
-func TestValidateIPBeforeDial(t *testing.T) {
-	tests := []struct {
-		name    string
-		ip      string
-		wantErr bool
-	}{
-		{"public IP", "8.8.8.8", false},
-		{"private IP", "192.168.1.1", true},
-		{"loopback", "127.0.0.1", true},
-		{"link-local", "169.254.169.254", true},
-	}
-
-	for _, tt := range tests {
-		t.Run(tt.name, func(t *testing.T) {
-			ip := net.ParseIP(tt.ip)
-			if ip == nil {
-				t.Fatalf("Failed to parse IP: %s", tt.ip)
-			}
-			err := ValidateIPBeforeDial(ip)
-			if (err != nil) != tt.wantErr {
-				t.Errorf("ValidateIPBeforeDial(%s) error = %v, wantErr %v", tt.ip, err, tt.wantErr)
-			}
-		})
-	}
-}
-
-// isErrorType checks if err is or wraps errType.
-func isErrorType(err, errType error) bool {
-	if err == errType {
-		return true
+	if ErrInvalidScheme != ssrf.ErrInvalidScheme {
+		t.Error("ErrInvalidScheme should alias ssrf.ErrInvalidScheme")
 	}
-	// Check if err wraps errType
-	for err != nil {
-		if err == errType {
-			return true
-		}
-		unwrapped, ok := err.(interface{ Unwrap() error })
-		if !ok {
-			break
-		}
-		err = unwrapped.Unwrap()
+	if ErrInvalidURL != ssrf.ErrInvalidURL {
+		t.Error("ErrInvalidURL should alias ssrf.ErrInvalidURL")
 	}
-	return false
 }

pkg/webhook/webhook.go 🔗

@@ -10,14 +10,13 @@ import (
 	"errors"
 	"fmt"
 	"io"
-	"net"
 	"net/http"
-	"time"
 
 	"github.com/charmbracelet/soft-serve/git"
 	"github.com/charmbracelet/soft-serve/pkg/db"
 	"github.com/charmbracelet/soft-serve/pkg/db/models"
 	"github.com/charmbracelet/soft-serve/pkg/proto"
+	"github.com/charmbracelet/soft-serve/pkg/ssrf"
 	"github.com/charmbracelet/soft-serve/pkg/store"
 	"github.com/charmbracelet/soft-serve/pkg/utils"
 	"github.com/charmbracelet/soft-serve/pkg/version"
@@ -38,42 +37,8 @@ type Delivery struct {
 	Event Event
 }
 
-// secureHTTPClient creates an HTTP client with SSRF protection.
-var secureHTTPClient = &http.Client{
-	Timeout: 30 * time.Second,
-	Transport: &http.Transport{
-		DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
-			// Parse the address to get the IP
-			host, _, err := net.SplitHostPort(addr)
-			if err != nil {
-				return nil, err //nolint:wrapcheck
-			}
-
-			// Validate the resolved IP before connecting
-			ip := net.ParseIP(host)
-			if ip != nil {
-				if err := ValidateIPBeforeDial(ip); err != nil {
-					return nil, fmt.Errorf("blocked connection to private IP: %w", err)
-				}
-			}
-
-			// Use standard dialer with timeout
-			dialer := &net.Dialer{
-				Timeout:   10 * time.Second,
-				KeepAlive: 30 * time.Second,
-			}
-			return dialer.DialContext(ctx, network, addr)
-		},
-		MaxIdleConns:          100,
-		IdleConnTimeout:       90 * time.Second,
-		TLSHandshakeTimeout:   10 * time.Second,
-		ExpectContinueTimeout: 1 * time.Second,
-	},
-	// Don't follow redirects to prevent bypassing IP validation
-	CheckRedirect: func(*http.Request, []*http.Request) error {
-		return http.ErrUseLastResponse
-	},
-}
+// secureHTTPClient is an HTTP client with SSRF protection.
+var secureHTTPClient = ssrf.NewSecureClient()
 
 // do sends a webhook.
 // Caller must close the returned body.