diff --git a/internal/app/app.go b/internal/app/app.go index 75042e89648779cf50a4376aa01aa3b6ac8e72a0..b096c1b4f5612901a1cedeaa2ee758b666cda517 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -9,7 +9,7 @@ import ( "sync" "time" - configv2 "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/db" "github.com/charmbracelet/crush/internal/format" "github.com/charmbracelet/crush/internal/history" @@ -55,9 +55,9 @@ func New(ctx context.Context, conn *sql.DB) (*App, error) { // Initialize LSP clients in the background go app.initLSPClients(ctx) - cfg := configv2.Get() + cfg := config.Get() - coderAgentCfg := cfg.Agents[configv2.AgentCoder] + coderAgentCfg := cfg.Agents[config.AgentCoder] if coderAgentCfg.ID == "" { return nil, fmt.Errorf("coder agent configuration is missing") } diff --git a/internal/config/config.go b/internal/config/config.go index 13444a5ccc8e99bdaa57a6156151b45a40176c09..8ebc1ce6cf5226fb7ad43601eb95d346bfebc0ef 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -47,6 +47,13 @@ const ( AgentTask AgentID = "task" ) +type ModelType string + +const ( + LargeModel ModelType = "large" + SmallModel ModelType = "small" +) + type Model struct { ID string `json:"id"` Name string `json:"model"` @@ -90,8 +97,7 @@ type Agent struct { // This is the id of the system prompt used by the agent Disabled bool `json:"disabled"` - Provider provider.InferenceProvider `json:"provider"` - Model string `json:"model"` + Model ModelType `json:"model"` // The available tools for the agent // if this is nil, all tools are available @@ -291,8 +297,7 @@ func loadConfig(cwd string, debug bool) (*Config, error) { ID: AgentCoder, Name: "Coder", Description: "An agent that helps with executing coding tasks.", - Provider: cfg.Models.Large.Provider, - Model: cfg.Models.Large.ModelID, + Model: LargeModel, ContextPaths: cfg.Options.ContextPaths, // All tools allowed }, @@ -300,8 +305,7 @@ func loadConfig(cwd string, debug bool) (*Config, error) { ID: AgentTask, Name: "Task", Description: "An agent that helps with searching for context and finding implementation details.", - Provider: cfg.Models.Large.Provider, - Model: cfg.Models.Large.ModelID, + Model: LargeModel, ContextPaths: cfg.Options.ContextPaths, AllowedTools: []string{ "glob", @@ -490,9 +494,8 @@ func mergeAgents(base, global, local *Config) { switch agentID { case AgentCoder: baseAgent := base.Agents[agentID] - if newAgent.Model != "" && newAgent.Provider != "" { + if newAgent.Model != "" { baseAgent.Model = newAgent.Model - baseAgent.Provider = newAgent.Provider } baseAgent.AllowedMCP = newAgent.AllowedMCP baseAgent.AllowedLSP = newAgent.AllowedLSP @@ -502,9 +505,8 @@ func mergeAgents(base, global, local *Config) { baseAgent.Name = newAgent.Name baseAgent.Description = newAgent.Description baseAgent.Disabled = newAgent.Disabled - if newAgent.Model == "" || newAgent.Provider == "" { - baseAgent.Provider = base.Models.Large.Provider - baseAgent.Model = base.Models.Large.ModelID + if newAgent.Model == "" { + baseAgent.Model = LargeModel } baseAgent.AllowedTools = newAgent.AllowedTools baseAgent.AllowedMCP = newAgent.AllowedMCP @@ -709,6 +711,8 @@ func WorkingDirectory() string { return cwd } +// TODO: Handle error state + func GetAgentModel(agentID AgentID) Model { cfg := Get() agent, ok := cfg.Agents[agentID] @@ -717,15 +721,25 @@ func GetAgentModel(agentID AgentID) Model { return Model{} } - providerConfig, ok := cfg.Providers[agent.Provider] + var model PreferredModel + switch agent.Model { + case LargeModel: + model = cfg.Models.Large + case SmallModel: + model = cfg.Models.Small + default: + logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model) + model = cfg.Models.Large // Fallback to large model + } + providerConfig, ok := cfg.Providers[model.Provider] if !ok { - logging.Error("Provider not found for agent", "agent_id", agentID, "provider", agent.Provider) + logging.Error("Provider not found for agent", "agent_id", agentID, "provider", model.Provider) return Model{} } - for _, model := range providerConfig.Models { - if model.ID == agent.Model { - return model + for _, m := range providerConfig.Models { + if m.ID == model.ModelID { + return m } } @@ -733,6 +747,34 @@ func GetAgentModel(agentID AgentID) Model { return Model{} } +func GetAgentProvider(agentID AgentID) ProviderConfig { + cfg := Get() + agent, ok := cfg.Agents[agentID] + if !ok { + logging.Error("Agent not found", "agent_id", agentID) + return ProviderConfig{} + } + + var model PreferredModel + switch agent.Model { + case LargeModel: + model = cfg.Models.Large + case SmallModel: + model = cfg.Models.Small + default: + logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model) + model = cfg.Models.Large // Fallback to large model + } + + providerConfig, ok := cfg.Providers[model.Provider] + if !ok { + logging.Error("Provider not found for agent", "agent_id", agentID, "provider", model.Provider) + return ProviderConfig{} + } + + return providerConfig +} + func GetProviderModel(provider provider.InferenceProvider, modelID string) Model { cfg := Get() providerConfig, ok := cfg.Providers[provider] @@ -750,3 +792,40 @@ func GetProviderModel(provider provider.InferenceProvider, modelID string) Model logging.Error("Model not found for provider", "provider", provider, "model_id", modelID) return Model{} } + +func GetModel(modelType ModelType) Model { + cfg := Get() + var model PreferredModel + switch modelType { + case LargeModel: + model = cfg.Models.Large + case SmallModel: + model = cfg.Models.Small + default: + model = cfg.Models.Large // Fallback to large model + } + providerConfig, ok := cfg.Providers[model.Provider] + if !ok { + return Model{} + } + + for _, m := range providerConfig.Models { + if m.ID == model.ModelID { + return m + } + } + return Model{} +} + +func UpdatePreferredModel(modelType ModelType, model PreferredModel) error { + cfg := Get() + switch modelType { + case LargeModel: + cfg.Models.Large = model + case SmallModel: + cfg.Models.Small = model + default: + return fmt.Errorf("unknown model type: %s", modelType) + } + return nil +} diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index f9e97b164aa98fe1ae76490fdfcf336efb43098f..8c6faf8c4a06bbef5da279847cd14ce2314648cd 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -9,7 +9,7 @@ import ( "sync" "time" - configv2 "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/history" "github.com/charmbracelet/crush/internal/llm/prompt" "github.com/charmbracelet/crush/internal/llm/provider" @@ -49,19 +49,18 @@ type AgentEvent struct { type Service interface { pubsub.Suscriber[AgentEvent] - Model() configv2.Model + Model() config.Model Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) Cancel(sessionID string) CancelAll() IsSessionBusy(sessionID string) bool IsBusy() bool - Update(model configv2.PreferredModel) (configv2.Model, error) Summarize(ctx context.Context, sessionID string) error } type agent struct { *pubsub.Broker[AgentEvent] - agentCfg configv2.Agent + agentCfg config.Agent sessions session.Service messages message.Service @@ -76,13 +75,13 @@ type agent struct { activeRequests sync.Map } -var agentPromptMap = map[configv2.AgentID]prompt.PromptID{ - configv2.AgentCoder: prompt.PromptCoder, - configv2.AgentTask: prompt.PromptTask, +var agentPromptMap = map[config.AgentID]prompt.PromptID{ + config.AgentCoder: prompt.PromptCoder, + config.AgentTask: prompt.PromptTask, } func NewAgent( - agentCfg configv2.Agent, + agentCfg config.Agent, // These services are needed in the tools permissions permission.Service, sessions session.Service, @@ -91,7 +90,7 @@ func NewAgent( lspClients map[string]*lsp.Client, ) (Service, error) { ctx := context.Background() - cfg := configv2.Get() + cfg := config.Get() otherTools := GetMcpTools(ctx, permissions) if len(lspClients) > 0 { otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients)) @@ -109,8 +108,8 @@ func NewAgent( tools.NewWriteTool(lspClients, permissions, history), } - if agentCfg.ID == configv2.AgentCoder { - taskAgentCfg := configv2.Get().Agents[configv2.AgentTask] + if agentCfg.ID == config.AgentCoder { + taskAgentCfg := config.Get().Agents[config.AgentTask] if taskAgentCfg.ID == "" { return nil, fmt.Errorf("task agent not found in config") } @@ -130,26 +129,14 @@ func NewAgent( } allTools = append(allTools, otherTools...) - var providerCfg configv2.ProviderConfig - for _, p := range cfg.Providers { - if p.ID == agentCfg.Provider { - providerCfg = p - break - } - } + providerCfg := config.GetAgentProvider(agentCfg.ID) if providerCfg.ID == "" { - return nil, fmt.Errorf("provider %s not found in config", agentCfg.Provider) + return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name) } + model := config.GetAgentModel(agentCfg.ID) - var model configv2.Model - for _, m := range providerCfg.Models { - if m.ID == agentCfg.Model { - model = m - break - } - } if model.ID == "" { - return nil, fmt.Errorf("model %s not found in provider %s", agentCfg.Model, agentCfg.Provider) + return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name) } promptID := agentPromptMap[agentCfg.ID] @@ -157,7 +144,7 @@ func NewAgent( promptID = prompt.PromptDefault } opts := []provider.ProviderClientOption{ - provider.WithModel(model), + provider.WithModel(agentCfg.Model), provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)), provider.WithMaxTokens(model.DefaultMaxTokens), } @@ -167,9 +154,9 @@ func NewAgent( } smallModelCfg := cfg.Models.Small - var smallModel configv2.Model + var smallModel config.Model - var smallModelProviderCfg configv2.ProviderConfig + var smallModelProviderCfg config.ProviderConfig if smallModelCfg.Provider == providerCfg.ID { smallModelProviderCfg = providerCfg } else { @@ -194,7 +181,7 @@ func NewAgent( } titleOpts := []provider.ProviderClientOption{ - provider.WithModel(smallModel), + provider.WithModel(config.SmallModel), provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)), provider.WithMaxTokens(40), } @@ -203,7 +190,7 @@ func NewAgent( return nil, err } summarizeOpts := []provider.ProviderClientOption{ - provider.WithModel(smallModel), + provider.WithModel(config.SmallModel), provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)), provider.WithMaxTokens(smallModel.DefaultMaxTokens), } @@ -240,8 +227,8 @@ func NewAgent( return agent, nil } -func (a *agent) Model() configv2.Model { - return a.provider.Model() +func (a *agent) Model() config.Model { + return config.GetAgentModel(a.agentCfg.ID) } func (a *agent) Cancel(sessionID string) { @@ -336,7 +323,7 @@ func (a *agent) err(err error) AgentEvent { } func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) { - if !a.provider.Model().SupportsImages && attachments != nil { + if !a.Model().SupportsImages && attachments != nil { attachments = nil } events := make(chan AgentEvent) @@ -458,7 +445,7 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ Role: message.Assistant, Parts: []message.ContentPart{}, - Model: a.provider.Model().ID, + Model: a.Model().ID, Provider: a.providerID, }) if err != nil { @@ -609,13 +596,13 @@ func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg if err := a.messages.Update(ctx, *assistantMsg); err != nil { return fmt.Errorf("failed to update message: %w", err) } - return a.TrackUsage(ctx, sessionID, a.provider.Model(), event.Response.Usage) + return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage) } return nil } -func (a *agent) TrackUsage(ctx context.Context, sessionID string, model configv2.Model, usage provider.TokenUsage) error { +func (a *agent) TrackUsage(ctx context.Context, sessionID string, model config.Model, usage provider.TokenUsage) error { sess, err := a.sessions.Get(ctx, sessionID) if err != nil { return fmt.Errorf("failed to get session: %w", err) @@ -637,52 +624,6 @@ func (a *agent) TrackUsage(ctx context.Context, sessionID string, model configv2 return nil } -func (a *agent) Update(modelCfg configv2.PreferredModel) (configv2.Model, error) { - if a.IsBusy() { - return configv2.Model{}, fmt.Errorf("cannot change model while processing requests") - } - - cfg := configv2.Get() - var providerCfg configv2.ProviderConfig - for _, p := range cfg.Providers { - if p.ID == modelCfg.Provider { - providerCfg = p - break - } - } - if providerCfg.ID == "" { - return configv2.Model{}, fmt.Errorf("provider %s not found in config", modelCfg.Provider) - } - - var model configv2.Model - for _, m := range providerCfg.Models { - if m.ID == modelCfg.ModelID { - model = m - break - } - } - if model.ID == "" { - return configv2.Model{}, fmt.Errorf("model %s not found in provider %s", modelCfg.ModelID, modelCfg.Provider) - } - - promptID := agentPromptMap[a.agentCfg.ID] - if promptID == "" { - promptID = prompt.PromptDefault - } - opts := []provider.ProviderClientOption{ - provider.WithModel(model), - provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)), - provider.WithMaxTokens(model.DefaultMaxTokens), - } - agentProvider, err := provider.NewProviderV2(providerCfg, opts...) - if err != nil { - return configv2.Model{}, err - } - a.provider = agentProvider - - return a.provider.Model(), nil -} - func (a *agent) Summarize(ctx context.Context, sessionID string) error { if a.summarizeProvider == nil { return fmt.Errorf("summarize provider not available") diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go index aca4d5b7f0adc4977fb349956be1005186e267e6..626882f283c030454477b27b152bd6a717d08476 100644 --- a/internal/llm/provider/anthropic.go +++ b/internal/llm/provider/anthropic.go @@ -145,6 +145,7 @@ func (a *anthropicClient) finishReason(reason string) message.FinishReason { } func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, tools []anthropic.ToolUnionParam) anthropic.MessageNewParams { + model := a.providerOptions.model(a.providerOptions.modelType) var thinkingParam anthropic.ThinkingConfigParamUnion // TODO: Implement a proper thinking function // lastMessage := messages[len(messages)-1] @@ -164,7 +165,7 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to // } return anthropic.MessageNewParams{ - Model: anthropic.Model(a.providerOptions.model.ID), + Model: anthropic.Model(model.ID), MaxTokens: a.providerOptions.maxTokens, Temperature: temperature, Messages: messages, @@ -425,6 +426,10 @@ func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage { } } +func (a *anthropicClient) Model() config.Model { + return a.providerOptions.model(a.providerOptions.modelType) +} + // TODO: check if we need func DefaultShouldThinkFn(s string) bool { return strings.Contains(strings.ToLower(s), "think") diff --git a/internal/llm/provider/bedrock.go b/internal/llm/provider/bedrock.go index 6b31c7d7fd6625ad7c2962f409f6c50f01ff726b..1519099b00401e32ad5f19c1f6ed253eb8b7130d 100644 --- a/internal/llm/provider/bedrock.go +++ b/internal/llm/provider/bedrock.go @@ -6,6 +6,7 @@ import ( "fmt" "strings" + "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/message" ) @@ -30,13 +31,20 @@ func newBedrockClient(opts providerClientOptions) BedrockClient { } } - // Prefix the model name with region - regionPrefix := region[:2] - modelName := opts.model.ID - opts.model.ID = fmt.Sprintf("%s.%s", regionPrefix, modelName) + opts.model = func(modelType config.ModelType) config.Model { + model := config.GetModel(modelType) + + // Prefix the model name with region + regionPrefix := region[:2] + modelName := model.ID + model.ID = fmt.Sprintf("%s.%s", regionPrefix, modelName) + return model + } + + model := opts.model(opts.modelType) // Determine which provider to use based on the model - if strings.Contains(string(opts.model.ID), "anthropic") { + if strings.Contains(string(model.ID), "anthropic") { // Create Anthropic client with Bedrock configuration anthropicOpts := opts // TODO: later find a way to check if the AWS account has caching enabled @@ -78,3 +86,7 @@ func (b *bedrockClient) stream(ctx context.Context, messages []message.Message, return b.childProvider.stream(ctx, messages, tools) } + +func (b *bedrockClient) Model() config.Model { + return b.providerOptions.model(b.providerOptions.modelType) +} diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go index a91c1eae2427a7629ee1f4de6d6b9abb5944a972..a5c012861ad9e6b537c0e9bca8e957ef3f38bf2f 100644 --- a/internal/llm/provider/gemini.go +++ b/internal/llm/provider/gemini.go @@ -173,7 +173,8 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too if len(tools) > 0 { config.Tools = g.convertTools(tools) } - chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.ID, config, history) + model := g.providerOptions.model(g.providerOptions.modelType) + chat, _ := g.client.Chats.Create(ctx, model.ID, config, history) attempts := 0 for { @@ -261,7 +262,8 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t if len(tools) > 0 { config.Tools = g.convertTools(tools) } - chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.ID, config, history) + model := g.providerOptions.model(g.providerOptions.modelType) + chat, _ := g.client.Chats.Create(ctx, model.ID, config, history) attempts := 0 eventChan := make(chan ProviderEvent) @@ -439,6 +441,10 @@ func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage { } } +func (g *geminiClient) Model() config.Model { + return g.providerOptions.model(g.providerOptions.modelType) +} + // Helper functions func parseJsonToMap(jsonStr string) (map[string]any, error) { var result map[string]any diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index 448ab3674f25053453f51c0f48475db5699ee913..9af060a80f75309e1e314e3c33df72e607c9c77a 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -152,13 +152,13 @@ func (o *openaiClient) finishReason(reason string) message.FinishReason { } func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams { + model := o.providerOptions.model(o.providerOptions.modelType) params := openai.ChatCompletionNewParams{ - Model: openai.ChatModel(o.providerOptions.model.ID), + Model: openai.ChatModel(model.ID), Messages: messages, Tools: tools, } - - if o.providerOptions.model.CanReason { + if model.CanReason { params.MaxCompletionTokens = openai.Int(o.providerOptions.maxTokens) switch o.options.reasoningEffort { case "low": @@ -384,3 +384,7 @@ func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage { CacheReadTokens: cachedTokens, } } + +func (a *openaiClient) Model() config.Model { + return a.providerOptions.model(a.providerOptions.modelType) +} diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 3152cd6a9a7e6fd6a68d0e6b54b6ea6853a38273..9723dc9fe55af414ed415653e3e9e31031395a02 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -4,7 +4,7 @@ import ( "context" "fmt" - configv2 "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/message" @@ -55,13 +55,14 @@ type Provider interface { StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent - Model() configv2.Model + Model() config.Model } type providerClientOptions struct { baseURL string apiKey string - model configv2.Model + modelType config.ModelType + model func(config.ModelType) config.Model disableCache bool maxTokens int64 systemMessage string @@ -74,6 +75,8 @@ type ProviderClientOption func(*providerClientOptions) type ProviderClient interface { send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent + + Model() config.Model } type baseProvider[C ProviderClient] struct { @@ -97,18 +100,18 @@ func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.M return p.client.send(ctx, messages, tools) } -func (p *baseProvider[C]) Model() configv2.Model { - return p.options.model -} - func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { messages = p.cleanMessages(messages) return p.client.stream(ctx, messages, tools) } -func WithModel(model configv2.Model) ProviderClientOption { +func (p *baseProvider[C]) Model() config.Model { + return p.client.Model() +} + +func WithModel(model config.ModelType) ProviderClientOption { return func(options *providerClientOptions) { - options.model = model + options.modelType = model } } @@ -130,11 +133,14 @@ func WithSystemMessage(systemMessage string) ProviderClientOption { } } -func NewProviderV2(cfg configv2.ProviderConfig, opts ...ProviderClientOption) (Provider, error) { +func NewProviderV2(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) { clientOptions := providerClientOptions{ baseURL: cfg.BaseURL, apiKey: cfg.APIKey, extraHeaders: cfg.ExtraHeaders, + model: func(tp config.ModelType) config.Model { + return config.GetModel(tp) + }, } for _, o := range opts { o(&clientOptions) diff --git a/internal/tui/components/dialogs/init/init.go b/internal/tui/components/dialogs/init/init.go index 4e331198f5984f81db87332e3c998d9477810806..74d0dc0b3d9d4630b28c4b240fb17fbe611ba21f 100644 --- a/internal/tui/components/dialogs/init/init.go +++ b/internal/tui/components/dialogs/init/init.go @@ -5,7 +5,7 @@ import ( tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/lipgloss/v2" - configv2 "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/config" cmpChat "github.com/charmbracelet/crush/internal/tui/components/chat" "github.com/charmbracelet/crush/internal/tui/components/core" "github.com/charmbracelet/crush/internal/tui/components/dialogs" @@ -184,7 +184,7 @@ If there are Cursor rules (in .cursor/rules/ or .cursorrules) or Copilot rules ( Add the .crush directory to the .gitignore file if it's not already there.` // Mark the project as initialized - if err := configv2.MarkProjectInitialized(); err != nil { + if err := config.MarkProjectInitialized(); err != nil { return util.ReportError(err) } @@ -196,7 +196,7 @@ Add the .crush directory to the .gitignore file if it's not already there.` ) } else { // Mark the project as initialized without running the command - if err := configv2.MarkProjectInitialized(); err != nil { + if err := config.MarkProjectInitialized(); err != nil { return util.ReportError(err) } } diff --git a/internal/tui/components/dialogs/models/models.go b/internal/tui/components/dialogs/models/models.go index b5f87b16681ea17e2fb303a4b52a3a83ae30eb85..6d5fa155b2371865771b55c16f8fdbf65d3df952 100644 --- a/internal/tui/components/dialogs/models/models.go +++ b/internal/tui/components/dialogs/models/models.go @@ -4,7 +4,7 @@ import ( "github.com/charmbracelet/bubbles/v2/help" "github.com/charmbracelet/bubbles/v2/key" tea "github.com/charmbracelet/bubbletea/v2" - configv2 "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/tui/components/completions" "github.com/charmbracelet/crush/internal/tui/components/core" @@ -24,7 +24,7 @@ const ( // ModelSelectedMsg is sent when a model is selected type ModelSelectedMsg struct { - Model configv2.PreferredModel + Model config.PreferredModel } // CloseModelDialogMsg is sent when a model is selected @@ -84,12 +84,12 @@ func NewModelDialogCmp() ModelDialog { } func (m *modelDialogCmp) Init() tea.Cmd { - providers := configv2.Providers() - cfg := configv2.Get() + providers := config.Providers() - coderAgent := cfg.Agents[configv2.AgentCoder] modelItems := []util.Model{} selectIndex := 0 + agentModel := config.GetAgentModel(config.AgentCoder) + agentProvider := config.GetAgentProvider(config.AgentCoder) for _, provider := range providers { name := provider.Name if name == "" { @@ -97,7 +97,7 @@ func (m *modelDialogCmp) Init() tea.Cmd { } modelItems = append(modelItems, commands.NewItemSection(name)) for _, model := range provider.Models { - if model.ID == coderAgent.Model && provider.ID == coderAgent.Provider { + if model.ID == agentModel.ID && provider.ID == agentProvider.ID { selectIndex = len(modelItems) // Set the selected index to the current model } modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{ @@ -128,7 +128,7 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m, tea.Sequence( util.CmdHandler(dialogs.CloseDialogMsg{}), - util.CmdHandler(ModelSelectedMsg{Model: configv2.PreferredModel{ + util.CmdHandler(ModelSelectedMsg{Model: config.PreferredModel{ ModelID: selectedItem.Model.ID, Provider: selectedItem.Provider.ID, }}), diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 54978b53576940e6fa478b7d05af514f66641acf..032b481eeaad75531debe7dc453efe19b866dd8d 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -8,7 +8,6 @@ import ( tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/crush/internal/app" "github.com/charmbracelet/crush/internal/config" - configv2 "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/llm/agent" "github.com/charmbracelet/crush/internal/logging" "github.com/charmbracelet/crush/internal/permission" @@ -70,7 +69,7 @@ func (a appModel) Init() tea.Cmd { // Check if we should show the init dialog cmds = append(cmds, func() tea.Msg { - shouldShow, err := configv2.ProjectNeedsInitialization() + shouldShow, err := config.ProjectNeedsInitialization() if err != nil { return util.InfoMsg{ Type: util.InfoTypeError, @@ -173,12 +172,8 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { // Model Switch case models.ModelSelectedMsg: - model, err := a.app.CoderAgent.Update(msg.Model) - if err != nil { - return a, util.ReportError(err) - } - - return a, util.ReportInfo(fmt.Sprintf("Model changed to %s", model.Name)) + config.UpdatePreferredModel(config.LargeModel, msg.Model) + return a, util.ReportInfo(fmt.Sprintf("Model changed to %s", msg.Model.ModelID)) // File Picker case chat.OpenFilePickerMsg: