From aeadffde3bbb407bec508f7c8b95acf0991b1f3c Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Sat, 3 Jan 2026 18:36:33 +0100 Subject: [PATCH] fix: copilot quota handling (#1738) Co-authored-by: Gustave-241021 <2909789120@qq.com> Co-authored-by: Christian Rocha --- internal/agent/coordinator.go | 17 +++++-- internal/oauth/copilot/client.go | 77 ++++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+), 4 deletions(-) create mode 100644 internal/oauth/copilot/client.go diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go index 363f5690c8868ebb95726d1f66f628f301abef91..37cd960f16f0f7641288768a6a173e8ae92e9cc2 100644 --- a/internal/agent/coordinator.go +++ b/internal/agent/coordinator.go @@ -26,6 +26,7 @@ import ( "github.com/charmbracelet/crush/internal/log" "github.com/charmbracelet/crush/internal/lsp" "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/oauth/copilot" "github.com/charmbracelet/crush/internal/permission" "github.com/charmbracelet/crush/internal/session" "golang.org/x/sync/errgroup" @@ -582,15 +583,23 @@ func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[stri return openrouter.New(opts...) } -func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any) (fantasy.Provider, error) { +func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any, providerID string) (fantasy.Provider, error) { opts := []openaicompat.Option{ openaicompat.WithBaseURL(baseURL), openaicompat.WithAPIKey(apiKey), } - if c.cfg.Options.Debug { - httpClient := log.NewHTTPClient() + + // Set HTTP client based on provider and debug mode. + var httpClient *http.Client + if providerID == string(catwalk.InferenceProviderCopilot) { + httpClient = copilot.NewClient(c.cfg.Options.Debug) + } else if c.cfg.Options.Debug { + httpClient = log.NewHTTPClient() + } + if httpClient != nil { opts = append(opts, openaicompat.WithHTTPClient(httpClient)) } + if len(headers) > 0 { opts = append(opts, openaicompat.WithHeaders(headers)) } @@ -745,7 +754,7 @@ func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model con } providerCfg.ExtraBody["tool_stream"] = true } - return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody) + return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody, providerCfg.ID) case hyper.Name: return c.buildHyperProvider(baseURL, apiKey) default: diff --git a/internal/oauth/copilot/client.go b/internal/oauth/copilot/client.go new file mode 100644 index 0000000000000000000000000000000000000000..2b59a1a4bf379edd3c6d02b24b2d241f5fd4d9e1 --- /dev/null +++ b/internal/oauth/copilot/client.go @@ -0,0 +1,77 @@ +// Package copilot provides GitHub Copilot integration. +package copilot + +import ( + "bytes" + "fmt" + "io" + "log/slog" + "net/http" + "regexp" + + "github.com/charmbracelet/crush/internal/log" +) + +// NewClient creates a new HTTP client with a custom transport that adds the +// X-Initiator header based on message history in the request body. +func NewClient(debug bool) *http.Client { + return &http.Client{ + Transport: &initiatorTransport{debug: debug}, + } +} + +type initiatorTransport struct { + debug bool +} + +func (t *initiatorTransport) RoundTrip(req *http.Request) (*http.Response, error) { + const ( + xInitiatorHeader = "X-Initiator" + userInitiator = "user" + agentInitiator = "agent" + ) + + if req == nil { + return nil, fmt.Errorf("HTTP request is nil") + } + if req.Body == http.NoBody { + // No body to inspect; default to user. + req.Header.Set(xInitiatorHeader, userInitiator) + slog.Debug("Setting X-Initiator header to user (no request body)") + return t.roundTrip(req) + } + + // Clone request to avoid modifying the original. + req = req.Clone(req.Context()) + + // Read the original body into bytes so we can examine it. + bodyBytes, err := io.ReadAll(req.Body) + if err != nil { + return nil, fmt.Errorf("failed to read request body: %w", err) + } + defer req.Body.Close() + + // Restore the original body using the preserved bytes. + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + + // Check for assistant messages using regex to handle whitespace + // variations in the JSON while avoiding full unmarshalling overhead. + initiator := userInitiator + assistantRolePattern := regexp.MustCompile(`"role"\s*:\s*"assistant"`) + if assistantRolePattern.Match(bodyBytes) { + slog.Debug("Setting X-Initiator header to agent (found assistant messages in history)") + initiator = agentInitiator + } else { + slog.Debug("Setting X-Initiator header to user (no assistant messages)") + } + req.Header.Set(xInitiatorHeader, initiator) + + return t.roundTrip(req) +} + +func (t *initiatorTransport) roundTrip(req *http.Request) (*http.Response, error) { + if t.debug { + return log.NewHTTPClient().Transport.RoundTrip(req) + } + return http.DefaultTransport.RoundTrip(req) +}