1package webhook
2
3import (
4 "context"
5 "errors"
6 "net/http"
7 "testing"
8 "time"
9
10 "github.com/charmbracelet/soft-serve/pkg/db/models"
11 "github.com/charmbracelet/soft-serve/pkg/ssrf"
12)
13
14// TestSSRFProtection is an integration test verifying the webhook send path
15// blocks private IPs end-to-end (models.Webhook -> secureHTTPClient -> ssrf).
16func TestSSRFProtection(t *testing.T) {
17 tests := []struct {
18 name string
19 webhookURL string
20 shouldBlock bool
21 }{
22 {"block loopback", "http://127.0.0.1:8080/webhook", true},
23 {"block metadata", "http://169.254.169.254/latest/meta-data/", true},
24 {"allow public IP", "http://8.8.8.8/webhook", false},
25 }
26
27 for _, tt := range tests {
28 t.Run(tt.name, func(t *testing.T) {
29 w := models.Webhook{
30 URL: tt.webhookURL,
31 ContentType: int(ContentTypeJSON),
32 }
33
34 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
35 defer cancel()
36
37 req, err := http.NewRequestWithContext(ctx, "POST", w.URL, nil)
38 if err != nil {
39 t.Fatalf("failed to create request: %v", err)
40 }
41
42 resp, err := secureHTTPClient.Do(req)
43 if resp != nil {
44 resp.Body.Close()
45 }
46
47 if tt.shouldBlock {
48 if err == nil {
49 t.Errorf("%s: expected error but got none", tt.name)
50 }
51 } else {
52 if err != nil && errors.Is(err, ssrf.ErrPrivateIP) {
53 t.Errorf("%s: should not block public IPs, got: %v", tt.name, err)
54 }
55 }
56 })
57 }
58}