@@ -26,6 +26,7 @@ import (
"github.com/charmbracelet/crush/internal/log"
"github.com/charmbracelet/crush/internal/lsp"
"github.com/charmbracelet/crush/internal/message"
+ "github.com/charmbracelet/crush/internal/oauth/copilot"
"github.com/charmbracelet/crush/internal/permission"
"github.com/charmbracelet/crush/internal/session"
"golang.org/x/sync/errgroup"
@@ -582,15 +583,23 @@ func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[stri
return openrouter.New(opts...)
}
-func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any) (fantasy.Provider, error) {
+func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any, providerID string) (fantasy.Provider, error) {
opts := []openaicompat.Option{
openaicompat.WithBaseURL(baseURL),
openaicompat.WithAPIKey(apiKey),
}
- if c.cfg.Options.Debug {
- httpClient := log.NewHTTPClient()
+
+ // Set HTTP client based on provider and debug mode.
+ var httpClient *http.Client
+ if providerID == string(catwalk.InferenceProviderCopilot) {
+ httpClient = copilot.NewClient(c.cfg.Options.Debug)
+ } else if c.cfg.Options.Debug {
+ httpClient = log.NewHTTPClient()
+ }
+ if httpClient != nil {
opts = append(opts, openaicompat.WithHTTPClient(httpClient))
}
+
if len(headers) > 0 {
opts = append(opts, openaicompat.WithHeaders(headers))
}
@@ -745,7 +754,7 @@ func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model con
}
providerCfg.ExtraBody["tool_stream"] = true
}
- return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody)
+ return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody, providerCfg.ID)
case hyper.Name:
return c.buildHyperProvider(baseURL, apiKey)
default:
@@ -0,0 +1,77 @@
+// Package copilot provides GitHub Copilot integration.
+package copilot
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "log/slog"
+ "net/http"
+ "regexp"
+
+ "github.com/charmbracelet/crush/internal/log"
+)
+
+// NewClient creates a new HTTP client with a custom transport that adds the
+// X-Initiator header based on message history in the request body.
+func NewClient(debug bool) *http.Client {
+ return &http.Client{
+ Transport: &initiatorTransport{debug: debug},
+ }
+}
+
+type initiatorTransport struct {
+ debug bool
+}
+
+func (t *initiatorTransport) RoundTrip(req *http.Request) (*http.Response, error) {
+ const (
+ xInitiatorHeader = "X-Initiator"
+ userInitiator = "user"
+ agentInitiator = "agent"
+ )
+
+ if req == nil {
+ return nil, fmt.Errorf("HTTP request is nil")
+ }
+ if req.Body == http.NoBody {
+ // No body to inspect; default to user.
+ req.Header.Set(xInitiatorHeader, userInitiator)
+ slog.Debug("Setting X-Initiator header to user (no request body)")
+ return t.roundTrip(req)
+ }
+
+ // Clone request to avoid modifying the original.
+ req = req.Clone(req.Context())
+
+ // Read the original body into bytes so we can examine it.
+ bodyBytes, err := io.ReadAll(req.Body)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read request body: %w", err)
+ }
+ defer req.Body.Close()
+
+ // Restore the original body using the preserved bytes.
+ req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+
+ // Check for assistant messages using regex to handle whitespace
+ // variations in the JSON while avoiding full unmarshalling overhead.
+ initiator := userInitiator
+ assistantRolePattern := regexp.MustCompile(`"role"\s*:\s*"assistant"`)
+ if assistantRolePattern.Match(bodyBytes) {
+ slog.Debug("Setting X-Initiator header to agent (found assistant messages in history)")
+ initiator = agentInitiator
+ } else {
+ slog.Debug("Setting X-Initiator header to user (no assistant messages)")
+ }
+ req.Header.Set(xInitiatorHeader, initiator)
+
+ return t.roundTrip(req)
+}
+
+func (t *initiatorTransport) roundTrip(req *http.Request) (*http.Response, error) {
+ if t.debug {
+ return log.NewHTTPClient().Transport.RoundTrip(req)
+ }
+ return http.DefaultTransport.RoundTrip(req)
+}