Detailed changes
@@ -82,8 +82,9 @@ type ProviderConfig struct {
}
type Agent struct {
- Name string `json:"name"`
- Description string `json:"description,omitempty"`
+ ID AgentID `json:"id"`
+ Name string `json:"name"`
+ Description string `json:"description,omitempty"`
// This is the id of the system prompt used by the agent
Disabled bool `json:"disabled"`
@@ -229,6 +230,7 @@ func loadConfig(cwd string) (*Config, error) {
agents := map[AgentID]Agent{
AgentCoder: {
+ ID: AgentCoder,
Name: "Coder",
Description: "An agent that helps with executing coding tasks.",
Provider: preferredProvider.ID,
@@ -237,6 +239,7 @@ func loadConfig(cwd string) (*Config, error) {
// All tools allowed
},
AgentTask: {
+ ID: AgentTask,
Name: "Task",
Description: "An agent that helps with searching for context and finding implementation details.",
Provider: preferredProvider.ID,
@@ -254,6 +257,7 @@ func loadConfig(cwd string) (*Config, error) {
AllowedLSP: []string{},
},
AgentTitle: {
+ ID: AgentTitle,
Name: "Title",
Description: "An agent that helps with generating titles for sessions.",
Provider: preferredProvider.ID,
@@ -265,6 +269,7 @@ func loadConfig(cwd string) (*Config, error) {
AllowedLSP: []string{},
},
AgentSummarize: {
+ ID: AgentSummarize,
Name: "Summarize",
Description: "An agent that helps with summarizing sessions.",
Provider: preferredProvider.ID,
@@ -429,43 +434,44 @@ func mergeAgents(base, global, local *Config) {
if cfg == nil {
continue
}
- for agentID, globalAgent := range cfg.Agents {
+ for agentID, newAgent := range cfg.Agents {
if _, ok := base.Agents[agentID]; !ok {
- base.Agents[agentID] = globalAgent
+ newAgent.ID = agentID // Ensure the ID is set correctly
+ base.Agents[agentID] = newAgent
} else {
switch agentID {
case AgentCoder:
baseAgent := base.Agents[agentID]
- baseAgent.Model = globalAgent.Model
- baseAgent.Provider = globalAgent.Provider
- baseAgent.AllowedMCP = globalAgent.AllowedMCP
- baseAgent.AllowedLSP = globalAgent.AllowedLSP
+ baseAgent.Model = newAgent.Model
+ baseAgent.Provider = newAgent.Provider
+ baseAgent.AllowedMCP = newAgent.AllowedMCP
+ baseAgent.AllowedLSP = newAgent.AllowedLSP
base.Agents[agentID] = baseAgent
case AgentTask:
baseAgent := base.Agents[agentID]
- baseAgent.Model = globalAgent.Model
- baseAgent.Provider = globalAgent.Provider
+ baseAgent.Model = newAgent.Model
+ baseAgent.Provider = newAgent.Provider
base.Agents[agentID] = baseAgent
case AgentTitle:
baseAgent := base.Agents[agentID]
- baseAgent.Model = globalAgent.Model
- baseAgent.Provider = globalAgent.Provider
+ baseAgent.Model = newAgent.Model
+ baseAgent.Provider = newAgent.Provider
base.Agents[agentID] = baseAgent
case AgentSummarize:
baseAgent := base.Agents[agentID]
- baseAgent.Model = globalAgent.Model
- baseAgent.Provider = globalAgent.Provider
+ baseAgent.Model = newAgent.Model
+ baseAgent.Provider = newAgent.Provider
base.Agents[agentID] = baseAgent
default:
baseAgent := base.Agents[agentID]
- baseAgent.Name = globalAgent.Name
- baseAgent.Description = globalAgent.Description
- baseAgent.Disabled = globalAgent.Disabled
- baseAgent.Provider = globalAgent.Provider
- baseAgent.Model = globalAgent.Model
- baseAgent.AllowedTools = globalAgent.AllowedTools
- baseAgent.AllowedMCP = globalAgent.AllowedMCP
- baseAgent.AllowedLSP = globalAgent.AllowedLSP
+ baseAgent.Name = newAgent.Name
+ baseAgent.Description = newAgent.Description
+ baseAgent.Disabled = newAgent.Disabled
+ baseAgent.Provider = newAgent.Provider
+ baseAgent.Model = newAgent.Model
+ baseAgent.AllowedTools = newAgent.AllowedTools
+ baseAgent.AllowedMCP = newAgent.AllowedMCP
+ baseAgent.AllowedLSP = newAgent.AllowedLSP
base.Agents[agentID] = baseAgent
}
@@ -26,6 +26,10 @@ type AgentParams struct {
Prompt string `json:"prompt"`
}
+func (b *agentTool) Name() string {
+ return AgentToolName
+}
+
func (b *agentTool) Info() tools.ToolInfo {
return tools.ToolInfo{
Name: AgentToolName,
@@ -32,6 +32,10 @@ type MCPClient interface {
Close() error
}
+func (b *mcpTool) Name() string {
+ return fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name)
+}
+
func (b *mcpTool) Info() tools.ToolInfo {
return tools.ToolInfo{
Name: fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name),
@@ -3,7 +3,6 @@ package provider
import (
"context"
"fmt"
- "maps"
"os"
"github.com/charmbracelet/crush/internal/llm/models"
@@ -177,18 +176,6 @@ func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message
return p.client.stream(ctx, messages, tools)
}
-func WithBaseURL(baseURL string) ProviderClientOption {
- return func(options *providerClientOptions) {
- options.baseURL = baseURL
- }
-}
-
-func WithAPIKey(apiKey string) ProviderClientOption {
- return func(options *providerClientOptions) {
- options.apiKey = apiKey
- }
-}
-
func WithModel(model models.Model) ProviderClientOption {
return func(options *providerClientOptions) {
options.model = model
@@ -201,15 +188,6 @@ func WithDisableCache(disableCache bool) ProviderClientOption {
}
}
-func WithExtraHeaders(extraHeaders map[string]string) ProviderClientOption {
- return func(options *providerClientOptions) {
- if options.extraHeaders == nil {
- options.extraHeaders = make(map[string]string)
- }
- maps.Copy(options.extraHeaders, extraHeaders)
- }
-}
-
func WithMaxTokens(maxTokens int64) ProviderClientOption {
return func(options *providerClientOptions) {
options.maxTokens = maxTokens
@@ -250,6 +250,10 @@ func NewBashTool(permission permission.Service) BaseTool {
}
}
+func (b *bashTool) Name() string {
+ return BashToolName
+}
+
func (b *bashTool) Info() ToolInfo {
return ToolInfo{
Name: BashToolName,
@@ -51,6 +51,10 @@ func NewDiagnosticsTool(lspClients map[string]*lsp.Client) BaseTool {
}
}
+func (b *diagnosticsTool) Name() string {
+ return DiagnosticsToolName
+}
+
func (b *diagnosticsTool) Info() ToolInfo {
return ToolInfo{
Name: DiagnosticsToolName,
@@ -106,6 +106,10 @@ func NewEditTool(lspClients map[string]*lsp.Client, permissions permission.Servi
}
}
+func (e *editTool) Name() string {
+ return EditToolName
+}
+
func (e *editTool) Info() ToolInfo {
return ToolInfo{
Name: EditToolName,
@@ -79,6 +79,10 @@ func NewFetchTool(permissions permission.Service) BaseTool {
}
}
+func (t *fetchTool) Name() string {
+ return FetchToolName
+}
+
func (t *fetchTool) Info() ToolInfo {
return ToolInfo{
Name: FetchToolName,
@@ -74,6 +74,10 @@ func NewGlobTool() BaseTool {
return &globTool{}
}
+func (g *globTool) Name() string {
+ return GlobToolName
+}
+
func (g *globTool) Info() ToolInfo {
return ToolInfo{
Name: GlobToolName,
@@ -140,6 +140,10 @@ func NewGrepTool() BaseTool {
return &grepTool{}
}
+func (g *grepTool) Name() string {
+ return GrepToolName
+}
+
func (g *grepTool) Info() ToolInfo {
return ToolInfo{
Name: GrepToolName,
@@ -74,6 +74,10 @@ func NewLsTool() BaseTool {
return &lsTool{}
}
+func (l *lsTool) Name() string {
+ return LSToolName
+}
+
func (l *lsTool) Info() ToolInfo {
return ToolInfo{
Name: LSToolName,
@@ -138,6 +138,10 @@ func NewSourcegraphTool() BaseTool {
}
}
+func (t *sourcegraphTool) Name() string {
+ return SourcegraphToolName
+}
+
func (t *sourcegraphTool) Info() ToolInfo {
return ToolInfo{
Name: SourcegraphToolName,
@@ -68,6 +68,7 @@ type ToolCall struct {
type BaseTool interface {
Info() ToolInfo
+ Name() string
Run(ctx context.Context, params ToolCall) (ToolResponse, error)
}
@@ -77,6 +77,10 @@ func NewViewTool(lspClients map[string]*lsp.Client) BaseTool {
}
}
+func (v *viewTool) Name() string {
+ return ViewToolName
+}
+
func (v *viewTool) Info() ToolInfo {
return ToolInfo{
Name: ViewToolName,
@@ -84,6 +84,10 @@ func NewWriteTool(lspClients map[string]*lsp.Client, permissions permission.Serv
}
}
+func (w *writeTool) Name() string {
+ return WriteToolName
+}
+
func (w *writeTool) Info() ToolInfo {
return ToolInfo{
Name: WriteToolName,