ssrf_test.go

  1package webhook
  2
import (
	"context"
	"errors"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"github.com/charmbracelet/soft-serve/pkg/db/models"
)
 12
 13// TestSSRFProtection tests that the webhook system blocks SSRF attempts.
 14func TestSSRFProtection(t *testing.T) {
 15	tests := []struct {
 16		name        string
 17		webhookURL  string
 18		shouldBlock bool
 19		description string
 20	}{
 21		{
 22			name:        "block localhost",
 23			webhookURL:  "http://localhost:8080/webhook",
 24			shouldBlock: true,
 25			description: "should block localhost addresses",
 26		},
 27		{
 28			name:        "block 127.0.0.1",
 29			webhookURL:  "http://127.0.0.1:8080/webhook",
 30			shouldBlock: true,
 31			description: "should block loopback addresses",
 32		},
 33		{
 34			name:        "block 169.254.169.254",
 35			webhookURL:  "http://169.254.169.254/latest/meta-data/",
 36			shouldBlock: true,
 37			description: "should block cloud metadata service",
 38		},
 39		{
 40			name:        "block private network",
 41			webhookURL:  "http://192.168.1.1/webhook",
 42			shouldBlock: true,
 43			description: "should block private networks",
 44		},
 45		{
 46			name:        "allow public IP",
 47			webhookURL:  "http://8.8.8.8/webhook",
 48			shouldBlock: false,
 49			description: "should allow public IP addresses",
 50		},
 51	}
 52
 53	for _, tt := range tests {
 54		t.Run(tt.name, func(t *testing.T) {
 55			// Create a test webhook
 56			webhook := models.Webhook{
 57				URL:         tt.webhookURL,
 58				ContentType: int(ContentTypeJSON),
 59				Secret:      "",
 60			}
 61
 62			// Try to send a webhook
 63			ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
 64			defer cancel()
 65
 66			// Create a simple payload
 67			payload := map[string]string{"test": "data"}
 68
 69			err := sendWebhookWithContext(ctx, webhook, EventPush, payload)
 70
 71			if tt.shouldBlock {
 72				if err == nil {
 73					t.Errorf("%s: expected error but got none", tt.description)
 74				}
 75			} else {
 76				// For public IPs, we expect a connection error (since 8.8.8.8 won't be listening)
 77				// but NOT an SSRF blocking error
 78				if err != nil && isSSRFError(err) {
 79					t.Errorf("%s: should not block public IPs, got: %v", tt.description, err)
 80				}
 81			}
 82		})
 83	}
 84}
 85
 86// TestSecureHTTPClientBlocksRedirects tests that redirects are not followed.
 87func TestSecureHTTPClientBlocksRedirects(t *testing.T) {
 88	// Create a test server on a public-looking address that redirects
 89	redirectServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 90		http.Redirect(w, r, "http://8.8.8.8:8080/safe", http.StatusFound)
 91	}))
 92	defer redirectServer.Close()
 93
 94	// Try to make a request that would redirect
 95	req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, redirectServer.URL, nil)
 96	if err != nil {
 97		t.Fatalf("Failed to create request: %v", err)
 98	}
 99
100	resp, err := secureHTTPClient.Do(req)
101	if err != nil {
102		// httptest.NewServer uses 127.0.0.1, which will be blocked by our SSRF protection
103		// This is actually correct behavior - we're blocking the initial connection
104		if !isSSRFError(err) {
105			t.Fatalf("Request failed with non-SSRF error: %v", err)
106		}
107		// Test passed - we blocked the loopback connection
108		return
109	}
110	defer resp.Body.Close()
111
112	// If we got here, check that we got the redirect response (not followed)
113	if resp.StatusCode != http.StatusFound {
114		t.Errorf("Expected redirect response (302), got %d", resp.StatusCode)
115	}
116}
117
118// TestDialContextBlocksPrivateIPs tests the DialContext function directly.
119func TestDialContextBlocksPrivateIPs(t *testing.T) {
120	transport := secureHTTPClient.Transport.(*http.Transport)
121
122	tests := []struct {
123		name    string
124		addr    string
125		wantErr bool
126	}{
127		{"block loopback", "127.0.0.1:80", true},
128		{"block private 10.x", "10.0.0.1:80", true},
129		{"block private 192.168.x", "192.168.1.1:80", true},
130		{"block link-local", "169.254.169.254:80", true},
131		{"allow public IP", "8.8.8.8:80", false},
132	}
133
134	for _, tt := range tests {
135		t.Run(tt.name, func(t *testing.T) {
136			ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
137			defer cancel()
138
139			conn, err := transport.DialContext(ctx, "tcp", tt.addr)
140			if conn != nil {
141				conn.Close()
142			}
143
144			if tt.wantErr {
145				if err == nil {
146					t.Errorf("Expected error for %s, got none", tt.addr)
147				}
148			} else {
149				// For public IPs, we expect a connection timeout/refused (not an SSRF block)
150				if err != nil && isSSRFError(err) {
151					t.Errorf("Should not block %s with SSRF error, got: %v", tt.addr, err)
152				}
153			}
154		})
155	}
156}
157
158// sendWebhookWithContext is a test helper that doesn't require database.
159func sendWebhookWithContext(ctx context.Context, w models.Webhook, _ Event, _ any) error {
160	// This is a simplified version for testing that just attempts the HTTP connection
161	req, err := http.NewRequest("POST", w.URL, nil)
162	if err != nil {
163		return err //nolint:wrapcheck
164	}
165	req = req.WithContext(ctx)
166
167	resp, err := secureHTTPClient.Do(req)
168	if resp != nil {
169		resp.Body.Close()
170	}
171	return err //nolint:wrapcheck
172}
173
174// isSSRFError checks if an error is related to SSRF blocking.
175func isSSRFError(err error) bool {
176	if err == nil {
177		return false
178	}
179	errMsg := err.Error()
180	return contains(errMsg, "private IP") ||
181		contains(errMsg, "blocked connection") ||
182		err == ErrPrivateIP
183}
184
// contains reports whether substr occurs anywhere within s. An empty substr
// is always contained. It is a thin wrapper over strings.Contains, kept for
// existing callers; prefer calling strings.Contains directly.
func contains(s, substr string) bool {
	return strings.Contains(s, substr)
}
188
// indexOfSubstring returns the byte index of the first occurrence of substr
// in s, or -1 if substr does not occur. An empty substr matches at index 0.
// It delegates to strings.Index.
func indexOfSubstring(s, substr string) int {
	return strings.Index(s, substr)
}
197
198// TestPrivateIPResolution tests that hostnames resolving to private IPs are blocked.
199func TestPrivateIPResolution(t *testing.T) {
200	// This test verifies that even if a hostname looks public, if it resolves to a private IP, it's blocked
201	webhook := models.Webhook{
202		URL:         "http://127.0.0.1:9999/webhook",
203		ContentType: int(ContentTypeJSON),
204	}
205
206	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
207	defer cancel()
208
209	err := sendWebhookWithContext(ctx, webhook, EventPush, map[string]string{"test": "data"})
210	if err == nil {
211		t.Error("Expected error when connecting to loopback address")
212		return
213	}
214
215	if !isSSRFError(err) {
216		t.Errorf("Expected SSRF blocking error, got: %v", err)
217	}
218}