diff --git a/.gitignore b/.gitignore index b28e5a0c727163e8f3585522e680d1df2ad6e621..2f16f744432d89e0a72fd6ea8e359678a64b6d42 100644 --- a/.gitignore +++ b/.gitignore @@ -16,6 +16,7 @@ # Go workspace file go.work +go.work.sum # IDE specific files .idea/ diff --git a/cmd/root.go b/cmd/root.go index e27bc46adcf38ae4b36cfba8d0f518690091242f..9ae26b993dd1be7374907305ae4cc90036cb05d6 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -12,7 +12,6 @@ import ( "github.com/charmbracelet/crush/internal/app" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/db" - "github.com/charmbracelet/crush/internal/format" "github.com/charmbracelet/crush/internal/llm/agent" "github.com/charmbracelet/crush/internal/log" "github.com/charmbracelet/crush/internal/tui" @@ -52,14 +51,8 @@ to assist developers in writing, debugging, and understanding code directly from debug, _ := cmd.Flags().GetBool("debug") cwd, _ := cmd.Flags().GetString("cwd") prompt, _ := cmd.Flags().GetString("prompt") - outputFormat, _ := cmd.Flags().GetString("output-format") quiet, _ := cmd.Flags().GetBool("quiet") - // Validate format option - if !format.IsValid(outputFormat) { - return fmt.Errorf("invalid format option: %s\n%s", outputFormat, format.GetHelpText()) - } - if cwd != "" { err := os.Chdir(cwd) if err != nil { @@ -79,9 +72,7 @@ to assist developers in writing, debugging, and understanding code directly from return err } - // Create main context for the application - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := cmd.Context() // Connect DB, this will also run migrations conn, err := db.Connect(ctx, cfg.Options.DataDirectory) @@ -109,7 +100,7 @@ to assist developers in writing, debugging, and understanding code directly from // Non-interactive mode if prompt != "" { // Run non-interactive flow using the App method - return app.RunNonInteractive(ctx, prompt, outputFormat, quiet) + return app.RunNonInteractive(ctx, prompt, quiet) } // Set up the TUI @@ -152,6 +143,7 @@ func Execute() { context.Background(), rootCmd, fang.WithVersion(version.Version), + fang.WithNotifySignal(os.Interrupt), ); err != nil { os.Exit(1) } @@ -164,17 +156,8 @@ func init() { rootCmd.Flags().BoolP("debug", "d", false, "Debug") rootCmd.Flags().StringP("prompt", "p", "", "Prompt to run in non-interactive mode") - // Add format flag with validation logic - rootCmd.Flags().StringP("output-format", "f", format.Text.String(), - "Output format for non-interactive mode (text, json)") - // Add quiet flag to hide spinner in non-interactive mode rootCmd.Flags().BoolP("quiet", "q", false, "Hide spinner in non-interactive mode") - - // Register custom validation for the format flag - rootCmd.RegisterFlagCompletionFunc("output-format", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - return format.SupportedFormats, cobra.ShellCompDirectiveNoFileComp - }) } func maybePrependStdin(prompt string) (string, error) { diff --git a/go.mod b/go.mod index 35907121af5791acc5cfc5f3aa07f10df9eba763..d510a774a03c27ceca623400257228763cc2e9a1 100644 --- a/go.mod +++ b/go.mod @@ -18,9 +18,9 @@ require ( github.com/charlievieth/fastwalk v1.0.11 github.com/charmbracelet/bubbles/v2 v2.0.0-beta.1.0.20250710161907-a4c42b579198 github.com/charmbracelet/bubbletea/v2 v2.0.0-beta.1 - github.com/charmbracelet/fang v0.1.0 + github.com/charmbracelet/fang v0.3.1-0.20250711140230-d5ebb8c1d674 github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe - github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.2.0.20250703152125-8e1c474f8a71 + github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.3 github.com/charmbracelet/log/v2 v2.0.0-20250226163916-c379e29ff706 github.com/charmbracelet/x/ansi v0.9.3 github.com/charmbracelet/x/exp/charmtone v0.0.0-20250708181618-a60a724ba6c3 @@ -29,7 +29,7 @@ require ( github.com/fsnotify/fsnotify v1.8.0 github.com/google/uuid v1.6.0 github.com/joho/godotenv v1.5.1 - github.com/mark3labs/mcp-go v0.32.0 + github.com/mark3labs/mcp-go v0.33.0 github.com/muesli/termenv v0.16.0 github.com/ncruces/go-sqlite3 v0.25.0 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 diff --git a/go.sum b/go.sum index 50e30a46d4a47cb210add9c3fe61f0c9fb8e6c26..d7004401154b86ce0658162c06bfc610a0c77126 100644 --- a/go.sum +++ b/go.sum @@ -74,8 +74,8 @@ github.com/charmbracelet/bubbletea-internal/v2 v2.0.0-20250710185017-3c0ffd25e59 github.com/charmbracelet/bubbletea-internal/v2 v2.0.0-20250710185017-3c0ffd25e595/go.mod h1:+Tl7rePElw6OKt382t04zXwtPFoPXxAaJzNrYmtsLds= github.com/charmbracelet/colorprofile v0.3.1 h1:k8dTHMd7fgw4bnFd7jXTLZrSU/CQrKnL3m+AxCzDz40= github.com/charmbracelet/colorprofile v0.3.1/go.mod h1:/GkGusxNs8VB/RSOh3fu0TJmQ4ICMMPApIIVn0KszZ0= -github.com/charmbracelet/fang v0.1.0 h1:SlZS2crf3/zQh7Mr4+W+7QR1k+L08rrPX5rm5z3d7Wg= -github.com/charmbracelet/fang v0.1.0/go.mod h1:Zl/zeUQ8EtQuGyiV0ZKZlZPDowKRTzu8s/367EpN/fc= +github.com/charmbracelet/fang v0.3.1-0.20250711140230-d5ebb8c1d674 h1:+Cz+VfxD5DO+JT1LlswXWhre0HYLj6l2HW8HVGfMuC0= +github.com/charmbracelet/fang v0.3.1-0.20250711140230-d5ebb8c1d674/go.mod h1:9gCUAHmVx5BwSafeyNr3GI0GgvlB1WYjL21SkPp1jyU= github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe h1:i6ce4CcAlPpTj2ER69m1DBeLZ3RRcHnKExuwhKa3GfY= github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe/go.mod h1:p3Q+aN4eQKeM5jhrmXPMgPrlKbmc59rWSnMsSA3udhk= github.com/charmbracelet/lipgloss-internal/v2 v2.0.0-20250710185058-03664cb9cecb h1:lswj7CYZVYbLn2OhYJsXOMRQQGdRIfyuSnh5FdVSMr0= @@ -165,8 +165,8 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0 github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= -github.com/mark3labs/mcp-go v0.32.0 h1:fgwmbfL2gbd67obg57OfV2Dnrhs1HtSdlY/i5fn7MU8= -github.com/mark3labs/mcp-go v0.32.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= +github.com/mark3labs/mcp-go v0.33.0 h1:naxhjnTIs/tyPZmWUZFuG0lDmdA6sUyYGGf3gsHvTCc= +github.com/mark3labs/mcp-go v0.33.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= diff --git a/internal/app/app.go b/internal/app/app.go index 099b092089c4a4e4e0ddcc9ccf79c36ca66acdce..c3dae3d88a2be7c4cd5491e089b97695b08a7a23 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -92,16 +92,26 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) { } // RunNonInteractive handles the execution flow when a prompt is provided via CLI flag. -func (a *App) RunNonInteractive(ctx context.Context, prompt string, outputFormat string, quiet bool) error { +func (a *App) RunNonInteractive(ctx context.Context, prompt string, quiet bool) error { slog.Info("Running in non-interactive mode") + ctx, cancel := context.WithCancel(ctx) + defer cancel() + // Start spinner if not in quiet mode var spinner *format.Spinner if !quiet { - spinner = format.NewSpinner(ctx, "Generating") + spinner = format.NewSpinner(ctx, cancel, "Generating") spinner.Start() - defer spinner.Stop() } + // Helper function to stop spinner once + stopSpinner := func() { + if !quiet && spinner != nil { + spinner.Stop() + spinner = nil + } + } + defer stopSpinner() const maxPromptLengthForTitle = 100 titlePrefix := "Non-interactive: " @@ -128,35 +138,42 @@ func (a *App) RunNonInteractive(ctx context.Context, prompt string, outputFormat return fmt.Errorf("failed to start agent processing stream: %w", err) } - result := <-done + messageEvents := a.Messages.Subscribe(ctx) + readBts := 0 - // Stop spinner before printing output - if !quiet && spinner != nil { - spinner.Stop() - } + for { + select { + case result := <-done: + stopSpinner() + + if result.Error != nil { + if errors.Is(result.Error, context.Canceled) || errors.Is(result.Error, agent.ErrRequestCancelled) { + slog.Info("Agent processing cancelled", "session_id", sess.ID) + return nil + } + return fmt.Errorf("agent processing failed: %w", result.Error) + } + + part := result.Message.Content().String()[readBts:] + fmt.Println(part) - if result.Error != nil { - if errors.Is(result.Error, context.Canceled) || errors.Is(result.Error, agent.ErrRequestCancelled) { - slog.Info("Agent processing cancelled", "session_id", sess.ID) + slog.Info("Non-interactive run completed", "session_id", sess.ID) return nil - } - return fmt.Errorf("agent processing failed: %w", result.Error) - } - // Get the text content from the response - content := "No content available" - if result.Message.Content().String() != "" { - content = result.Message.Content().String() - } + case event := <-messageEvents: + msg := event.Payload + if msg.SessionID == sess.ID && msg.Role == message.Assistant && len(msg.Parts) > 0 { + stopSpinner() + part := msg.Content().String()[readBts:] + fmt.Print(part) + readBts += len(part) + } - out, err := format.FormatOutput(content, outputFormat) - if err != nil { - return err + case <-ctx.Done(): + stopSpinner() + return ctx.Err() + } } - - fmt.Println(out) - slog.Info("Non-interactive run completed", "session_id", sess.ID) - return nil } func (app *App) UpdateAgentModel() error { diff --git a/internal/app/lsp.go b/internal/app/lsp.go index ba98d4b3a074c2e9abcef87eb3030a21be669eab..33506016690645dd714c682ddd2e65e992d2d1f9 100644 --- a/internal/app/lsp.go +++ b/internal/app/lsp.go @@ -59,11 +59,8 @@ func (app *App) createAndStartLSPClient(ctx context.Context, name string, comman // Create a child context that can be canceled when the app is shutting down watchCtx, cancelFunc := context.WithCancel(ctx) - // Create a context with the server name for better identification - watchCtx = context.WithValue(watchCtx, "serverName", name) - // Create the workspace watcher - workspaceWatcher := watcher.NewWorkspaceWatcher(lspClient) + workspaceWatcher := watcher.NewWorkspaceWatcher(name, lspClient) // Store the cancel function to be called during cleanup app.cancelFuncsMutex.Lock() diff --git a/internal/config/config.go b/internal/config/config.go index 5c978106bc49f7b5956ea1d1d6e4d994f53eae58..5108a5cbee1684b92f779243b35aa3a50f162e60 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -72,7 +72,7 @@ type ProviderConfig struct { Disable bool `json:"disable,omitempty"` // Extra headers to send with each request to the provider. - ExtraHeaders map[string]string + ExtraHeaders map[string]string `json:"extra_headers,omitempty"` // Used to pass extra parameters to the provider. ExtraParams map[string]string `json:"-"` @@ -371,3 +371,34 @@ func (c *Config) SetProviderAPIKey(providerID, apiKey string) error { c.Providers[providerID] = providerConfig return nil } + +func (c *Config) SetupAgents() { + agents := map[string]Agent{ + "coder": { + ID: "coder", + Name: "Coder", + Description: "An agent that helps with executing coding tasks.", + Model: SelectedModelTypeLarge, + ContextPaths: c.Options.ContextPaths, + // All tools allowed + }, + "task": { + ID: "task", + Name: "Task", + Description: "An agent that helps with searching for context and finding implementation details.", + Model: SelectedModelTypeLarge, + ContextPaths: c.Options.ContextPaths, + AllowedTools: []string{ + "glob", + "grep", + "ls", + "sourcegraph", + "view", + }, + // NO MCPs or LSPs by default + AllowedMCP: map[string][]string{}, + AllowedLSP: []string{}, + }, + } + c.Agents = agents +} diff --git a/internal/config/load.go b/internal/config/load.go index 9f2b5e55f1ccc0a687d46083b67e81d6e5fa212a..81cb4398e5b3a7a2147ab5388b37088788ea041b 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -83,37 +83,7 @@ func Load(workingDir string, debug bool) (*Config, error) { if err := cfg.configureSelectedModels(providers); err != nil { return nil, fmt.Errorf("failed to configure selected models: %w", err) } - - // TODO: remove the agents concept from the config - agents := map[string]Agent{ - "coder": { - ID: "coder", - Name: "Coder", - Description: "An agent that helps with executing coding tasks.", - Model: SelectedModelTypeLarge, - ContextPaths: cfg.Options.ContextPaths, - // All tools allowed - }, - "task": { - ID: "task", - Name: "Task", - Description: "An agent that helps with searching for context and finding implementation details.", - Model: SelectedModelTypeLarge, - ContextPaths: cfg.Options.ContextPaths, - AllowedTools: []string{ - "glob", - "grep", - "ls", - "sourcegraph", - "view", - }, - // NO MCPs or LSPs by default - AllowedMCP: map[string][]string{}, - AllowedLSP: []string{}, - }, - } - cfg.Agents = agents - + cfg.SetupAgents() return cfg, nil } @@ -331,6 +301,7 @@ func (cfg *Config) defaultModelSelection(knownProviders []provider.Provider) (la defaultSmallModel := cfg.GetModel(string(p.ID), p.DefaultSmallModelID) if defaultSmallModel == nil { err = fmt.Errorf("default small model %s not found for provider %s", p.DefaultSmallModelID, p.ID) + return } smallModel = SelectedModel{ Provider: string(p.ID), @@ -387,8 +358,6 @@ func (cfg *Config) configureSelectedModels(knownProviders []provider.Provider) e large.Provider = largeModelSelected.Provider } model := cfg.GetModel(large.Provider, large.Model) - slog.Info("Configuring selected large model", "provider", large.Provider, "model", large.Model) - slog.Info("Model configured", "model", model) if model == nil { large = defaultLarge // override the model type to large diff --git a/internal/config/resolve.go b/internal/config/resolve.go index 9c9116661814fe7abee91e2821829442bc65080d..3c97a6456cf7fe5968311746d62b2772b21d6aaa 100644 --- a/internal/config/resolve.go +++ b/internal/config/resolve.go @@ -44,7 +44,7 @@ func (r *shellVariableResolver) ResolveValue(value string) (string, error) { if strings.HasPrefix(value, "$(") && strings.HasSuffix(value, ")") { command := strings.TrimSuffix(strings.TrimPrefix(value, "$("), ")") - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() stdout, _, err := r.shell.Exec(ctx, command) diff --git a/internal/format/format.go b/internal/format/format.go deleted file mode 100644 index 9f5a98910cafa41b924ff516da54ab751eb7f058..0000000000000000000000000000000000000000 --- a/internal/format/format.go +++ /dev/null @@ -1,91 +0,0 @@ -package format - -import ( - "encoding/json" - "fmt" - "strings" -) - -// OutputFormat represents the output format type for non-interactive mode -type OutputFormat string - -const ( - // Text format outputs the AI response as plain text. - Text OutputFormat = "text" - - // JSON format outputs the AI response wrapped in a JSON object. - JSON OutputFormat = "json" -) - -// String returns the string representation of the OutputFormat -func (f OutputFormat) String() string { - return string(f) -} - -// SupportedFormats is a list of all supported output formats as strings -var SupportedFormats = []string{ - string(Text), - string(JSON), -} - -// Parse converts a string to an OutputFormat -func Parse(s string) (OutputFormat, error) { - s = strings.ToLower(strings.TrimSpace(s)) - - switch s { - case string(Text): - return Text, nil - case string(JSON): - return JSON, nil - default: - return "", fmt.Errorf("invalid format: %s", s) - } -} - -// IsValid checks if the provided format string is supported -func IsValid(s string) bool { - _, err := Parse(s) - return err == nil -} - -// GetHelpText returns a formatted string describing all supported formats -func GetHelpText() string { - return fmt.Sprintf(`Supported output formats: -- %s: Plain text output (default) -- %s: Output wrapped in a JSON object`, - Text, JSON) -} - -// FormatOutput formats the AI response according to the specified format -func FormatOutput(content string, formatStr string) (string, error) { - format, err := Parse(formatStr) - if err != nil { - format = Text - } - - switch format { - case JSON: - return formatAsJSON(content) - case Text: - fallthrough - default: - return content, nil - } -} - -// formatAsJSON wraps the content in a simple JSON object -func formatAsJSON(content string) (string, error) { - // Use the JSON package to properly escape the content - response := struct { - Response string `json:"response"` - }{ - Response: content, - } - - jsonBytes, err := json.MarshalIndent(response, "", " ") - if err != nil { - return "", fmt.Errorf("failed to marshal output into JSON: %w", err) - } - - return string(jsonBytes), nil -} diff --git a/internal/format/spinner.go b/internal/format/spinner.go index 9377bd3b4c145fc6866ac1e0f4e63dff8ab51619..da64fb93ce262e04a0b5fb9da8c4aea8403d10d8 100644 --- a/internal/format/spinner.go +++ b/internal/format/spinner.go @@ -18,24 +18,48 @@ type Spinner struct { prog *tea.Program } +type model struct { + cancel context.CancelFunc + anim anim.Anim +} + +func (m model) Init() tea.Cmd { return m.anim.Init() } +func (m model) View() string { return m.anim.View() } + +// Update implements tea.Model. +func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case tea.KeyPressMsg: + switch msg.String() { + case "ctrl+c", "esc": + m.cancel() + return m, tea.Quit + } + } + mm, cmd := m.anim.Update(msg) + m.anim = mm.(anim.Anim) + return m, cmd +} + // NewSpinner creates a new spinner with the given message -func NewSpinner(ctx context.Context, message string) *Spinner { +func NewSpinner(ctx context.Context, cancel context.CancelFunc, message string) *Spinner { t := styles.CurrentTheme() - model := anim.New(anim.Settings{ - Size: 10, - Label: message, - LabelColor: t.FgBase, - GradColorA: t.Primary, - GradColorB: t.Secondary, - CycleColors: true, - }) + model := model{ + anim: anim.New(anim.Settings{ + Size: 10, + Label: message, + LabelColor: t.FgBase, + GradColorA: t.Primary, + GradColorB: t.Secondary, + CycleColors: true, + }), + cancel: cancel, + } prog := tea.NewProgram( model, - tea.WithInput(nil), tea.WithOutput(os.Stderr), tea.WithContext(ctx), - tea.WithoutCatchPanics(), ) return &Spinner{ @@ -47,13 +71,13 @@ func NewSpinner(ctx context.Context, message string) *Spinner { // Start begins the spinner animation func (s *Spinner) Start() { go func() { + defer close(s.done) _, err := s.prog.Run() // ensures line is cleared fmt.Fprint(os.Stderr, ansi.EraseEntireLine) - if err != nil && !errors.Is(err, context.Canceled) { + if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, tea.ErrInterrupted) { fmt.Fprintf(os.Stderr, "Error running spinner: %v\n", err) } - close(s.done) }() } diff --git a/internal/fur/client/client.go b/internal/fur/client/client.go index 5f0ddeaeee708d4b5475403ce1874591f7e9bb2c..d007c9aee18f77c8b03fe804726b4196e474d0b4 100644 --- a/internal/fur/client/client.go +++ b/internal/fur/client/client.go @@ -10,7 +10,7 @@ import ( "github.com/charmbracelet/crush/internal/fur/provider" ) -const defaultURL = "https://fur.charmcli.dev" +const defaultURL = "https://fur.charm.sh" // Client represents a client for the fur service. type Client struct { diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 313b83c0448d8a668e2390368c6797c82dd22452..fbb5b4fd8c6390ff0dfad0e072af35342355ba41 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -149,7 +149,7 @@ func NewAgent( } opts := []provider.ProviderClientOption{ provider.WithModel(agentCfg.Model), - provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)), + provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)), } agentProvider, err := provider.NewProvider(*providerCfg, opts...) if err != nil { @@ -827,7 +827,7 @@ func (a *agent) UpdateModel() error { opts := []provider.ProviderClientOption{ provider.WithModel(a.agentCfg.Model), - provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)), + provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)), } newProvider, err := provider.NewProvider(*currentProviderCfg, opts...) diff --git a/internal/llm/prompt/prompt.go b/internal/llm/prompt/prompt.go index 835279b4f4c0e08e46aaad271b7cb7f2a59b467f..4a2661bb9f663d9f93cf0371ac5d71dd513392c7 100644 --- a/internal/llm/prompt/prompt.go +++ b/internal/llm/prompt/prompt.go @@ -5,6 +5,9 @@ import ( "path/filepath" "strings" "sync" + + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/env" ) type PromptID string @@ -21,7 +24,7 @@ func GetPrompt(promptID PromptID, provider string, contextPaths ...string) strin basePrompt := "" switch promptID { case PromptCoder: - basePrompt = CoderPrompt(provider) + basePrompt = CoderPrompt(provider, contextPaths...) case PromptTitle: basePrompt = TitlePrompt() case PromptTask: @@ -38,6 +41,32 @@ func getContextFromPaths(workingDir string, contextPaths []string) string { return processContextPaths(workingDir, contextPaths) } +// expandPath expands ~ and environment variables in file paths +func expandPath(path string) string { + // Handle tilde expansion + if strings.HasPrefix(path, "~/") { + homeDir, err := os.UserHomeDir() + if err == nil { + path = filepath.Join(homeDir, path[2:]) + } + } else if path == "~" { + homeDir, err := os.UserHomeDir() + if err == nil { + path = homeDir + } + } + + // Handle environment variable expansion using the same pattern as config + if strings.HasPrefix(path, "$") { + resolver := config.NewEnvironmentVariableResolver(env.New()) + if expanded, err := resolver.ResolveValue(path); err == nil { + path = expanded + } + } + + return path +} + func processContextPaths(workDir string, paths []string) string { var ( wg sync.WaitGroup @@ -53,8 +82,23 @@ func processContextPaths(workDir string, paths []string) string { go func(p string) { defer wg.Done() - if strings.HasSuffix(p, "/") { - filepath.WalkDir(filepath.Join(workDir, p), func(path string, d os.DirEntry, err error) error { + // Expand ~ and environment variables before processing + p = expandPath(p) + + // Use absolute path if provided, otherwise join with workDir + fullPath := p + if !filepath.IsAbs(p) { + fullPath = filepath.Join(workDir, p) + } + + // Check if the path is a directory using os.Stat + info, err := os.Stat(fullPath) + if err != nil { + return // Skip if path doesn't exist or can't be accessed + } + + if info.IsDir() { + filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error { if err != nil { return err } @@ -78,8 +122,7 @@ func processContextPaths(workDir string, paths []string) string { return nil }) } else { - fullPath := filepath.Join(workDir, p) - + // It's a file, process it directly // Check if we've already processed this file (case-insensitive) lowerPath := strings.ToLower(fullPath) diff --git a/internal/llm/prompt/prompt_test.go b/internal/llm/prompt/prompt_test.go new file mode 100644 index 0000000000000000000000000000000000000000..ce7fa0fb35cfdf021b886a96a828202001588a7f --- /dev/null +++ b/internal/llm/prompt/prompt_test.go @@ -0,0 +1,113 @@ +package prompt + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestExpandPath(t *testing.T) { + tests := []struct { + name string + input string + expected func() string + }{ + { + name: "regular path unchanged", + input: "/absolute/path", + expected: func() string { + return "/absolute/path" + }, + }, + { + name: "tilde expansion", + input: "~/documents", + expected: func() string { + home, _ := os.UserHomeDir() + return filepath.Join(home, "documents") + }, + }, + { + name: "tilde only", + input: "~", + expected: func() string { + home, _ := os.UserHomeDir() + return home + }, + }, + { + name: "environment variable expansion", + input: "$HOME", + expected: func() string { + return os.Getenv("HOME") + }, + }, + { + name: "relative path unchanged", + input: "relative/path", + expected: func() string { + return "relative/path" + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := expandPath(tt.input) + expected := tt.expected() + + // Skip test if environment variable is not set + if strings.HasPrefix(tt.input, "$") && expected == "" { + t.Skip("Environment variable not set") + } + + if result != expected { + t.Errorf("expandPath(%q) = %q, want %q", tt.input, result, expected) + } + }) + } +} + +func TestProcessContextPaths(t *testing.T) { + // Create a temporary directory and file for testing + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.txt") + testContent := "test content" + + err := os.WriteFile(testFile, []byte(testContent), 0o644) + if err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + + // Test with absolute path to file + result := processContextPaths("", []string{testFile}) + expected := "# From:" + testFile + "\n" + testContent + + if result != expected { + t.Errorf("processContextPaths with absolute path failed.\nGot: %q\nWant: %q", result, expected) + } + + // Test with directory path (should process all files in directory) + result = processContextPaths("", []string{tmpDir}) + if !strings.Contains(result, testContent) { + t.Errorf("processContextPaths with directory path failed to include file content") + } + + // Test with tilde expansion (if we can create a file in home directory) + tmpDir = t.TempDir() + t.Setenv("HOME", tmpDir) + homeTestFile := filepath.Join(tmpDir, "crush_test_file.txt") + err = os.WriteFile(homeTestFile, []byte(testContent), 0o644) + if err == nil { + defer os.Remove(homeTestFile) // Clean up + + tildeFile := "~/crush_test_file.txt" + result = processContextPaths("", []string{tildeFile}) + expected = "# From:" + homeTestFile + "\n" + testContent + + if result != expected { + t.Errorf("processContextPaths with tilde expansion failed.\nGot: %q\nWant: %q", result, expected) + } + } +} diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index 0a10568a39315f6c4077385b8ca83f6b3e52691c..6d7a9a32b3829da02021be80e6e41e28888efd83 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "log/slog" "runtime" "strings" "time" @@ -41,9 +42,74 @@ const ( ) var bannedCommands = []string{ - "alias", "curl", "curlie", "wget", "axel", "aria2c", - "nc", "telnet", "lynx", "w3m", "links", "httpie", "xh", - "http-prompt", "chrome", "firefox", "safari", + // Network/Download tools + "alias", + "aria2c", + "axel", + "chrome", + "curl", + "curlie", + "firefox", + "http-prompt", + "httpie", + "links", + "lynx", + "nc", + "safari", + "telnet", + "w3m", + "wget", + "xh", + + // System administration + "doas", + "su", + "sudo", + + // Package managers + "apk", + "apt", + "apt-cache", + "apt-get", + "dnf", + "dpkg", + "emerge", + "home-manager", + "makepkg", + "opkg", + "pacman", + "paru", + "pkg", + "pkg_add", + "pkg_delete", + "portage", + "rpm", + "yay", + "yum", + "zypper", + + // System modification + "at", + "batch", + "chkconfig", + "crontab", + "fdisk", + "mkfs", + "mount", + "parted", + "service", + "systemctl", + "umount", + + // Network configuration + "firewall-cmd", + "ifconfig", + "ip", + "iptables", + "netstat", + "pfctl", + "route", + "ufw", } // getSafeReadOnlyCommands returns platform-appropriate safe commands @@ -244,7 +310,42 @@ Important: - Never update git config`, bannedCommandsStr, MaxOutputLength) } +func blockFuncs() []shell.BlockFunc { + return []shell.BlockFunc{ + shell.CommandsBlocker(bannedCommands), + shell.ArgumentsBlocker([][]string{ + // System package managers + {"apk", "add"}, + {"apt", "install"}, + {"apt-get", "install"}, + {"dnf", "install"}, + {"emerge"}, + {"pacman", "-S"}, + {"pkg", "install"}, + {"yum", "install"}, + {"zypper", "install"}, + + // Language-specific package managers + {"brew", "install"}, + {"cargo", "install"}, + {"gem", "install"}, + {"go", "install"}, + {"npm", "install", "-g"}, + {"npm", "install", "--global"}, + {"pip", "install", "--user"}, + {"pip3", "install", "--user"}, + {"pnpm", "add", "-g"}, + {"pnpm", "add", "--global"}, + {"yarn", "global", "add"}, + }), + } +} + func NewBashTool(permission permission.Service, workingDir string) BaseTool { + // Set up command blocking on the persistent shell + persistentShell := shell.GetPersistentShell(workingDir) + persistentShell.SetBlockFuncs(blockFuncs()) + return &bashTool{ permissions: permission, workingDir: workingDir, @@ -289,13 +390,6 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) return NewTextErrorResponse("missing command"), nil } - baseCmd := strings.Fields(params.Command)[0] - for _, banned := range bannedCommands { - if strings.EqualFold(baseCmd, banned) { - return NewTextErrorResponse(fmt.Sprintf("command '%s' is not allowed", baseCmd)), nil - } - } - isSafeReadOnly := false cmdLower := strings.ToLower(params.Command) @@ -349,7 +443,20 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) stdout = truncateOutput(stdout) stderr = truncateOutput(stderr) + slog.Info("Bash command executed", + "command", params.Command, + "stdout", stdout, + "stderr", stderr, + "exit_code", exitCode, + "interrupted", interrupted, + "err", err, + ) + errorMessage := stderr + if errorMessage == "" && err != nil { + errorMessage = err.Error() + } + if interrupted { if errorMessage != "" { errorMessage += "\n" diff --git a/internal/lsp/transport.go b/internal/lsp/transport.go index 431a099fa1cda5e5035de7ce6c10ef3761e397ea..9a3dfd261fb68b1afdd17f614daab761f9294327 100644 --- a/internal/lsp/transport.go +++ b/internal/lsp/transport.go @@ -222,29 +222,32 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any } // Wait for response - resp := <-ch - - if cfg.Options.DebugLSP { - slog.Debug("Received response", "id", id) - } - - if resp.Error != nil { - return fmt.Errorf("request failed: %s (code: %d)", resp.Error.Message, resp.Error.Code) - } + select { + case <-ctx.Done(): + return ctx.Err() + case resp := <-ch: + if cfg.Options.DebugLSP { + slog.Debug("Received response", "id", id) + } - if result != nil { - // If result is a json.RawMessage, just copy the raw bytes - if rawMsg, ok := result.(*json.RawMessage); ok { - *rawMsg = resp.Result - return nil + if resp.Error != nil { + return fmt.Errorf("request failed: %s (code: %d)", resp.Error.Message, resp.Error.Code) } - // Otherwise unmarshal into the provided type - if err := json.Unmarshal(resp.Result, result); err != nil { - return fmt.Errorf("failed to unmarshal result: %w", err) + + if result != nil { + // If result is a json.RawMessage, just copy the raw bytes + if rawMsg, ok := result.(*json.RawMessage); ok { + *rawMsg = resp.Result + return nil + } + // Otherwise unmarshal into the provided type + if err := json.Unmarshal(resp.Result, result); err != nil { + return fmt.Errorf("failed to unmarshal result: %w", err) + } } - } - return nil + return nil + } } // Notify sends a notification (a request without an ID that doesn't expect a response) diff --git a/internal/lsp/watcher/watcher.go b/internal/lsp/watcher/watcher.go index a6d27f057e06ea7026a6eed0308979991a44fb9d..5bd016eebe413a17acca29ef628612825d40b923 100644 --- a/internal/lsp/watcher/watcher.go +++ b/internal/lsp/watcher/watcher.go @@ -21,6 +21,7 @@ import ( // WorkspaceWatcher manages LSP file watching type WorkspaceWatcher struct { client *lsp.Client + name string workspacePath string debounceTime time.Duration @@ -33,8 +34,9 @@ type WorkspaceWatcher struct { } // NewWorkspaceWatcher creates a new workspace watcher -func NewWorkspaceWatcher(client *lsp.Client) *WorkspaceWatcher { +func NewWorkspaceWatcher(name string, client *lsp.Client) *WorkspaceWatcher { return &WorkspaceWatcher{ + name: name, client: client, debounceTime: 300 * time.Millisecond, debounceMap: make(map[string]*time.Timer), @@ -95,7 +97,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc } // Determine server type for specialized handling - serverName := getServerNameFromContext(ctx) + serverName := w.name slog.Debug("Server type detected", "serverName", serverName) // Check if this server has sent file watchers @@ -325,17 +327,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str cfg := config.Get() w.workspacePath = workspacePath - // Store the watcher in the context for later use - ctx = context.WithValue(ctx, "workspaceWatcher", w) - - // If the server name isn't already in the context, try to detect it - if _, ok := ctx.Value("serverName").(string); !ok { - serverName := getServerNameFromContext(ctx) - ctx = context.WithValue(ctx, "serverName", serverName) - } - - serverName := getServerNameFromContext(ctx) - slog.Debug("Starting workspace watcher", "workspacePath", workspacePath, "serverName", serverName) + slog.Debug("Starting workspace watcher", "workspacePath", workspacePath, "serverName", w.name) // Register handler for file watcher registrations from the server lsp.RegisterFileWatchHandler(func(id string, watchers []protocol.FileSystemWatcher) { @@ -697,40 +689,6 @@ func (w *WorkspaceWatcher) notifyFileEvent(ctx context.Context, uri string, chan return w.client.DidChangeWatchedFiles(ctx, params) } -// getServerNameFromContext extracts the server name from the context -// This is a best-effort function that tries to identify which LSP server we're dealing with -func getServerNameFromContext(ctx context.Context) string { - // First check if the server name is directly stored in the context - if serverName, ok := ctx.Value("serverName").(string); ok && serverName != "" { - return strings.ToLower(serverName) - } - - // Otherwise, try to extract server name from the client command path - if w, ok := ctx.Value("workspaceWatcher").(*WorkspaceWatcher); ok && w != nil && w.client != nil && w.client.Cmd != nil { - path := strings.ToLower(w.client.Cmd.Path) - - // Extract server name from path - if strings.Contains(path, "typescript") || strings.Contains(path, "tsserver") || strings.Contains(path, "vtsls") { - return "typescript" - } else if strings.Contains(path, "gopls") { - return "gopls" - } else if strings.Contains(path, "rust-analyzer") { - return "rust-analyzer" - } else if strings.Contains(path, "pyright") || strings.Contains(path, "pylsp") || strings.Contains(path, "python") { - return "python" - } else if strings.Contains(path, "clangd") { - return "clangd" - } else if strings.Contains(path, "jdtls") || strings.Contains(path, "java") { - return "java" - } - - // Return the base name as fallback - return filepath.Base(path) - } - - return "unknown" -} - // shouldPreloadFiles determines if we should preload files for a specific language server // Some servers work better with preloaded files, others don't need it func shouldPreloadFiles(serverName string) bool { @@ -884,64 +842,63 @@ func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) { } // Check if this path should be watched according to server registrations - if watched, _ := w.isPathWatched(path); watched { - // Get server name for specialized handling - serverName := getServerNameFromContext(ctx) + if watched, _ := w.isPathWatched(path); !watched { + return + } - // Check if the file is a high-priority file that should be opened immediately - // This helps with project initialization for certain language servers - if isHighPriorityFile(path, serverName) { - if cfg.Options.DebugLSP { - slog.Debug("Opening high-priority file", "path", path, "serverName", serverName) - } - if err := w.client.OpenFile(ctx, path); err != nil && cfg.Options.DebugLSP { - slog.Error("Error opening high-priority file", "path", path, "error", err) - } - return - } + serverName := w.name - // For non-high-priority files, we'll use different strategies based on server type - if shouldPreloadFiles(serverName) { - // For servers that benefit from preloading, open files but with limits + // Get server name for specialized handling + // Check if the file is a high-priority file that should be opened immediately + // This helps with project initialization for certain language servers + if isHighPriorityFile(path, serverName) { + if cfg.Options.DebugLSP { + slog.Debug("Opening high-priority file", "path", path, "serverName", serverName) + } + if err := w.client.OpenFile(ctx, path); err != nil && cfg.Options.DebugLSP { + slog.Error("Error opening high-priority file", "path", path, "error", err) + } + return + } - // Check file size - for preloading we're more conservative - if info.Size() > (1 * 1024 * 1024) { // 1MB limit for preloaded files - if cfg.Options.DebugLSP { - slog.Debug("Skipping large file for preloading", "path", path, "size", info.Size()) - } - return - } + // For non-high-priority files, we'll use different strategies based on server type + if !shouldPreloadFiles(serverName) { + return + } + // For servers that benefit from preloading, open files but with limits - // Check file extension for common source files - ext := strings.ToLower(filepath.Ext(path)) + // Check file size - for preloading we're more conservative + if info.Size() > (1 * 1024 * 1024) { // 1MB limit for preloaded files + if cfg.Options.DebugLSP { + slog.Debug("Skipping large file for preloading", "path", path, "size", info.Size()) + } + return + } - // Only preload source files for the specific language - shouldOpen := false + // Check file extension for common source files + ext := strings.ToLower(filepath.Ext(path)) - switch serverName { - case "typescript", "typescript-language-server", "tsserver", "vtsls": - shouldOpen = ext == ".ts" || ext == ".js" || ext == ".tsx" || ext == ".jsx" - case "gopls": - shouldOpen = ext == ".go" - case "rust-analyzer": - shouldOpen = ext == ".rs" - case "python", "pyright", "pylsp": - shouldOpen = ext == ".py" - case "clangd": - shouldOpen = ext == ".c" || ext == ".cpp" || ext == ".h" || ext == ".hpp" - case "java", "jdtls": - shouldOpen = ext == ".java" - default: - // For unknown servers, be conservative - shouldOpen = false - } + // Only preload source files for the specific language + var shouldOpen bool + switch serverName { + case "typescript", "typescript-language-server", "tsserver", "vtsls": + shouldOpen = ext == ".ts" || ext == ".js" || ext == ".tsx" || ext == ".jsx" + case "gopls": + shouldOpen = ext == ".go" + case "rust-analyzer": + shouldOpen = ext == ".rs" + case "python", "pyright", "pylsp": + shouldOpen = ext == ".py" + case "clangd": + shouldOpen = ext == ".c" || ext == ".cpp" || ext == ".h" || ext == ".hpp" + case "java", "jdtls": + shouldOpen = ext == ".java" + } - if shouldOpen { - // Don't need to check if it's already open - the client.OpenFile handles that - if err := w.client.OpenFile(ctx, path); err != nil && cfg.Options.DebugLSP { - slog.Error("Error opening file", "path", path, "error", err) - } - } + if shouldOpen { + // Don't need to check if it's already open - the client.OpenFile handles that + if err := w.client.OpenFile(ctx, path); err != nil && cfg.Options.DebugLSP { + slog.Error("Error opening file", "path", path, "error", err) } } } diff --git a/internal/shell/command_block_test.go b/internal/shell/command_block_test.go new file mode 100644 index 0000000000000000000000000000000000000000..fd7c46bcd98e54f44abbe982e834f3cbb04cbfa4 --- /dev/null +++ b/internal/shell/command_block_test.go @@ -0,0 +1,123 @@ +package shell + +import ( + "context" + "os" + "strings" + "testing" +) + +func TestCommandBlocking(t *testing.T) { + tests := []struct { + name string + blockFuncs []BlockFunc + command string + shouldBlock bool + }{ + { + name: "block simple command", + blockFuncs: []BlockFunc{ + func(args []string) bool { + return len(args) > 0 && args[0] == "curl" + }, + }, + command: "curl https://example.com", + shouldBlock: true, + }, + { + name: "allow non-blocked command", + blockFuncs: []BlockFunc{ + func(args []string) bool { + return len(args) > 0 && args[0] == "curl" + }, + }, + command: "echo hello", + shouldBlock: false, + }, + { + name: "block subcommand", + blockFuncs: []BlockFunc{ + func(args []string) bool { + return len(args) >= 2 && args[0] == "brew" && args[1] == "install" + }, + }, + command: "brew install wget", + shouldBlock: true, + }, + { + name: "allow different subcommand", + blockFuncs: []BlockFunc{ + func(args []string) bool { + return len(args) >= 2 && args[0] == "brew" && args[1] == "install" + }, + }, + command: "brew list", + shouldBlock: false, + }, + { + name: "block npm global install with -g", + blockFuncs: []BlockFunc{ + ArgumentsBlocker([][]string{ + {"npm", "install", "-g"}, + {"npm", "install", "--global"}, + }), + }, + command: "npm install -g typescript", + shouldBlock: true, + }, + { + name: "block npm global install with --global", + blockFuncs: []BlockFunc{ + ArgumentsBlocker([][]string{ + {"npm", "install", "-g"}, + {"npm", "install", "--global"}, + }), + }, + command: "npm install --global typescript", + shouldBlock: true, + }, + { + name: "allow npm local install", + blockFuncs: []BlockFunc{ + ArgumentsBlocker([][]string{ + {"npm", "install", "-g"}, + {"npm", "install", "--global"}, + }), + }, + command: "npm install typescript", + shouldBlock: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a temporary directory for each test + tmpDir, err := os.MkdirTemp("", "shell-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + shell := NewShell(&Options{ + WorkingDir: tmpDir, + BlockFuncs: tt.blockFuncs, + }) + + _, _, err = shell.Exec(context.Background(), tt.command) + + if tt.shouldBlock { + if err == nil { + t.Errorf("Expected command to be blocked, but it was allowed") + } else if !strings.Contains(err.Error(), "not allowed for security reasons") { + t.Errorf("Expected security error, got: %v", err) + } + } else { + // For non-blocked commands, we might get other errors (like command not found) + // but we shouldn't get the security error + if err != nil && strings.Contains(err.Error(), "not allowed for security reasons") { + t.Errorf("Command was unexpectedly blocked: %v", err) + } + } + }) + } +} diff --git a/internal/shell/shell.go b/internal/shell/shell.go index 0467c9072c5111e4b4ea9a5439519e4edf76af46..b655c5dbecf5b69c7ad102c53108733515138771 100644 --- a/internal/shell/shell.go +++ b/internal/shell/shell.go @@ -44,12 +44,16 @@ type noopLogger struct{} func (noopLogger) InfoPersist(msg string, keysAndValues ...interface{}) {} +// BlockFunc is a function that determines if a command should be blocked +type BlockFunc func(args []string) bool + // Shell provides cross-platform shell execution with optional state persistence type Shell struct { - env []string - cwd string - mu sync.Mutex - logger Logger + env []string + cwd string + mu sync.Mutex + logger Logger + blockFuncs []BlockFunc } // Options for creating a new shell @@ -57,6 +61,7 @@ type Options struct { WorkingDir string Env []string Logger Logger + BlockFuncs []BlockFunc } // NewShell creates a new shell instance with the given options @@ -81,9 +86,10 @@ func NewShell(opts *Options) *Shell { } return &Shell{ - cwd: cwd, - env: env, - logger: logger, + cwd: cwd, + env: env, + logger: logger, + blockFuncs: opts.BlockFuncs, } } @@ -152,6 +158,13 @@ func (s *Shell) SetEnv(key, value string) { s.env = append(s.env, keyPrefix+value) } +// SetBlockFuncs sets the command block functions for the shell +func (s *Shell) SetBlockFuncs(blockFuncs []BlockFunc) { + s.mu.Lock() + defer s.mu.Unlock() + s.blockFuncs = blockFuncs +} + // Windows-specific commands that should use native shell var windowsNativeCommands = map[string]bool{ "dir": true, @@ -203,6 +216,60 @@ func (s *Shell) determineShellType(command string) ShellType { return ShellTypePOSIX } +// CommandsBlocker creates a BlockFunc that blocks exact command matches +func CommandsBlocker(bannedCommands []string) BlockFunc { + bannedSet := make(map[string]bool) + for _, cmd := range bannedCommands { + bannedSet[cmd] = true + } + + return func(args []string) bool { + if len(args) == 0 { + return false + } + return bannedSet[args[0]] + } +} + +// ArgumentsBlocker creates a BlockFunc that blocks specific subcommands +func ArgumentsBlocker(blockedSubCommands [][]string) BlockFunc { + return func(args []string) bool { + for _, blocked := range blockedSubCommands { + if len(args) >= len(blocked) { + match := true + for i, part := range blocked { + if args[i] != part { + match = false + break + } + } + if match { + return true + } + } + } + return false + } +} + +func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc { + return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc { + return func(ctx context.Context, args []string) error { + if len(args) == 0 { + return next(ctx, args) + } + + for _, blockFunc := range s.blockFuncs { + if blockFunc(args) { + return fmt.Errorf("command is not allowed for security reasons: %s", strings.Join(args, " ")) + } + } + + return next(ctx, args) + } + } +} + // execWindows executes commands using native Windows shells (cmd.exe or PowerShell) func (s *Shell) execWindows(ctx context.Context, command string, shell string) (string, string, error) { var cmd *exec.Cmd @@ -291,6 +358,7 @@ func (s *Shell) execPOSIX(ctx context.Context, command string) (string, string, interp.Interactive(false), interp.Env(expand.ListEnviron(s.env...)), interp.Dir(s.cwd), + interp.ExecHandlers(s.blockHandler()), ) if err != nil { return "", "", fmt.Errorf("could not run command: %w", err) diff --git a/internal/tui/components/chat/editor/editor.go b/internal/tui/components/chat/editor/editor.go index 2185715c813dbdcb288bddde0fe70d63046cf731..67ba67f5e6c40f16a89f7bc4fe1b6932c9989754 100644 --- a/internal/tui/components/chat/editor/editor.go +++ b/internal/tui/components/chat/editor/editor.go @@ -361,8 +361,9 @@ func (m *editorCmp) startCompletions() tea.Msg { }) } - x := m.textarea.Cursor().X + m.x + 1 - y := m.textarea.Cursor().Y + m.y + 1 + cur := m.textarea.Cursor() + x := cur.X + m.x // adjust for padding + y := cur.Y + m.y + 1 return completions.OpenCompletionsMsg{ Completions: completionItems, X: x, diff --git a/internal/tui/components/chat/messages/renderer.go b/internal/tui/components/chat/messages/renderer.go index cad86659e04c6eb77e957e2fef4885000214a953..87eb2c8476655fe7d11fc8c787e73b32d4584de4 100644 --- a/internal/tui/components/chat/messages/renderer.go +++ b/internal/tui/components/chat/messages/renderer.go @@ -207,6 +207,7 @@ func (br bashRenderer) Render(v *toolCallCmp) string { } cmd := strings.ReplaceAll(params.Command, "\n", " ") + cmd = strings.ReplaceAll(cmd, "\t", " ") args := newParamBuilder().addMain(cmd).build() return br.renderWithParams(v, "Bash", args, func() string { @@ -578,8 +579,8 @@ func renderParamList(nested bool, paramsWidth int, params ...string) string { return "" } mainParam := params[0] - if paramsWidth-3 >= 0 && len(mainParam) > paramsWidth { - mainParam = mainParam[:paramsWidth-3] + "…" + if paramsWidth >= 0 && lipgloss.Width(mainParam) > paramsWidth { + mainParam = ansi.Truncate(mainParam, paramsWidth, "…") } if len(params) == 1 { @@ -649,7 +650,7 @@ func joinHeaderBody(header, body string) string { return header } body = t.S().Base.PaddingLeft(2).Render(body) - return lipgloss.JoinVertical(lipgloss.Left, header, "", body, "") + return lipgloss.JoinVertical(lipgloss.Left, header, "", body) } func renderPlainContent(v *toolCallCmp, content string) string { diff --git a/internal/tui/components/chat/splash/keys.go b/internal/tui/components/chat/splash/keys.go index 9cf2e3124daa87b0fc62c2ea404fb1c6c86ec649..675c608a94af4aa72b701376f3983506166ac7d7 100644 --- a/internal/tui/components/chat/splash/keys.go +++ b/internal/tui/components/chat/splash/keys.go @@ -30,11 +30,11 @@ func DefaultKeyMap() KeyMap { key.WithHelp("↑", "previous item"), ), Yes: key.NewBinding( - key.WithKeys("y"), + key.WithKeys("y", "Y"), key.WithHelp("y", "yes"), ), No: key.NewBinding( - key.WithKeys("n"), + key.WithKeys("n", "N"), key.WithHelp("n", "no"), ), Tab: key.NewBinding( diff --git a/internal/tui/components/chat/splash/splash.go b/internal/tui/components/chat/splash/splash.go index 722aaea6f75c6ef0bef7e0a9ec2de319c6d71bfb..5b343e6c5538cc17b476e521e6f2bfaf6b3490cb 100644 --- a/internal/tui/components/chat/splash/splash.go +++ b/internal/tui/components/chat/splash/splash.go @@ -313,6 +313,7 @@ func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd { return util.ReportError(err) } } + cfg.SetupAgents() return nil } diff --git a/internal/tui/components/completions/completions.go b/internal/tui/components/completions/completions.go index 29ea86365e9f1532eab3aa1a61214ef74b7f4a05..5a6bcfe92e23f38c3f40c84770a0dcc9893e59d5 100644 --- a/internal/tui/components/completions/completions.go +++ b/internal/tui/components/completions/completions.go @@ -9,6 +9,8 @@ import ( "github.com/charmbracelet/lipgloss/v2" ) +const maxCompletionsHeight = 10 + type Completion struct { Title string // The title of the completion item Value any // The value of the completion item @@ -43,7 +45,7 @@ type Completions interface { type completionsCmp struct { width int height int // Height of the completions component` - x int // X position for the completions popup\ + x int // X position for the completions popup y int // Y position for the completions popup open bool // Indicates if the completions are open keyMap KeyMap @@ -150,18 +152,25 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if !c.open { return c, nil // If completions are not open, do nothing } - cmd := c.list.Filter(msg.Query) - c.height = max(min(10, len(c.list.Items())), 1) - return c, tea.Batch( - cmd, - c.list.SetSize(c.width, c.height), - ) + var cmds []tea.Cmd + cmds = append(cmds, c.list.Filter(msg.Query)) + itemsLen := len(c.list.Items()) + c.height = max(min(maxCompletionsHeight, itemsLen), 1) + cmds = append(cmds, c.list.SetSize(c.width, c.height)) + if itemsLen == 0 { + // Close completions if no items match the query + cmds = append(cmds, util.CmdHandler(CloseCompletionsMsg{})) + } + return c, tea.Batch(cmds...) } return c, nil } // View implements Completions. func (c *completionsCmp) View() string { + if !c.open { + return "" + } if len(c.list.Items()) == 0 { return c.style().Render("No completions found") } diff --git a/internal/tui/components/dialogs/permissions/keys.go b/internal/tui/components/dialogs/permissions/keys.go index 052c5222bc1ff7d7de1eb7e8f8a1378a7c79c1bc..9edc368d275d90d670eeb8f03346184d3edea800 100644 --- a/internal/tui/components/dialogs/permissions/keys.go +++ b/internal/tui/components/dialogs/permissions/keys.go @@ -109,9 +109,5 @@ func (k KeyMap) ShortHelp() []key.Binding { key.WithKeys("shift+left", "shift+down", "shift+up", "shift+right"), key.WithHelp("shift+←↓↑→", "scroll"), ), - key.NewBinding( - key.WithKeys("shift+h", "shift+j", "shift+k", "shift+l"), - key.WithHelp("shift+hjkl", "scroll"), - ), } } diff --git a/internal/tui/exp/diffview/diffview.go b/internal/tui/exp/diffview/diffview.go index bb51a7e505666e67fd9e914a135a0dd7632bb184..1cb56a678f51d0809c584edc1bedd73befc59966 100644 --- a/internal/tui/exp/diffview/diffview.go +++ b/internal/tui/exp/diffview/diffview.go @@ -365,7 +365,8 @@ func (dv *DiffView) renderUnified() string { shouldWrite := func() bool { return printedLines >= 0 } getContent := func(in string, ls LineStyle) (content string, leadingEllipsis bool) { - content = strings.TrimSuffix(in, "\n") + content = strings.ReplaceAll(in, "\r\n", "\n") + content = strings.TrimSuffix(content, "\n") content = dv.hightlightCode(content, ls.Code.GetBackground()) content = ansi.GraphemeWidth.Cut(content, dv.xOffset, len(content)) content = ansi.Truncate(content, dv.codeWidth, "…") @@ -488,7 +489,8 @@ func (dv *DiffView) renderSplit() string { shouldWrite := func() bool { return printedLines >= 0 } getContent := func(in string, ls LineStyle) (content string, leadingEllipsis bool) { - content = strings.TrimSuffix(in, "\n") + content = strings.ReplaceAll(in, "\r\n", "\n") + content = strings.TrimSuffix(content, "\n") content = dv.hightlightCode(content, ls.Code.GetBackground()) content = ansi.GraphemeWidth.Cut(content, dv.xOffset, len(content)) content = ansi.Truncate(content, dv.codeWidth, "…") diff --git a/internal/tui/page/chat/chat.go b/internal/tui/page/chat/chat.go index 33267772e96662f14934a8417149259c7d22541a..5c4b7738580db046920ac7812c7a493c21e996ee 100644 --- a/internal/tui/page/chat/chat.go +++ b/internal/tui/page/chat/chat.go @@ -2,7 +2,6 @@ package chat import ( "context" - "runtime" "time" "github.com/charmbracelet/bubbles/v2/help" @@ -615,26 +614,12 @@ func (a *chatPage) Help() help.KeyMap { fullList = append(fullList, []key.Binding{v}) } case a.isOnboarding && a.splash.IsShowingAPIKey(): - var pasteKey key.Binding - if runtime.GOOS != "darwin" { - pasteKey = key.NewBinding( - key.WithKeys("ctrl+v"), - key.WithHelp("ctrl+v", "paste API key"), - ) - } else { - pasteKey = key.NewBinding( - key.WithKeys("cmd+v"), - key.WithHelp("cmd+v", "paste API key"), - ) - } shortList = append(shortList, // Go back key.NewBinding( key.WithKeys("esc"), key.WithHelp("esc", "back"), ), - // Paste - pasteKey, // Quit key.NewBinding( key.WithKeys("ctrl+c"), diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 633766a1d80bf8b0056e8d856b71df04613e1101..365db72299865897feb94879f837baa93bff5e43 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -327,6 +327,9 @@ func (a *appModel) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd { // If the commands dialog is already open, close it return util.CmdHandler(dialogs.CloseDialogMsg{}) } + if a.dialog.HasDialogs() { + return nil // Don't open commands dialog if another dialog is active + } return util.CmdHandler(dialogs.OpenDialogMsg{ Model: commands.NewCommandDialog(a.selectedSessionID), }) diff --git a/internal/tui/util/util.go b/internal/tui/util/util.go index d737acb3f06a155ab52cc7eed7d32a634d85d582..1f4ea30c49c8fb0517a5068d3b7f05970638743a 100644 --- a/internal/tui/util/util.go +++ b/internal/tui/util/util.go @@ -1,6 +1,7 @@ package util import ( + "log/slog" "time" tea "github.com/charmbracelet/bubbletea/v2" @@ -22,6 +23,7 @@ func CmdHandler(msg tea.Msg) tea.Cmd { } func ReportError(err error) tea.Cmd { + slog.Error("Error reported", "error", err) return CmdHandler(InfoMsg{ Type: InfoTypeError, Msg: err.Error(), diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azidentity/go.work.sum b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azidentity/go.work.sum deleted file mode 100644 index c592f283b6bdb1cb2b13aa4b0769b94811a1cfe9..0000000000000000000000000000000000000000 --- a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azidentity/go.work.sum +++ /dev/null @@ -1,60 +0,0 @@ -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.0-beta.1 h1:ODs3brnqQM99Tq1PffODpAViYv3Bf8zOg464MU7p5ew= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.0-beta.1/go.mod h1:3Ug6Qzto9anB6mGlEdgYMDF5zHQ+wwhEaYR4s17PHMw= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.0 h1:fb8kj/Dh4CSwgsOzHeZY4Xh68cFVbzXx+ONXGMY//4w= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.0/go.mod h1:uReU2sSxZExRPBAg3qKzmAucSi51+SP1OhohieR821Q= -github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0/go.mod h1:okt5dMMTOFjX/aovMlrjvvXoPMBVSPzk9185BT0+eZM= -github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.2/go.mod h1:yInRyqWXAuaPrgI7p70+lDDgh3mlBohis29jGMISnmc= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI= -github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/keybase/dbus v0.0.0-20220506165403-5aa21ea2c23a/go.mod h1:YPNKjjE7Ubp9dTbnWvsP3HT+hYnY6TfXzubYTBeUxc8= -github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= -github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/montanaflynn/stats v0.7.0/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow= -github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= -github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= -github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= -github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= -golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= -golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= -golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= -golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= -golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= -golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= -golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= -golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o= -golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= -golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= -golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= -golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0= -golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= -golang.org/x/term v0.19.0/go.mod h1:2CuTdWZ7KHSQwUzKva0cbMg6q2DMI3Mmxp+gKJbskEk= -golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= -golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0= -golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= -gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/vendor/github.com/charmbracelet/fang/README.md b/vendor/github.com/charmbracelet/fang/README.md index 88a225cfd6e698d15dd29a9af0a5dca74b61ecf7..575b40ce13fa57eb0e41082943a3c21e05c82777 100644 --- a/vendor/github.com/charmbracelet/fang/README.md +++ b/vendor/github.com/charmbracelet/fang/README.md @@ -1,7 +1,7 @@ # Fang

- Charm Fang + Charm Fang

Latest Release @@ -12,7 +12,7 @@ The CLI starter kit. A small, experimental library for batteries-included [Cobra][cobra] applications.

- fang-02 + The Charm Fang mascot and title treatment

## Features @@ -45,6 +45,7 @@ To use it, invoke `fang.Execute` passing your root `*cobra.Command`: package main import ( + "context" "os" "github.com/charmbracelet/fang" @@ -56,7 +57,7 @@ func main() { Use: "example", Short: "A simple example program!", } - if err := fang.Execute(context.TODO(), cmd); err != nil { + if err := fang.Execute(context.Background(), cmd); err != nil { os.Exit(1) } } diff --git a/vendor/github.com/charmbracelet/fang/fang.go b/vendor/github.com/charmbracelet/fang/fang.go index c1f9bc06a5299c991bac569aa6868e3d08fcd37c..6a6ab99a63fc4debf404694473d23e0a576d2fab 100644 --- a/vendor/github.com/charmbracelet/fang/fang.go +++ b/vendor/github.com/charmbracelet/fang/fang.go @@ -4,11 +4,14 @@ package fang import ( "context" "fmt" + "io" "os" + "os/signal" "runtime/debug" "github.com/charmbracelet/colorprofile" "github.com/charmbracelet/lipgloss/v2" + "github.com/charmbracelet/x/term" mango "github.com/muesli/mango-cobra" "github.com/muesli/roff" "github.com/spf13/cobra" @@ -16,12 +19,24 @@ import ( const shaLen = 7 +// ErrorHandler handles an error, printing them to the given [io.Writer]. +// +// Note that this will only be used if the STDERR is a terminal, and should +// be used for styling only. +type ErrorHandler = func(w io.Writer, styles Styles, err error) + +// ColorSchemeFunc gets a [lipgloss.LightDarkFunc] and returns a [ColorScheme]. +type ColorSchemeFunc = func(lipgloss.LightDarkFunc) ColorScheme + type settings struct { completions bool manpages bool + skipVersion bool version string commit string - theme *ColorScheme + colorscheme ColorSchemeFunc + errHandler ErrorHandler + signals []os.Signal } // Option changes fang settings. @@ -41,10 +56,21 @@ func WithoutManpage() Option { } } +// WithColorSchemeFunc sets a function that return colorscheme. +func WithColorSchemeFunc(cs ColorSchemeFunc) Option { + return func(s *settings) { + s.colorscheme = cs + } +} + // WithTheme sets the colorscheme. +// +// Deprecated: use [WithColorSchemeFunc] instead. func WithTheme(theme ColorScheme) Option { return func(s *settings) { - s.theme = &theme + s.colorscheme = func(lipgloss.LightDarkFunc) ColorScheme { + return theme + } } } @@ -55,6 +81,13 @@ func WithVersion(version string) Option { } } +// WithoutVersion skips the `-v`/`--version` functionality. +func WithoutVersion() Option { + return func(s *settings) { + s.skipVersion = true + } +} + // WithCommit sets the commit SHA. func WithCommit(commit string) Option { return func(s *settings) { @@ -62,30 +95,45 @@ func WithCommit(commit string) Option { } } +// WithErrorHandler sets the error handler. +func WithErrorHandler(handler ErrorHandler) Option { + return func(s *settings) { + s.errHandler = handler + } +} + +// WithNotifySignal sets the signals that should interrupt the execution of the +// program. +func WithNotifySignal(signals ...os.Signal) Option { + return func(s *settings) { + s.signals = signals + } +} + // Execute applies fang to the command and executes it. func Execute(ctx context.Context, root *cobra.Command, options ...Option) error { opts := settings{ manpages: true, completions: true, + colorscheme: DefaultColorScheme, + errHandler: DefaultErrorHandler, } + for _, option := range options { option(&opts) } - if opts.theme == nil { - isDark := lipgloss.HasDarkBackground(os.Stdin, os.Stderr) - t := DefaultTheme(isDark) - opts.theme = &t + helpFunc := func(c *cobra.Command, _ []string) { + w := colorprofile.NewWriter(c.OutOrStdout(), os.Environ()) + helpFn(c, w, makeStyles(mustColorscheme(opts.colorscheme))) } - styles := makeStyles(*opts.theme) - - root.SetHelpFunc(func(c *cobra.Command, _ []string) { - w := colorprofile.NewWriter(c.OutOrStdout(), os.Environ()) - helpFn(c, w, styles) - }) root.SilenceUsage = true root.SilenceErrors = true + if !opts.skipVersion { + root.Version = buildVersion(opts) + } + root.SetHelpFunc(helpFunc) if opts.manpages { root.AddCommand(&cobra.Command{ @@ -108,34 +156,49 @@ func Execute(ctx context.Context, root *cobra.Command, options ...Option) error }) } - if opts.completions { - root.InitDefaultCompletionCmd() - } else { + if !opts.completions { root.CompletionOptions.DisableDefaultCmd = true } - if opts.version == "" { - if info, ok := debug.ReadBuildInfo(); ok && info.Main.Sum != "" { - opts.version = info.Main.Version - opts.commit = getKey(info, "vcs.revision") - } else { - opts.version = "unknown (built from source)" - } - } - if len(opts.commit) >= shaLen { - opts.version += " (" + opts.commit[:shaLen] + ")" + if len(opts.signals) > 0 { + var cancel context.CancelFunc + ctx, cancel = signal.NotifyContext(ctx, opts.signals...) + defer cancel() } - root.Version = opts.version - if err := root.ExecuteContext(ctx); err != nil { + if w, ok := root.ErrOrStderr().(term.File); ok { + // if stderr is not a tty, simply print the error without any + // styling or going through an [ErrorHandler]: + if !term.IsTerminal(w.Fd()) { + _, _ = fmt.Fprintln(w, err.Error()) + return err //nolint:wrapcheck + } + } w := colorprofile.NewWriter(root.ErrOrStderr(), os.Environ()) - writeError(w, styles, err) + opts.errHandler(w, makeStyles(mustColorscheme(opts.colorscheme)), err) return err //nolint:wrapcheck } return nil } +func buildVersion(opts settings) string { + commit := opts.commit + version := opts.version + if version == "" { + if info, ok := debug.ReadBuildInfo(); ok && info.Main.Sum != "" { + version = info.Main.Version + commit = getKey(info, "vcs.revision") + } else { + version = "unknown (built from source)" + } + } + if len(commit) >= shaLen { + version += " (" + commit[:shaLen] + ")" + } + return version +} + func getKey(info *debug.BuildInfo, key string) string { if info == nil { return "" diff --git a/vendor/github.com/charmbracelet/fang/help.go b/vendor/github.com/charmbracelet/fang/help.go index 340090eadf1f779c0e702b03440d7e7efb29b62b..ba2a6185844787e83753c51c3415d5ccc06e36ec 100644 --- a/vendor/github.com/charmbracelet/fang/help.go +++ b/vendor/github.com/charmbracelet/fang/help.go @@ -3,7 +3,10 @@ package fang import ( "cmp" "fmt" + "io" + "iter" "os" + "reflect" "regexp" "strconv" "strings" @@ -20,6 +23,7 @@ import ( const ( minSpace = 10 shortPad = 2 + longPad = 4 ) var width = sync.OnceValue(func() int { @@ -45,65 +49,95 @@ func helpFn(c *cobra.Command, w *colorprofile.Writer, styles Styles) { blockWidth = max(blockWidth, lipgloss.Width(ex)) } blockWidth = min(width()-padding, blockWidth+padding) + blockStyle := styles.Codeblock.Base.Width(blockWidth) - styles.Codeblock.Base = styles.Codeblock.Base.Width(blockWidth) + // if the color profile is ascii or notty, or if the block has no + // background color set, remove the vertical padding. + if w.Profile <= colorprofile.Ascii || reflect.DeepEqual(blockStyle.GetBackground(), lipgloss.NoColor{}) { + blockStyle = blockStyle.PaddingTop(0).PaddingBottom(0) + } _, _ = fmt.Fprintln(w, styles.Title.Render("usage")) - _, _ = fmt.Fprintln(w, styles.Codeblock.Base.Render(usage)) + _, _ = fmt.Fprintln(w, blockStyle.Render(usage)) if len(examples) > 0 { - cw := styles.Codeblock.Base.GetWidth() - styles.Codeblock.Base.GetHorizontalPadding() + cw := blockStyle.GetWidth() - blockStyle.GetHorizontalPadding() _, _ = fmt.Fprintln(w, styles.Title.Render("examples")) for i, example := range examples { if lipgloss.Width(example) > cw { examples[i] = ansi.Truncate(example, cw, "…") } } - _, _ = fmt.Fprintln(w, styles.Codeblock.Base.Render(strings.Join(examples, "\n"))) + _, _ = fmt.Fprintln(w, blockStyle.Render(strings.Join(examples, "\n"))) } + groups, groupKeys := evalGroups(c) cmds, cmdKeys := evalCmds(c, styles) flags, flagKeys := evalFlags(c, styles) space := calculateSpace(cmdKeys, flagKeys) - leftPadding := 4 - if len(cmds) > 0 { - _, _ = fmt.Fprintln(w, styles.Title.Render("commands")) - for _, k := range cmdKeys { - _, _ = fmt.Fprintln(w, lipgloss.JoinHorizontal( - lipgloss.Left, - lipgloss.NewStyle().PaddingLeft(leftPadding).Render(k), - strings.Repeat(" ", space-lipgloss.Width(k)), - cmds[k], - )) + for _, groupID := range groupKeys { + group := cmds[groupID] + if len(group) == 0 { + continue } + renderGroup(w, styles, space, groups[groupID], func(yield func(string, string) bool) { + for _, k := range cmdKeys { + cmds, ok := group[k] + if !ok { + continue + } + if !yield(k, cmds) { + return + } + } + }) } if len(flags) > 0 { - _, _ = fmt.Fprintln(w, styles.Title.Render("flags")) - for _, k := range flagKeys { - _, _ = fmt.Fprintln(w, lipgloss.JoinHorizontal( - lipgloss.Left, - lipgloss.NewStyle().PaddingLeft(leftPadding).Render(k), - strings.Repeat(" ", space-lipgloss.Width(k)), - flags[k], - )) - } + renderGroup(w, styles, space, "flags", func(yield func(string, string) bool) { + for _, k := range flagKeys { + if !yield(k, flags[k]) { + return + } + } + }) } _, _ = fmt.Fprintln(w) } -func writeError(w *colorprofile.Writer, styles Styles, err error) { +// DefaultErrorHandler is the default [ErrorHandler] implementation. +func DefaultErrorHandler(w io.Writer, styles Styles, err error) { _, _ = fmt.Fprintln(w, styles.ErrorHeader.String()) _, _ = fmt.Fprintln(w, styles.ErrorText.Render(err.Error()+".")) _, _ = fmt.Fprintln(w) - _, _ = fmt.Fprintln(w, lipgloss.JoinHorizontal( - lipgloss.Left, - styles.ErrorText.UnsetWidth().Render("Try"), - styles.Program.Flag.Render("--help"), - styles.ErrorText.UnsetWidth().UnsetMargins().UnsetTransform().PaddingLeft(1).Render("for usage."), - )) - _, _ = fmt.Fprintln(w) + if isUsageError(err) { + _, _ = fmt.Fprintln(w, lipgloss.JoinHorizontal( + lipgloss.Left, + styles.ErrorText.UnsetWidth().Render("Try"), + styles.Program.Flag.Render(" --help "), + styles.ErrorText.UnsetWidth().UnsetMargins().UnsetTransform().Render("for usage."), + )) + _, _ = fmt.Fprintln(w) + } +} + +// XXX: this is a hack to detect usage errors. +// See: https://github.com/spf13/cobra/pull/2266 +func isUsageError(err error) bool { + s := err.Error() + for _, prefix := range []string{ + "flag needs an argument:", + "unknown flag:", + "unknown shorthand flag:", + "unknown command", + "invalid argument", + } { + if strings.HasPrefix(s, prefix) { + return true + } + } + return false } func writeLongShort(w *colorprofile.Writer, styles Styles, longShort string) { @@ -118,8 +152,10 @@ var otherArgsRe = regexp.MustCompile(`(\[.*\])`) // styleUsage stylized styleUsage line for a given command. func styleUsage(c *cobra.Command, styles Program, complete bool) string { - // XXX: maybe use c.UseLine() here? u := c.Use + if complete { + u = c.UseLine() + } hasArgs := strings.Contains(u, "[args]") hasFlags := strings.Contains(u, "[flags]") || strings.Contains(u, "[--flags]") || c.HasFlags() || c.HasPersistentFlags() || c.HasAvailableFlags() hasCommands := strings.Contains(u, "[command]") || c.HasAvailableSubCommands() @@ -139,34 +175,38 @@ func styleUsage(c *cobra.Command, styles Program, complete bool) string { u = strings.TrimSpace(u) - useLine := []string{ - styles.Name.Render(u), - } - if !complete { - useLine[0] = styles.Command.Render(u) + useLine := []string{} + if complete { + parts := strings.Fields(u) + useLine = append(useLine, styles.Name.Render(parts[0])) + if len(parts) > 1 { + useLine = append(useLine, styles.Command.Render(" "+strings.Join(parts[1:], " "))) + } + } else { + useLine = append(useLine, styles.Command.Render(u)) } if hasCommands { useLine = append( useLine, - styles.DimmedArgument.Render("[command]"), + styles.DimmedArgument.Render(" [command]"), ) } if hasArgs { useLine = append( useLine, - styles.DimmedArgument.Render("[args]"), + styles.DimmedArgument.Render(" [args]"), ) } for _, arg := range otherArgs { useLine = append( useLine, - styles.DimmedArgument.Render(arg), + styles.DimmedArgument.Render(" "+arg), ) } if hasFlags { useLine = append( useLine, - styles.DimmedArgument.Render("[--flags]"), + styles.DimmedArgument.Render(" [--flags]"), ) } return lipgloss.JoinHorizontal(lipgloss.Left, useLine...) @@ -180,19 +220,21 @@ func styleExamples(c *cobra.Command, styles Styles) []string { } usage := []string{} examples := strings.Split(c.Example, "\n") + var indent bool for i, line := range examples { line = strings.TrimSpace(line) if (i == 0 || i == len(examples)-1) && line == "" { continue } - s := styleExample(c, line, styles.Codeblock) + s := styleExample(c, line, indent, styles.Codeblock) usage = append(usage, s) + indent = len(line) > 1 && (line[len(line)-1] == '\\' || line[len(line)-1] == '|') } return usage } -func styleExample(c *cobra.Command, line string, styles Codeblock) string { +func styleExample(c *cobra.Command, line string, indent bool, styles Codeblock) string { if strings.HasPrefix(line, "# ") { return lipgloss.JoinHorizontal( lipgloss.Left, @@ -200,66 +242,110 @@ func styleExample(c *cobra.Command, line string, styles Codeblock) string { ) } - args := strings.Fields(line) - var nextIsFlag bool var isQuotedString bool + var foundProgramName bool + var isRedirecting bool + programName := c.Root().Name() + args := strings.Fields(line) + var cleanArgs []string for i, arg := range args { - if i == 0 { - args[i] = styles.Program.Name.Render(arg) - continue + isQuoteStart := arg[0] == '"' || arg[0] == '\'' + isQuoteEnd := arg[len(arg)-1] == '"' || arg[len(arg)-1] == '\'' + isFlag := arg[0] == '-' + + switch i { + case 0: + args[i] = "" + if indent { + args[i] = styles.Program.DimmedArgument.Render(" ") + indent = false + } + default: + args[i] = styles.Program.DimmedArgument.Render(" ") } - quoteStart := arg[0] == '"' - quoteEnd := arg[len(arg)-1] == '"' - flagStart := arg[0] == '-' - if i == 1 && !quoteStart && !flagStart { - args[i] = styles.Program.Command.Render(arg) + if isRedirecting { + args[i] += styles.Program.DimmedArgument.Render(arg) + isRedirecting = false continue } - if quoteStart { - isQuotedString = true - } - if isQuotedString { - args[i] = styles.Program.QuotedString.Render(arg) - if quoteEnd { - isQuotedString = false + + switch arg { + case "\\": + if i == len(args)-1 { + args[i] += styles.Program.DimmedArgument.Render(arg) + continue } + case "|", "||", "-", "&", "&&": + args[i] += styles.Program.DimmedArgument.Render(arg) continue } - if nextIsFlag { - args[i] = styles.Program.Flag.Render(arg) + + if isRedirect(arg) { + args[i] += styles.Program.DimmedArgument.Render(arg) + isRedirecting = true continue } - var dashes string - if strings.HasPrefix(arg, "-") { - dashes = "-" + + if !foundProgramName { //nolint:nestif + if isQuotedString { + args[i] += styles.Program.QuotedString.Render(arg) + isQuotedString = !isQuoteEnd + continue + } + if left, right, ok := strings.Cut(arg, "="); ok { + args[i] += styles.Program.Flag.Render(left + "=") + if right[0] == '"' { + isQuotedString = true + args[i] += styles.Program.QuotedString.Render(right) + continue + } + args[i] += styles.Program.Argument.Render(right) + continue + } + + if arg == programName { + args[i] += styles.Program.Name.Render(arg) + foundProgramName = true + continue + } } - if strings.HasPrefix(arg, "--") { - dashes = "--" + + if !isQuoteStart && !isQuotedString && !isFlag { + cleanArgs = append(cleanArgs, arg) + } + + if !isQuoteStart && !isFlag && isSubCommand(c, cleanArgs, arg) { + args[i] += styles.Program.Command.Render(arg) + continue + } + isQuotedString = isQuotedString || isQuoteStart + if isQuotedString { + args[i] += styles.Program.QuotedString.Render(arg) + isQuotedString = !isQuoteEnd + continue } // handle a flag - if dashes != "" { + if isFlag { name, value, ok := strings.Cut(arg, "=") - name = strings.TrimPrefix(name, dashes) // it is --flag=value if ok { - args[i] = lipgloss.JoinHorizontal( + args[i] += lipgloss.JoinHorizontal( lipgloss.Left, - styles.Program.Flag.Render(dashes+name+"="), - styles.Program.Argument.UnsetPadding().Render(value), + styles.Program.Flag.Render(name+"="), + styles.Program.Argument.Render(value), ) continue } // it is either --bool-flag or --flag value - args[i] = lipgloss.JoinHorizontal( + args[i] += lipgloss.JoinHorizontal( lipgloss.Left, - styles.Program.Flag.Render(dashes+name), + styles.Program.Flag.Render(name), ) - // if the flag is not a bool flag, next arg continues current flag - nextIsFlag = !isFlagBool(c, name) continue } - args[i] = styles.Program.Argument.Render(arg) + + args[i] += styles.Program.Argument.Render(arg) } return lipgloss.JoinHorizontal( @@ -284,8 +370,7 @@ func evalFlags(c *cobra.Command, styles Styles) (map[string]string, []string) { } else { parts = append( parts, - styles.Program.Flag.Render("-"+f.Shorthand), - styles.Program.Flag.Render("--"+f.Name), + styles.Program.Flag.Render("-"+f.Shorthand+" --"+f.Name), ) } key := lipgloss.JoinHorizontal(lipgloss.Left, parts...) @@ -303,22 +388,50 @@ func evalFlags(c *cobra.Command, styles Styles) (map[string]string, []string) { return flags, keys } -func evalCmds(c *cobra.Command, styles Styles) (map[string]string, []string) { +// result is map[groupID]map[styled cmd name]styled cmd help, and the keys in +// the order they are defined. +func evalCmds(c *cobra.Command, styles Styles) (map[string](map[string]string), []string) { padStyle := lipgloss.NewStyle().PaddingLeft(0) //nolint:mnd keys := []string{} - cmds := map[string]string{} + cmds := map[string]map[string]string{} for _, sc := range c.Commands() { if sc.Hidden { continue } + if _, ok := cmds[sc.GroupID]; !ok { + cmds[sc.GroupID] = map[string]string{} + } key := padStyle.Render(styleUsage(sc, styles.Program, false)) help := styles.FlagDescription.Render(sc.Short) - cmds[key] = help + cmds[sc.GroupID][key] = help keys = append(keys, key) } return cmds, keys } +func evalGroups(c *cobra.Command) (map[string]string, []string) { + // make sure the default group is the first + ids := []string{""} + groups := map[string]string{"": "commands"} + for _, g := range c.Groups() { + groups[g.ID] = g.Title + ids = append(ids, g.ID) + } + return groups, ids +} + +func renderGroup(w io.Writer, styles Styles, space int, name string, items iter.Seq2[string, string]) { + _, _ = fmt.Fprintln(w, styles.Title.Render(name)) + for key, help := range items { + _, _ = fmt.Fprintln(w, lipgloss.JoinHorizontal( + lipgloss.Left, + lipgloss.NewStyle().PaddingLeft(longPad).Render(key), + strings.Repeat(" ", space-lipgloss.Width(key)), + help, + )) + } +} + func calculateSpace(k1, k2 []string) int { const spaceBetween = 2 space := minSpace @@ -328,13 +441,18 @@ func calculateSpace(k1, k2 []string) int { return space } -func isFlagBool(c *cobra.Command, name string) bool { - flag := c.Flags().Lookup(name) - if flag == nil && len(name) == 1 { - flag = c.Flags().ShorthandLookup(name) - } - if flag == nil { - return false +func isSubCommand(c *cobra.Command, args []string, word string) bool { + cmd, _, _ := c.Root().Traverse(args) + return cmd != nil && cmd.Name() == word +} + +var redirectPrefixes = []string{">", "<", "&>", "2>", "1>", ">>", "2>>"} + +func isRedirect(s string) bool { + for _, p := range redirectPrefixes { + if strings.HasPrefix(s, p) { + return true + } } - return flag.Value.Type() == "bool" + return false } diff --git a/vendor/github.com/charmbracelet/fang/theme.go b/vendor/github.com/charmbracelet/fang/theme.go index 8e3389f6e84b4cc66ed0369f2425c4cc7c27d1b4..12cc868089d475d397691e757f55614a4614e44d 100644 --- a/vendor/github.com/charmbracelet/fang/theme.go +++ b/vendor/github.com/charmbracelet/fang/theme.go @@ -2,10 +2,12 @@ package fang import ( "image/color" + "os" "strings" "github.com/charmbracelet/lipgloss/v2" "github.com/charmbracelet/x/exp/charmtone" + "github.com/charmbracelet/x/term" "golang.org/x/text/cases" "golang.org/x/text/language" ) @@ -31,8 +33,14 @@ type ColorScheme struct { } // DefaultTheme is the default colorscheme. +// +// Deprecated: use [DefaultColorScheme] instead. func DefaultTheme(isDark bool) ColorScheme { - c := lipgloss.LightDark(isDark) + return DefaultColorScheme(lipgloss.LightDark(isDark)) +} + +// DefaultColorScheme is the default colorscheme. +func DefaultColorScheme(c lipgloss.LightDarkFunc) ColorScheme { return ColorScheme{ Base: c(charmtone.Charcoal, charmtone.Ash), Title: charmtone.Charple, @@ -45,7 +53,7 @@ func DefaultTheme(isDark bool) ColorScheme { Argument: c(charmtone.Charcoal, charmtone.Ash), Description: c(charmtone.Charcoal, charmtone.Ash), // flag and command descriptions FlagDefault: c(charmtone.Smoke, charmtone.Squid), // flag default values in descriptions - QuotedString: c(charmtone.Charcoal, charmtone.Ash), + QuotedString: c(charmtone.Coral, charmtone.Salmon), ErrorHeader: [2]color.Color{ charmtone.Butter, charmtone.Cherry, @@ -53,6 +61,26 @@ func DefaultTheme(isDark bool) ColorScheme { } } +// AnsiColorScheme is a ANSI colorscheme. +func AnsiColorScheme(c lipgloss.LightDarkFunc) ColorScheme { + base := c(lipgloss.Black, lipgloss.White) + return ColorScheme{ + Base: base, + Title: lipgloss.Blue, + Description: base, + Comment: c(lipgloss.BrightWhite, lipgloss.BrightBlack), + Flag: lipgloss.Magenta, + FlagDefault: lipgloss.BrightMagenta, + Command: lipgloss.Cyan, + QuotedString: lipgloss.Green, + Argument: base, + Help: base, + Dash: base, + ErrorHeader: [2]color.Color{lipgloss.Black, lipgloss.Red}, + ErrorDetails: lipgloss.Red, + } +} + // Styles represents all the styles used. type Styles struct { Text lipgloss.Style @@ -84,6 +112,14 @@ type Program struct { QuotedString lipgloss.Style } +func mustColorscheme(cs func(lipgloss.LightDarkFunc) ColorScheme) ColorScheme { + var isDark bool + if term.IsTerminal(os.Stdout.Fd()) { + isDark = lipgloss.HasDarkBackground(os.Stdin, os.Stdout) + } + return cs(lipgloss.LightDark(isDark)) +} + func makeStyles(cs ColorScheme) Styles { //nolint:mnd return Styles{ @@ -98,8 +134,7 @@ func makeStyles(cs ColorScheme) Styles { Foreground(cs.Description). Transform(titleFirstWord), FlagDefault: lipgloss.NewStyle(). - Foreground(cs.FlagDefault). - PaddingLeft(1), + Foreground(cs.FlagDefault), Codeblock: Codeblock{ Base: lipgloss.NewStyle(). Background(cs.Codeblock). @@ -116,23 +151,18 @@ func makeStyles(cs ColorScheme) Styles { Background(cs.Codeblock). Foreground(cs.Program), Flag: lipgloss.NewStyle(). - PaddingLeft(1). Background(cs.Codeblock). Foreground(cs.Flag), Argument: lipgloss.NewStyle(). - PaddingLeft(1). Background(cs.Codeblock). Foreground(cs.Argument), DimmedArgument: lipgloss.NewStyle(). - PaddingLeft(1). Background(cs.Codeblock). Foreground(cs.DimmedArgument), Command: lipgloss.NewStyle(). - PaddingLeft(1). Background(cs.Codeblock). Foreground(cs.Command), QuotedString: lipgloss.NewStyle(). - PaddingLeft(1). Background(cs.Codeblock). Foreground(cs.QuotedString), }, @@ -141,18 +171,14 @@ func makeStyles(cs ColorScheme) Styles { Name: lipgloss.NewStyle(). Foreground(cs.Program), Argument: lipgloss.NewStyle(). - PaddingLeft(1). Foreground(cs.Argument), DimmedArgument: lipgloss.NewStyle(). - PaddingLeft(1). Foreground(cs.DimmedArgument), Flag: lipgloss.NewStyle(). - PaddingLeft(1). Foreground(cs.Flag), Command: lipgloss.NewStyle(). Foreground(cs.Command), QuotedString: lipgloss.NewStyle(). - PaddingLeft(1). Foreground(cs.QuotedString), }, Span: lipgloss.NewStyle(). diff --git a/vendor/github.com/mark3labs/mcp-go/client/client.go b/vendor/github.com/mark3labs/mcp-go/client/client.go index dd0e31a013595ccbb900a10fe413e02d1ed9d0ad..e2c466586050cf69e2015e83056fdaf6eda949f6 100644 --- a/vendor/github.com/mark3labs/mcp-go/client/client.go +++ b/vendor/github.com/mark3labs/mcp-go/client/client.go @@ -22,6 +22,7 @@ type Client struct { requestID atomic.Int64 clientCapabilities mcp.ClientCapabilities serverCapabilities mcp.ServerCapabilities + samplingHandler SamplingHandler } type ClientOption func(*Client) @@ -33,6 +34,21 @@ func WithClientCapabilities(capabilities mcp.ClientCapabilities) ClientOption { } } +// WithSamplingHandler sets the sampling handler for the client. +// When set, the client will declare sampling capability during initialization. +func WithSamplingHandler(handler SamplingHandler) ClientOption { + return func(c *Client) { + c.samplingHandler = handler + } +} + +// WithSession assumes a MCP Session has already been initialized +func WithSession() ClientOption { + return func(c *Client) { + c.initialized = true + } +} + // NewClient creates a new MCP client with the given transport. // Usage: // @@ -71,6 +87,12 @@ func (c *Client) Start(ctx context.Context) error { handler(notification) } }) + + // Set up request handler for bidirectional communication (e.g., sampling) + if bidirectional, ok := c.transport.(transport.BidirectionalInterface); ok { + bidirectional.SetRequestHandler(c.handleIncomingRequest) + } + return nil } @@ -127,6 +149,12 @@ func (c *Client) Initialize( ctx context.Context, request mcp.InitializeRequest, ) (*mcp.InitializeResult, error) { + // Merge client capabilities with sampling capability if handler is configured + capabilities := request.Params.Capabilities + if c.samplingHandler != nil { + capabilities.Sampling = &struct{}{} + } + // Ensure we send a params object with all required fields params := struct { ProtocolVersion string `json:"protocolVersion"` @@ -135,7 +163,7 @@ func (c *Client) Initialize( }{ ProtocolVersion: request.Params.ProtocolVersion, ClientInfo: request.Params.ClientInfo, - Capabilities: request.Params.Capabilities, // Will be empty struct if not set + Capabilities: capabilities, } response, err := c.sendRequest(ctx, "initialize", params) @@ -398,6 +426,64 @@ func (c *Client) Complete( return &result, nil } +// handleIncomingRequest processes incoming requests from the server. +// This is the main entry point for server-to-client requests like sampling. +func (c *Client) handleIncomingRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) { + switch request.Method { + case string(mcp.MethodSamplingCreateMessage): + return c.handleSamplingRequestTransport(ctx, request) + default: + return nil, fmt.Errorf("unsupported request method: %s", request.Method) + } +} + +// handleSamplingRequestTransport handles sampling requests at the transport level. +func (c *Client) handleSamplingRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) { + if c.samplingHandler == nil { + return nil, fmt.Errorf("no sampling handler configured") + } + + // Parse the request parameters + var params mcp.CreateMessageParams + if request.Params != nil { + paramsBytes, err := json.Marshal(request.Params) + if err != nil { + return nil, fmt.Errorf("failed to marshal params: %w", err) + } + if err := json.Unmarshal(paramsBytes, ¶ms); err != nil { + return nil, fmt.Errorf("failed to unmarshal params: %w", err) + } + } + + // Create the MCP request + mcpRequest := mcp.CreateMessageRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodSamplingCreateMessage), + }, + CreateMessageParams: params, + } + + // Call the sampling handler + result, err := c.samplingHandler.CreateMessage(ctx, mcpRequest) + if err != nil { + return nil, err + } + + // Marshal the result + resultBytes, err := json.Marshal(result) + if err != nil { + return nil, fmt.Errorf("failed to marshal result: %w", err) + } + + // Create the transport response + response := &transport.JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: request.ID, + Result: json.RawMessage(resultBytes), + } + + return response, nil +} func listByPage[T any]( ctx context.Context, client *Client, @@ -432,3 +518,17 @@ func (c *Client) GetServerCapabilities() mcp.ServerCapabilities { func (c *Client) GetClientCapabilities() mcp.ClientCapabilities { return c.clientCapabilities } + +// GetSessionId returns the session ID of the transport. +// If the transport does not support sessions, it returns an empty string. +func (c *Client) GetSessionId() string { + if c.transport == nil { + return "" + } + return c.transport.GetSessionId() +} + +// IsInitialized returns true if the client has been initialized. +func (c *Client) IsInitialized() bool { + return c.initialized +} diff --git a/vendor/github.com/mark3labs/mcp-go/client/http.go b/vendor/github.com/mark3labs/mcp-go/client/http.go index cb3be35d64cfc731efe2cef0c268a018c53a9538..d001a1e63d08e42d7457adbf5c497d93d029e203 100644 --- a/vendor/github.com/mark3labs/mcp-go/client/http.go +++ b/vendor/github.com/mark3labs/mcp-go/client/http.go @@ -13,5 +13,10 @@ func NewStreamableHttpClient(baseURL string, options ...transport.StreamableHTTP if err != nil { return nil, fmt.Errorf("failed to create SSE transport: %w", err) } - return NewClient(trans), nil + clientOptions := make([]ClientOption, 0) + sessionID := trans.GetSessionId() + if sessionID != "" { + clientOptions = append(clientOptions, WithSession()) + } + return NewClient(trans, clientOptions...), nil } diff --git a/vendor/github.com/mark3labs/mcp-go/client/sampling.go b/vendor/github.com/mark3labs/mcp-go/client/sampling.go new file mode 100644 index 0000000000000000000000000000000000000000..245e2c1f7f305ddb75658a345eddcaba5e2898e3 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/client/sampling.go @@ -0,0 +1,20 @@ +package client + +import ( + "context" + + "github.com/mark3labs/mcp-go/mcp" +) + +// SamplingHandler defines the interface for handling sampling requests from servers. +// Clients can implement this interface to provide LLM sampling capabilities to servers. +type SamplingHandler interface { + // CreateMessage handles a sampling request from the server and returns the generated message. + // The implementation should: + // 1. Validate the request parameters + // 2. Optionally prompt the user for approval (human-in-the-loop) + // 3. Select an appropriate model based on preferences + // 4. Generate the response using the selected model + // 5. Return the result with model information and stop reason + CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) +} diff --git a/vendor/github.com/mark3labs/mcp-go/client/stdio.go b/vendor/github.com/mark3labs/mcp-go/client/stdio.go index 100c08a7cc0529ca30ca1386f74f6ea4f9be4654..199ec14c381b57c12d691495179cf0c45029d29e 100644 --- a/vendor/github.com/mark3labs/mcp-go/client/stdio.go +++ b/vendor/github.com/mark3labs/mcp-go/client/stdio.go @@ -19,10 +19,26 @@ func NewStdioMCPClient( env []string, args ...string, ) (*Client, error) { + return NewStdioMCPClientWithOptions(command, env, args) +} + +// NewStdioMCPClientWithOptions creates a new stdio-based MCP client that communicates with a subprocess. +// It launches the specified command with given arguments and sets up stdin/stdout pipes for communication. +// Optional configuration functions can be provided to customize the transport before it starts, +// such as setting a custom command function. +// +// NOTICE: NewStdioMCPClientWithOptions automatically starts the underlying transport. +// Don't call the Start method manually. +// This is for backward compatibility. +func NewStdioMCPClientWithOptions( + command string, + env []string, + args []string, + opts ...transport.StdioOption, +) (*Client, error) { + stdioTransport := transport.NewStdioWithOptions(command, env, args, opts...) - stdioTransport := transport.NewStdio(command, env, args...) - err := stdioTransport.Start(context.Background()) - if err != nil { + if err := stdioTransport.Start(context.Background()); err != nil { return nil, fmt.Errorf("failed to start stdio transport: %w", err) } diff --git a/vendor/github.com/mark3labs/mcp-go/client/transport/inprocess.go b/vendor/github.com/mark3labs/mcp-go/client/transport/inprocess.go index 90fc2fae1f05ebf635b46d0fc415e0260348d3a0..0e2393f0731bcf361d5544da99304d1f07e08706 100644 --- a/vendor/github.com/mark3labs/mcp-go/client/transport/inprocess.go +++ b/vendor/github.com/mark3labs/mcp-go/client/transport/inprocess.go @@ -68,3 +68,7 @@ func (c *InProcessTransport) SetNotificationHandler(handler func(notification mc func (*InProcessTransport) Close() error { return nil } + +func (c *InProcessTransport) GetSessionId() string { + return "" +} diff --git a/vendor/github.com/mark3labs/mcp-go/client/transport/interface.go b/vendor/github.com/mark3labs/mcp-go/client/transport/interface.go index c83c7c65a3a8b0c7a301564144516242919fe2a5..5f8ed6180b6404a1a0f4085c5557aeb1789e3485 100644 --- a/vendor/github.com/mark3labs/mcp-go/client/transport/interface.go +++ b/vendor/github.com/mark3labs/mcp-go/client/transport/interface.go @@ -29,6 +29,22 @@ type Interface interface { // Close the connection. Close() error + + // GetSessionId returns the session ID of the transport. + GetSessionId() string +} + +// RequestHandler defines a function that handles incoming requests from the server. +type RequestHandler func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) + +// BidirectionalInterface extends Interface to support incoming requests from the server. +// This is used for features like sampling where the server can send requests to the client. +type BidirectionalInterface interface { + Interface + + // SetRequestHandler sets the handler for incoming requests from the server. + // The handler should process the request and return a response. + SetRequestHandler(handler RequestHandler) } type JSONRPCRequest struct { @@ -41,10 +57,10 @@ type JSONRPCRequest struct { type JSONRPCResponse struct { JSONRPC string `json:"jsonrpc"` ID mcp.RequestId `json:"id"` - Result json.RawMessage `json:"result"` + Result json.RawMessage `json:"result,omitempty"` Error *struct { Code int `json:"code"` Message string `json:"message"` Data json.RawMessage `json:"data"` - } `json:"error"` + } `json:"error,omitempty"` } diff --git a/vendor/github.com/mark3labs/mcp-go/client/transport/sse.go b/vendor/github.com/mark3labs/mcp-go/client/transport/sse.go index b22ff62d40124b765b633d2b1700c7407d92d041..ffe3247f0ecd87a4e9c68df72b726a8cc44a7736 100644 --- a/vendor/github.com/mark3labs/mcp-go/client/transport/sse.go +++ b/vendor/github.com/mark3labs/mcp-go/client/transport/sse.go @@ -428,6 +428,12 @@ func (c *SSE) Close() error { return nil } +// GetSessionId returns the session ID of the transport. +// Since SSE does not maintain a session ID, it returns an empty string. +func (c *SSE) GetSessionId() string { + return "" +} + // SendNotification sends a JSON-RPC notification to the server without expecting a response. func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error { if c.endpoint == nil { diff --git a/vendor/github.com/mark3labs/mcp-go/client/transport/stdio.go b/vendor/github.com/mark3labs/mcp-go/client/transport/stdio.go index c300c405f7e3880f0b94e1e09e3ee5ca7def732a..c36dc2d37737d71d9028ba11485932c23bb09f9f 100644 --- a/vendor/github.com/mark3labs/mcp-go/client/transport/stdio.go +++ b/vendor/github.com/mark3labs/mcp-go/client/transport/stdio.go @@ -23,6 +23,7 @@ type Stdio struct { env []string cmd *exec.Cmd + cmdFunc CommandFunc stdin io.WriteCloser stdout *bufio.Reader stderr io.ReadCloser @@ -31,6 +32,28 @@ type Stdio struct { done chan struct{} onNotification func(mcp.JSONRPCNotification) notifyMu sync.RWMutex + onRequest RequestHandler + requestMu sync.RWMutex + ctx context.Context + ctxMu sync.RWMutex +} + +// StdioOption defines a function that configures a Stdio transport instance. +// Options can be used to customize the behavior of the transport before it starts, +// such as setting a custom command function. +type StdioOption func(*Stdio) + +// CommandFunc is a factory function that returns a custom exec.Cmd used to launch the MCP subprocess. +// It can be used to apply sandboxing, custom environment control, working directories, etc. +type CommandFunc func(ctx context.Context, command string, env []string, args []string) (*exec.Cmd, error) + +// WithCommandFunc sets a custom command factory function for the stdio transport. +// The CommandFunc is responsible for constructing the exec.Cmd used to launch the subprocess, +// allowing control over attributes like environment, working directory, and system-level sandboxing. +func WithCommandFunc(f CommandFunc) StdioOption { + return func(s *Stdio) { + s.cmdFunc = f + } } // NewIO returns a new stdio-based transport using existing input, output, and @@ -44,6 +67,7 @@ func NewIO(input io.Reader, output io.WriteCloser, logging io.ReadCloser) *Stdio responses: make(map[string]chan *JSONRPCResponse), done: make(chan struct{}), + ctx: context.Background(), } } @@ -55,20 +79,43 @@ func NewStdio( env []string, args ...string, ) *Stdio { + return NewStdioWithOptions(command, env, args) +} - client := &Stdio{ +// NewStdioWithOptions creates a new stdio transport to communicate with a subprocess. +// It launches the specified command with given arguments and sets up stdin/stdout pipes for communication. +// Returns an error if the subprocess cannot be started or the pipes cannot be created. +// Optional configuration functions can be provided to customize the transport before it starts, +// such as setting a custom command factory. +func NewStdioWithOptions( + command string, + env []string, + args []string, + opts ...StdioOption, +) *Stdio { + s := &Stdio{ command: command, args: args, env: env, responses: make(map[string]chan *JSONRPCResponse), done: make(chan struct{}), + ctx: context.Background(), + } + + for _, opt := range opts { + opt(s) } - return client + return s } func (c *Stdio) Start(ctx context.Context) error { + // Store the context for use in request handling + c.ctxMu.Lock() + c.ctx = ctx + c.ctxMu.Unlock() + if err := c.spawnCommand(ctx); err != nil { return err } @@ -83,18 +130,25 @@ func (c *Stdio) Start(ctx context.Context) error { return nil } -// spawnCommand spawns a new process running c.command. +// spawnCommand spawns a new process running the configured command, args, and env. +// If an (optional) cmdFunc custom command factory function was configured, it will be used to construct the subprocess; +// otherwise, the default behavior uses exec.CommandContext with the merged environment. +// Initializes stdin, stdout, and stderr pipes for JSON-RPC communication. func (c *Stdio) spawnCommand(ctx context.Context) error { if c.command == "" { return nil } - cmd := exec.CommandContext(ctx, c.command, c.args...) - - mergedEnv := os.Environ() - mergedEnv = append(mergedEnv, c.env...) + var cmd *exec.Cmd + var err error - cmd.Env = mergedEnv + // Standard behavior if no command func present. + if c.cmdFunc == nil { + cmd = exec.CommandContext(ctx, c.command, c.args...) + cmd.Env = append(os.Environ(), c.env...) + } else if cmd, err = c.cmdFunc(ctx, c.command, c.env, c.args); err != nil { + return err + } stdin, err := cmd.StdinPipe() if err != nil { @@ -148,6 +202,12 @@ func (c *Stdio) Close() error { return nil } +// GetSessionId returns the session ID of the transport. +// Since stdio does not maintain a session ID, it returns an empty string. +func (c *Stdio) GetSessionId() string { + return "" +} + // SetNotificationHandler sets the handler function to be called when a notification is received. // Only one handler can be set at a time; setting a new one replaces the previous handler. func (c *Stdio) SetNotificationHandler( @@ -158,6 +218,14 @@ func (c *Stdio) SetNotificationHandler( c.onNotification = handler } +// SetRequestHandler sets the handler function to be called when a request is received from the server. +// This enables bidirectional communication for features like sampling. +func (c *Stdio) SetRequestHandler(handler RequestHandler) { + c.requestMu.Lock() + defer c.requestMu.Unlock() + c.onRequest = handler +} + // readResponses continuously reads and processes responses from the server's stdout. // It handles both responses to requests and notifications, routing them appropriately. // Runs until the done channel is closed or an error occurs reading from stdout. @@ -175,13 +243,18 @@ func (c *Stdio) readResponses() { return } - var baseMessage JSONRPCResponse + // First try to parse as a generic message to check for ID field + var baseMessage struct { + JSONRPC string `json:"jsonrpc"` + ID *mcp.RequestId `json:"id,omitempty"` + Method string `json:"method,omitempty"` + } if err := json.Unmarshal([]byte(line), &baseMessage); err != nil { continue } - // Handle notification - if baseMessage.ID.IsNil() { + // If it has a method but no ID, it's a notification + if baseMessage.Method != "" && baseMessage.ID == nil { var notification mcp.JSONRPCNotification if err := json.Unmarshal([]byte(line), ¬ification); err != nil { continue @@ -194,15 +267,30 @@ func (c *Stdio) readResponses() { continue } + // If it has a method and an ID, it's an incoming request + if baseMessage.Method != "" && baseMessage.ID != nil { + var request JSONRPCRequest + if err := json.Unmarshal([]byte(line), &request); err == nil { + c.handleIncomingRequest(request) + continue + } + } + + // Otherwise, it's a response to our request + var response JSONRPCResponse + if err := json.Unmarshal([]byte(line), &response); err != nil { + continue + } + // Create string key for map lookup - idKey := baseMessage.ID.String() + idKey := response.ID.String() c.mu.RLock() ch, exists := c.responses[idKey] c.mu.RUnlock() if exists { - ch <- &baseMessage + ch <- &response c.mu.Lock() delete(c.responses, idKey) c.mu.Unlock() @@ -281,6 +369,96 @@ func (c *Stdio) SendNotification( return nil } +// handleIncomingRequest processes incoming requests from the server. +// It calls the registered request handler and sends the response back to the server. +func (c *Stdio) handleIncomingRequest(request JSONRPCRequest) { + c.requestMu.RLock() + handler := c.onRequest + c.requestMu.RUnlock() + + if handler == nil { + // Send error response if no handler is configured + errorResponse := JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: request.ID, + Error: &struct { + Code int `json:"code"` + Message string `json:"message"` + Data json.RawMessage `json:"data"` + }{ + Code: mcp.METHOD_NOT_FOUND, + Message: "No request handler configured", + }, + } + c.sendResponse(errorResponse) + return + } + + // Handle the request in a goroutine to avoid blocking + go func() { + c.ctxMu.RLock() + ctx := c.ctx + c.ctxMu.RUnlock() + + // Check if context is already cancelled before processing + select { + case <-ctx.Done(): + errorResponse := JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: request.ID, + Error: &struct { + Code int `json:"code"` + Message string `json:"message"` + Data json.RawMessage `json:"data"` + }{ + Code: mcp.INTERNAL_ERROR, + Message: ctx.Err().Error(), + }, + } + c.sendResponse(errorResponse) + return + default: + } + + response, err := handler(ctx, request) + + if err != nil { + errorResponse := JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: request.ID, + Error: &struct { + Code int `json:"code"` + Message string `json:"message"` + Data json.RawMessage `json:"data"` + }{ + Code: mcp.INTERNAL_ERROR, + Message: err.Error(), + }, + } + c.sendResponse(errorResponse) + return + } + + if response != nil { + c.sendResponse(*response) + } + }() +} + +// sendResponse sends a response back to the server. +func (c *Stdio) sendResponse(response JSONRPCResponse) { + responseBytes, err := json.Marshal(response) + if err != nil { + fmt.Printf("Error marshaling response: %v\n", err) + return + } + responseBytes = append(responseBytes, '\n') + + if _, err := c.stdin.Write(responseBytes); err != nil { + fmt.Printf("Error writing response: %v\n", err) + } +} + // Stderr returns a reader for the stderr output of the subprocess. // This can be used to capture error messages or logs from the subprocess. func (c *Stdio) Stderr() io.Reader { diff --git a/vendor/github.com/mark3labs/mcp-go/client/transport/streamable_http.go b/vendor/github.com/mark3labs/mcp-go/client/transport/streamable_http.go index 50bde9c288d39e32e00bc5691cbaf75addd740f5..e358751b3344c3783be539cc5daa3e09ffa81020 100644 --- a/vendor/github.com/mark3labs/mcp-go/client/transport/streamable_http.go +++ b/vendor/github.com/mark3labs/mcp-go/client/transport/streamable_http.go @@ -17,10 +17,24 @@ import ( "time" "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/util" ) type StreamableHTTPCOption func(*StreamableHTTP) +// WithContinuousListening enables receiving server-to-client notifications when no request is in flight. +// In particular, if you want to receive global notifications from the server (like ToolListChangedNotification), +// you should enable this option. +// +// It will establish a standalone long-live GET HTTP connection to the server. +// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server +// NOTICE: Even enabled, the server may not support this feature. +func WithContinuousListening() StreamableHTTPCOption { + return func(sc *StreamableHTTP) { + sc.getListeningEnabled = true + } +} + // WithHTTPClient sets a custom HTTP client on the StreamableHTTP transport. func WithHTTPBasicClient(client *http.Client) StreamableHTTPCOption { return func(sc *StreamableHTTP) { @@ -54,6 +68,19 @@ func WithHTTPOAuth(config OAuthConfig) StreamableHTTPCOption { } } +func WithLogger(logger util.Logger) StreamableHTTPCOption { + return func(sc *StreamableHTTP) { + sc.logger = logger + } +} + +// WithSession creates a client with a pre-configured session +func WithSession(sessionID string) StreamableHTTPCOption { + return func(sc *StreamableHTTP) { + sc.sessionID.Store(sessionID) + } +} + // StreamableHTTP implements Streamable HTTP transport. // // It transmits JSON-RPC messages over individual HTTP requests. One message per request. @@ -64,19 +91,22 @@ func WithHTTPOAuth(config OAuthConfig) StreamableHTTPCOption { // // The current implementation does not support the following features: // - batching -// - continuously listening for server notifications when no request is in flight -// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server) // - resuming stream // (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#resumability-and-redelivery) // - server -> client request type StreamableHTTP struct { - serverURL *url.URL - httpClient *http.Client - headers map[string]string - headerFunc HTTPHeaderFunc + serverURL *url.URL + httpClient *http.Client + headers map[string]string + headerFunc HTTPHeaderFunc + logger util.Logger + getListeningEnabled bool sessionID atomic.Value // string + initialized chan struct{} + initializedOnce sync.Once + notificationHandler func(mcp.JSONRPCNotification) notifyMu sync.RWMutex @@ -95,15 +125,19 @@ func NewStreamableHTTP(serverURL string, options ...StreamableHTTPCOption) (*Str } smc := &StreamableHTTP{ - serverURL: parsedURL, - httpClient: &http.Client{}, - headers: make(map[string]string), - closed: make(chan struct{}), + serverURL: parsedURL, + httpClient: &http.Client{}, + headers: make(map[string]string), + closed: make(chan struct{}), + logger: util.DefaultLogger(), + initialized: make(chan struct{}), } smc.sessionID.Store("") // set initial value to simplify later usage for _, opt := range options { - opt(smc) + if opt != nil { + opt(smc) + } } // If OAuth is configured, set the base URL for metadata discovery @@ -118,7 +152,20 @@ func NewStreamableHTTP(serverURL string, options ...StreamableHTTPCOption) (*Str // Start initiates the HTTP connection to the server. func (c *StreamableHTTP) Start(ctx context.Context) error { - // For Streamable HTTP, we don't need to establish a persistent connection + // For Streamable HTTP, we don't need to establish a persistent connection by default + if c.getListeningEnabled { + go func() { + select { + case <-c.initialized: + ctx, cancel := c.contextAwareOfClientClose(ctx) + defer cancel() + c.listenForever(ctx) + case <-c.closed: + return + } + }() + } + return nil } @@ -142,13 +189,13 @@ func (c *StreamableHTTP) Close() error { defer cancel() req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.serverURL.String(), nil) if err != nil { - fmt.Printf("failed to create close request\n: %v", err) + c.logger.Errorf("failed to create close request: %v", err) return } req.Header.Set(headerKeySessionID, sessionId) res, err := c.httpClient.Do(req) if err != nil { - fmt.Printf("failed to send close request\n: %v", err) + c.logger.Errorf("failed to send close request: %v", err) return } res.Body.Close() @@ -185,77 +232,29 @@ func (c *StreamableHTTP) SendRequest( request JSONRPCRequest, ) (*JSONRPCResponse, error) { - // Create a combined context that could be canceled when the client is closed - newCtx, cancel := context.WithCancel(ctx) - defer cancel() - go func() { - select { - case <-c.closed: - cancel() - case <-newCtx.Done(): - // The original context was canceled, no need to do anything - } - }() - ctx = newCtx - // Marshal request requestBody, err := json.Marshal(request) if err != nil { return nil, fmt.Errorf("failed to marshal request: %w", err) } - // Create HTTP request - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.serverURL.String(), bytes.NewReader(requestBody)) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - - // Set headers - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json, text/event-stream") - sessionID := c.sessionID.Load() - if sessionID != "" { - req.Header.Set(headerKeySessionID, sessionID.(string)) - } - for k, v := range c.headers { - req.Header.Set(k, v) - } - - // Add OAuth authorization if configured - if c.oauthHandler != nil { - authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx) - if err != nil { - // If we get an authorization error, return a specific error that can be handled by the client - if err.Error() == "no valid token available, authorization required" { - return nil, &OAuthAuthorizationRequiredError{ - Handler: c.oauthHandler, - } - } - return nil, fmt.Errorf("failed to get authorization header: %w", err) - } - req.Header.Set("Authorization", authHeader) - } - - if c.headerFunc != nil { - for k, v := range c.headerFunc(ctx) { - req.Header.Set(k, v) - } - } + ctx, cancel := c.contextAwareOfClientClose(ctx) + defer cancel() - // Send request - resp, err := c.httpClient.Do(req) + resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream") if err != nil { - return nil, fmt.Errorf("failed to send request: %w", err) + if errors.Is(err, ErrSessionTerminated) && request.Method == string(mcp.MethodInitialize) { + // If the request is initialize, should not return a SessionTerminated error + // It should be a genuine endpoint-routing issue. + // ( Fall through to return StatusCode checking. ) + } else { + return nil, fmt.Errorf("failed to send request: %w", err) + } } defer resp.Body.Close() // Check if we got an error response if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { - // handle session closed - if resp.StatusCode == http.StatusNotFound { - c.sessionID.CompareAndSwap(sessionID, "") - return nil, fmt.Errorf("session terminated (404). need to re-initialize") - } // Handle OAuth unauthorized error if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil { @@ -279,6 +278,10 @@ func (c *StreamableHTTP) SendRequest( if sessionID := resp.Header.Get(headerKeySessionID); sessionID != "" { c.sessionID.Store(sessionID) } + + c.initializedOnce.Do(func() { + close(c.initialized) + }) } // Handle different response types @@ -300,16 +303,77 @@ func (c *StreamableHTTP) SendRequest( case "text/event-stream": // Server is using SSE for streaming responses - return c.handleSSEResponse(ctx, resp.Body) + return c.handleSSEResponse(ctx, resp.Body, false) default: return nil, fmt.Errorf("unexpected content type: %s", resp.Header.Get("Content-Type")) } } +func (c *StreamableHTTP) sendHTTP( + ctx context.Context, + method string, + body io.Reader, + acceptType string, +) (resp *http.Response, err error) { + + // Create HTTP request + req, err := http.NewRequestWithContext(ctx, method, c.serverURL.String(), body) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + // Set headers + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", acceptType) + sessionID := c.sessionID.Load().(string) + if sessionID != "" { + req.Header.Set(headerKeySessionID, sessionID) + } + for k, v := range c.headers { + req.Header.Set(k, v) + } + + // Add OAuth authorization if configured + if c.oauthHandler != nil { + authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx) + if err != nil { + // If we get an authorization error, return a specific error that can be handled by the client + if err.Error() == "no valid token available, authorization required" { + return nil, &OAuthAuthorizationRequiredError{ + Handler: c.oauthHandler, + } + } + return nil, fmt.Errorf("failed to get authorization header: %w", err) + } + req.Header.Set("Authorization", authHeader) + } + + if c.headerFunc != nil { + for k, v := range c.headerFunc(ctx) { + req.Header.Set(k, v) + } + } + + // Send request + resp, err = c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + + // universal handling for session terminated + if resp.StatusCode == http.StatusNotFound { + c.sessionID.CompareAndSwap(sessionID, "") + return nil, ErrSessionTerminated + } + + return resp, nil +} + // handleSSEResponse processes an SSE stream for a specific request. // It returns the final result for the request once received, or an error. -func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser) (*JSONRPCResponse, error) { +// If ignoreResponse is true, it won't return when a response messge is received. This is for continuous listening. +func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser, ignoreResponse bool) (*JSONRPCResponse, error) { // Create a channel for this specific request responseChan := make(chan *JSONRPCResponse, 1) @@ -328,7 +392,7 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl var message JSONRPCResponse if err := json.Unmarshal([]byte(data), &message); err != nil { - fmt.Printf("failed to unmarshal message: %v\n", err) + c.logger.Errorf("failed to unmarshal message: %v", err) return } @@ -336,7 +400,7 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl if message.ID.IsNil() { var notification mcp.JSONRPCNotification if err := json.Unmarshal([]byte(data), ¬ification); err != nil { - fmt.Printf("failed to unmarshal notification: %v\n", err) + c.logger.Errorf("failed to unmarshal notification: %v", err) return } c.notifyMu.RLock() @@ -347,7 +411,9 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl return } - responseChan <- &message + if !ignoreResponse { + responseChan <- &message + } }) }() @@ -393,7 +459,7 @@ func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, hand case <-ctx.Done(): return default: - fmt.Printf("SSE stream error: %v\n", err) + c.logger.Errorf("SSE stream error: %v", err) return } } @@ -432,44 +498,10 @@ func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp. } // Create HTTP request - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.serverURL.String(), bytes.NewReader(requestBody)) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - - // Set headers - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json, text/event-stream") - if sessionID := c.sessionID.Load(); sessionID != "" { - req.Header.Set(headerKeySessionID, sessionID.(string)) - } - for k, v := range c.headers { - req.Header.Set(k, v) - } - - // Add OAuth authorization if configured - if c.oauthHandler != nil { - authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx) - if err != nil { - // If we get an authorization error, return a specific error that can be handled by the client - if errors.Is(err, ErrOAuthAuthorizationRequired) { - return &OAuthAuthorizationRequiredError{ - Handler: c.oauthHandler, - } - } - return fmt.Errorf("failed to get authorization header: %w", err) - } - req.Header.Set("Authorization", authHeader) - } - - if c.headerFunc != nil { - for k, v := range c.headerFunc(ctx) { - req.Header.Set(k, v) - } - } + ctx, cancel := c.contextAwareOfClientClose(ctx) + defer cancel() - // Send request - resp, err := c.httpClient.Do(req) + resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream") if err != nil { return fmt.Errorf("failed to send request: %w", err) } @@ -513,3 +545,84 @@ func (c *StreamableHTTP) GetOAuthHandler() *OAuthHandler { func (c *StreamableHTTP) IsOAuthEnabled() bool { return c.oauthHandler != nil } + +func (c *StreamableHTTP) listenForever(ctx context.Context) { + c.logger.Infof("listening to server forever") + for { + err := c.createGETConnectionToServer(ctx) + if errors.Is(err, ErrGetMethodNotAllowed) { + // server does not support listening + c.logger.Errorf("server does not support listening") + return + } + + select { + case <-ctx.Done(): + return + default: + } + + if err != nil { + c.logger.Errorf("failed to listen to server. retry in 1 second: %v", err) + } + time.Sleep(retryInterval) + } +} + +var ( + ErrSessionTerminated = fmt.Errorf("session terminated (404). need to re-initialize") + ErrGetMethodNotAllowed = fmt.Errorf("GET method not allowed") + + retryInterval = 1 * time.Second // a variable is convenient for testing +) + +func (c *StreamableHTTP) createGETConnectionToServer(ctx context.Context) error { + + resp, err := c.sendHTTP(ctx, http.MethodGet, nil, "text/event-stream") + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + // Check if we got an error response + if resp.StatusCode == http.StatusMethodNotAllowed { + return ErrGetMethodNotAllowed + } + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body) + } + + // handle SSE response + contentType := resp.Header.Get("Content-Type") + if contentType != "text/event-stream" { + return fmt.Errorf("unexpected content type: %s", contentType) + } + + // When ignoreResponse is true, the function will never return expect context is done. + // NOTICE: Due to the ambiguity of the specification, other SDKs may use the GET connection to transfer the response + // messages. To be more compatible, we should handle this response, however, as the transport layer is message-based, + // currently, there is no convenient way to handle this response. + // So we ignore the response here. It's not a bug, but may be not compatible with other SDKs. + _, err = c.handleSSEResponse(ctx, resp.Body, true) + if err != nil { + return fmt.Errorf("failed to handle SSE response: %w", err) + } + + return nil +} + +func (c *StreamableHTTP) contextAwareOfClientClose(ctx context.Context) (context.Context, context.CancelFunc) { + newCtx, cancel := context.WithCancel(ctx) + go func() { + select { + case <-c.closed: + cancel() + case <-newCtx.Done(): + // The original context was canceled + cancel() + } + }() + return newCtx, cancel +} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/tools.go b/vendor/github.com/mark3labs/mcp-go/mcp/tools.go index 5f3524b0212d12a89b14fd1e2f4b2e6ba4dbd806..3e3931b09c9aedfce1f6e58a80be180e107b3116 100644 --- a/vendor/github.com/mark3labs/mcp-go/mcp/tools.go +++ b/vendor/github.com/mark3labs/mcp-go/mcp/tools.go @@ -945,7 +945,20 @@ func PropertyNames(schema map[string]any) PropertyOption { } } -// Items defines the schema for array items +// Items defines the schema for array items. +// Accepts any schema definition for maximum flexibility. +// +// Example: +// +// Items(map[string]any{ +// "type": "object", +// "properties": map[string]any{ +// "name": map[string]any{"type": "string"}, +// "age": map[string]any{"type": "number"}, +// }, +// }) +// +// For simple types, use ItemsString(), ItemsNumber(), ItemsBoolean() instead. func Items(schema any) PropertyOption { return func(schemaMap map[string]any) { schemaMap["items"] = schema @@ -972,3 +985,94 @@ func UniqueItems(unique bool) PropertyOption { schema["uniqueItems"] = unique } } + +// WithStringItems configures an array's items to be of type string. +// +// Supported options: Description(), DefaultString(), Enum(), MaxLength(), MinLength(), Pattern() +// Note: Options like Required() are not valid for item schemas and will be ignored. +// +// Examples: +// +// mcp.WithArray("tags", mcp.WithStringItems()) +// mcp.WithArray("colors", mcp.WithStringItems(mcp.Enum("red", "green", "blue"))) +// mcp.WithArray("names", mcp.WithStringItems(mcp.MinLength(1), mcp.MaxLength(50))) +// +// Limitations: Only supports simple string arrays. Use Items() for complex objects. +func WithStringItems(opts ...PropertyOption) PropertyOption { + return func(schema map[string]any) { + itemSchema := map[string]any{ + "type": "string", + } + + for _, opt := range opts { + opt(itemSchema) + } + + schema["items"] = itemSchema + } +} + +// WithStringEnumItems configures an array's items to be of type string with a specified enum. +// Example: +// +// mcp.WithArray("priority", mcp.WithStringEnumItems([]string{"low", "medium", "high"})) +// +// Limitations: Only supports string enums. Use WithStringItems(Enum(...)) for more flexibility. +func WithStringEnumItems(values []string) PropertyOption { + return func(schema map[string]any) { + schema["items"] = map[string]any{ + "type": "string", + "enum": values, + } + } +} + +// WithNumberItems configures an array's items to be of type number. +// +// Supported options: Description(), DefaultNumber(), Min(), Max(), MultipleOf() +// Note: Options like Required() are not valid for item schemas and will be ignored. +// +// Examples: +// +// mcp.WithArray("scores", mcp.WithNumberItems(mcp.Min(0), mcp.Max(100))) +// mcp.WithArray("prices", mcp.WithNumberItems(mcp.Min(0))) +// +// Limitations: Only supports simple number arrays. Use Items() for complex objects. +func WithNumberItems(opts ...PropertyOption) PropertyOption { + return func(schema map[string]any) { + itemSchema := map[string]any{ + "type": "number", + } + + for _, opt := range opts { + opt(itemSchema) + } + + schema["items"] = itemSchema + } +} + +// WithBooleanItems configures an array's items to be of type boolean. +// +// Supported options: Description(), DefaultBool() +// Note: Options like Required() are not valid for item schemas and will be ignored. +// +// Examples: +// +// mcp.WithArray("flags", mcp.WithBooleanItems()) +// mcp.WithArray("permissions", mcp.WithBooleanItems(mcp.Description("User permissions"))) +// +// Limitations: Only supports simple boolean arrays. Use Items() for complex objects. +func WithBooleanItems(opts ...PropertyOption) PropertyOption { + return func(schema map[string]any) { + itemSchema := map[string]any{ + "type": "boolean", + } + + for _, opt := range opts { + opt(itemSchema) + } + + schema["items"] = itemSchema + } +} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/types.go b/vendor/github.com/mark3labs/mcp-go/mcp/types.go index 0091d2e42d380253ee03d0d1b5cde8597775be8f..241b55ce9b549941d764a2ca5b4ba11e551d301d 100644 --- a/vendor/github.com/mark3labs/mcp-go/mcp/types.go +++ b/vendor/github.com/mark3labs/mcp-go/mcp/types.go @@ -763,6 +763,11 @@ const ( /* Sampling */ +const ( + // MethodSamplingCreateMessage allows servers to request LLM completions from clients + MethodSamplingCreateMessage MCPMethod = "sampling/createMessage" +) + // CreateMessageRequest is a request from the server to sample an LLM via the // client. The client has full discretion over which model to select. The client // should also inform the user before beginning sampling, to allow them to inspect @@ -865,6 +870,22 @@ type AudioContent struct { func (AudioContent) isContent() {} +// ResourceLink represents a link to a resource that the client can access. +type ResourceLink struct { + Annotated + Type string `json:"type"` // Must be "resource_link" + // The URI of the resource. + URI string `json:"uri"` + // The name of the resource. + Name string `json:"name"` + // The description of the resource. + Description string `json:"description"` + // The MIME type of the resource. + MIMEType string `json:"mimeType"` +} + +func (ResourceLink) isContent() {} + // EmbeddedResource represents the contents of a resource, embedded into a prompt or tool call result. // // It is up to the client how best to render embedded resources for the diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/utils.go b/vendor/github.com/mark3labs/mcp-go/mcp/utils.go index 55bef7a997e2a406f111b2fb399812ca1941ab96..3e652efd7e842d24bc6ab13fa119d21f272a8ba7 100644 --- a/vendor/github.com/mark3labs/mcp-go/mcp/utils.go +++ b/vendor/github.com/mark3labs/mcp-go/mcp/utils.go @@ -222,6 +222,17 @@ func NewAudioContent(data, mimeType string) AudioContent { } } +// Helper function to create a new ResourceLink +func NewResourceLink(uri, name, description, mimeType string) ResourceLink { + return ResourceLink{ + Type: "resource_link", + URI: uri, + Name: name, + Description: description, + MIMEType: mimeType, + } +} + // Helper function to create a new EmbeddedResource func NewEmbeddedResource(resource ResourceContents) EmbeddedResource { return EmbeddedResource{ @@ -476,6 +487,16 @@ func ParseContent(contentMap map[string]any) (Content, error) { } return NewAudioContent(data, mimeType), nil + case "resource_link": + uri := ExtractString(contentMap, "uri") + name := ExtractString(contentMap, "name") + description := ExtractString(contentMap, "description") + mimeType := ExtractString(contentMap, "mimeType") + if uri == "" || name == "" { + return nil, fmt.Errorf("resource_link uri or name is missing") + } + return NewResourceLink(uri, name, description, mimeType), nil + case "resource": resourceMap := ExtractMap(contentMap, "resource") if resourceMap == nil { diff --git a/vendor/github.com/mark3labs/mcp-go/server/sampling.go b/vendor/github.com/mark3labs/mcp-go/server/sampling.go new file mode 100644 index 0000000000000000000000000000000000000000..b633b24d07ebfeeedd9b49468d7aadb411c87b4c --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/sampling.go @@ -0,0 +1,37 @@ +package server + +import ( + "context" + "fmt" + + "github.com/mark3labs/mcp-go/mcp" +) + +// EnableSampling enables sampling capabilities for the server. +// This allows the server to send sampling requests to clients that support it. +func (s *MCPServer) EnableSampling() { + s.capabilitiesMu.Lock() + defer s.capabilitiesMu.Unlock() +} + +// RequestSampling sends a sampling request to the client. +// The client must have declared sampling capability during initialization. +func (s *MCPServer) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + session := ClientSessionFromContext(ctx) + if session == nil { + return nil, fmt.Errorf("no active session") + } + + // Check if the session supports sampling requests + if samplingSession, ok := session.(SessionWithSampling); ok { + return samplingSession.RequestSampling(ctx, request) + } + + return nil, fmt.Errorf("session does not support sampling") +} + +// SessionWithSampling extends ClientSession to support sampling requests. +type SessionWithSampling interface { + ClientSession + RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) +} diff --git a/vendor/github.com/mark3labs/mcp-go/server/stdio.go b/vendor/github.com/mark3labs/mcp-go/server/stdio.go index 746a7d96f6c3635ec05c6bc2d7b92820824a8e20..33ac9bb8854527db09ea31a0be4d109521fa0c37 100644 --- a/vendor/github.com/mark3labs/mcp-go/server/stdio.go +++ b/vendor/github.com/mark3labs/mcp-go/server/stdio.go @@ -9,6 +9,7 @@ import ( "log" "os" "os/signal" + "sync" "sync/atomic" "syscall" @@ -51,10 +52,21 @@ func WithStdioContextFunc(fn StdioContextFunc) StdioOption { // stdioSession is a static client session, since stdio has only one client. type stdioSession struct { - notifications chan mcp.JSONRPCNotification - initialized atomic.Bool - loggingLevel atomic.Value - clientInfo atomic.Value // stores session-specific client info + notifications chan mcp.JSONRPCNotification + initialized atomic.Bool + loggingLevel atomic.Value + clientInfo atomic.Value // stores session-specific client info + writer io.Writer // for sending requests to client + requestID atomic.Int64 // for generating unique request IDs + mu sync.RWMutex // protects writer + pendingRequests map[int64]chan *samplingResponse // for tracking pending sampling requests + pendingMu sync.RWMutex // protects pendingRequests +} + +// samplingResponse represents a response to a sampling request +type samplingResponse struct { + result *mcp.CreateMessageResult + err error } func (s *stdioSession) SessionID() string { @@ -100,14 +112,86 @@ func (s *stdioSession) GetLogLevel() mcp.LoggingLevel { return level.(mcp.LoggingLevel) } +// RequestSampling sends a sampling request to the client and waits for the response. +func (s *stdioSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + s.mu.RLock() + writer := s.writer + s.mu.RUnlock() + + if writer == nil { + return nil, fmt.Errorf("no writer available for sending requests") + } + + // Generate a unique request ID + id := s.requestID.Add(1) + + // Create a response channel for this request + responseChan := make(chan *samplingResponse, 1) + s.pendingMu.Lock() + s.pendingRequests[id] = responseChan + s.pendingMu.Unlock() + + // Cleanup function to remove the pending request + cleanup := func() { + s.pendingMu.Lock() + delete(s.pendingRequests, id) + s.pendingMu.Unlock() + } + defer cleanup() + + // Create the JSON-RPC request + jsonRPCRequest := struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Method string `json:"method"` + Params mcp.CreateMessageParams `json:"params"` + }{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: id, + Method: string(mcp.MethodSamplingCreateMessage), + Params: request.CreateMessageParams, + } + + // Marshal and send the request + requestBytes, err := json.Marshal(jsonRPCRequest) + if err != nil { + return nil, fmt.Errorf("failed to marshal sampling request: %w", err) + } + requestBytes = append(requestBytes, '\n') + + if _, err := writer.Write(requestBytes); err != nil { + return nil, fmt.Errorf("failed to write sampling request: %w", err) + } + + // Wait for the response or context cancellation + select { + case <-ctx.Done(): + return nil, ctx.Err() + case response := <-responseChan: + if response.err != nil { + return nil, response.err + } + return response.result, nil + } +} + +// SetWriter sets the writer for sending requests to the client. +func (s *stdioSession) SetWriter(writer io.Writer) { + s.mu.Lock() + defer s.mu.Unlock() + s.writer = writer +} + var ( _ ClientSession = (*stdioSession)(nil) _ SessionWithLogging = (*stdioSession)(nil) _ SessionWithClientInfo = (*stdioSession)(nil) + _ SessionWithSampling = (*stdioSession)(nil) ) var stdioSessionInstance = stdioSession{ - notifications: make(chan mcp.JSONRPCNotification, 100), + notifications: make(chan mcp.JSONRPCNotification, 100), + pendingRequests: make(map[int64]chan *samplingResponse), } // NewStdioServer creates a new stdio server wrapper around an MCPServer. @@ -224,6 +308,9 @@ func (s *StdioServer) Listen( defer s.server.UnregisterSession(ctx, stdioSessionInstance.SessionID()) ctx = s.server.WithContext(ctx, &stdioSessionInstance) + // Set the writer for sending requests to the client + stdioSessionInstance.SetWriter(stdout) + // Add in any custom context. if s.contextFunc != nil { ctx = s.contextFunc(ctx) @@ -256,7 +343,29 @@ func (s *StdioServer) processMessage( return s.writeResponse(response, writer) } - // Handle the message using the wrapped server + // Check if this is a response to a sampling request + if s.handleSamplingResponse(rawMessage) { + return nil + } + + // Check if this is a tool call that might need sampling (and thus should be processed concurrently) + var baseMessage struct { + Method string `json:"method"` + } + if json.Unmarshal(rawMessage, &baseMessage) == nil && baseMessage.Method == "tools/call" { + // Process tool calls concurrently to avoid blocking on sampling requests + go func() { + response := s.server.HandleMessage(ctx, rawMessage) + if response != nil { + if err := s.writeResponse(response, writer); err != nil { + s.errLogger.Printf("Error writing tool response: %v", err) + } + } + }() + return nil + } + + // Handle other messages synchronously response := s.server.HandleMessage(ctx, rawMessage) // Only write response if there is one (not for notifications) @@ -269,6 +378,65 @@ func (s *StdioServer) processMessage( return nil } +// handleSamplingResponse checks if the message is a response to a sampling request +// and routes it to the appropriate pending request channel. +func (s *StdioServer) handleSamplingResponse(rawMessage json.RawMessage) bool { + return stdioSessionInstance.handleSamplingResponse(rawMessage) +} + +// handleSamplingResponse handles incoming sampling responses for this session +func (s *stdioSession) handleSamplingResponse(rawMessage json.RawMessage) bool { + // Try to parse as a JSON-RPC response + var response struct { + JSONRPC string `json:"jsonrpc"` + ID json.Number `json:"id"` + Result json.RawMessage `json:"result,omitempty"` + Error *struct { + Code int `json:"code"` + Message string `json:"message"` + } `json:"error,omitempty"` + } + + if err := json.Unmarshal(rawMessage, &response); err != nil { + return false + } + // Parse the ID as int64 + idInt64, err := response.ID.Int64() + if err != nil || (response.Result == nil && response.Error == nil) { + return false + } + + // Look for a pending request with this ID + s.pendingMu.RLock() + responseChan, exists := s.pendingRequests[idInt64] + s.pendingMu.RUnlock() + + if !exists { + return false + } // Parse and send the response + samplingResp := &samplingResponse{} + + if response.Error != nil { + samplingResp.err = fmt.Errorf("sampling request failed: %s", response.Error.Message) + } else { + var result mcp.CreateMessageResult + if err := json.Unmarshal(response.Result, &result); err != nil { + samplingResp.err = fmt.Errorf("failed to unmarshal sampling response: %w", err) + } else { + samplingResp.result = &result + } + } + + // Send the response (non-blocking) + select { + case responseChan <- samplingResp: + default: + // Channel is full or closed, ignore + } + + return true +} + // writeResponse marshals and writes a JSON-RPC response message followed by a newline. // Returns an error if marshaling or writing fails. func (s *StdioServer) writeResponse( diff --git a/vendor/github.com/mark3labs/mcp-go/server/streamable_http.go b/vendor/github.com/mark3labs/mcp-go/server/streamable_http.go index e9a011fb1c31b46771e3baeffe666ef9a71ef1a1..1312c9753a5ddc2d37a2d3c9f6266cc80d517e2e 100644 --- a/vendor/github.com/mark3labs/mcp-go/server/streamable_http.go +++ b/vendor/github.com/mark3labs/mcp-go/server/streamable_http.go @@ -40,7 +40,9 @@ func WithEndpointPath(endpointPath string) StreamableHTTPOption { // to StatelessSessionIdManager. func WithStateLess(stateLess bool) StreamableHTTPOption { return func(s *StreamableHTTPServer) { - s.sessionIdManager = &StatelessSessionIdManager{} + if stateLess { + s.sessionIdManager = &StatelessSessionIdManager{} + } } } @@ -374,7 +376,7 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") - w.WriteHeader(http.StatusAccepted) + w.WriteHeader(http.StatusOK) flusher, ok := w.(http.Flusher) if !ok { diff --git a/vendor/modules.txt b/vendor/modules.txt index 33d95285eebb41a1038aa2d95233bbcc96a87151..ebdc8318f987500b38cb989a7a0de6bea45caf5f 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -260,8 +260,8 @@ github.com/charmbracelet/bubbletea/v2 # github.com/charmbracelet/colorprofile v0.3.1 ## explicit; go 1.23.0 github.com/charmbracelet/colorprofile -# github.com/charmbracelet/fang v0.1.0 -## explicit; go 1.23.0 +# github.com/charmbracelet/fang v0.3.1-0.20250711140230-d5ebb8c1d674 +## explicit; go 1.24.0 github.com/charmbracelet/fang # github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe ## explicit; go 1.23.0 @@ -269,7 +269,7 @@ github.com/charmbracelet/glamour/v2 github.com/charmbracelet/glamour/v2/ansi github.com/charmbracelet/glamour/v2/internal/autolink github.com/charmbracelet/glamour/v2/styles -# github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.2.0.20250703152125-8e1c474f8a71 => github.com/charmbracelet/lipgloss-internal/v2 v2.0.0-20250710185058-03664cb9cecb +# github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.3 => github.com/charmbracelet/lipgloss-internal/v2 v2.0.0-20250710185058-03664cb9cecb ## explicit; go 1.24.2 github.com/charmbracelet/lipgloss/v2 github.com/charmbracelet/lipgloss/v2/table @@ -403,7 +403,7 @@ github.com/kylelemons/godebug/pretty # github.com/lucasb-eyer/go-colorful v1.2.0 ## explicit; go 1.12 github.com/lucasb-eyer/go-colorful -# github.com/mark3labs/mcp-go v0.32.0 +# github.com/mark3labs/mcp-go v0.33.0 ## explicit; go 1.23 github.com/mark3labs/mcp-go/client github.com/mark3labs/mcp-go/client/transport