1package ssrf
2
3import (
4 "context"
5 "errors"
6 "net"
7 "net/http"
8 "net/http/httptest"
9 "testing"
10 "time"
11)
12
13func TestNewSecureClientBlocksPrivateIPs(t *testing.T) {
14 client := NewSecureClient()
15 transport := client.Transport.(*http.Transport)
16
17 tests := []struct {
18 name string
19 addr string
20 wantErr bool
21 }{
22 {"block loopback", "127.0.0.1:80", true},
23 {"block private 10.x", "10.0.0.1:80", true},
24 {"block link-local", "169.254.169.254:80", true},
25 {"block CGNAT", "100.64.0.1:80", true},
26 {"allow public IP", "8.8.8.8:80", false},
27 }
28
29 for _, tt := range tests {
30 t.Run(tt.name, func(t *testing.T) {
31 ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
32 defer cancel()
33
34 conn, err := transport.DialContext(ctx, "tcp", tt.addr)
35 if conn != nil {
36 conn.Close()
37 }
38
39 if tt.wantErr {
40 if err == nil {
41 t.Errorf("expected error for %s, got none", tt.addr)
42 }
43 } else {
44 if err != nil && errors.Is(err, ErrPrivateIP) {
45 t.Errorf("should not block %s with SSRF error, got: %v", tt.addr, err)
46 }
47 }
48 })
49 }
50}
51
52func TestNewSecureClientBlocksPrivateHostnames(t *testing.T) {
53 client := NewSecureClient()
54 transport := client.Transport.(*http.Transport)
55
56 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
57 defer cancel()
58
59 // "localhost" resolves to 127.0.0.1 (loopback) -- must be blocked.
60 // This exercises the hostname resolution path in DialContext:
61 // net.LookupIP("localhost") -> 127.0.0.1 -> isPrivateOrInternal -> blocked.
62 conn, err := transport.DialContext(ctx, "tcp", "localhost:80")
63 if conn != nil {
64 conn.Close()
65 }
66 if !errors.Is(err, ErrPrivateIP) {
67 t.Errorf("expected ErrPrivateIP for hostname resolving to loopback, got: %v", err)
68 }
69}
70
71func TestNewSecureClientNilIPNotErrPrivateIP(t *testing.T) {
72 client := NewSecureClient()
73 transport := client.Transport.(*http.Transport)
74
75 ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
76 defer cancel()
77
78 conn, err := transport.DialContext(ctx, "tcp", "not-an-ip:80")
79 if conn != nil {
80 conn.Close()
81 }
82 if err == nil {
83 t.Fatal("expected error for non-IP address, got none")
84 }
85 if errors.Is(err, ErrPrivateIP) {
86 t.Errorf("nil-IP path should not wrap ErrPrivateIP, got: %v", err)
87 }
88}
89
90func TestNewSecureClientBlocksRedirects(t *testing.T) {
91 redirectServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
92 http.Redirect(w, r, "http://8.8.8.8:8080/safe", http.StatusFound)
93 }))
94 defer redirectServer.Close()
95
96 client := NewSecureClient()
97 req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, redirectServer.URL, nil)
98 if err != nil {
99 t.Fatalf("Failed to create request: %v", err)
100 }
101
102 resp, err := client.Do(req)
103 if err != nil {
104 // httptest uses 127.0.0.1, blocked by SSRF protection
105 if !errors.Is(err, ErrPrivateIP) {
106 t.Fatalf("Request failed with non-SSRF error: %v", err)
107 }
108 return
109 }
110 defer resp.Body.Close()
111
112 if resp.StatusCode != http.StatusFound {
113 t.Errorf("Expected redirect response (302), got %d", resp.StatusCode)
114 }
115}
116
117func TestIsPrivateOrInternal(t *testing.T) {
118 tests := []struct {
119 ip string
120 want bool
121 }{
122 // Public
123 {"8.8.8.8", false},
124 {"2001:4860:4860::8888", false},
125
126 // Loopback
127 {"127.0.0.1", true},
128 {"::1", true},
129
130 // Private ranges
131 {"10.0.0.1", true},
132 {"192.168.1.1", true},
133 {"172.16.0.1", true},
134
135 // Link-local (cloud metadata)
136 {"169.254.169.254", true},
137
138 // CGNAT boundaries
139 {"100.64.0.1", true},
140 {"100.127.255.255", true},
141
142 // IPv6-mapped IPv4 (bypass vector the old webhook code missed)
143 {"::ffff:127.0.0.1", true},
144 {"::ffff:169.254.169.254", true},
145 {"::ffff:8.8.8.8", false},
146
147 // Reserved
148 {"0.0.0.0", true},
149 {"240.0.0.1", true},
150 }
151
152 for _, tt := range tests {
153 t.Run(tt.ip, func(t *testing.T) {
154 ip := net.ParseIP(tt.ip)
155 if ip == nil {
156 t.Fatalf("failed to parse IP: %s", tt.ip)
157 }
158 if got := isPrivateOrInternal(ip); got != tt.want {
159 t.Errorf("isPrivateOrInternal(%s) = %v, want %v", tt.ip, got, tt.want)
160 }
161 })
162 }
163}
164
165func TestValidateURL(t *testing.T) {
166 tests := []struct {
167 name string
168 url string
169 wantErr bool
170 errType error
171 }{
172 // Valid
173 {"valid https", "https://1.1.1.1/webhook", false, nil},
174
175 // Scheme validation
176 {"ftp scheme", "ftp://example.com/webhook", true, ErrInvalidScheme},
177 {"no scheme", "example.com/webhook", true, ErrInvalidScheme},
178
179 // Localhost
180 {"localhost", "http://localhost/webhook", true, ErrPrivateIP},
181 {"subdomain.localhost", "http://test.localhost/webhook", true, ErrPrivateIP},
182
183 // IP-based blocking (one per category -- range coverage is in TestIsPrivateOrInternal)
184 {"loopback IP", "http://127.0.0.1/webhook", true, ErrPrivateIP},
185 {"metadata IP", "http://169.254.169.254/latest/meta-data/", true, ErrPrivateIP},
186
187 // Invalid URLs
188 {"empty", "", true, ErrInvalidURL},
189 {"missing hostname", "http:///webhook", true, ErrInvalidURL},
190 }
191
192 for _, tt := range tests {
193 t.Run(tt.name, func(t *testing.T) {
194 err := ValidateURL(tt.url)
195 if (err != nil) != tt.wantErr {
196 t.Errorf("ValidateURL(%q) error = %v, wantErr %v", tt.url, err, tt.wantErr)
197 return
198 }
199 if tt.wantErr && tt.errType != nil {
200 if !errors.Is(err, tt.errType) {
201 t.Errorf("ValidateURL(%q) error = %v, want error type %v", tt.url, err, tt.errType)
202 }
203 }
204 })
205 }
206}
207
208func TestIsLocalhost(t *testing.T) {
209 tests := []struct {
210 hostname string
211 want bool
212 }{
213 {"localhost", true},
214 {"LOCALHOST", true},
215 {"test.localhost", true},
216 {"example.com", false},
217 {"localhost.com", false},
218 }
219
220 for _, tt := range tests {
221 t.Run(tt.hostname, func(t *testing.T) {
222 if got := isLocalhost(tt.hostname); got != tt.want {
223 t.Errorf("isLocalhost(%s) = %v, want %v", tt.hostname, got, tt.want)
224 }
225 })
226 }
227}