1package ssrf
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "net"
8 "net/http"
9 "net/url"
10 "slices"
11 "strings"
12 "time"
13)
14
// Sentinel errors returned by this package. Compare with errors.Is, since
// callers may receive them wrapped with additional context.
var (
	// ErrInvalidURL reports that a URL is empty, cannot be parsed, lacks a
	// hostname, or has a hostname that cannot be resolved.
	ErrInvalidURL = errors.New("invalid URL")
	// ErrInvalidScheme reports a URL whose scheme is neither http nor https.
	ErrInvalidScheme = errors.New("URL must use http or https scheme")
	// ErrPrivateIP reports that the target is a private, internal, or
	// otherwise reserved address and the connection was refused.
	ErrPrivateIP = errors.New("connection to private or internal IP address is not allowed")
)
23
24// NewSecureClient returns an HTTP client with SSRF protection.
25// It validates resolved IPs at dial time to block connections to private
26// and internal networks. Hostnames are resolved and the validated IP is
27// used directly in the dial call to prevent DNS rebinding (TOCTOU between
28// validation and connection). Redirects are disabled to match the webhook
29// client convention and prevent redirect-based SSRF.
30func NewSecureClient() *http.Client {
31 return &http.Client{
32 Timeout: 30 * time.Second,
33 Transport: &http.Transport{
34 DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
35 host, port, err := net.SplitHostPort(addr)
36 if err != nil {
37 return nil, err //nolint:wrapcheck
38 }
39
40 ip := net.ParseIP(host)
41 if ip == nil {
42 ips, err := net.LookupIP(host) //nolint
43 if err != nil {
44 return nil, fmt.Errorf("DNS resolution failed for host %s: %v", host, err)
45 }
46 if len(ips) == 0 {
47 return nil, fmt.Errorf("no IP addresses found for host: %s", host)
48 }
49 ip = ips[0] // Use the first resolved IP address
50 }
51 if isPrivateOrInternal(ip) {
52 return nil, fmt.Errorf("%w", ErrPrivateIP)
53 }
54
55 dialer := &net.Dialer{
56 Timeout: 10 * time.Second,
57 KeepAlive: 30 * time.Second,
58 }
59 // Dial using the validated IP to prevent DNS rebinding.
60 // Without this, the dialer resolves the hostname again
61 // independently, and the second resolution could return
62 // a different (private) IP.
63 return dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port))
64 },
65 MaxIdleConns: 100,
66 IdleConnTimeout: 90 * time.Second,
67 TLSHandshakeTimeout: 10 * time.Second,
68 ExpectContinueTimeout: 1 * time.Second,
69 },
70 CheckRedirect: func(*http.Request, []*http.Request) error {
71 return http.ErrUseLastResponse
72 },
73 }
74}
75
// isPrivateOrInternal reports whether ip must be treated as non-public:
// loopback, link-local, private, unspecified, multicast, or inside one of
// the special-purpose ranges listed below.
//
// The check fails closed: a nil or malformed ip (neither 4 nor 16 bytes) is
// reported as internal. IPv4-mapped IPv6 addresses are judged by their IPv4
// form, and IPv6 addresses that embed an IPv4 address for translation
// (NAT64 well-known prefix 64:ff9b::/96 and 6to4 2002::/16) are judged by
// the embedded IPv4 address, so a translating gateway cannot be used to
// reach an internal IPv4 host.
func isPrivateOrInternal(ip net.IP) bool {
	// The net.IP predicates used below all report false for nil or
	// wrong-length values, which would classify garbage as "public".
	if len(ip) != net.IPv4len && len(ip) != net.IPv6len {
		return true
	}

	// Normalize IPv4-mapped IPv6 (e.g. ::ffff:127.0.0.1) to the 4-byte form
	// so every check applies consistently.
	if ip4 := ip.To4(); ip4 != nil {
		ip = ip4
	}

	// IsMulticast already covers the link-local multicast ranges.
	if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsPrivate() ||
		ip.IsUnspecified() || ip.IsMulticast() {
		return true
	}

	if len(ip) == net.IPv4len {
		switch {
		case ip[0] == 0:
			// 0.0.0.0/8
			return true
		case ip[0] == 100 && ip[1] >= 64 && ip[1] <= 127:
			// 100.64.0.0/10 (Shared Address Space / CGNAT)
			return true
		case ip[0] == 192 && ip[1] == 0 && ip[2] == 0:
			// 192.0.0.0/24 (IETF Protocol Assignments)
			return true
		case ip[0] == 192 && ip[1] == 0 && ip[2] == 2:
			// 192.0.2.0/24 (TEST-NET-1)
			return true
		case ip[0] == 198 && (ip[1] == 18 || ip[1] == 19):
			// 198.18.0.0/15 (benchmarking)
			return true
		case ip[0] == 198 && ip[1] == 51 && ip[2] == 100:
			// 198.51.100.0/24 (TEST-NET-2)
			return true
		case ip[0] == 203 && ip[1] == 0 && ip[2] == 113:
			// 203.0.113.0/24 (TEST-NET-3)
			return true
		case ip[0] >= 240:
			// 240.0.0.0/4 (Reserved, includes 255.255.255.255 broadcast)
			return true
		}
		return false
	}

	// From here on ip is a 16-byte address that is not IPv4-mapped.

	// 64:ff9b::/96 (NAT64 well-known prefix): bytes 4-11 must be zero and
	// the last four bytes carry the IPv4 destination.
	nat64 := ip[0] == 0x00 && ip[1] == 0x64 && ip[2] == 0xff && ip[3] == 0x9b
	for _, b := range ip[4:12] {
		nat64 = nat64 && b == 0
	}

	switch {
	case ip[0] == 0x20 && ip[1] == 0x01 && ip[2] == 0x0d && ip[3] == 0xb8:
		// 2001:db8::/32 (documentation), the IPv6 analogue of TEST-NET.
		return true
	case ip[0] == 0x20 && ip[1] == 0x02:
		// 2002::/16 (6to4): bytes 2-5 carry the embedded IPv4 address.
		return isPrivateOrInternal(ip[2:6])
	case nat64:
		return isPrivateOrInternal(ip[12:16])
	}

	return false
}
126
127// ValidateURL validates that a URL is safe to make requests to.
128// It checks that the scheme is http/https, the hostname is not localhost,
129// and all resolved IPs are public.
130func ValidateURL(rawURL string) error {
131 if rawURL == "" {
132 return ErrInvalidURL
133 }
134
135 u, err := url.Parse(rawURL)
136 if err != nil {
137 return fmt.Errorf("%w: %v", ErrInvalidURL, err)
138 }
139
140 if u.Scheme != "http" && u.Scheme != "https" {
141 return ErrInvalidScheme
142 }
143
144 hostname := u.Hostname()
145 if hostname == "" {
146 return fmt.Errorf("%w: missing hostname", ErrInvalidURL)
147 }
148
149 if isLocalhost(hostname) {
150 return ErrPrivateIP
151 }
152
153 if ip := net.ParseIP(hostname); ip != nil {
154 if isPrivateOrInternal(ip) {
155 return ErrPrivateIP
156 }
157 return nil
158 }
159
160 ips, err := net.DefaultResolver.LookupIPAddr(context.Background(), hostname)
161 if err != nil {
162 return fmt.Errorf("%w: cannot resolve hostname: %v", ErrInvalidURL, err)
163 }
164
165 if slices.ContainsFunc(ips, func(addr net.IPAddr) bool {
166 return isPrivateOrInternal(addr.IP)
167 }) {
168 return ErrPrivateIP
169 }
170
171 return nil
172}
173
174// ValidateIPBeforeDial validates an IP address before establishing a connection.
175// This prevents DNS rebinding attacks by checking the resolved IP at dial time.
176func ValidateIPBeforeDial(ip net.IP) error {
177 if isPrivateOrInternal(ip) {
178 return ErrPrivateIP
179 }
180 return nil
181}
182
// isLocalhost reports whether hostname names the local machine: "localhost",
// "localhost.localdomain", or any name under the ".localhost" domain.
// Matching is case-insensitive and ignores a single trailing dot, so the
// fully-qualified forms ("localhost.", "foo.localhost.") are recognised too.
func isLocalhost(hostname string) bool {
	// A trailing dot only marks the name as fully qualified; it refers to
	// the same host and must not bypass the check.
	hostname = strings.TrimSuffix(strings.ToLower(hostname), ".")
	return hostname == "localhost" ||
		hostname == "localhost.localdomain" ||
		strings.HasSuffix(hostname, ".localhost")
}