1package ssrf
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "net"
8 "net/http"
9 "net/url"
10 "slices"
11 "strings"
12 "time"
13)
14
// Sentinel errors reported by this package. Some call sites wrap them with
// extra context, so match them with errors.Is rather than ==.
var (
	// ErrInvalidURL is returned when a URL is empty, cannot be parsed, has no
	// hostname, or its hostname cannot be resolved.
	ErrInvalidURL = errors.New("invalid URL")
	// ErrInvalidScheme is returned when a URL's scheme is neither http nor https.
	ErrInvalidScheme = errors.New("URL must use http or https scheme")
	// ErrPrivateIP is returned when a destination is a localhost-style name or
	// a private, internal, or reserved IP address.
	ErrPrivateIP = errors.New("connection to private or internal IP address is not allowed")
)
23
24// NewSecureClient returns an HTTP client with SSRF protection.
25// It validates resolved IPs at dial time to block connections to private
26// and internal networks. Since validation uses the already-resolved IP
27// from the Transport's DNS lookup, there is no TOCTOU gap between
28// resolution and connection. Redirects are disabled to match the
29// webhook client convention and prevent redirect-based SSRF.
30func NewSecureClient() *http.Client {
31 return &http.Client{
32 Timeout: 30 * time.Second,
33 Transport: &http.Transport{
34 DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
35 host, _, err := net.SplitHostPort(addr)
36 if err != nil {
37 return nil, err //nolint:wrapcheck
38 }
39
40 ip := net.ParseIP(host)
41 if ip == nil {
42 ips, err := net.LookupIP(host) //nolint
43 if err != nil {
44 return nil, fmt.Errorf("DNS resolution failed for host %s: %v", host, err)
45 }
46 if len(ips) == 0 {
47 return nil, fmt.Errorf("no IP addresses found for host: %s", host)
48 }
49 ip = ips[0] // Use the first resolved IP address
50 }
51 if isPrivateOrInternal(ip) {
52 return nil, fmt.Errorf("%w", ErrPrivateIP)
53 }
54
55 dialer := &net.Dialer{
56 Timeout: 10 * time.Second,
57 KeepAlive: 30 * time.Second,
58 }
59 return dialer.DialContext(ctx, network, addr)
60 },
61 MaxIdleConns: 100,
62 IdleConnTimeout: 90 * time.Second,
63 TLSHandshakeTimeout: 10 * time.Second,
64 ExpectContinueTimeout: 1 * time.Second,
65 },
66 CheckRedirect: func(*http.Request, []*http.Request) error {
67 return http.ErrUseLastResponse
68 },
69 }
70}
71
// blockedNets lists special-purpose and non-public ranges that the net.IP
// classification methods used by isPrivateOrInternal do not already cover.
// The IPv6 entries include prefixes that embed or alias IPv4 addresses
// (IPv4-compatible, NAT64, 6to4, Teredo). Without them, an address such as
// 64:ff9b::7f00:1 could reach 127.0.0.1 on a NAT64 network.
var blockedNets = mustParseCIDRs(
	"0.0.0.0/8",       // "this network"
	"100.64.0.0/10",   // shared address space (CGNAT)
	"192.0.0.0/24",    // IETF protocol assignments
	"192.0.2.0/24",    // TEST-NET-1
	"198.18.0.0/15",   // benchmarking
	"198.51.100.0/24", // TEST-NET-2
	"203.0.113.0/24",  // TEST-NET-3
	"240.0.0.0/4",     // reserved, includes 255.255.255.255 broadcast
	"::/96",           // IPv4-compatible (deprecated); also covers :: and ::1
	"64:ff9b::/96",    // NAT64 well-known prefix (embeds IPv4)
	"64:ff9b:1::/48",  // local-use NAT64 (embeds IPv4)
	"100::/64",        // discard-only
	"2001::/23",       // IETF protocol assignments, includes Teredo 2001::/32
	"2001:db8::/32",   // documentation
	"2002::/16",       // 6to4 (embeds IPv4)
	"fec0::/10",       // site-local (deprecated)
)

// mustParseCIDRs parses constant CIDR literals at package initialization and
// panics on a malformed one, which can only be a programming error.
func mustParseCIDRs(cidrs ...string) []*net.IPNet {
	nets := make([]*net.IPNet, 0, len(cidrs))
	for _, c := range cidrs {
		_, n, err := net.ParseCIDR(c)
		if err != nil {
			panic(fmt.Sprintf("ssrf: invalid CIDR %q: %v", c, err))
		}
		nets = append(nets, n)
	}
	return nets
}

// isPrivateOrInternal reports whether ip is private, internal, or reserved,
// and so must not be contacted.
//
// A nil or malformed IP (anything that is not 4 or 16 bytes) is reported as
// internal. The function fails closed: an address that cannot be classified is
// never treated as public.
func isPrivateOrInternal(ip net.IP) bool {
	if ip.To16() == nil {
		return true
	}

	// Normalize IPv6-mapped IPv4 (e.g. ::ffff:127.0.0.1) to IPv4 form
	// so all checks apply consistently.
	if ip4 := ip.To4(); ip4 != nil {
		ip = ip4
	}

	if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() ||
		ip.IsPrivate() || ip.IsUnspecified() || ip.IsMulticast() {
		return true
	}

	for _, n := range blockedNets {
		if n.Contains(ip) {
			return true
		}
	}
	return false
}
122
123// ValidateURL validates that a URL is safe to make requests to.
124// It checks that the scheme is http/https, the hostname is not localhost,
125// and all resolved IPs are public.
126func ValidateURL(rawURL string) error {
127 if rawURL == "" {
128 return ErrInvalidURL
129 }
130
131 u, err := url.Parse(rawURL)
132 if err != nil {
133 return fmt.Errorf("%w: %v", ErrInvalidURL, err)
134 }
135
136 if u.Scheme != "http" && u.Scheme != "https" {
137 return ErrInvalidScheme
138 }
139
140 hostname := u.Hostname()
141 if hostname == "" {
142 return fmt.Errorf("%w: missing hostname", ErrInvalidURL)
143 }
144
145 if isLocalhost(hostname) {
146 return ErrPrivateIP
147 }
148
149 if ip := net.ParseIP(hostname); ip != nil {
150 if isPrivateOrInternal(ip) {
151 return ErrPrivateIP
152 }
153 return nil
154 }
155
156 ips, err := net.DefaultResolver.LookupIPAddr(context.Background(), hostname)
157 if err != nil {
158 return fmt.Errorf("%w: cannot resolve hostname: %v", ErrInvalidURL, err)
159 }
160
161 if slices.ContainsFunc(ips, func(addr net.IPAddr) bool {
162 return isPrivateOrInternal(addr.IP)
163 }) {
164 return ErrPrivateIP
165 }
166
167 return nil
168}
169
170// ValidateIPBeforeDial validates an IP address before establishing a connection.
171// This prevents DNS rebinding attacks by checking the resolved IP at dial time.
172func ValidateIPBeforeDial(ip net.IP) error {
173 if isPrivateOrInternal(ip) {
174 return ErrPrivateIP
175 }
176 return nil
177}
178
// isLocalhost reports whether hostname is a loopback name. The names matched
// are "localhost", "localhost.localdomain", and any name under the reserved
// ".localhost" domain (RFC 6761).
//
// The comparison is case-insensitive and ignores a single trailing dot.
// "localhost." is the fully-qualified form of the same name and must not
// slip past this check.
func isLocalhost(hostname string) bool {
	name := strings.TrimSuffix(strings.ToLower(hostname), ".")
	return name == "localhost" ||
		name == "localhost.localdomain" ||
		strings.HasSuffix(name, ".localhost")
}