Detailed changes
@@ -9,6 +9,7 @@ import (
"net/http"
"charm.land/log/v2"
+ "github.com/charmbracelet/soft-serve/pkg/ssrf"
)
// httpClient is a Git LFS client to communicate with a LFS source API.
@@ -22,11 +23,12 @@ var _ Client = (*httpClient)(nil)
// newHTTPClient returns a new Git LFS client.
func newHTTPClient(endpoint Endpoint) *httpClient {
+ client := ssrf.NewSecureClient()
return &httpClient{
- client: http.DefaultClient,
+ client: client,
endpoint: endpoint,
transfers: map[string]TransferAdapter{
- TransferBasic: &BasicTransferAdapter{http.DefaultClient},
+ TransferBasic: &BasicTransferAdapter{client},
},
}
}
@@ -0,0 +1,178 @@
+package ssrf
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net"
+ "net/http"
+ "net/url"
+ "slices"
+ "strings"
+ "time"
+)
+
+var (
+ // ErrPrivateIP is returned when a connection to a private or internal IP is blocked.
+ ErrPrivateIP = errors.New("connection to private or internal IP address is not allowed")
+ // ErrInvalidScheme is returned when a URL scheme is not http or https.
+ ErrInvalidScheme = errors.New("URL must use http or https scheme")
+ // ErrInvalidURL is returned when a URL is invalid.
+ ErrInvalidURL = errors.New("invalid URL")
+)
+
+// NewSecureClient returns an HTTP client with SSRF protection.
+// It validates resolved IPs at dial time to block connections to private
+// and internal networks. Since validation uses the already-resolved IP
+// from the Transport's DNS lookup, there is no TOCTOU gap between
+// resolution and connection. Redirects are disabled to match the
+// webhook client convention and prevent redirect-based SSRF.
+func NewSecureClient() *http.Client {
+ return &http.Client{
+ Timeout: 30 * time.Second,
+ Transport: &http.Transport{
+ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+ host, _, err := net.SplitHostPort(addr)
+ if err != nil {
+ return nil, err //nolint:wrapcheck
+ }
+
+ ip := net.ParseIP(host)
+ if ip == nil {
+ return nil, fmt.Errorf("unexpected non-IP address in dial: %s", host)
+ }
+ if isPrivateOrInternal(ip) {
+ return nil, fmt.Errorf("%w", ErrPrivateIP)
+ }
+
+ dialer := &net.Dialer{
+ Timeout: 10 * time.Second,
+ KeepAlive: 30 * time.Second,
+ }
+ return dialer.DialContext(ctx, network, addr)
+ },
+ MaxIdleConns: 100,
+ IdleConnTimeout: 90 * time.Second,
+ TLSHandshakeTimeout: 10 * time.Second,
+ ExpectContinueTimeout: 1 * time.Second,
+ },
+ CheckRedirect: func(*http.Request, []*http.Request) error {
+ return http.ErrUseLastResponse
+ },
+ }
+}
+
+// isPrivateOrInternal checks if an IP address is private, internal, or reserved.
+func isPrivateOrInternal(ip net.IP) bool {
+ // Normalize IPv6-mapped IPv4 (e.g. ::ffff:127.0.0.1) to IPv4 form
+ // so all checks apply consistently.
+ if ip4 := ip.To4(); ip4 != nil {
+ ip = ip4
+ }
+
+ if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() ||
+ ip.IsPrivate() || ip.IsUnspecified() || ip.IsMulticast() {
+ return true
+ }
+
+ if ip4 := ip.To4(); ip4 != nil {
+ // 0.0.0.0/8
+ if ip4[0] == 0 {
+ return true
+ }
+ // 100.64.0.0/10 (Shared Address Space / CGNAT)
+ if ip4[0] == 100 && ip4[1] >= 64 && ip4[1] <= 127 {
+ return true
+ }
+ // 192.0.0.0/24 (IETF Protocol Assignments)
+ if ip4[0] == 192 && ip4[1] == 0 && ip4[2] == 0 {
+ return true
+ }
+ // 192.0.2.0/24 (TEST-NET-1)
+ if ip4[0] == 192 && ip4[1] == 0 && ip4[2] == 2 {
+ return true
+ }
+ // 198.18.0.0/15 (benchmarking)
+ if ip4[0] == 198 && (ip4[1] == 18 || ip4[1] == 19) {
+ return true
+ }
+ // 198.51.100.0/24 (TEST-NET-2)
+ if ip4[0] == 198 && ip4[1] == 51 && ip4[2] == 100 {
+ return true
+ }
+ // 203.0.113.0/24 (TEST-NET-3)
+ if ip4[0] == 203 && ip4[1] == 0 && ip4[2] == 113 {
+ return true
+ }
+ // 240.0.0.0/4 (Reserved, includes 255.255.255.255 broadcast)
+ if ip4[0] >= 240 {
+ return true
+ }
+ }
+
+ return false
+}
+
+// ValidateURL validates that a URL is safe to make requests to.
+// It checks that the scheme is http/https, the hostname is not localhost,
+// and all resolved IPs are public.
+func ValidateURL(rawURL string) error {
+ if rawURL == "" {
+ return ErrInvalidURL
+ }
+
+ u, err := url.Parse(rawURL)
+ if err != nil {
+ return fmt.Errorf("%w: %v", ErrInvalidURL, err)
+ }
+
+ if u.Scheme != "http" && u.Scheme != "https" {
+ return ErrInvalidScheme
+ }
+
+ hostname := u.Hostname()
+ if hostname == "" {
+ return fmt.Errorf("%w: missing hostname", ErrInvalidURL)
+ }
+
+ if isLocalhost(hostname) {
+ return ErrPrivateIP
+ }
+
+ if ip := net.ParseIP(hostname); ip != nil {
+ if isPrivateOrInternal(ip) {
+ return ErrPrivateIP
+ }
+ return nil
+ }
+
+ ips, err := net.DefaultResolver.LookupIPAddr(context.Background(), hostname)
+ if err != nil {
+ return fmt.Errorf("%w: cannot resolve hostname: %v", ErrInvalidURL, err)
+ }
+
+ if slices.ContainsFunc(ips, func(addr net.IPAddr) bool {
+ return isPrivateOrInternal(addr.IP)
+ }) {
+ return ErrPrivateIP
+ }
+
+ return nil
+}
+
+// ValidateIPBeforeDial validates an IP address before establishing a connection.
+// This prevents DNS rebinding attacks by checking the resolved IP at dial time.
+func ValidateIPBeforeDial(ip net.IP) error {
+ if isPrivateOrInternal(ip) {
+ return ErrPrivateIP
+ }
+ return nil
+}
+
+// isLocalhost checks if the hostname is localhost or similar.
+func isLocalhost(hostname string) bool {
+ hostname = strings.ToLower(hostname)
+ return hostname == "localhost" ||
+ hostname == "localhost.localdomain" ||
+ strings.HasSuffix(hostname, ".localhost")
+}
@@ -0,0 +1,208 @@
+package ssrf
+
+import (
+ "context"
+ "errors"
+ "net"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+)
+
+func TestNewSecureClientBlocksPrivateIPs(t *testing.T) {
+ client := NewSecureClient()
+ transport := client.Transport.(*http.Transport)
+
+ tests := []struct {
+ name string
+ addr string
+ wantErr bool
+ }{
+ {"block loopback", "127.0.0.1:80", true},
+ {"block private 10.x", "10.0.0.1:80", true},
+ {"block link-local", "169.254.169.254:80", true},
+ {"block CGNAT", "100.64.0.1:80", true},
+ {"allow public IP", "8.8.8.8:80", false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
+ defer cancel()
+
+ conn, err := transport.DialContext(ctx, "tcp", tt.addr)
+ if conn != nil {
+ conn.Close()
+ }
+
+ if tt.wantErr {
+ if err == nil {
+ t.Errorf("expected error for %s, got none", tt.addr)
+ }
+ } else {
+ if err != nil && errors.Is(err, ErrPrivateIP) {
+ t.Errorf("should not block %s with SSRF error, got: %v", tt.addr, err)
+ }
+ }
+ })
+ }
+}
+
+func TestNewSecureClientNilIPNotErrPrivateIP(t *testing.T) {
+ client := NewSecureClient()
+ transport := client.Transport.(*http.Transport)
+
+ ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
+ defer cancel()
+
+ conn, err := transport.DialContext(ctx, "tcp", "not-an-ip:80")
+ if conn != nil {
+ conn.Close()
+ }
+ if err == nil {
+ t.Fatal("expected error for non-IP address, got none")
+ }
+ if errors.Is(err, ErrPrivateIP) {
+ t.Errorf("nil-IP path should not wrap ErrPrivateIP, got: %v", err)
+ }
+}
+
+func TestNewSecureClientBlocksRedirects(t *testing.T) {
+ redirectServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ http.Redirect(w, r, "http://8.8.8.8:8080/safe", http.StatusFound)
+ }))
+ defer redirectServer.Close()
+
+ client := NewSecureClient()
+ req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, redirectServer.URL, nil)
+ if err != nil {
+ t.Fatalf("Failed to create request: %v", err)
+ }
+
+ resp, err := client.Do(req)
+ if err != nil {
+ // httptest uses 127.0.0.1, blocked by SSRF protection
+ if !errors.Is(err, ErrPrivateIP) {
+ t.Fatalf("Request failed with non-SSRF error: %v", err)
+ }
+ return
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusFound {
+ t.Errorf("Expected redirect response (302), got %d", resp.StatusCode)
+ }
+}
+
+func TestIsPrivateOrInternal(t *testing.T) {
+ tests := []struct {
+ ip string
+ want bool
+ }{
+ // Public
+ {"8.8.8.8", false},
+ {"2001:4860:4860::8888", false},
+
+ // Loopback
+ {"127.0.0.1", true},
+ {"::1", true},
+
+ // Private ranges
+ {"10.0.0.1", true},
+ {"192.168.1.1", true},
+ {"172.16.0.1", true},
+
+ // Link-local (cloud metadata)
+ {"169.254.169.254", true},
+
+ // CGNAT boundaries
+ {"100.64.0.1", true},
+ {"100.127.255.255", true},
+
+ // IPv6-mapped IPv4 (bypass vector the old webhook code missed)
+ {"::ffff:127.0.0.1", true},
+ {"::ffff:169.254.169.254", true},
+ {"::ffff:8.8.8.8", false},
+
+ // Reserved
+ {"0.0.0.0", true},
+ {"240.0.0.1", true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.ip, func(t *testing.T) {
+ ip := net.ParseIP(tt.ip)
+ if ip == nil {
+ t.Fatalf("failed to parse IP: %s", tt.ip)
+ }
+ if got := isPrivateOrInternal(ip); got != tt.want {
+ t.Errorf("isPrivateOrInternal(%s) = %v, want %v", tt.ip, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestValidateURL(t *testing.T) {
+ tests := []struct {
+ name string
+ url string
+ wantErr bool
+ errType error
+ }{
+ // Valid
+ {"valid https", "https://1.1.1.1/webhook", false, nil},
+
+ // Scheme validation
+ {"ftp scheme", "ftp://example.com/webhook", true, ErrInvalidScheme},
+ {"no scheme", "example.com/webhook", true, ErrInvalidScheme},
+
+ // Localhost
+ {"localhost", "http://localhost/webhook", true, ErrPrivateIP},
+ {"subdomain.localhost", "http://test.localhost/webhook", true, ErrPrivateIP},
+
+ // IP-based blocking (one per category -- range coverage is in TestIsPrivateOrInternal)
+ {"loopback IP", "http://127.0.0.1/webhook", true, ErrPrivateIP},
+ {"metadata IP", "http://169.254.169.254/latest/meta-data/", true, ErrPrivateIP},
+
+ // Invalid URLs
+ {"empty", "", true, ErrInvalidURL},
+ {"missing hostname", "http:///webhook", true, ErrInvalidURL},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := ValidateURL(tt.url)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("ValidateURL(%q) error = %v, wantErr %v", tt.url, err, tt.wantErr)
+ return
+ }
+ if tt.wantErr && tt.errType != nil {
+ if !errors.Is(err, tt.errType) {
+ t.Errorf("ValidateURL(%q) error = %v, want error type %v", tt.url, err, tt.errType)
+ }
+ }
+ })
+ }
+}
+
+func TestIsLocalhost(t *testing.T) {
+ tests := []struct {
+ hostname string
+ want bool
+ }{
+ {"localhost", true},
+ {"LOCALHOST", true},
+ {"test.localhost", true},
+ {"example.com", false},
+ {"localhost.com", false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.hostname, func(t *testing.T) {
+ if got := isLocalhost(tt.hostname); got != tt.want {
+ t.Errorf("isLocalhost(%s) = %v, want %v", tt.hostname, got, tt.want)
+ }
+ })
+ }
+}
@@ -2,217 +2,57 @@ package webhook
import (
"context"
+ "errors"
"net/http"
- "net/http/httptest"
"testing"
"time"
"github.com/charmbracelet/soft-serve/pkg/db/models"
+ "github.com/charmbracelet/soft-serve/pkg/ssrf"
)
-// TestSSRFProtection tests that the webhook system blocks SSRF attempts.
+// TestSSRFProtection is an integration test verifying the webhook send path
+// blocks private IPs end-to-end (models.Webhook -> secureHTTPClient -> ssrf).
func TestSSRFProtection(t *testing.T) {
tests := []struct {
name string
webhookURL string
shouldBlock bool
- description string
}{
- {
- name: "block localhost",
- webhookURL: "http://localhost:8080/webhook",
- shouldBlock: true,
- description: "should block localhost addresses",
- },
- {
- name: "block 127.0.0.1",
- webhookURL: "http://127.0.0.1:8080/webhook",
- shouldBlock: true,
- description: "should block loopback addresses",
- },
- {
- name: "block 169.254.169.254",
- webhookURL: "http://169.254.169.254/latest/meta-data/",
- shouldBlock: true,
- description: "should block cloud metadata service",
- },
- {
- name: "block private network",
- webhookURL: "http://192.168.1.1/webhook",
- shouldBlock: true,
- description: "should block private networks",
- },
- {
- name: "allow public IP",
- webhookURL: "http://8.8.8.8/webhook",
- shouldBlock: false,
- description: "should allow public IP addresses",
- },
+ {"block loopback", "http://127.0.0.1:8080/webhook", true},
+ {"block metadata", "http://169.254.169.254/latest/meta-data/", true},
+ {"allow public IP", "http://8.8.8.8/webhook", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- // Create a test webhook
- webhook := models.Webhook{
+ w := models.Webhook{
URL: tt.webhookURL,
ContentType: int(ContentTypeJSON),
- Secret: "",
}
- // Try to send a webhook
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
- // Create a simple payload
- payload := map[string]string{"test": "data"}
-
- err := sendWebhookWithContext(ctx, webhook, EventPush, payload)
-
- if tt.shouldBlock {
- if err == nil {
- t.Errorf("%s: expected error but got none", tt.description)
- }
- } else {
- // For public IPs, we expect a connection error (since 8.8.8.8 won't be listening)
- // but NOT an SSRF blocking error
- if err != nil && isSSRFError(err) {
- t.Errorf("%s: should not block public IPs, got: %v", tt.description, err)
- }
+ req, err := http.NewRequestWithContext(ctx, "POST", w.URL, nil)
+ if err != nil {
+ t.Fatalf("failed to create request: %v", err)
}
- })
- }
-}
-
-// TestSecureHTTPClientBlocksRedirects tests that redirects are not followed.
-func TestSecureHTTPClientBlocksRedirects(t *testing.T) {
- // Create a test server on a public-looking address that redirects
- redirectServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- http.Redirect(w, r, "http://8.8.8.8:8080/safe", http.StatusFound)
- }))
- defer redirectServer.Close()
-
- // Try to make a request that would redirect
- req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, redirectServer.URL, nil)
- if err != nil {
- t.Fatalf("Failed to create request: %v", err)
- }
- resp, err := secureHTTPClient.Do(req)
- if err != nil {
- // httptest.NewServer uses 127.0.0.1, which will be blocked by our SSRF protection
- // This is actually correct behavior - we're blocking the initial connection
- if !isSSRFError(err) {
- t.Fatalf("Request failed with non-SSRF error: %v", err)
- }
- // Test passed - we blocked the loopback connection
- return
- }
- defer resp.Body.Close()
-
- // If we got here, check that we got the redirect response (not followed)
- if resp.StatusCode != http.StatusFound {
- t.Errorf("Expected redirect response (302), got %d", resp.StatusCode)
- }
-}
-
-// TestDialContextBlocksPrivateIPs tests the DialContext function directly.
-func TestDialContextBlocksPrivateIPs(t *testing.T) {
- transport := secureHTTPClient.Transport.(*http.Transport)
-
- tests := []struct {
- name string
- addr string
- wantErr bool
- }{
- {"block loopback", "127.0.0.1:80", true},
- {"block private 10.x", "10.0.0.1:80", true},
- {"block private 192.168.x", "192.168.1.1:80", true},
- {"block link-local", "169.254.169.254:80", true},
- {"allow public IP", "8.8.8.8:80", false},
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
- defer cancel()
-
- conn, err := transport.DialContext(ctx, "tcp", tt.addr)
- if conn != nil {
- conn.Close()
+ resp, err := secureHTTPClient.Do(req)
+ if resp != nil {
+ resp.Body.Close()
}
- if tt.wantErr {
+ if tt.shouldBlock {
if err == nil {
- t.Errorf("Expected error for %s, got none", tt.addr)
+ t.Errorf("%s: expected error but got none", tt.name)
}
} else {
- // For public IPs, we expect a connection timeout/refused (not an SSRF block)
- if err != nil && isSSRFError(err) {
- t.Errorf("Should not block %s with SSRF error, got: %v", tt.addr, err)
+ if err != nil && errors.Is(err, ssrf.ErrPrivateIP) {
+ t.Errorf("%s: should not block public IPs, got: %v", tt.name, err)
}
}
})
}
}
-
-// sendWebhookWithContext is a test helper that doesn't require database.
-func sendWebhookWithContext(ctx context.Context, w models.Webhook, _ Event, _ any) error {
- // This is a simplified version for testing that just attempts the HTTP connection
- req, err := http.NewRequestWithContext(ctx, "POST", w.URL, nil)
- if err != nil {
- return err //nolint:wrapcheck
- }
- req = req.WithContext(ctx)
-
- resp, err := secureHTTPClient.Do(req)
- if resp != nil {
- resp.Body.Close()
- }
- return err //nolint:wrapcheck
-}
-
-// isSSRFError checks if an error is related to SSRF blocking.
-func isSSRFError(err error) bool {
- if err == nil {
- return false
- }
- errMsg := err.Error()
- return contains(errMsg, "private IP") ||
- contains(errMsg, "blocked connection") ||
- err == ErrPrivateIP
-}
-
-func contains(s, substr string) bool {
- return len(s) >= len(substr) && (s == substr || len(substr) == 0 || indexOfSubstring(s, substr) >= 0)
-}
-
-func indexOfSubstring(s, substr string) int {
- for i := 0; i <= len(s)-len(substr); i++ {
- if s[i:i+len(substr)] == substr {
- return i
- }
- }
- return -1
-}
-
-// TestPrivateIPResolution tests that hostnames resolving to private IPs are blocked.
-func TestPrivateIPResolution(t *testing.T) {
- // This test verifies that even if a hostname looks public, if it resolves to a private IP, it's blocked
- webhook := models.Webhook{
- URL: "http://127.0.0.1:9999/webhook",
- ContentType: int(ContentTypeJSON),
- }
-
- ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
- defer cancel()
-
- err := sendWebhookWithContext(ctx, webhook, EventPush, map[string]string{"test": "data"})
- if err == nil {
- t.Error("Expected error when connecting to loopback address")
- return
- }
-
- if !isSSRFError(err) {
- t.Errorf("Expected SSRF blocking error, got: %v", err)
- }
-}
@@ -1,169 +1,20 @@
package webhook
import (
- "context"
- "errors"
- "fmt"
- "net"
- "net/url"
- "slices"
- "strings"
+ "github.com/charmbracelet/soft-serve/pkg/ssrf"
)
+// Error aliases for backward compatibility.
var (
- // ErrInvalidScheme is returned when the webhook URL scheme is not http or https.
- ErrInvalidScheme = errors.New("webhook URL must use http or https scheme")
- // ErrPrivateIP is returned when the webhook URL resolves to a private IP address.
- ErrPrivateIP = errors.New("webhook URL cannot resolve to private or internal IP addresses")
- // ErrInvalidURL is returned when the webhook URL is invalid.
- ErrInvalidURL = errors.New("invalid webhook URL")
+ ErrInvalidScheme = ssrf.ErrInvalidScheme
+ ErrPrivateIP = ssrf.ErrPrivateIP
+ ErrInvalidURL = ssrf.ErrInvalidURL
)
// ValidateWebhookURL validates that a webhook URL is safe to use.
-// It checks:
-// - URL is properly formatted
-// - Scheme is http or https
-// - Hostname does not resolve to private/internal IP addresses
-// - Hostname is not localhost or similar.
func ValidateWebhookURL(rawURL string) error {
- if rawURL == "" {
- return ErrInvalidURL
- }
-
- // Parse the URL
- u, err := url.Parse(rawURL)
- if err != nil {
- return fmt.Errorf("%w: %v", ErrInvalidURL, err)
- }
-
- // Check scheme
- if u.Scheme != "http" && u.Scheme != "https" {
- return ErrInvalidScheme
- }
-
- // Extract hostname (without port)
- hostname := u.Hostname()
- if hostname == "" {
- return fmt.Errorf("%w: missing hostname", ErrInvalidURL)
- }
-
- // Check for localhost variations
- if isLocalhost(hostname) {
- return ErrPrivateIP
- }
-
- // If it's an IP address, validate it directly
- if ip := net.ParseIP(hostname); ip != nil {
- if isPrivateOrInternalIP(ip) {
- return ErrPrivateIP
- }
- return nil
- }
-
- // Resolve hostname to IP addresses
- ips, err := net.DefaultResolver.LookupIPAddr(context.Background(), hostname)
- if err != nil {
- return fmt.Errorf("%w: cannot resolve hostname: %v", ErrInvalidURL, err)
- }
-
- // Check all resolved IPs
- if slices.ContainsFunc(ips, isPrivateOrInternalIPAddr) {
- return ErrPrivateIP
- }
-
- return nil
-}
-
-// isLocalhost checks if the hostname is localhost or similar.
-func isLocalhost(hostname string) bool {
- hostname = strings.ToLower(hostname)
- return hostname == "localhost" ||
- hostname == "localhost.localdomain" ||
- strings.HasSuffix(hostname, ".localhost")
-}
-
-// isPrivateOrInternalIPAddr is a helper function that users net.IPAddr instead of net.IP.
-func isPrivateOrInternalIPAddr(ipAddr net.IPAddr) bool {
- return isPrivateOrInternalIP(ipAddr.IP)
-}
-
-// isPrivateOrInternalIP checks if an IP address is private, internal, or reserved.
-func isPrivateOrInternalIP(ip net.IP) bool {
- // Loopback addresses (127.0.0.0/8, ::1)
- if ip.IsLoopback() {
- return true
- }
-
- // Link-local addresses (169.254.0.0/16, fe80::/10)
- // This blocks AWS/GCP/Azure metadata services
- if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
- return true
- }
-
- // Private addresses (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, fc00::/7)
- if ip.IsPrivate() {
- return true
- }
-
- // Unspecified addresses (0.0.0.0, ::)
- if ip.IsUnspecified() {
- return true
- }
-
- // Multicast addresses
- if ip.IsMulticast() {
- return true
- }
-
- // Additional checks for IPv4
- if ip4 := ip.To4(); ip4 != nil {
- // 0.0.0.0/8 (current network)
- if ip4[0] == 0 {
- return true
- }
- // 100.64.0.0/10 (Shared Address Space)
- if ip4[0] == 100 && ip4[1] >= 64 && ip4[1] <= 127 {
- return true
- }
- // 192.0.0.0/24 (IETF Protocol Assignments)
- if ip4[0] == 192 && ip4[1] == 0 && ip4[2] == 0 {
- return true
- }
- // 192.0.2.0/24 (TEST-NET-1)
- if ip4[0] == 192 && ip4[1] == 0 && ip4[2] == 2 {
- return true
- }
- // 198.18.0.0/15 (benchmarking)
- if ip4[0] == 198 && (ip4[1] == 18 || ip4[1] == 19) {
- return true
- }
- // 198.51.100.0/24 (TEST-NET-2)
- if ip4[0] == 198 && ip4[1] == 51 && ip4[2] == 100 {
- return true
- }
- // 203.0.113.0/24 (TEST-NET-3)
- if ip4[0] == 203 && ip4[1] == 0 && ip4[2] == 113 {
- return true
- }
- // 224.0.0.0/4 (Multicast - already handled by IsMulticast)
- // 240.0.0.0/4 (Reserved for future use)
- if ip4[0] >= 240 {
- return true
- }
- // 255.255.255.255/32 (Broadcast)
- if ip4[0] == 255 && ip4[1] == 255 && ip4[2] == 255 && ip4[3] == 255 {
- return true
- }
- }
-
- return false
+ return ssrf.ValidateURL(rawURL) //nolint:wrapcheck
}
// ValidateIPBeforeDial validates an IP address before establishing a connection.
-// This is used to prevent DNS rebinding attacks.
-func ValidateIPBeforeDial(ip net.IP) error {
- if isPrivateOrInternalIP(ip) {
- return ErrPrivateIP
- }
- return nil
-}
+var ValidateIPBeforeDial = ssrf.ValidateIPBeforeDial
@@ -1,315 +1,52 @@
package webhook
import (
- "net"
+ "errors"
"testing"
+
+ "github.com/charmbracelet/soft-serve/pkg/ssrf"
)
+// TestValidateWebhookURL verifies the wrapper delegates correctly and
+// error aliases work across the package boundary. IP range coverage
+// is in pkg/ssrf/ssrf_test.go -- here we just confirm the plumbing.
func TestValidateWebhookURL(t *testing.T) {
tests := []struct {
name string
url string
wantErr bool
errType error
- skip string
}{
- // Valid URLs (these will perform DNS lookups, so may fail in some environments)
- {
- name: "valid https URL",
- url: "https://1.1.1.1/webhook",
- wantErr: false,
- },
- {
- name: "valid http URL",
- url: "http://8.8.8.8/webhook",
- wantErr: false,
- },
- {
- name: "valid URL with port",
- url: "https://1.1.1.1:8080/webhook",
- wantErr: false,
- },
- {
- name: "valid URL with path and query",
- url: "https://8.8.8.8/webhook?token=abc123",
- wantErr: false,
- },
-
- // Invalid schemes
- {
- name: "ftp scheme",
- url: "ftp://example.com/webhook",
- wantErr: true,
- errType: ErrInvalidScheme,
- },
- {
- name: "file scheme",
- url: "file:///etc/passwd",
- wantErr: true,
- errType: ErrInvalidScheme,
- },
- {
- name: "gopher scheme",
- url: "gopher://example.com",
- wantErr: true,
- errType: ErrInvalidScheme,
- },
- {
- name: "no scheme",
- url: "example.com/webhook",
- wantErr: true,
- errType: ErrInvalidScheme,
- },
-
- // Localhost variations
- {
- name: "localhost",
- url: "http://localhost/webhook",
- wantErr: true,
- errType: ErrPrivateIP,
- },
- {
- name: "localhost with port",
- url: "http://localhost:8080/webhook",
- wantErr: true,
- errType: ErrPrivateIP,
- },
- {
- name: "localhost.localdomain",
- url: "http://localhost.localdomain/webhook",
- wantErr: true,
- errType: ErrPrivateIP,
- },
-
- // Loopback IPs
- {
- name: "127.0.0.1",
- url: "http://127.0.0.1/webhook",
- wantErr: true,
- errType: ErrPrivateIP,
- },
- {
- name: "127.0.0.1 with port",
- url: "http://127.0.0.1:8080/webhook",
- wantErr: true,
- errType: ErrPrivateIP,
- },
- {
- name: "127.1.2.3",
- url: "http://127.1.2.3/webhook",
- wantErr: true,
- errType: ErrPrivateIP,
- },
- {
- name: "IPv6 loopback",
- url: "http://[::1]/webhook",
- wantErr: true,
- errType: ErrPrivateIP,
- },
-
- // Private IPv4 ranges
- {
- name: "10.0.0.0",
- url: "http://10.0.0.1/webhook",
- wantErr: true,
- errType: ErrPrivateIP,
- },
- {
- name: "192.168.0.0",
- url: "http://192.168.1.1/webhook",
- wantErr: true,
- errType: ErrPrivateIP,
- },
- {
- name: "172.16.0.0",
- url: "http://172.16.0.1/webhook",
- wantErr: true,
- errType: ErrPrivateIP,
- },
- {
- name: "172.31.255.255",
- url: "http://172.31.255.255/webhook",
- wantErr: true,
- errType: ErrPrivateIP,
- },
-
- // Link-local (AWS/GCP/Azure metadata)
- {
- name: "AWS metadata service",
- url: "http://169.254.169.254/latest/meta-data/",
- wantErr: true,
- errType: ErrPrivateIP,
- },
- {
- name: "link-local",
- url: "http://169.254.1.1/webhook",
- wantErr: true,
- errType: ErrPrivateIP,
- },
-
- // Other reserved ranges
- {
- name: "0.0.0.0",
- url: "http://0.0.0.0/webhook",
- wantErr: true,
- errType: ErrPrivateIP,
- },
- {
- name: "broadcast",
- url: "http://255.255.255.255/webhook",
- wantErr: true,
- errType: ErrPrivateIP,
- },
-
- // Invalid URLs
- {
- name: "empty URL",
- url: "",
- wantErr: true,
- errType: ErrInvalidURL,
- },
- {
- name: "missing hostname",
- url: "http:///webhook",
- wantErr: true,
- errType: ErrInvalidURL,
- },
+ {"valid", "https://1.1.1.1/webhook", false, nil},
+ {"bad scheme", "ftp://example.com", true, ErrInvalidScheme},
+ {"private IP", "http://127.0.0.1/webhook", true, ErrPrivateIP},
+ {"empty", "", true, ErrInvalidURL},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- if tt.skip != "" {
- t.Skip(tt.skip)
- }
err := ValidateWebhookURL(tt.url)
if (err != nil) != tt.wantErr {
- t.Errorf("ValidateWebhookURL() error = %v, wantErr %v", err, tt.wantErr)
+ t.Errorf("ValidateWebhookURL(%q) error = %v, wantErr %v", tt.url, err, tt.wantErr)
return
}
if tt.wantErr && tt.errType != nil {
- if !isErrorType(err, tt.errType) {
- t.Errorf("ValidateWebhookURL() error = %v, want error type %v", err, tt.errType)
+ if !errors.Is(err, tt.errType) {
+ t.Errorf("ValidateWebhookURL(%q) error = %v, want %v", tt.url, err, tt.errType)
}
}
})
}
}
-func TestIsPrivateOrInternalIP(t *testing.T) {
- tests := []struct {
- name string
- ip string
- isPriv bool
- }{
- // Public IPs
- {"Google DNS", "8.8.8.8", false},
- {"Cloudflare DNS", "1.1.1.1", false},
- {"Public IPv6", "2001:4860:4860::8888", false},
-
- // Loopback
- {"127.0.0.1", "127.0.0.1", true},
- {"127.1.2.3", "127.1.2.3", true},
- {"::1", "::1", true},
-
- // Private ranges
- {"10.0.0.1", "10.0.0.1", true},
- {"192.168.1.1", "192.168.1.1", true},
- {"172.16.0.1", "172.16.0.1", true},
- {"172.31.255.255", "172.31.255.255", true},
-
- // Link-local
- {"169.254.169.254", "169.254.169.254", true},
- {"169.254.1.1", "169.254.1.1", true},
- {"fe80::1", "fe80::1", true},
-
- // Other reserved
- {"0.0.0.0", "0.0.0.0", true},
- {"255.255.255.255", "255.255.255.255", true},
- {"240.0.0.1", "240.0.0.1", true},
-
- // Shared address space
- {"100.64.0.1", "100.64.0.1", true},
- {"100.127.255.255", "100.127.255.255", true},
+func TestErrorAliases(t *testing.T) {
+ if ErrPrivateIP != ssrf.ErrPrivateIP {
+ t.Error("ErrPrivateIP should alias ssrf.ErrPrivateIP")
}
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- ip := net.ParseIP(tt.ip)
- if ip == nil {
- t.Fatalf("Failed to parse IP: %s", tt.ip)
- }
- if got := isPrivateOrInternalIP(ip); got != tt.isPriv {
- t.Errorf("isPrivateOrInternalIP(%s) = %v, want %v", tt.ip, got, tt.isPriv)
- }
- })
- }
-}
-
-func TestIsLocalhost(t *testing.T) {
- tests := []struct {
- name string
- hostname string
- want bool
- }{
- {"localhost", "localhost", true},
- {"LOCALHOST", "LOCALHOST", true},
- {"localhost.localdomain", "localhost.localdomain", true},
- {"test.localhost", "test.localhost", true},
- {"example.com", "example.com", false},
- {"localhos", "localhos", false},
- {"localhost.com", "localhost.com", false},
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if got := isLocalhost(tt.hostname); got != tt.want {
- t.Errorf("isLocalhost(%s) = %v, want %v", tt.hostname, got, tt.want)
- }
- })
- }
-}
-
-func TestValidateIPBeforeDial(t *testing.T) {
- tests := []struct {
- name string
- ip string
- wantErr bool
- }{
- {"public IP", "8.8.8.8", false},
- {"private IP", "192.168.1.1", true},
- {"loopback", "127.0.0.1", true},
- {"link-local", "169.254.169.254", true},
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- ip := net.ParseIP(tt.ip)
- if ip == nil {
- t.Fatalf("Failed to parse IP: %s", tt.ip)
- }
- err := ValidateIPBeforeDial(ip)
- if (err != nil) != tt.wantErr {
- t.Errorf("ValidateIPBeforeDial(%s) error = %v, wantErr %v", tt.ip, err, tt.wantErr)
- }
- })
- }
-}
-
-// isErrorType checks if err is or wraps errType.
-func isErrorType(err, errType error) bool {
- if err == errType {
- return true
+ if ErrInvalidScheme != ssrf.ErrInvalidScheme {
+ t.Error("ErrInvalidScheme should alias ssrf.ErrInvalidScheme")
}
- // Check if err wraps errType
- for err != nil {
- if err == errType {
- return true
- }
- unwrapped, ok := err.(interface{ Unwrap() error })
- if !ok {
- break
- }
- err = unwrapped.Unwrap()
+ if ErrInvalidURL != ssrf.ErrInvalidURL {
+ t.Error("ErrInvalidURL should alias ssrf.ErrInvalidURL")
}
- return false
}
@@ -10,14 +10,13 @@ import (
"errors"
"fmt"
"io"
- "net"
"net/http"
- "time"
"github.com/charmbracelet/soft-serve/git"
"github.com/charmbracelet/soft-serve/pkg/db"
"github.com/charmbracelet/soft-serve/pkg/db/models"
"github.com/charmbracelet/soft-serve/pkg/proto"
+ "github.com/charmbracelet/soft-serve/pkg/ssrf"
"github.com/charmbracelet/soft-serve/pkg/store"
"github.com/charmbracelet/soft-serve/pkg/utils"
"github.com/charmbracelet/soft-serve/pkg/version"
@@ -38,42 +37,8 @@ type Delivery struct {
Event Event
}
-// secureHTTPClient creates an HTTP client with SSRF protection.
-var secureHTTPClient = &http.Client{
- Timeout: 30 * time.Second,
- Transport: &http.Transport{
- DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
- // Parse the address to get the IP
- host, _, err := net.SplitHostPort(addr)
- if err != nil {
- return nil, err //nolint:wrapcheck
- }
-
- // Validate the resolved IP before connecting
- ip := net.ParseIP(host)
- if ip != nil {
- if err := ValidateIPBeforeDial(ip); err != nil {
- return nil, fmt.Errorf("blocked connection to private IP: %w", err)
- }
- }
-
- // Use standard dialer with timeout
- dialer := &net.Dialer{
- Timeout: 10 * time.Second,
- KeepAlive: 30 * time.Second,
- }
- return dialer.DialContext(ctx, network, addr)
- },
- MaxIdleConns: 100,
- IdleConnTimeout: 90 * time.Second,
- TLSHandshakeTimeout: 10 * time.Second,
- ExpectContinueTimeout: 1 * time.Second,
- },
- // Don't follow redirects to prevent bypassing IP validation
- CheckRedirect: func(*http.Request, []*http.Request) error {
- return http.ErrUseLastResponse
- },
-}
+// secureHTTPClient is an HTTP client with SSRF protection.
+var secureHTTPClient = ssrf.NewSecureClient()
// do sends a webhook.
// Caller must close the returned body.