diff --git a/internal/agent/agentic_fetch_tool.go b/internal/agent/agentic_fetch_tool.go index bebdbbafaab1c3f01d5a700a02166e80dff00c44..333ec7926f80735c3798c524378964a8e41fe3e4 100644 --- a/internal/agent/agentic_fetch_tool.go +++ b/internal/agent/agentic_fetch_tool.go @@ -143,7 +143,7 @@ func (c *coordinator) agenticFetchTool(_ context.Context, client *http.Client) ( return fantasy.ToolResponse{}, fmt.Errorf("error creating prompt: %s", err) } - _, small, err := c.buildAgentModels(ctx) + _, small, err := c.buildAgentModels(ctx, true) if err != nil { return fantasy.ToolResponse{}, fmt.Errorf("error building models: %s", err) } diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go index 37cd960f16f0f7641288768a6a173e8ae92e9cc2..6a4293bbdae0c3ad5150d4cfa7c4a064c20e6af4 100644 --- a/internal/agent/coordinator.go +++ b/internal/agent/coordinator.go @@ -317,7 +317,7 @@ func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderO } func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, isSubAgent bool) (SessionAgent, error) { - large, small, err := c.buildAgentModels(ctx) + large, small, err := c.buildAgentModels(ctx, isSubAgent) if err != nil { return nil, err } @@ -441,7 +441,7 @@ func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fan } // TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config -func (c *coordinator) buildAgentModels(ctx context.Context) (Model, Model, error) { +func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Model, Model, error) { largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge] if !ok { return Model{}, Model{}, errors.New("large model not selected") @@ -456,7 +456,7 @@ func (c *coordinator) buildAgentModels(ctx context.Context) (Model, Model, error return Model{}, Model{}, errors.New("large model provider not configured") } - largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg) + largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg, isSubAgent) if err != nil { return Model{}, Model{}, err } @@ -466,7 +466,7 @@ func (c *coordinator) buildAgentModels(ctx context.Context) (Model, Model, error return Model{}, Model{}, errors.New("large model provider not configured") } - smallProvider, err := c.buildProvider(smallProviderCfg, largeModelCfg) + smallProvider, err := c.buildProvider(smallProviderCfg, largeModelCfg, isSubAgent) if err != nil { return Model{}, Model{}, err } @@ -583,7 +583,7 @@ func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[stri return openrouter.New(opts...) } -func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any, providerID string) (fantasy.Provider, error) { +func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any, providerID string, isSubAgent bool) (fantasy.Provider, error) { opts := []openaicompat.Option{ openaicompat.WithBaseURL(baseURL), openaicompat.WithAPIKey(apiKey), @@ -592,7 +592,7 @@ func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers // Set HTTP client based on provider and debug mode. var httpClient *http.Client if providerID == string(catwalk.InferenceProviderCopilot) { - httpClient = copilot.NewClient(c.cfg.Options.Debug) + httpClient = copilot.NewClient(isSubAgent, c.cfg.Options.Debug) } else if c.cfg.Options.Debug { httpClient = log.NewHTTPClient() } @@ -714,7 +714,7 @@ func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool { return false } -func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel) (fantasy.Provider, error) { +func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel, isSubAgent bool) (fantasy.Provider, error) { headers := maps.Clone(providerCfg.ExtraHeaders) if headers == nil { headers = make(map[string]string) @@ -754,7 +754,7 @@ func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model con } providerCfg.ExtraBody["tool_stream"] = true } - return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody, providerCfg.ID) + return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody, providerCfg.ID, isSubAgent) case hyper.Name: return c.buildHyperProvider(baseURL, apiKey) default: @@ -799,7 +799,7 @@ func (c *coordinator) Model() Model { func (c *coordinator) UpdateModels(ctx context.Context) error { // build the models again so we make sure we get the latest config - large, small, err := c.buildAgentModels(ctx) + large, small, err := c.buildAgentModels(ctx, false) if err != nil { return err } diff --git a/internal/oauth/copilot/client.go b/internal/oauth/copilot/client.go index 2b59a1a4bf379edd3c6d02b24b2d241f5fd4d9e1..f76f3bf640c4331968b4173cf0d48e0dbc69aed2 100644 --- a/internal/oauth/copilot/client.go +++ b/internal/oauth/copilot/client.go @@ -14,14 +14,15 @@ import ( // NewClient creates a new HTTP client with a custom transport that adds the // X-Initiator header based on message history in the request body. -func NewClient(debug bool) *http.Client { +func NewClient(isSubAgent, debug bool) *http.Client { return &http.Client{ - Transport: &initiatorTransport{debug: debug}, + Transport: &initiatorTransport{debug: debug, isSubAgent: isSubAgent}, } } type initiatorTransport struct { - debug bool + debug bool + isSubAgent bool } func (t *initiatorTransport) RoundTrip(req *http.Request) (*http.Response, error) { @@ -58,7 +59,7 @@ func (t *initiatorTransport) RoundTrip(req *http.Request) (*http.Response, error // variations in the JSON while avoiding full unmarshalling overhead. initiator := userInitiator assistantRolePattern := regexp.MustCompile(`"role"\s*:\s*"assistant"`) - if assistantRolePattern.Match(bodyBytes) { + if assistantRolePattern.Match(bodyBytes) || t.isSubAgent { slog.Debug("Setting X-Initiator header to agent (found assistant messages in history)") initiator = agentInitiator } else {