client.go

 1// Package copilot provides GitHub Copilot integration.
 2package copilot
 3
 4import (
 5	"bytes"
 6	"fmt"
 7	"io"
 8	"log/slog"
 9	"net/http"
10	"regexp"
11
12	"github.com/charmbracelet/crush/internal/log"
13)
14
// assistantRolePattern matches a JSON `"role": "assistant"` key/value pair,
// allowing any amount of whitespace around the colon. It is compiled once at
// package initialization so request handling never recompiles it.
var (
	assistantRolePattern = regexp.MustCompile(`"role"\s*:\s*"assistant"`)
)
16
17// NewClient creates a new HTTP client with a custom transport that adds the
18// X-Initiator header based on message history in the request body.
19func NewClient(isSubAgent, debug bool) *http.Client {
20	return &http.Client{
21		Transport: &initiatorTransport{debug: debug, isSubAgent: isSubAgent},
22	}
23}
24
// initiatorTransport is an http.RoundTripper that adds an X-Initiator header
// to each request before forwarding it to an underlying transport.
type initiatorTransport struct {
	// debug selects the logging HTTP client's transport from internal/log
	// instead of http.DefaultTransport.
	debug bool

	// isSubAgent marks a transport serving a sub-agent; RoundTrip consults it
	// when choosing the X-Initiator value.
	isSubAgent bool
}
29
30func (t *initiatorTransport) RoundTrip(req *http.Request) (*http.Response, error) {
31	const (
32		xInitiatorHeader = "X-Initiator"
33		userInitiator    = "user"
34		agentInitiator   = "agent"
35	)
36
37	if req == nil {
38		return nil, fmt.Errorf("HTTP request is nil")
39	}
40	if req.Body == http.NoBody {
41		// No body to inspect; default to user.
42		req.Header.Set(xInitiatorHeader, userInitiator)
43		slog.Debug("Setting X-Initiator header to user (no request body)")
44		return t.roundTrip(req)
45	}
46
47	// Clone request to avoid modifying the original.
48	req = req.Clone(req.Context())
49
50	// Read the original body into bytes so we can examine it.
51	bodyBytes, err := io.ReadAll(req.Body)
52	if err != nil {
53		return nil, fmt.Errorf("failed to read request body: %w", err)
54	}
55	defer req.Body.Close()
56
57	// Restore the original body using the preserved bytes.
58	req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
59
60	// Check for assistant messages using regex to handle whitespace
61	// variations in the JSON while avoiding full unmarshalling overhead.
62	initiator := userInitiator
63	if assistantRolePattern.Match(bodyBytes) || t.isSubAgent {
64		slog.Debug("Setting X-Initiator header to agent (found assistant messages in history)")
65		initiator = agentInitiator
66	} else {
67		slog.Debug("Setting X-Initiator header to user (no assistant messages)")
68	}
69	req.Header.Set(xInitiatorHeader, initiator)
70
71	return t.roundTrip(req)
72}
73
74func (t *initiatorTransport) roundTrip(req *http.Request) (*http.Response, error) {
75	if t.debug {
76		return log.NewHTTPClient().Transport.RoundTrip(req)
77	}
78	return http.DefaultTransport.RoundTrip(req)
79}