client.go

 1// Package copilot provides GitHub Copilot integration.
 2package copilot
 3
 4import (
 5	"bytes"
 6	"fmt"
 7	"io"
 8	"log/slog"
 9	"net/http"
10	"regexp"
11
12	"git.secluded.site/crush/internal/log"
13)
14
15// NewClient creates a new HTTP client with a custom transport that adds the
16// X-Initiator header based on message history in the request body.
17func NewClient(isSubAgent, debug bool) *http.Client {
18	return &http.Client{
19		Transport: &initiatorTransport{debug: debug, isSubAgent: isSubAgent},
20	}
21}
22
// initiatorTransport is an http.RoundTripper that sets the X-Initiator header
// on each request before delegating to the underlying transport.
type initiatorTransport struct {
	// debug selects log.NewHTTPClient's transport instead of
	// http.DefaultTransport; isSubAgent marks every request that has a body
	// as agent-initiated regardless of its content.
	debug, isSubAgent bool
}
27
28func (t *initiatorTransport) RoundTrip(req *http.Request) (*http.Response, error) {
29	const (
30		xInitiatorHeader = "X-Initiator"
31		userInitiator    = "user"
32		agentInitiator   = "agent"
33	)
34
35	if req == nil {
36		return nil, fmt.Errorf("HTTP request is nil")
37	}
38	if req.Body == http.NoBody {
39		// No body to inspect; default to user.
40		req.Header.Set(xInitiatorHeader, userInitiator)
41		slog.Debug("Setting X-Initiator header to user (no request body)")
42		return t.roundTrip(req)
43	}
44
45	// Clone request to avoid modifying the original.
46	req = req.Clone(req.Context())
47
48	// Read the original body into bytes so we can examine it.
49	bodyBytes, err := io.ReadAll(req.Body)
50	if err != nil {
51		return nil, fmt.Errorf("failed to read request body: %w", err)
52	}
53	defer req.Body.Close()
54
55	// Restore the original body using the preserved bytes.
56	req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
57
58	// Check for assistant messages using regex to handle whitespace
59	// variations in the JSON while avoiding full unmarshalling overhead.
60	initiator := userInitiator
61	assistantRolePattern := regexp.MustCompile(`"role"\s*:\s*"assistant"`)
62	if assistantRolePattern.Match(bodyBytes) || t.isSubAgent {
63		slog.Debug("Setting X-Initiator header to agent (found assistant messages in history)")
64		initiator = agentInitiator
65	} else {
66		slog.Debug("Setting X-Initiator header to user (no assistant messages)")
67	}
68	req.Header.Set(xInitiatorHeader, initiator)
69
70	return t.roundTrip(req)
71}
72
73func (t *initiatorTransport) roundTrip(req *http.Request) (*http.Response, error) {
74	if t.debug {
75		return log.NewHTTPClient().Transport.RoundTrip(req)
76	}
77	return http.DefaultTransport.RoundTrip(req)
78}