fix: copilot quota handling (#1738)

Kujtim Hoxha , Gustave-241021 , and Christian Rocha created

Co-authored-by: Gustave-241021 <2909789120@qq.com>
Co-authored-by: Christian Rocha <christian@rocha.is>

Change summary

internal/agent/coordinator.go    | 17 +++++-
internal/oauth/copilot/client.go | 77 ++++++++++++++++++++++++++++++++++
2 files changed, 90 insertions(+), 4 deletions(-)

Detailed changes

internal/agent/coordinator.go 🔗

@@ -26,6 +26,7 @@ import (
 	"github.com/charmbracelet/crush/internal/log"
 	"github.com/charmbracelet/crush/internal/lsp"
 	"github.com/charmbracelet/crush/internal/message"
+	"github.com/charmbracelet/crush/internal/oauth/copilot"
 	"github.com/charmbracelet/crush/internal/permission"
 	"github.com/charmbracelet/crush/internal/session"
 	"golang.org/x/sync/errgroup"
@@ -582,15 +583,23 @@ func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[stri
 	return openrouter.New(opts...)
 }
 
-func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any) (fantasy.Provider, error) {
+func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any, providerID string) (fantasy.Provider, error) {
 	opts := []openaicompat.Option{
 		openaicompat.WithBaseURL(baseURL),
 		openaicompat.WithAPIKey(apiKey),
 	}
-	if c.cfg.Options.Debug {
-		httpClient := log.NewHTTPClient()
+
+	// Set HTTP client based on provider and debug mode.
+	var httpClient *http.Client
+	if providerID == string(catwalk.InferenceProviderCopilot) {
+		httpClient = copilot.NewClient(c.cfg.Options.Debug)
+	} else if c.cfg.Options.Debug {
+		httpClient = log.NewHTTPClient()
+	}
+	if httpClient != nil {
 		opts = append(opts, openaicompat.WithHTTPClient(httpClient))
 	}
+
 	if len(headers) > 0 {
 		opts = append(opts, openaicompat.WithHeaders(headers))
 	}
@@ -745,7 +754,7 @@ func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model con
 			}
 			providerCfg.ExtraBody["tool_stream"] = true
 		}
-		return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody)
+		return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody, providerCfg.ID)
 	case hyper.Name:
 		return c.buildHyperProvider(baseURL, apiKey)
 	default:

internal/oauth/copilot/client.go 🔗

@@ -0,0 +1,77 @@
+// Package copilot provides GitHub Copilot integration.
+package copilot
+
+import (
+	"bytes"
+	"fmt"
+	"io"
+	"log/slog"
+	"net/http"
+	"regexp"
+
+	"github.com/charmbracelet/crush/internal/log"
+)
+
+// NewClient creates a new HTTP client with a custom transport that adds the
+// X-Initiator header based on message history in the request body.
+func NewClient(debug bool) *http.Client {
+	return &http.Client{
+		Transport: &initiatorTransport{debug: debug},
+	}
+}
+
+type initiatorTransport struct {
+	debug bool
+}
+
+func (t *initiatorTransport) RoundTrip(req *http.Request) (*http.Response, error) {
+	const (
+		xInitiatorHeader = "X-Initiator"
+		userInitiator    = "user"
+		agentInitiator   = "agent"
+	)
+
+	if req == nil {
+		return nil, fmt.Errorf("HTTP request is nil")
+	}
+	if req.Body == http.NoBody {
+		// No body to inspect; default to user.
+		req.Header.Set(xInitiatorHeader, userInitiator)
+		slog.Debug("Setting X-Initiator header to user (no request body)")
+		return t.roundTrip(req)
+	}
+
+	// Clone request to avoid modifying the original.
+	req = req.Clone(req.Context())
+
+	// Read the original body into bytes so we can examine it.
+	bodyBytes, err := io.ReadAll(req.Body)
+	if err != nil {
+		return nil, fmt.Errorf("failed to read request body: %w", err)
+	}
+	defer req.Body.Close()
+
+	// Restore the original body using the preserved bytes.
+	req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+
+	// Check for assistant messages using regex to handle whitespace
+	// variations in the JSON while avoiding full unmarshalling overhead.
+	initiator := userInitiator
+	assistantRolePattern := regexp.MustCompile(`"role"\s*:\s*"assistant"`)
+	if assistantRolePattern.Match(bodyBytes) {
+		slog.Debug("Setting X-Initiator header to agent (found assistant messages in history)")
+		initiator = agentInitiator
+	} else {
+		slog.Debug("Setting X-Initiator header to user (no assistant messages)")
+	}
+	req.Header.Set(xInitiatorHeader, initiator)
+
+	return t.roundTrip(req)
+}
+
+func (t *initiatorTransport) roundTrip(req *http.Request) (*http.Response, error) {
+	if t.debug {
+		return log.NewHTTPClient().Transport.RoundTrip(req)
+	}
+	return http.DefaultTransport.RoundTrip(req)
+}