Detailed changes
@@ -446,7 +446,7 @@ func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fan
}
}
- for _, tool := range tools.GetMCPTools(c.permissions, c.cfg.WorkingDir()) {
+ for _, tool := range tools.GetMCPTools(c.permissions, c.cfg, c.cfg.WorkingDir()) {
if agent.AllowedMCP == nil {
// No MCP restrictions
filteredTools = append(filteredTools, tool)
@@ -6,11 +6,12 @@ import (
"charm.land/fantasy"
"github.com/charmbracelet/crush/internal/agent/tools/mcp"
+ "github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/permission"
)
// GetMCPTools gets all the currently available MCP tools.
-func GetMCPTools(permissions permission.Service, wd string) []*Tool {
+func GetMCPTools(permissions permission.Service, cfg *config.Config, wd string) []*Tool {
var result []*Tool
for mcpName, tools := range mcp.Tools() {
for _, tool := range tools {
@@ -19,6 +20,7 @@ func GetMCPTools(permissions permission.Service, wd string) []*Tool {
tool: tool,
permissions: permissions,
workingDir: wd,
+ cfg: cfg,
})
}
}
@@ -29,6 +31,7 @@ func GetMCPTools(permissions permission.Service, wd string) []*Tool {
type Tool struct {
mcpName string
tool *mcp.Tool
+ cfg *config.Config
permissions permission.Service
workingDir string
providerOptions fantasy.ProviderOptions
@@ -107,7 +110,7 @@ func (m *Tool) Run(ctx context.Context, params fantasy.ToolCall) (fantasy.ToolRe
return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
}
- result, err := mcp.RunTool(ctx, m.mcpName, m.tool.Name, params.Input)
+ result, err := mcp.RunTool(ctx, m.cfg, m.mcpName, m.tool.Name, params.Input)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
@@ -189,7 +189,7 @@ func Initialize(ctx context.Context, permissions permission.Service, cfg *config
return
}
- toolCount := updateTools(name, tools)
+ toolCount := updateTools(cfg, name, tools)
updatePrompts(name, prompts)
sessions.Set(name, session)
@@ -214,13 +214,12 @@ func WaitForInit(ctx context.Context) error {
}
}
-func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
+func getOrRenewClient(ctx context.Context, cfg *config.Config, name string) (*mcp.ClientSession, error) {
sess, ok := sessions.Get(name)
if !ok {
return nil, fmt.Errorf("mcp '%s' not available", name)
}
- cfg := config.Get()
m := cfg.MCP[name]
state, _ := states.Get(name)
@@ -5,6 +5,7 @@ import (
"iter"
"log/slog"
+ "github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/csync"
"github.com/modelcontextprotocol/go-sdk/mcp"
)
@@ -19,8 +20,8 @@ func Prompts() iter.Seq2[string, []*Prompt] {
}
// GetPromptMessages retrieves the content of an MCP prompt with the given arguments.
-func GetPromptMessages(ctx context.Context, clientName, promptName string, args map[string]string) ([]string, error) {
- c, err := getOrRenewClient(ctx, clientName)
+func GetPromptMessages(ctx context.Context, cfg *config.Config, clientName, promptName string, args map[string]string) ([]string, error) {
+ c, err := getOrRenewClient(ctx, cfg, clientName)
if err != nil {
return nil, err
}
@@ -32,13 +32,13 @@ func Tools() iter.Seq2[string, []*Tool] {
}
// RunTool runs an MCP tool with the given input parameters.
-func RunTool(ctx context.Context, name, toolName string, input string) (ToolResult, error) {
+func RunTool(ctx context.Context, cfg *config.Config, name, toolName string, input string) (ToolResult, error) {
var args map[string]any
if err := json.Unmarshal([]byte(input), &args); err != nil {
return ToolResult{}, fmt.Errorf("error parsing parameters: %s", err)
}
- c, err := getOrRenewClient(ctx, name)
+ c, err := getOrRenewClient(ctx, cfg, name)
if err != nil {
return ToolResult{}, err
}
@@ -108,7 +108,7 @@ func RunTool(ctx context.Context, name, toolName string, input string) (ToolResu
// RefreshTools gets the updated list of tools from the MCP and updates the
// global state.
-func RefreshTools(ctx context.Context, name string) {
+func RefreshTools(ctx context.Context, cfg *config.Config, name string) {
session, ok := sessions.Get(name)
if !ok {
slog.Warn("Refresh tools: no session", "name", name)
@@ -121,7 +121,7 @@ func RefreshTools(ctx context.Context, name string) {
return
}
- toolCount := updateTools(name, tools)
+ toolCount := updateTools(cfg, name, tools)
prev, _ := states.Get(name)
prev.Counts.Tools = toolCount
@@ -139,8 +139,8 @@ func getTools(ctx context.Context, session *mcp.ClientSession) ([]*Tool, error)
return result.Tools, nil
}
-func updateTools(name string, tools []*Tool) int {
- tools = filterDisabledTools(name, tools)
+func updateTools(cfg *config.Config, name string, tools []*Tool) int {
+ tools = filterDisabledTools(cfg, name, tools)
if len(tools) == 0 {
allTools.Del(name)
return 0
@@ -150,8 +150,7 @@ func updateTools(name string, tools []*Tool) int {
}
// filterDisabledTools removes tools that are disabled via config.
-func filterDisabledTools(mcpName string, tools []*Tool) []*Tool {
- cfg := config.Get()
+func filterDisabledTools(cfg *config.Config, mcpName string, tools []*Tool) []*Tool {
mcpCfg, ok := cfg.MCP[mcpName]
if !ok || len(mcpCfg.DisabledTools) == 0 {
return tools
@@ -114,7 +114,7 @@ func (app *App) createAndStartLSPClient(ctx context.Context, name string, config
updateLSPState(name, lsp.StateStarting, nil, nil, 0)
// Create LSP client.
- lspClient, err := lsp.New(ctx, name, config, app.config.Resolver())
+ lspClient, err := lsp.New(ctx, name, config, app.config.Resolver(), app.config.Options.DebugLSP)
if err != nil {
if !userConfigured {
slog.Warn("Default LSP config skipped due to error", "name", name, "error", err)
@@ -52,17 +52,16 @@ crush login copilot
}
switch provider {
case "hyper":
- return loginHyper()
+ return loginHyper(app.Config())
case "copilot", "github", "github-copilot":
- return loginCopilot()
+ return loginCopilot(app.Config())
default:
return fmt.Errorf("unknown platform: %s", args[0])
}
},
}
-func loginHyper() error {
- cfg := config.Get()
+func loginHyper(cfg *config.Config) error {
if !hyperp.Enabled() {
return fmt.Errorf("hyper not enabled")
}
@@ -124,10 +123,9 @@ func loginHyper() error {
return nil
}
-func loginCopilot() error {
+func loginCopilot(cfg *config.Config) error {
ctx := getLoginContext()
- cfg := config.Get()
if cfg.HasConfigField("providers.copilot.oauth") {
fmt.Println("You are already logged in to GitHub Copilot.")
return nil
@@ -227,21 +227,21 @@ func setupApp(cmd *cobra.Command) (*app.App, error) {
return nil, err
}
- if shouldEnableMetrics() {
+ if shouldEnableMetrics(cfg) {
event.Init()
}
return appInstance, nil
}
-func shouldEnableMetrics() bool {
+func shouldEnableMetrics(cfg *config.Config) bool {
if v, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_METRICS")); v {
return false
}
if v, _ := strconv.ParseBool(os.Getenv("DO_NOT_TRACK")); v {
return false
}
- if config.Get().Options.DisableMetrics {
+ if cfg.Options.DisableMetrics {
return false
}
return true
@@ -227,9 +227,9 @@ func isMarkdownFile(name string) bool {
return strings.HasSuffix(strings.ToLower(name), ".md")
}
-func GetMCPPrompt(clientID, promptID string, args map[string]string) (string, error) {
+func GetMCPPrompt(cfg *config.Config, clientID, promptID string, args map[string]string) (string, error) {
// TODO: we should pass the context down
- result, err := mcp.GetPromptMessages(context.Background(), clientID, promptID, args)
+ result, err := mcp.GetPromptMessages(context.Background(), cfg, clientID, promptID, args)
if err != nil {
return "", err
}
@@ -6,7 +6,6 @@ import (
"path/filepath"
"slices"
"strings"
- "sync/atomic"
"github.com/charmbracelet/crush/internal/fsext"
)
@@ -19,25 +18,15 @@ type ProjectInitFlag struct {
Initialized bool `json:"initialized"`
}
-// TODO: we need to remove the global config instance keeping it now just until everything is migrated
-var instance atomic.Pointer[Config]
-
func Init(workingDir, dataDir string, debug bool) (*Config, error) {
cfg, err := Load(workingDir, dataDir, debug)
if err != nil {
return nil, err
}
- instance.Store(cfg)
- return instance.Load(), nil
-}
-
-func Get() *Config {
- cfg := instance.Load()
- return cfg
+ return cfg, nil
}
-func ProjectNeedsInitialization() (bool, error) {
- cfg := Get()
+func ProjectNeedsInitialization(cfg *Config) (bool, error) {
if cfg == nil {
return false, fmt.Errorf("config not loaded")
}
@@ -110,8 +99,7 @@ func dirHasNoVisibleFiles(dir string) (bool, error) {
return len(files) == 0, nil
}
-func MarkProjectInitialized() error {
- cfg := Get()
+func MarkProjectInitialized(cfg *Config) error {
if cfg == nil {
return fmt.Errorf("config not loaded")
}
@@ -126,10 +114,13 @@ func MarkProjectInitialized() error {
return nil
}
-func HasInitialDataConfig() bool {
+func HasInitialDataConfig(cfg *Config) bool {
+ if cfg == nil {
+ return false
+ }
cfgPath := GlobalConfigData()
if _, err := os.Stat(cfgPath); err != nil {
return false
}
- return Get().IsConfigured()
+ return cfg.IsConfigured()
}
@@ -35,6 +35,7 @@ type DiagnosticCounts struct {
type Client struct {
client *powernap.Client
name string
+ debug bool
// Working directory this LSP is scoped to.
workDir string
@@ -68,7 +69,7 @@ type Client struct {
}
// New creates a new LSP client using the powernap implementation.
-func New(ctx context.Context, name string, cfg config.LSPConfig, resolver config.VariableResolver) (*Client, error) {
+func New(ctx context.Context, name string, cfg config.LSPConfig, resolver config.VariableResolver, debug bool) (*Client, error) {
client := &Client{
name: name,
fileTypes: cfg.FileTypes,
@@ -76,6 +77,7 @@ func New(ctx context.Context, name string, cfg config.LSPConfig, resolver config
openFiles: csync.NewMap[string, *OpenFileInfo](),
config: cfg,
ctx: ctx,
+ debug: debug,
resolver: resolver,
}
client.serverState.Store(StateStarting)
@@ -174,7 +176,11 @@ func (c *Client) registerHandlers() {
c.RegisterServerRequestHandler("workspace/applyEdit", HandleApplyEdit)
c.RegisterServerRequestHandler("workspace/configuration", HandleWorkspaceConfiguration)
c.RegisterServerRequestHandler("client/registerCapability", HandleRegisterCapability)
- c.RegisterNotificationHandler("window/showMessage", HandleServerMessage)
+ c.RegisterNotificationHandler("window/showMessage", func(ctx context.Context, method string, params json.RawMessage) {
+ if c.debug {
+ HandleServerMessage(ctx, method, params)
+ }
+ })
c.RegisterNotificationHandler("textDocument/publishDiagnostics", func(_ context.Context, _ string, params json.RawMessage) {
HandleDiagnostics(c, params)
})
@@ -262,8 +268,6 @@ func (c *Client) SetDiagnosticsCallback(callback func(name string, count int)) {
// WaitForServerReady waits for the server to be ready
func (c *Client) WaitForServerReady(ctx context.Context) error {
- cfg := config.Get()
-
// Set initial state
c.SetServerState(StateStarting)
@@ -275,7 +279,7 @@ func (c *Client) WaitForServerReady(ctx context.Context) error {
ticker := time.NewTicker(500 * time.Millisecond)
defer ticker.Stop()
- if cfg != nil && cfg.Options.DebugLSP {
+ if c.debug {
slog.Debug("Waiting for LSP server to be ready...")
}
@@ -289,7 +293,7 @@ func (c *Client) WaitForServerReady(ctx context.Context) error {
case <-ticker.C:
// Check if client is running
if !c.client.IsRunning() {
- if cfg != nil && cfg.Options.DebugLSP {
+ if c.debug {
slog.Debug("LSP server not ready yet", "server", c.name)
}
continue
@@ -297,7 +301,7 @@ func (c *Client) WaitForServerReady(ctx context.Context) error {
// Server is ready
c.SetServerState(StateReady)
- if cfg != nil && cfg.Options.DebugLSP {
+ if c.debug {
slog.Debug("LSP server is ready")
}
return nil
@@ -416,10 +420,8 @@ func (c *Client) IsFileOpen(filepath string) bool {
// CloseAllFiles closes all currently open files.
func (c *Client) CloseAllFiles(ctx context.Context) {
- cfg := config.Get()
- debugLSP := cfg != nil && cfg.Options.DebugLSP
for uri := range c.openFiles.Seq2() {
- if debugLSP {
+ if c.debug {
slog.Debug("Closing file", "file", uri)
}
if err := c.client.NotifyDidCloseTextDocument(ctx, uri); err != nil {
@@ -23,7 +23,7 @@ func TestClient(t *testing.T) {
// but we can still test the basic structure
client, err := New(ctx, "test", cfg, config.NewEnvironmentVariableResolver(env.NewFromMap(map[string]string{
"THE_CMD": "echo",
- })))
+ })), false)
if err != nil {
// Expected to fail with echo command, skip the rest
t.Skipf("Powernap client creation failed as expected with dummy command: %v", err)
@@ -5,7 +5,6 @@ import (
"encoding/json"
"log/slog"
- "github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/lsp/util"
"github.com/charmbracelet/x/powernap/pkg/lsp/protocol"
)
@@ -80,11 +79,6 @@ func notifyFileWatchRegistration(id string, watchers []protocol.FileSystemWatche
// HandleServerMessage handles server messages
func HandleServerMessage(_ context.Context, method string, params json.RawMessage) {
- cfg := config.Get()
- if !cfg.Options.DebugLSP {
- return
- }
-
var msg protocol.ShowMessageParams
if err := json.Unmarshal(params, &msg); err != nil {
slog.Debug("Server message", "type", msg.Type, "message", msg.Message)
@@ -186,16 +186,18 @@ type AssistantInfoItem struct {
id string
message *message.Message
sty *styles.Styles
+ cfg *config.Config
lastUserMessageTime time.Time
}
// NewAssistantInfoItem creates a new AssistantInfoItem.
-func NewAssistantInfoItem(sty *styles.Styles, message *message.Message, lastUserMessageTime time.Time) MessageItem {
+func NewAssistantInfoItem(sty *styles.Styles, message *message.Message, cfg *config.Config, lastUserMessageTime time.Time) MessageItem {
return &AssistantInfoItem{
cachedMessageItem: &cachedMessageItem{},
id: AssistantInfoID(message.ID),
message: message,
sty: sty,
+ cfg: cfg,
lastUserMessageTime: lastUserMessageTime,
}
}
@@ -231,13 +233,13 @@ func (a *AssistantInfoItem) renderContent(width int) string {
duration := finishTime.Sub(a.lastUserMessageTime)
infoMsg := a.sty.Chat.Message.AssistantInfoDuration.Render(duration.String())
icon := a.sty.Chat.Message.AssistantInfoIcon.Render(styles.ModelIcon)
- model := config.Get().GetModel(a.message.Provider, a.message.Model)
+ model := a.cfg.GetModel(a.message.Provider, a.message.Model)
if model == nil {
model = &catwalk.Model{Name: "Unknown Model"}
}
modelFormatted := a.sty.Chat.Message.AssistantInfoModel.Render(model.Name)
providerName := a.message.Provider
- if providerConfig, ok := config.Get().Providers.Get(a.message.Provider); ok {
+ if providerConfig, ok := a.cfg.Providers.Get(a.message.Provider); ok {
providerName = providerConfig.Name
}
provider := a.sty.Chat.Message.AssistantInfoProvider.Render(fmt.Sprintf("via %s", providerName))
@@ -296,7 +296,7 @@ func (m *APIKeyInput) verifyAPIKey() tea.Msg {
Type: m.provider.Type,
BaseURL: m.provider.APIEndpoint,
}
- err := providerConfig.TestConnection(config.Get().Resolver())
+ err := providerConfig.TestConnection(m.com.Config().Resolver())
// intentionally wait for at least 750ms to make sure the user sees the spinner
elapsed := time.Since(start)
@@ -117,8 +117,8 @@ func renderHeaderDetails(
parts = append(parts, t.LSP.ErrorDiagnostic.Render(fmt.Sprintf("%s%d", styles.LSPErrorIcon, errorCount)))
}
- agentCfg := config.Get().Agents[config.AgentCoder]
- model := config.Get().GetModelByType(agentCfg.Model)
+ agentCfg := com.Config().Agents[config.AgentCoder]
+ model := com.Config().GetModelByType(agentCfg.Model)
percentage := (float64(session.CompletionTokens+session.PromptTokens) / float64(model.ContextWindow)) * 100
formattedPercentage := t.Header.Percentage.Render(fmt.Sprintf("%d%%", int(percentage)))
parts = append(parts, formattedPercentage)
@@ -19,7 +19,7 @@ import (
// markProjectInitialized marks the current project as initialized in the config.
func (m *UI) markProjectInitialized() tea.Msg {
// TODO: handle error so we show it in the tui footer
- err := config.MarkProjectInitialized()
+ err := config.MarkProjectInitialized(m.com.Config())
if err != nil {
slog.Error(err.Error())
}
@@ -298,7 +298,7 @@ func New(com *common.Common) *UI {
desiredFocus := uiFocusEditor
if !com.Config().IsConfigured() {
desiredState = uiOnboarding
- } else if n, _ := config.ProjectNeedsInitialization(); n {
+ } else if n, _ := config.ProjectNeedsInitialization(com.Config()); n {
desiredState = uiInitialize
}
@@ -776,7 +776,7 @@ func (m *UI) setSessionMessages(msgs []message.Message) tea.Cmd {
case message.Assistant:
items = append(items, chat.ExtractMessageItems(m.com.Styles, msg, toolResultMap)...)
if msg.FinishPart() != nil && msg.FinishPart().Reason == message.FinishReasonEndTurn {
- infoItem := chat.NewAssistantInfoItem(m.com.Styles, msg, time.Unix(m.lastUserMessageTime, 0))
+ infoItem := chat.NewAssistantInfoItem(m.com.Styles, msg, m.com.Config(), time.Unix(m.lastUserMessageTime, 0))
items = append(items, infoItem)
}
default:
@@ -906,7 +906,7 @@ func (m *UI) appendSessionMessage(msg message.Message) tea.Cmd {
}
}
if msg.FinishPart() != nil && msg.FinishPart().Reason == message.FinishReasonEndTurn {
- infoItem := chat.NewAssistantInfoItem(m.com.Styles, &msg, time.Unix(m.lastUserMessageTime, 0))
+ infoItem := chat.NewAssistantInfoItem(m.com.Styles, &msg, m.com.Config(), time.Unix(m.lastUserMessageTime, 0))
m.chat.AppendMessages(infoItem)
if atBottom {
if cmd := m.chat.ScrollToBottomAndAnimate(); cmd != nil {
@@ -977,7 +977,7 @@ func (m *UI) updateSessionMessage(msg message.Message) tea.Cmd {
if shouldRenderAssistant && msg.FinishPart() != nil && msg.FinishPart().Reason == message.FinishReasonEndTurn {
if infoItem := m.chat.MessageItem(chat.AssistantInfoID(msg.ID)); infoItem == nil {
- newInfoItem := chat.NewAssistantInfoItem(m.com.Styles, &msg, time.Unix(m.lastUserMessageTime, 0))
+ newInfoItem := chat.NewAssistantInfoItem(m.com.Styles, &msg, m.com.Config(), time.Unix(m.lastUserMessageTime, 0))
m.chat.AppendMessages(newInfoItem)
}
}
@@ -1249,7 +1249,7 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd {
// Attempt to import GitHub Copilot tokens from VSCode if available.
if isCopilot && !isConfigured() {
- config.Get().ImportCopilot()
+ m.com.Config().ImportCopilot()
}
if !isConfigured() {
@@ -3063,7 +3063,7 @@ func (m *UI) drawSessionDetails(scr uv.Screen, area uv.Rectangle) {
func (m *UI) runMCPPrompt(clientID, promptID string, arguments map[string]string) tea.Cmd {
load := func() tea.Msg {
- prompt, err := commands.GetMCPPrompt(clientID, promptID, arguments)
+ prompt, err := commands.GetMCPPrompt(m.com.Config(), clientID, promptID, arguments)
if err != nil {
// TODO: make this better
return util.ReportError(err)()