fix(ssrf): pin resolved IP in dial to prevent DNS rebinding (#791)

Vinayak Mishra created

Change summary

pkg/ssrf/ssrf.go      | 16 ++++++++++------
pkg/ssrf/ssrf_test.go | 19 +++++++++++++++++++
2 files changed, 29 insertions(+), 6 deletions(-)

Detailed changes

pkg/ssrf/ssrf.go 🔗

@@ -23,16 +23,16 @@ var (
 
 // NewSecureClient returns an HTTP client with SSRF protection.
 // It validates resolved IPs at dial time to block connections to private
-// and internal networks. Since validation uses the already-resolved IP
-// from the Transport's DNS lookup, there is no TOCTOU gap between
-// resolution and connection. Redirects are disabled to match the
-// webhook client convention and prevent redirect-based SSRF.
+// and internal networks. Hostnames are resolved and the validated IP is
+// used directly in the dial call to prevent DNS rebinding (TOCTOU between
+// validation and connection). Redirects are disabled to match the webhook
+// client convention and prevent redirect-based SSRF.
 func NewSecureClient() *http.Client {
 	return &http.Client{
 		Timeout: 30 * time.Second,
 		Transport: &http.Transport{
 			DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
-				host, _, err := net.SplitHostPort(addr)
+				host, port, err := net.SplitHostPort(addr)
 				if err != nil {
 					return nil, err //nolint:wrapcheck
 				}
@@ -56,7 +56,11 @@ func NewSecureClient() *http.Client {
 					Timeout:   10 * time.Second,
 					KeepAlive: 30 * time.Second,
 				}
-				return dialer.DialContext(ctx, network, addr)
+				// Dial using the validated IP to prevent DNS rebinding.
+				// Without this, the dialer resolves the hostname again
+				// independently, and the second resolution could return
+				// a different (private) IP.
+				return dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port))
 			},
 			MaxIdleConns:          100,
 			IdleConnTimeout:       90 * time.Second,

pkg/ssrf/ssrf_test.go 🔗

@@ -49,6 +49,25 @@ func TestNewSecureClientBlocksPrivateIPs(t *testing.T) {
 	}
 }
 
+func TestNewSecureClientBlocksPrivateHostnames(t *testing.T) {
+	client := NewSecureClient()
+	transport := client.Transport.(*http.Transport)
+
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+
+	// "localhost" resolves to 127.0.0.1 (loopback) -- must be blocked.
+	// This exercises the hostname resolution path in DialContext:
+	// net.LookupIP("localhost") -> 127.0.0.1 -> isPrivateOrInternal -> blocked.
+	conn, err := transport.DialContext(ctx, "tcp", "localhost:80")
+	if conn != nil {
+		conn.Close()
+	}
+	if !errors.Is(err, ErrPrivateIP) {
+		t.Errorf("expected ErrPrivateIP for hostname resolving to loopback, got: %v", err)
+	}
+}
+
 func TestNewSecureClientNilIPNotErrPrivateIP(t *testing.T) {
 	client := NewSecureClient()
 	transport := client.Transport.(*http.Transport)