@@ -47,43 +47,55 @@ type modelPair struct {
}
func anthropicBuilder(model string) builderFunc {
- return func(_ *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
- provider := anthropic.New(
+ return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
+ provider, err := anthropic.New(
anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
anthropic.WithHTTPClient(&http.Client{Transport: r}),
)
- return provider.LanguageModel(model)
+ if err != nil {
+ return nil, err
+ }
+ return provider.LanguageModel(t.Context(), model)
}
}
func openaiBuilder(model string) builderFunc {
- return func(_ *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
- provider := openai.New(
+ return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
+ provider, err := openai.New(
openai.WithAPIKey(os.Getenv("CRUSH_OPENAI_API_KEY")),
openai.WithHTTPClient(&http.Client{Transport: r}),
)
- return provider.LanguageModel(model)
+ if err != nil {
+ return nil, err
+ }
+ return provider.LanguageModel(t.Context(), model)
}
}
func openRouterBuilder(model string) builderFunc {
return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
- provider := openrouter.New(
+ provider, err := openrouter.New(
openrouter.WithAPIKey(os.Getenv("CRUSH_OPENROUTER_API_KEY")),
openrouter.WithHTTPClient(&http.Client{Transport: r}),
)
- return provider.LanguageModel(model)
+ if err != nil {
+ return nil, err
+ }
+ return provider.LanguageModel(t.Context(), model)
}
}
func zAIBuilder(model string) builderFunc {
return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
- provider := openaicompat.New(
+ provider, err := openaicompat.New(
openaicompat.WithBaseURL("https://api.z.ai/api/coding/paas/v4"),
openaicompat.WithAPIKey(os.Getenv("CRUSH_ZAI_API_KEY")),
openaicompat.WithHTTPClient(&http.Client{Transport: r}),
)
- return provider.LanguageModel(model)
+ if err != nil {
+ return nil, err
+ }
+ return provider.LanguageModel(t.Context(), model)
}
}
@@ -45,7 +45,7 @@ type Coordinator interface {
ClearQueue(sessionID string)
Summarize(context.Context, string) error
Model() Model
- UpdateModels() error
+ UpdateModels(ctx context.Context) error
}
type coordinator struct {
@@ -61,6 +61,7 @@ type coordinator struct {
}
func NewCoordinator(
+ ctx context.Context,
cfg *config.Config,
sessions session.Service,
messages message.Service,
@@ -89,7 +90,7 @@ func NewCoordinator(
return nil, err
}
- agent, err := c.buildAgent(prompt, agentCfg)
+ agent, err := c.buildAgent(ctx, prompt, agentCfg)
if err != nil {
return nil, err
}
@@ -255,8 +256,8 @@ func mergeCallOptions(model Model, tp catwalk.Type) (fantasy.ProviderOptions, *f
return modelOptions, temp, topP, topK, freqPenalty, presPenalty
}
-func (c *coordinator) buildAgent(prompt *prompt.Prompt, agent config.Agent) (SessionAgent, error) {
- large, small, err := c.buildAgentModels()
+func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent) (SessionAgent, error) {
+ large, small, err := c.buildAgentModels(ctx)
if err != nil {
return nil, err
}
@@ -266,17 +267,17 @@ func (c *coordinator) buildAgent(prompt *prompt.Prompt, agent config.Agent) (Ses
return nil, err
}
- tools, err := c.buildTools(agent)
+ tools, err := c.buildTools(ctx, agent)
if err != nil {
return nil, err
}
return NewSessionAgent(SessionAgentOptions{large, small, systemPrompt, c.cfg.Options.DisableAutoSummarize, c.sessions, c.messages, tools}), nil
}
-func (c *coordinator) buildTools(agent config.Agent) ([]fantasy.AgentTool, error) {
+func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
var allTools []fantasy.AgentTool
if slices.Contains(agent.AllowedTools, AgentToolName) {
- agentTool, err := c.agentTool()
+ agentTool, err := c.agentTool(ctx)
if err != nil {
return nil, err
}
@@ -334,7 +335,7 @@ func (c *coordinator) buildTools(agent config.Agent) ([]fantasy.AgentTool, error
}
// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
-func (c *coordinator) buildAgentModels() (Model, Model, error) {
+func (c *coordinator) buildAgentModels(ctx context.Context) (Model, Model, error) {
largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
if !ok {
return Model{}, Model{}, errors.New("large model not selected")
@@ -386,11 +387,11 @@ func (c *coordinator) buildAgentModels() (Model, Model, error) {
return Model{}, Model{}, errors.New("snall model not found in provider config")
}
- largeModel, err := largeProvider.LanguageModel(largeModelCfg.Model)
+ largeModel, err := largeProvider.LanguageModel(ctx, largeModelCfg.Model)
if err != nil {
return Model{}, Model{}, err
}
- smallModel, err := smallProvider.LanguageModel(smallModelCfg.Model)
+ smallModel, err := smallProvider.LanguageModel(ctx, smallModelCfg.Model)
if err != nil {
return Model{}, Model{}, err
}
@@ -406,7 +407,7 @@ func (c *coordinator) buildAgentModels() (Model, Model, error) {
}, nil
}
-func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) fantasy.Provider {
+func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
hasBearerAuth := false
for key := range headers {
if strings.ToLower(key) == "authorization" {
@@ -441,7 +442,7 @@ func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map
return anthropic.New(opts...)
}
-func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) fantasy.Provider {
+func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
opts := []openai.Option{
openai.WithAPIKey(apiKey),
openai.WithUseResponsesAPI(),
@@ -459,7 +460,7 @@ func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[st
return openai.New(opts...)
}
-func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) fantasy.Provider {
+func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
opts := []openrouter.Option{
openrouter.WithAPIKey(apiKey),
}
@@ -473,7 +474,7 @@ func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[stri
return openrouter.New(opts...)
}
-func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string) fantasy.Provider {
+func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
opts := []openaicompat.Option{
openaicompat.WithBaseURL(baseURL),
openaicompat.WithAPIKey(apiKey),
@@ -489,7 +490,7 @@ func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers
return openaicompat.New(opts...)
}
-func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) fantasy.Provider {
+func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
opts := []azure.Option{
azure.WithBaseURL(baseURL),
azure.WithAPIKey(apiKey),
@@ -511,7 +512,7 @@ func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[str
return azure.New(opts...)
}
-func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) fantasy.Provider {
+func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
opts := []google.Option{
google.WithBaseURL(baseURL),
google.WithGeminiAPIKey(apiKey),
@@ -526,7 +527,7 @@ func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[st
return google.New(opts...)
}
-func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) fantasy.Provider {
+func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
opts := []google.Option{}
if c.cfg.Options.Debug {
httpClient := log.NewHTTPClient()
@@ -574,27 +575,25 @@ func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model con
// TODO: make sure we have
apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
- var provider fantasy.Provider
+
switch providerCfg.Type {
case openai.Name:
- provider = c.buildOpenaiProvider(baseURL, apiKey, headers)
+ return c.buildOpenaiProvider(baseURL, apiKey, headers)
case anthropic.Name:
- provider = c.buildAnthropicProvider(baseURL, apiKey, headers)
+ return c.buildAnthropicProvider(baseURL, apiKey, headers)
case openrouter.Name:
- provider = c.buildOpenrouterProvider(baseURL, apiKey, headers)
+ return c.buildOpenrouterProvider(baseURL, apiKey, headers)
case azure.Name:
- provider = c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
+ return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
case google.Name:
- provider = c.buildGoogleProvider(baseURL, apiKey, headers)
- // this is not in fantasy since its just the google provider with extra stuff
+ return c.buildGoogleProvider(baseURL, apiKey, headers)
case "vertexai":
- provider = c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
+ return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
case openaicompat.Name:
- provider = c.buildOpenaiCompatProvider(baseURL, apiKey, headers)
+ return c.buildOpenaiCompatProvider(baseURL, apiKey, headers)
default:
return nil, errors.New("provider type not supported")
}
- return provider, nil
}
func (c *coordinator) Cancel(sessionID string) {
@@ -621,9 +620,9 @@ func (c *coordinator) Model() Model {
return c.currentAgent.Model()
}
-func (c *coordinator) UpdateModels() error {
+func (c *coordinator) UpdateModels(ctx context.Context) error {
// build the models again so we make sure we get the latest config
- large, small, err := c.buildAgentModels()
+ large, small, err := c.buildAgentModels(ctx)
if err != nil {
return err
}
@@ -634,7 +633,7 @@ func (c *coordinator) UpdateModels() error {
return errors.New("coder agent not configured")
}
- tools, err := c.buildTools(agentCfg)
+ tools, err := c.buildTools(ctx, agentCfg)
if err != nil {
return err
}