fix: quota for subagents in copilot

Kujtim Hoxha created

Change summary

internal/agent/agentic_fetch_tool.go |  2 +-
internal/agent/coordinator.go        | 18 +++++++++---------
internal/oauth/copilot/client.go     |  9 +++++----
3 files changed, 15 insertions(+), 14 deletions(-)

Detailed changes

internal/agent/agentic_fetch_tool.go 🔗

@@ -143,7 +143,7 @@ func (c *coordinator) agenticFetchTool(_ context.Context, client *http.Client) (
 				return fantasy.ToolResponse{}, fmt.Errorf("error creating prompt: %s", err)
 			}
 
-			_, small, err := c.buildAgentModels(ctx)
+			_, small, err := c.buildAgentModels(ctx, true)
 			if err != nil {
 				return fantasy.ToolResponse{}, fmt.Errorf("error building models: %s", err)
 			}

internal/agent/coordinator.go 🔗

@@ -317,7 +317,7 @@ func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderO
 }
 
 func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, isSubAgent bool) (SessionAgent, error) {
-	large, small, err := c.buildAgentModels(ctx)
+	large, small, err := c.buildAgentModels(ctx, isSubAgent)
 	if err != nil {
 		return nil, err
 	}
@@ -441,7 +441,7 @@ func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fan
 }
 
 // TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
-func (c *coordinator) buildAgentModels(ctx context.Context) (Model, Model, error) {
+func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Model, Model, error) {
 	largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
 	if !ok {
 		return Model{}, Model{}, errors.New("large model not selected")
@@ -456,7 +456,7 @@ func (c *coordinator) buildAgentModels(ctx context.Context) (Model, Model, error
 		return Model{}, Model{}, errors.New("large model provider not configured")
 	}
 
-	largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg)
+	largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg, isSubAgent)
 	if err != nil {
 		return Model{}, Model{}, err
 	}
@@ -466,7 +466,7 @@ func (c *coordinator) buildAgentModels(ctx context.Context) (Model, Model, error
 		return Model{}, Model{}, errors.New("large model provider not configured")
 	}
 
-	smallProvider, err := c.buildProvider(smallProviderCfg, largeModelCfg)
+	smallProvider, err := c.buildProvider(smallProviderCfg, largeModelCfg, isSubAgent)
 	if err != nil {
 		return Model{}, Model{}, err
 	}
@@ -583,7 +583,7 @@ func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[stri
 	return openrouter.New(opts...)
 }
 
-func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any, providerID string) (fantasy.Provider, error) {
+func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any, providerID string, isSubAgent bool) (fantasy.Provider, error) {
 	opts := []openaicompat.Option{
 		openaicompat.WithBaseURL(baseURL),
 		openaicompat.WithAPIKey(apiKey),
@@ -592,7 +592,7 @@ func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers
 	// Set HTTP client based on provider and debug mode.
 	var httpClient *http.Client
 	if providerID == string(catwalk.InferenceProviderCopilot) {
-		httpClient = copilot.NewClient(c.cfg.Options.Debug)
+		httpClient = copilot.NewClient(isSubAgent, c.cfg.Options.Debug)
 	} else if c.cfg.Options.Debug {
 		httpClient = log.NewHTTPClient()
 	}
@@ -714,7 +714,7 @@ func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
 	return false
 }
 
-func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel) (fantasy.Provider, error) {
+func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel, isSubAgent bool) (fantasy.Provider, error) {
 	headers := maps.Clone(providerCfg.ExtraHeaders)
 	if headers == nil {
 		headers = make(map[string]string)
@@ -754,7 +754,7 @@ func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model con
 			}
 			providerCfg.ExtraBody["tool_stream"] = true
 		}
-		return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody, providerCfg.ID)
+		return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody, providerCfg.ID, isSubAgent)
 	case hyper.Name:
 		return c.buildHyperProvider(baseURL, apiKey)
 	default:
@@ -799,7 +799,7 @@ func (c *coordinator) Model() Model {
 
 func (c *coordinator) UpdateModels(ctx context.Context) error {
 	// build the models again so we make sure we get the latest config
-	large, small, err := c.buildAgentModels(ctx)
+	large, small, err := c.buildAgentModels(ctx, false)
 	if err != nil {
 		return err
 	}

internal/oauth/copilot/client.go 🔗

@@ -14,14 +14,15 @@ import (
 
 // NewClient creates a new HTTP client with a custom transport that adds the
 // X-Initiator header based on message history in the request body.
-func NewClient(debug bool) *http.Client {
+func NewClient(isSubAgent, debug bool) *http.Client {
 	return &http.Client{
-		Transport: &initiatorTransport{debug: debug},
+		Transport: &initiatorTransport{debug: debug, isSubAgent: isSubAgent},
 	}
 }
 
 type initiatorTransport struct {
-	debug bool
+	debug      bool
+	isSubAgent bool
 }
 
 func (t *initiatorTransport) RoundTrip(req *http.Request) (*http.Response, error) {
@@ -58,7 +59,7 @@ func (t *initiatorTransport) RoundTrip(req *http.Request) (*http.Response, error
 	// variations in the JSON while avoiding full unmarshalling overhead.
 	initiator := userInitiator
 	assistantRolePattern := regexp.MustCompile(`"role"\s*:\s*"assistant"`)
-	if assistantRolePattern.Match(bodyBytes) {
+	if assistantRolePattern.Match(bodyBytes) || t.isSubAgent {
 		slog.Debug("Setting X-Initiator header to agent (found assistant messages in history)")
 		initiator = agentInitiator
 	} else {