ssrf.go

  1package ssrf
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"net"
  8	"net/http"
  9	"net/url"
 10	"slices"
 11	"strings"
 12	"time"
 13)
 14
 15var (
 16	// ErrPrivateIP is returned when a connection to a private or internal IP is blocked.
 17	ErrPrivateIP = errors.New("connection to private or internal IP address is not allowed")
 18	// ErrInvalidScheme is returned when a URL scheme is not http or https.
 19	ErrInvalidScheme = errors.New("URL must use http or https scheme")
 20	// ErrInvalidURL is returned when a URL is invalid.
 21	ErrInvalidURL = errors.New("invalid URL")
 22)
 23
 24// NewSecureClient returns an HTTP client with SSRF protection.
 25// It validates resolved IPs at dial time to block connections to private
 26// and internal networks. Since validation uses the already-resolved IP
 27// from the Transport's DNS lookup, there is no TOCTOU gap between
 28// resolution and connection. Redirects are disabled to match the
 29// webhook client convention and prevent redirect-based SSRF.
 30func NewSecureClient() *http.Client {
 31	return &http.Client{
 32		Timeout: 30 * time.Second,
 33		Transport: &http.Transport{
 34			DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
 35				host, _, err := net.SplitHostPort(addr)
 36				if err != nil {
 37					return nil, err //nolint:wrapcheck
 38				}
 39
 40				ip := net.ParseIP(host)
 41				if ip == nil {
 42					ips, err := net.LookupIP(host) //nolint
 43					if err != nil {
 44						return nil, fmt.Errorf("DNS resolution failed for host %s: %v", host, err)
 45					}
 46					if len(ips) == 0 {
 47						return nil, fmt.Errorf("no IP addresses found for host: %s", host)
 48					}
 49					ip = ips[0] // Use the first resolved IP address
 50				}
 51				if isPrivateOrInternal(ip) {
 52					return nil, fmt.Errorf("%w", ErrPrivateIP)
 53				}
 54
 55				dialer := &net.Dialer{
 56					Timeout:   10 * time.Second,
 57					KeepAlive: 30 * time.Second,
 58				}
 59				return dialer.DialContext(ctx, network, addr)
 60			},
 61			MaxIdleConns:          100,
 62			IdleConnTimeout:       90 * time.Second,
 63			TLSHandshakeTimeout:   10 * time.Second,
 64			ExpectContinueTimeout: 1 * time.Second,
 65		},
 66		CheckRedirect: func(*http.Request, []*http.Request) error {
 67			return http.ErrUseLastResponse
 68		},
 69	}
 70}
 71
 72// isPrivateOrInternal checks if an IP address is private, internal, or reserved.
 73func isPrivateOrInternal(ip net.IP) bool {
 74	// Normalize IPv6-mapped IPv4 (e.g. ::ffff:127.0.0.1) to IPv4 form
 75	// so all checks apply consistently.
 76	if ip4 := ip.To4(); ip4 != nil {
 77		ip = ip4
 78	}
 79
 80	if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() ||
 81		ip.IsPrivate() || ip.IsUnspecified() || ip.IsMulticast() {
 82		return true
 83	}
 84
 85	if ip4 := ip.To4(); ip4 != nil {
 86		// 0.0.0.0/8
 87		if ip4[0] == 0 {
 88			return true
 89		}
 90		// 100.64.0.0/10 (Shared Address Space / CGNAT)
 91		if ip4[0] == 100 && ip4[1] >= 64 && ip4[1] <= 127 {
 92			return true
 93		}
 94		// 192.0.0.0/24 (IETF Protocol Assignments)
 95		if ip4[0] == 192 && ip4[1] == 0 && ip4[2] == 0 {
 96			return true
 97		}
 98		// 192.0.2.0/24 (TEST-NET-1)
 99		if ip4[0] == 192 && ip4[1] == 0 && ip4[2] == 2 {
100			return true
101		}
102		// 198.18.0.0/15 (benchmarking)
103		if ip4[0] == 198 && (ip4[1] == 18 || ip4[1] == 19) {
104			return true
105		}
106		// 198.51.100.0/24 (TEST-NET-2)
107		if ip4[0] == 198 && ip4[1] == 51 && ip4[2] == 100 {
108			return true
109		}
110		// 203.0.113.0/24 (TEST-NET-3)
111		if ip4[0] == 203 && ip4[1] == 0 && ip4[2] == 113 {
112			return true
113		}
114		// 240.0.0.0/4 (Reserved, includes 255.255.255.255 broadcast)
115		if ip4[0] >= 240 {
116			return true
117		}
118	}
119
120	return false
121}
122
123// ValidateURL validates that a URL is safe to make requests to.
124// It checks that the scheme is http/https, the hostname is not localhost,
125// and all resolved IPs are public.
126func ValidateURL(rawURL string) error {
127	if rawURL == "" {
128		return ErrInvalidURL
129	}
130
131	u, err := url.Parse(rawURL)
132	if err != nil {
133		return fmt.Errorf("%w: %v", ErrInvalidURL, err)
134	}
135
136	if u.Scheme != "http" && u.Scheme != "https" {
137		return ErrInvalidScheme
138	}
139
140	hostname := u.Hostname()
141	if hostname == "" {
142		return fmt.Errorf("%w: missing hostname", ErrInvalidURL)
143	}
144
145	if isLocalhost(hostname) {
146		return ErrPrivateIP
147	}
148
149	if ip := net.ParseIP(hostname); ip != nil {
150		if isPrivateOrInternal(ip) {
151			return ErrPrivateIP
152		}
153		return nil
154	}
155
156	ips, err := net.DefaultResolver.LookupIPAddr(context.Background(), hostname)
157	if err != nil {
158		return fmt.Errorf("%w: cannot resolve hostname: %v", ErrInvalidURL, err)
159	}
160
161	if slices.ContainsFunc(ips, func(addr net.IPAddr) bool {
162		return isPrivateOrInternal(addr.IP)
163	}) {
164		return ErrPrivateIP
165	}
166
167	return nil
168}
169
170// ValidateIPBeforeDial validates an IP address before establishing a connection.
171// This prevents DNS rebinding attacks by checking the resolved IP at dial time.
172func ValidateIPBeforeDial(ip net.IP) error {
173	if isPrivateOrInternal(ip) {
174		return ErrPrivateIP
175	}
176	return nil
177}
178
179// isLocalhost checks if the hostname is localhost or similar.
180func isLocalhost(hostname string) bool {
181	hostname = strings.ToLower(hostname)
182	return hostname == "localhost" ||
183		hostname == "localhost.localdomain" ||
184		strings.HasSuffix(hostname, ".localhost")
185}