1// Package copilot provides GitHub Copilot integration.
2package copilot
3
4import (
5 "bytes"
6 "fmt"
7 "io"
8 "log/slog"
9 "net/http"
10 "regexp"
11
12 "github.com/charmbracelet/crush/internal/log"
13)
14
// assistantRolePattern matches a JSON `"role": "assistant"` key/value pair,
// allowing any whitespace around the colon. It is compiled once at package
// initialization so request bodies can be scanned without unmarshalling them.
var (
	assistantRolePattern = regexp.MustCompile(`"role"\s*:\s*"assistant"`)
)
16
17// NewClient creates a new HTTP client with a custom transport that adds the
18// X-Initiator header based on message history in the request body.
19func NewClient(isSubAgent, debug bool) *http.Client {
20 return &http.Client{
21 Transport: &initiatorTransport{debug: debug, isSubAgent: isSubAgent},
22 }
23}
24
// initiatorTransport is an http.RoundTripper that sets the X-Initiator header
// on each request before handing it to the next transport.
type initiatorTransport struct {
	// debug selects the transport from log.NewHTTPClient instead of
	// http.DefaultTransport; isSubAgent makes every request that has a body
	// report the "agent" initiator.
	debug, isSubAgent bool
}
29
30func (t *initiatorTransport) RoundTrip(req *http.Request) (*http.Response, error) {
31 const (
32 xInitiatorHeader = "X-Initiator"
33 userInitiator = "user"
34 agentInitiator = "agent"
35 )
36
37 if req == nil {
38 return nil, fmt.Errorf("HTTP request is nil")
39 }
40 if req.Body == http.NoBody {
41 // No body to inspect; default to user.
42 req.Header.Set(xInitiatorHeader, userInitiator)
43 slog.Debug("Setting X-Initiator header to user (no request body)")
44 return t.roundTrip(req)
45 }
46
47 // Clone request to avoid modifying the original.
48 req = req.Clone(req.Context())
49
50 // Read the original body into bytes so we can examine it.
51 bodyBytes, err := io.ReadAll(req.Body)
52 if err != nil {
53 return nil, fmt.Errorf("failed to read request body: %w", err)
54 }
55 defer req.Body.Close()
56
57 // Restore the original body using the preserved bytes.
58 req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
59
60 // Check for assistant messages using regex to handle whitespace
61 // variations in the JSON while avoiding full unmarshalling overhead.
62 initiator := userInitiator
63 if assistantRolePattern.Match(bodyBytes) || t.isSubAgent {
64 slog.Debug("Setting X-Initiator header to agent (found assistant messages in history)")
65 initiator = agentInitiator
66 } else {
67 slog.Debug("Setting X-Initiator header to user (no assistant messages)")
68 }
69 req.Header.Set(xInitiatorHeader, initiator)
70
71 return t.roundTrip(req)
72}
73
74func (t *initiatorTransport) roundTrip(req *http.Request) (*http.Response, error) {
75 if t.debug {
76 return log.NewHTTPClient().Transport.RoundTrip(req)
77 }
78 return http.DefaultTransport.RoundTrip(req)
79}