1// Package copilot provides GitHub Copilot integration.
2package copilot
3
4import (
5 "bytes"
6 "fmt"
7 "io"
8 "log/slog"
9 "net/http"
10 "regexp"
11
12 "git.secluded.site/crush/internal/log"
13)
14
15// NewClient creates a new HTTP client with a custom transport that adds the
16// X-Initiator header based on message history in the request body.
17func NewClient(isSubAgent, debug bool) *http.Client {
18 return &http.Client{
19 Transport: &initiatorTransport{debug: debug, isSubAgent: isSubAgent},
20 }
21}
22
// initiatorTransport is the http.RoundTripper installed by NewClient. It
// tags each request with an X-Initiator header and then forwards it.
//
// debug picks the transport of log.NewHTTPClient instead of
// http.DefaultTransport. isSubAgent marks every request that carries a body
// as agent-initiated.
type initiatorTransport struct {
	debug, isSubAgent bool
}
27
28func (t *initiatorTransport) RoundTrip(req *http.Request) (*http.Response, error) {
29 const (
30 xInitiatorHeader = "X-Initiator"
31 userInitiator = "user"
32 agentInitiator = "agent"
33 )
34
35 if req == nil {
36 return nil, fmt.Errorf("HTTP request is nil")
37 }
38 if req.Body == http.NoBody {
39 // No body to inspect; default to user.
40 req.Header.Set(xInitiatorHeader, userInitiator)
41 slog.Debug("Setting X-Initiator header to user (no request body)")
42 return t.roundTrip(req)
43 }
44
45 // Clone request to avoid modifying the original.
46 req = req.Clone(req.Context())
47
48 // Read the original body into bytes so we can examine it.
49 bodyBytes, err := io.ReadAll(req.Body)
50 if err != nil {
51 return nil, fmt.Errorf("failed to read request body: %w", err)
52 }
53 defer req.Body.Close()
54
55 // Restore the original body using the preserved bytes.
56 req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
57
58 // Check for assistant messages using regex to handle whitespace
59 // variations in the JSON while avoiding full unmarshalling overhead.
60 initiator := userInitiator
61 assistantRolePattern := regexp.MustCompile(`"role"\s*:\s*"assistant"`)
62 if assistantRolePattern.Match(bodyBytes) || t.isSubAgent {
63 slog.Debug("Setting X-Initiator header to agent (found assistant messages in history)")
64 initiator = agentInitiator
65 } else {
66 slog.Debug("Setting X-Initiator header to user (no assistant messages)")
67 }
68 req.Header.Set(xInitiatorHeader, initiator)
69
70 return t.roundTrip(req)
71}
72
73func (t *initiatorTransport) roundTrip(req *http.Request) (*http.Response, error) {
74 if t.debug {
75 return log.NewHTTPClient().Transport.RoundTrip(req)
76 }
77 return http.DefaultTransport.RoundTrip(req)
78}