webhook.go

  1package webhook
  2
  3import (
  4	"bytes"
  5	"context"
  6	"crypto/hmac"
  7	"crypto/sha256"
  8	"encoding/hex"
  9	"encoding/json"
 10	"errors"
 11	"fmt"
 12	"io"
 13	"net"
 14	"net/http"
 15	"time"
 16
 17	"github.com/charmbracelet/soft-serve/git"
 18	"github.com/charmbracelet/soft-serve/pkg/db"
 19	"github.com/charmbracelet/soft-serve/pkg/db/models"
 20	"github.com/charmbracelet/soft-serve/pkg/proto"
 21	"github.com/charmbracelet/soft-serve/pkg/store"
 22	"github.com/charmbracelet/soft-serve/pkg/utils"
 23	"github.com/charmbracelet/soft-serve/pkg/version"
 24	"github.com/google/go-querystring/query"
 25	"github.com/google/uuid"
 26)
 27
 28// Hook is a repository webhook.
 29type Hook struct {
 30	models.Webhook
 31	ContentType ContentType
 32	Events      []Event
 33}
 34
 35// Delivery is a webhook delivery.
 36type Delivery struct {
 37	models.WebhookDelivery
 38	Event Event
 39}
 40
 41// secureHTTPClient creates an HTTP client with SSRF protection.
 42var secureHTTPClient = &http.Client{
 43	Timeout: 30 * time.Second,
 44	Transport: &http.Transport{
 45		DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
 46			// Parse the address to get the IP
 47			host, _, err := net.SplitHostPort(addr)
 48			if err != nil {
 49				return nil, err //nolint:wrapcheck
 50			}
 51
 52			// Validate the resolved IP before connecting
 53			ip := net.ParseIP(host)
 54			if ip != nil {
 55				if err := ValidateIPBeforeDial(ip); err != nil {
 56					return nil, fmt.Errorf("blocked connection to private IP: %w", err)
 57				}
 58			}
 59
 60			// Use standard dialer with timeout
 61			dialer := &net.Dialer{
 62				Timeout:   10 * time.Second,
 63				KeepAlive: 30 * time.Second,
 64			}
 65			return dialer.DialContext(ctx, network, addr)
 66		},
 67		MaxIdleConns:          100,
 68		IdleConnTimeout:       90 * time.Second,
 69		TLSHandshakeTimeout:   10 * time.Second,
 70		ExpectContinueTimeout: 1 * time.Second,
 71	},
 72	// Don't follow redirects to prevent bypassing IP validation
 73	CheckRedirect: func(*http.Request, []*http.Request) error {
 74		return http.ErrUseLastResponse
 75	},
 76}
 77
 78// do sends a webhook.
 79// Caller must close the returned body.
 80func do(ctx context.Context, url string, method string, headers http.Header, body io.Reader) (*http.Response, error) {
 81	req, err := http.NewRequestWithContext(ctx, method, url, body)
 82	if err != nil {
 83		return nil, err
 84	}
 85
 86	req.Header = headers
 87	res, err := secureHTTPClient.Do(req)
 88	if err != nil {
 89		return nil, err
 90	}
 91
 92	return res, nil
 93}
 94
 95// SendWebhook sends a webhook event.
 96func SendWebhook(ctx context.Context, w models.Webhook, event Event, payload interface{}) error {
 97	var buf bytes.Buffer
 98	dbx := db.FromContext(ctx)
 99	datastore := store.FromContext(ctx)
100
101	contentType := ContentType(w.ContentType) //nolint:gosec
102	switch contentType {
103	case ContentTypeJSON:
104		if err := json.NewEncoder(&buf).Encode(payload); err != nil {
105			return err
106		}
107	case ContentTypeForm:
108		v, err := query.Values(payload)
109		if err != nil {
110			return err
111		}
112		buf.WriteString(v.Encode()) // nolint: errcheck
113	default:
114		return ErrInvalidContentType
115	}
116
117	headers := http.Header{}
118	headers.Add("Content-Type", contentType.String())
119	headers.Add("User-Agent", "SoftServe/"+version.Version)
120	headers.Add("X-SoftServe-Event", event.String())
121
122	id, err := uuid.NewUUID()
123	if err != nil {
124		return err
125	}
126
127	headers.Add("X-SoftServe-Delivery", id.String())
128
129	reqBody := buf.String()
130	if w.Secret != "" {
131		sig := hmac.New(sha256.New, []byte(w.Secret))
132		sig.Write([]byte(reqBody)) // nolint: errcheck
133		headers.Add("X-SoftServe-Signature", "sha256="+hex.EncodeToString(sig.Sum(nil)))
134	}
135
136	res, reqErr := do(ctx, w.URL, http.MethodPost, headers, &buf)
137	var reqHeaders string
138	for k, v := range headers {
139		reqHeaders += k + ": " + v[0] + "\n"
140	}
141
142	resStatus := 0
143	resHeaders := ""
144	resBody := ""
145
146	if res != nil {
147		resStatus = res.StatusCode
148		for k, v := range res.Header {
149			resHeaders += k + ": " + v[0] + "\n"
150		}
151
152		if res.Body != nil {
153			defer res.Body.Close() // nolint: errcheck
154			b, err := io.ReadAll(res.Body)
155			if err != nil {
156				return err
157			}
158
159			resBody = string(b)
160		}
161	}
162
163	return db.WrapError(datastore.CreateWebhookDelivery(ctx, dbx, id, w.ID, int(event), w.URL, http.MethodPost, reqErr, reqHeaders, reqBody, resStatus, resHeaders, resBody))
164}
165
166// SendEvent sends a webhook event.
167func SendEvent(ctx context.Context, payload EventPayload) error {
168	dbx := db.FromContext(ctx)
169	datastore := store.FromContext(ctx)
170	webhooks, err := datastore.GetWebhooksByRepoIDWhereEvent(ctx, dbx, payload.RepositoryID(), []int{int(payload.Event())})
171	if err != nil {
172		return db.WrapError(err)
173	}
174
175	for _, w := range webhooks {
176		if err := SendWebhook(ctx, w, payload.Event(), payload); err != nil {
177			return err
178		}
179	}
180
181	return nil
182}
183
184func repoURL(publicURL string, repo string) string {
185	return fmt.Sprintf("%s/%s.git", publicURL, utils.SanitizeRepo(repo))
186}
187
188func getDefaultBranch(repo proto.Repository) (string, error) {
189	branch, err := proto.RepositoryDefaultBranch(repo)
190	// XXX: we check for ErrReferenceNotExist here because we don't want to
191	// return an error if the repo is an empty repo.
192	// This means that the repo doesn't have a default branch yet and this is
193	// the first push to it.
194	if err != nil && !errors.Is(err, git.ErrReferenceNotExist) {
195		return "", err
196	}
197
198	return branch, nil
199}