Merge branch 'main' into feature/log-stdout

Tai Groot created

Change summary

.gitignore                                                             |   1 
cmd/root.go                                                            |  23 
go.mod                                                                 |   6 
go.sum                                                                 |   8 
internal/app/app.go                                                    |  69 
internal/app/lsp.go                                                    |   5 
internal/config/config.go                                              |  33 
internal/config/load.go                                                |  35 
internal/config/resolve.go                                             |   2 
internal/format/format.go                                              |  91 
internal/format/spinner.go                                             |  50 
internal/fur/client/client.go                                          |   2 
internal/llm/agent/agent.go                                            |   4 
internal/llm/prompt/prompt.go                                          |  53 
internal/llm/prompt/prompt_test.go                                     | 113 
internal/llm/tools/bash.go                                             | 127 
internal/lsp/transport.go                                              |  41 
internal/lsp/watcher/watcher.go                                        | 153 
internal/shell/command_block_test.go                                   | 123 
internal/shell/shell.go                                                |  82 
internal/tui/components/chat/editor/editor.go                          |   5 
internal/tui/components/chat/messages/renderer.go                      |   7 
internal/tui/components/chat/splash/keys.go                            |   4 
internal/tui/components/chat/splash/splash.go                          |   1 
internal/tui/components/completions/completions.go                     |  23 
internal/tui/components/dialogs/permissions/keys.go                    |   4 
internal/tui/exp/diffview/diffview.go                                  |   6 
internal/tui/page/chat/chat.go                                         |  15 
internal/tui/tui.go                                                    |   3 
internal/tui/util/util.go                                              |   2 
vendor/github.com/Azure/azure-sdk-for-go/sdk/azidentity/go.work.sum    |  60 
vendor/github.com/charmbracelet/fang/README.md                         |   7 
vendor/github.com/charmbracelet/fang/fang.go                           | 119 
vendor/github.com/charmbracelet/fang/help.go                           | 298 
vendor/github.com/charmbracelet/fang/theme.go                          |  52 
vendor/github.com/mark3labs/mcp-go/client/client.go                    | 102 
vendor/github.com/mark3labs/mcp-go/client/http.go                      |   7 
vendor/github.com/mark3labs/mcp-go/client/sampling.go                  |  20 
vendor/github.com/mark3labs/mcp-go/client/stdio.go                     |  22 
vendor/github.com/mark3labs/mcp-go/client/transport/inprocess.go       |   4 
vendor/github.com/mark3labs/mcp-go/client/transport/interface.go       |  20 
vendor/github.com/mark3labs/mcp-go/client/transport/sse.go             |   6 
vendor/github.com/mark3labs/mcp-go/client/transport/stdio.go           | 204 
vendor/github.com/mark3labs/mcp-go/client/transport/streamable_http.go | 343 
vendor/github.com/mark3labs/mcp-go/mcp/tools.go                        | 106 
vendor/github.com/mark3labs/mcp-go/mcp/types.go                        |  21 
vendor/github.com/mark3labs/mcp-go/mcp/utils.go                        |  21 
vendor/github.com/mark3labs/mcp-go/server/sampling.go                  |  37 
vendor/github.com/mark3labs/mcp-go/server/stdio.go                     | 180 
vendor/github.com/mark3labs/mcp-go/server/streamable_http.go           |   6 
vendor/modules.txt                                                     |   8 
51 files changed, 2,019 insertions(+), 715 deletions(-)

Detailed changes

.gitignore 🔗

@@ -16,6 +16,7 @@
 
 # Go workspace file
 go.work
+go.work.sum
 
 # IDE specific files
 .idea/

cmd/root.go 🔗

@@ -12,7 +12,6 @@ import (
 	"github.com/charmbracelet/crush/internal/app"
 	"github.com/charmbracelet/crush/internal/config"
 	"github.com/charmbracelet/crush/internal/db"
-	"github.com/charmbracelet/crush/internal/format"
 	"github.com/charmbracelet/crush/internal/llm/agent"
 	"github.com/charmbracelet/crush/internal/log"
 	"github.com/charmbracelet/crush/internal/tui"
@@ -52,14 +51,8 @@ to assist developers in writing, debugging, and understanding code directly from
 		debug, _ := cmd.Flags().GetBool("debug")
 		cwd, _ := cmd.Flags().GetString("cwd")
 		prompt, _ := cmd.Flags().GetString("prompt")
-		outputFormat, _ := cmd.Flags().GetString("output-format")
 		quiet, _ := cmd.Flags().GetBool("quiet")
 
-		// Validate format option
-		if !format.IsValid(outputFormat) {
-			return fmt.Errorf("invalid format option: %s\n%s", outputFormat, format.GetHelpText())
-		}
-
 		if cwd != "" {
 			err := os.Chdir(cwd)
 			if err != nil {
@@ -79,9 +72,7 @@ to assist developers in writing, debugging, and understanding code directly from
 			return err
 		}
 
-		// Create main context for the application
-		ctx, cancel := context.WithCancel(context.Background())
-		defer cancel()
+		ctx := cmd.Context()
 
 		// Connect DB, this will also run migrations
 		conn, err := db.Connect(ctx, cfg.Options.DataDirectory)
@@ -109,7 +100,7 @@ to assist developers in writing, debugging, and understanding code directly from
 		// Non-interactive mode
 		if prompt != "" {
 			// Run non-interactive flow using the App method
-			return app.RunNonInteractive(ctx, prompt, outputFormat, quiet)
+			return app.RunNonInteractive(ctx, prompt, quiet)
 		}
 
 		// Set up the TUI
@@ -152,6 +143,7 @@ func Execute() {
 		context.Background(),
 		rootCmd,
 		fang.WithVersion(version.Version),
+		fang.WithNotifySignal(os.Interrupt),
 	); err != nil {
 		os.Exit(1)
 	}
@@ -164,17 +156,8 @@ func init() {
 	rootCmd.Flags().BoolP("debug", "d", false, "Debug")
 	rootCmd.Flags().StringP("prompt", "p", "", "Prompt to run in non-interactive mode")
 
-	// Add format flag with validation logic
-	rootCmd.Flags().StringP("output-format", "f", format.Text.String(),
-		"Output format for non-interactive mode (text, json)")
-
 	// Add quiet flag to hide spinner in non-interactive mode
 	rootCmd.Flags().BoolP("quiet", "q", false, "Hide spinner in non-interactive mode")
-
-	// Register custom validation for the format flag
-	rootCmd.RegisterFlagCompletionFunc("output-format", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
-		return format.SupportedFormats, cobra.ShellCompDirectiveNoFileComp
-	})
 }
 
 func maybePrependStdin(prompt string) (string, error) {

go.mod 🔗

@@ -18,9 +18,9 @@ require (
 	github.com/charlievieth/fastwalk v1.0.11
 	github.com/charmbracelet/bubbles/v2 v2.0.0-beta.1.0.20250710161907-a4c42b579198
 	github.com/charmbracelet/bubbletea/v2 v2.0.0-beta.1
-	github.com/charmbracelet/fang v0.1.0
+	github.com/charmbracelet/fang v0.3.1-0.20250711140230-d5ebb8c1d674
 	github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe
-	github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.2.0.20250703152125-8e1c474f8a71
+	github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.3
 	github.com/charmbracelet/log/v2 v2.0.0-20250226163916-c379e29ff706
 	github.com/charmbracelet/x/ansi v0.9.3
 	github.com/charmbracelet/x/exp/charmtone v0.0.0-20250708181618-a60a724ba6c3
@@ -29,7 +29,7 @@ require (
 	github.com/fsnotify/fsnotify v1.8.0
 	github.com/google/uuid v1.6.0
 	github.com/joho/godotenv v1.5.1
-	github.com/mark3labs/mcp-go v0.32.0
+	github.com/mark3labs/mcp-go v0.33.0
 	github.com/muesli/termenv v0.16.0
 	github.com/ncruces/go-sqlite3 v0.25.0
 	github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646

go.sum 🔗

@@ -74,8 +74,8 @@ github.com/charmbracelet/bubbletea-internal/v2 v2.0.0-20250710185017-3c0ffd25e59
 github.com/charmbracelet/bubbletea-internal/v2 v2.0.0-20250710185017-3c0ffd25e595/go.mod h1:+Tl7rePElw6OKt382t04zXwtPFoPXxAaJzNrYmtsLds=
 github.com/charmbracelet/colorprofile v0.3.1 h1:k8dTHMd7fgw4bnFd7jXTLZrSU/CQrKnL3m+AxCzDz40=
 github.com/charmbracelet/colorprofile v0.3.1/go.mod h1:/GkGusxNs8VB/RSOh3fu0TJmQ4ICMMPApIIVn0KszZ0=
-github.com/charmbracelet/fang v0.1.0 h1:SlZS2crf3/zQh7Mr4+W+7QR1k+L08rrPX5rm5z3d7Wg=
-github.com/charmbracelet/fang v0.1.0/go.mod h1:Zl/zeUQ8EtQuGyiV0ZKZlZPDowKRTzu8s/367EpN/fc=
+github.com/charmbracelet/fang v0.3.1-0.20250711140230-d5ebb8c1d674 h1:+Cz+VfxD5DO+JT1LlswXWhre0HYLj6l2HW8HVGfMuC0=
+github.com/charmbracelet/fang v0.3.1-0.20250711140230-d5ebb8c1d674/go.mod h1:9gCUAHmVx5BwSafeyNr3GI0GgvlB1WYjL21SkPp1jyU=
 github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe h1:i6ce4CcAlPpTj2ER69m1DBeLZ3RRcHnKExuwhKa3GfY=
 github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe/go.mod h1:p3Q+aN4eQKeM5jhrmXPMgPrlKbmc59rWSnMsSA3udhk=
 github.com/charmbracelet/lipgloss-internal/v2 v2.0.0-20250710185058-03664cb9cecb h1:lswj7CYZVYbLn2OhYJsXOMRQQGdRIfyuSnh5FdVSMr0=
@@ -165,8 +165,8 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
 github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
 github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
 github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
-github.com/mark3labs/mcp-go v0.32.0 h1:fgwmbfL2gbd67obg57OfV2Dnrhs1HtSdlY/i5fn7MU8=
-github.com/mark3labs/mcp-go v0.32.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4=
+github.com/mark3labs/mcp-go v0.33.0 h1:naxhjnTIs/tyPZmWUZFuG0lDmdA6sUyYGGf3gsHvTCc=
+github.com/mark3labs/mcp-go v0.33.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4=
 github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
 github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
 github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=

internal/app/app.go 🔗

@@ -92,16 +92,26 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) {
 }
 
 // RunNonInteractive handles the execution flow when a prompt is provided via CLI flag.
-func (a *App) RunNonInteractive(ctx context.Context, prompt string, outputFormat string, quiet bool) error {
+func (a *App) RunNonInteractive(ctx context.Context, prompt string, quiet bool) error {
 	slog.Info("Running in non-interactive mode")
 
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
+
 	// Start spinner if not in quiet mode
 	var spinner *format.Spinner
 	if !quiet {
-		spinner = format.NewSpinner(ctx, "Generating")
+		spinner = format.NewSpinner(ctx, cancel, "Generating")
 		spinner.Start()
-		defer spinner.Stop()
 	}
+	// Helper function to stop spinner once
+	stopSpinner := func() {
+		if !quiet && spinner != nil {
+			spinner.Stop()
+			spinner = nil
+		}
+	}
+	defer stopSpinner()
 
 	const maxPromptLengthForTitle = 100
 	titlePrefix := "Non-interactive: "
@@ -128,35 +138,42 @@ func (a *App) RunNonInteractive(ctx context.Context, prompt string, outputFormat
 		return fmt.Errorf("failed to start agent processing stream: %w", err)
 	}
 
-	result := <-done
+	messageEvents := a.Messages.Subscribe(ctx)
+	readBts := 0
 
-	// Stop spinner before printing output
-	if !quiet && spinner != nil {
-		spinner.Stop()
-	}
+	for {
+		select {
+		case result := <-done:
+			stopSpinner()
+
+			if result.Error != nil {
+				if errors.Is(result.Error, context.Canceled) || errors.Is(result.Error, agent.ErrRequestCancelled) {
+					slog.Info("Agent processing cancelled", "session_id", sess.ID)
+					return nil
+				}
+				return fmt.Errorf("agent processing failed: %w", result.Error)
+			}
+
+			part := result.Message.Content().String()[readBts:]
+			fmt.Println(part)
 
-	if result.Error != nil {
-		if errors.Is(result.Error, context.Canceled) || errors.Is(result.Error, agent.ErrRequestCancelled) {
-			slog.Info("Agent processing cancelled", "session_id", sess.ID)
+			slog.Info("Non-interactive run completed", "session_id", sess.ID)
 			return nil
-		}
-		return fmt.Errorf("agent processing failed: %w", result.Error)
-	}
 
-	// Get the text content from the response
-	content := "No content available"
-	if result.Message.Content().String() != "" {
-		content = result.Message.Content().String()
-	}
+		case event := <-messageEvents:
+			msg := event.Payload
+			if msg.SessionID == sess.ID && msg.Role == message.Assistant && len(msg.Parts) > 0 {
+				stopSpinner()
+				part := msg.Content().String()[readBts:]
+				fmt.Print(part)
+				readBts += len(part)
+			}
 
-	out, err := format.FormatOutput(content, outputFormat)
-	if err != nil {
-		return err
+		case <-ctx.Done():
+			stopSpinner()
+			return ctx.Err()
+		}
 	}
-
-	fmt.Println(out)
-	slog.Info("Non-interactive run completed", "session_id", sess.ID)
-	return nil
 }
 
 func (app *App) UpdateAgentModel() error {

internal/app/lsp.go 🔗

@@ -59,11 +59,8 @@ func (app *App) createAndStartLSPClient(ctx context.Context, name string, comman
 	// Create a child context that can be canceled when the app is shutting down
 	watchCtx, cancelFunc := context.WithCancel(ctx)
 
-	// Create a context with the server name for better identification
-	watchCtx = context.WithValue(watchCtx, "serverName", name)
-
 	// Create the workspace watcher
-	workspaceWatcher := watcher.NewWorkspaceWatcher(lspClient)
+	workspaceWatcher := watcher.NewWorkspaceWatcher(name, lspClient)
 
 	// Store the cancel function to be called during cleanup
 	app.cancelFuncsMutex.Lock()

internal/config/config.go 🔗

@@ -72,7 +72,7 @@ type ProviderConfig struct {
 	Disable bool `json:"disable,omitempty"`
 
 	// Extra headers to send with each request to the provider.
-	ExtraHeaders map[string]string
+	ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
 
 	// Used to pass extra parameters to the provider.
 	ExtraParams map[string]string `json:"-"`
@@ -371,3 +371,34 @@ func (c *Config) SetProviderAPIKey(providerID, apiKey string) error {
 	c.Providers[providerID] = providerConfig
 	return nil
 }
+
+func (c *Config) SetupAgents() {
+	agents := map[string]Agent{
+		"coder": {
+			ID:           "coder",
+			Name:         "Coder",
+			Description:  "An agent that helps with executing coding tasks.",
+			Model:        SelectedModelTypeLarge,
+			ContextPaths: c.Options.ContextPaths,
+			// All tools allowed
+		},
+		"task": {
+			ID:           "task",
+			Name:         "Task",
+			Description:  "An agent that helps with searching for context and finding implementation details.",
+			Model:        SelectedModelTypeLarge,
+			ContextPaths: c.Options.ContextPaths,
+			AllowedTools: []string{
+				"glob",
+				"grep",
+				"ls",
+				"sourcegraph",
+				"view",
+			},
+			// NO MCPs or LSPs by default
+			AllowedMCP: map[string][]string{},
+			AllowedLSP: []string{},
+		},
+	}
+	c.Agents = agents
+}

internal/config/load.go 🔗

@@ -83,37 +83,7 @@ func Load(workingDir string, debug bool) (*Config, error) {
 	if err := cfg.configureSelectedModels(providers); err != nil {
 		return nil, fmt.Errorf("failed to configure selected models: %w", err)
 	}
-
-	// TODO: remove the agents concept from the config
-	agents := map[string]Agent{
-		"coder": {
-			ID:           "coder",
-			Name:         "Coder",
-			Description:  "An agent that helps with executing coding tasks.",
-			Model:        SelectedModelTypeLarge,
-			ContextPaths: cfg.Options.ContextPaths,
-			// All tools allowed
-		},
-		"task": {
-			ID:           "task",
-			Name:         "Task",
-			Description:  "An agent that helps with searching for context and finding implementation details.",
-			Model:        SelectedModelTypeLarge,
-			ContextPaths: cfg.Options.ContextPaths,
-			AllowedTools: []string{
-				"glob",
-				"grep",
-				"ls",
-				"sourcegraph",
-				"view",
-			},
-			// NO MCPs or LSPs by default
-			AllowedMCP: map[string][]string{},
-			AllowedLSP: []string{},
-		},
-	}
-	cfg.Agents = agents
-
+	cfg.SetupAgents()
 	return cfg, nil
 }
 
@@ -331,6 +301,7 @@ func (cfg *Config) defaultModelSelection(knownProviders []provider.Provider) (la
 		defaultSmallModel := cfg.GetModel(string(p.ID), p.DefaultSmallModelID)
 		if defaultSmallModel == nil {
 			err = fmt.Errorf("default small model %s not found for provider %s", p.DefaultSmallModelID, p.ID)
+			return
 		}
 		smallModel = SelectedModel{
 			Provider:        string(p.ID),
@@ -387,8 +358,6 @@ func (cfg *Config) configureSelectedModels(knownProviders []provider.Provider) e
 			large.Provider = largeModelSelected.Provider
 		}
 		model := cfg.GetModel(large.Provider, large.Model)
-		slog.Info("Configuring selected large model", "provider", large.Provider, "model", large.Model)
-		slog.Info("Model configured", "model", model)
 		if model == nil {
 			large = defaultLarge
 			// override the model type to large

internal/config/resolve.go 🔗

@@ -44,7 +44,7 @@ func (r *shellVariableResolver) ResolveValue(value string) (string, error) {
 
 	if strings.HasPrefix(value, "$(") && strings.HasSuffix(value, ")") {
 		command := strings.TrimSuffix(strings.TrimPrefix(value, "$("), ")")
-		ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 		defer cancel()
 
 		stdout, _, err := r.shell.Exec(ctx, command)

internal/format/format.go 🔗

@@ -1,91 +0,0 @@
-package format
-
-import (
-	"encoding/json"
-	"fmt"
-	"strings"
-)
-
-// OutputFormat represents the output format type for non-interactive mode
-type OutputFormat string
-
-const (
-	// Text format outputs the AI response as plain text.
-	Text OutputFormat = "text"
-
-	// JSON format outputs the AI response wrapped in a JSON object.
-	JSON OutputFormat = "json"
-)
-
-// String returns the string representation of the OutputFormat
-func (f OutputFormat) String() string {
-	return string(f)
-}
-
-// SupportedFormats is a list of all supported output formats as strings
-var SupportedFormats = []string{
-	string(Text),
-	string(JSON),
-}
-
-// Parse converts a string to an OutputFormat
-func Parse(s string) (OutputFormat, error) {
-	s = strings.ToLower(strings.TrimSpace(s))
-
-	switch s {
-	case string(Text):
-		return Text, nil
-	case string(JSON):
-		return JSON, nil
-	default:
-		return "", fmt.Errorf("invalid format: %s", s)
-	}
-}
-
-// IsValid checks if the provided format string is supported
-func IsValid(s string) bool {
-	_, err := Parse(s)
-	return err == nil
-}
-
-// GetHelpText returns a formatted string describing all supported formats
-func GetHelpText() string {
-	return fmt.Sprintf(`Supported output formats:
-- %s: Plain text output (default)
-- %s: Output wrapped in a JSON object`,
-		Text, JSON)
-}
-
-// FormatOutput formats the AI response according to the specified format
-func FormatOutput(content string, formatStr string) (string, error) {
-	format, err := Parse(formatStr)
-	if err != nil {
-		format = Text
-	}
-
-	switch format {
-	case JSON:
-		return formatAsJSON(content)
-	case Text:
-		fallthrough
-	default:
-		return content, nil
-	}
-}
-
-// formatAsJSON wraps the content in a simple JSON object
-func formatAsJSON(content string) (string, error) {
-	// Use the JSON package to properly escape the content
-	response := struct {
-		Response string `json:"response"`
-	}{
-		Response: content,
-	}
-
-	jsonBytes, err := json.MarshalIndent(response, "", "  ")
-	if err != nil {
-		return "", fmt.Errorf("failed to marshal output into JSON: %w", err)
-	}
-
-	return string(jsonBytes), nil
-}

internal/format/spinner.go 🔗

@@ -18,24 +18,48 @@ type Spinner struct {
 	prog *tea.Program
 }
 
+type model struct {
+	cancel context.CancelFunc
+	anim   anim.Anim
+}
+
+func (m model) Init() tea.Cmd { return m.anim.Init() }
+func (m model) View() string  { return m.anim.View() }
+
+// Update implements tea.Model.
+func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
+	switch msg := msg.(type) {
+	case tea.KeyPressMsg:
+		switch msg.String() {
+		case "ctrl+c", "esc":
+			m.cancel()
+			return m, tea.Quit
+		}
+	}
+	mm, cmd := m.anim.Update(msg)
+	m.anim = mm.(anim.Anim)
+	return m, cmd
+}
+
 // NewSpinner creates a new spinner with the given message
-func NewSpinner(ctx context.Context, message string) *Spinner {
+func NewSpinner(ctx context.Context, cancel context.CancelFunc, message string) *Spinner {
 	t := styles.CurrentTheme()
-	model := anim.New(anim.Settings{
-		Size:        10,
-		Label:       message,
-		LabelColor:  t.FgBase,
-		GradColorA:  t.Primary,
-		GradColorB:  t.Secondary,
-		CycleColors: true,
-	})
+	model := model{
+		anim: anim.New(anim.Settings{
+			Size:        10,
+			Label:       message,
+			LabelColor:  t.FgBase,
+			GradColorA:  t.Primary,
+			GradColorB:  t.Secondary,
+			CycleColors: true,
+		}),
+		cancel: cancel,
+	}
 
 	prog := tea.NewProgram(
 		model,
-		tea.WithInput(nil),
 		tea.WithOutput(os.Stderr),
 		tea.WithContext(ctx),
-		tea.WithoutCatchPanics(),
 	)
 
 	return &Spinner{
@@ -47,13 +71,13 @@ func NewSpinner(ctx context.Context, message string) *Spinner {
 // Start begins the spinner animation
 func (s *Spinner) Start() {
 	go func() {
+		defer close(s.done)
 		_, err := s.prog.Run()
 		// ensures line is cleared
 		fmt.Fprint(os.Stderr, ansi.EraseEntireLine)
-		if err != nil && !errors.Is(err, context.Canceled) {
+		if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, tea.ErrInterrupted) {
 			fmt.Fprintf(os.Stderr, "Error running spinner: %v\n", err)
 		}
-		close(s.done)
 	}()
 }
 

internal/fur/client/client.go 🔗

@@ -10,7 +10,7 @@ import (
 	"github.com/charmbracelet/crush/internal/fur/provider"
 )
 
-const defaultURL = "https://fur.charmcli.dev"
+const defaultURL = "https://fur.charm.sh"
 
 // Client represents a client for the fur service.
 type Client struct {

internal/llm/agent/agent.go 🔗

@@ -149,7 +149,7 @@ func NewAgent(
 	}
 	opts := []provider.ProviderClientOption{
 		provider.WithModel(agentCfg.Model),
-		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
+		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
 	}
 	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
 	if err != nil {
@@ -827,7 +827,7 @@ func (a *agent) UpdateModel() error {
 
 		opts := []provider.ProviderClientOption{
 			provider.WithModel(a.agentCfg.Model),
-			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)),
+			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
 		}
 
 		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)

internal/llm/prompt/prompt.go 🔗

@@ -5,6 +5,9 @@ import (
 	"path/filepath"
 	"strings"
 	"sync"
+
+	"github.com/charmbracelet/crush/internal/config"
+	"github.com/charmbracelet/crush/internal/env"
 )
 
 type PromptID string
@@ -21,7 +24,7 @@ func GetPrompt(promptID PromptID, provider string, contextPaths ...string) strin
 	basePrompt := ""
 	switch promptID {
 	case PromptCoder:
-		basePrompt = CoderPrompt(provider)
+		basePrompt = CoderPrompt(provider, contextPaths...)
 	case PromptTitle:
 		basePrompt = TitlePrompt()
 	case PromptTask:
@@ -38,6 +41,32 @@ func getContextFromPaths(workingDir string, contextPaths []string) string {
 	return processContextPaths(workingDir, contextPaths)
 }
 
+// expandPath expands ~ and environment variables in file paths
+func expandPath(path string) string {
+	// Handle tilde expansion
+	if strings.HasPrefix(path, "~/") {
+		homeDir, err := os.UserHomeDir()
+		if err == nil {
+			path = filepath.Join(homeDir, path[2:])
+		}
+	} else if path == "~" {
+		homeDir, err := os.UserHomeDir()
+		if err == nil {
+			path = homeDir
+		}
+	}
+
+	// Handle environment variable expansion using the same pattern as config
+	if strings.HasPrefix(path, "$") {
+		resolver := config.NewEnvironmentVariableResolver(env.New())
+		if expanded, err := resolver.ResolveValue(path); err == nil {
+			path = expanded
+		}
+	}
+
+	return path
+}
+
 func processContextPaths(workDir string, paths []string) string {
 	var (
 		wg       sync.WaitGroup
@@ -53,8 +82,23 @@ func processContextPaths(workDir string, paths []string) string {
 		go func(p string) {
 			defer wg.Done()
 
-			if strings.HasSuffix(p, "/") {
-				filepath.WalkDir(filepath.Join(workDir, p), func(path string, d os.DirEntry, err error) error {
+			// Expand ~ and environment variables before processing
+			p = expandPath(p)
+
+			// Use absolute path if provided, otherwise join with workDir
+			fullPath := p
+			if !filepath.IsAbs(p) {
+				fullPath = filepath.Join(workDir, p)
+			}
+
+			// Check if the path is a directory using os.Stat
+			info, err := os.Stat(fullPath)
+			if err != nil {
+				return // Skip if path doesn't exist or can't be accessed
+			}
+
+			if info.IsDir() {
+				filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error {
 					if err != nil {
 						return err
 					}
@@ -78,8 +122,7 @@ func processContextPaths(workDir string, paths []string) string {
 					return nil
 				})
 			} else {
-				fullPath := filepath.Join(workDir, p)
-
+				// It's a file, process it directly
 				// Check if we've already processed this file (case-insensitive)
 				lowerPath := strings.ToLower(fullPath)
 

internal/llm/prompt/prompt_test.go 🔗

@@ -0,0 +1,113 @@
+package prompt
+
+import (
+	"os"
+	"path/filepath"
+	"strings"
+	"testing"
+)
+
+func TestExpandPath(t *testing.T) {
+	tests := []struct {
+		name     string
+		input    string
+		expected func() string
+	}{
+		{
+			name:  "regular path unchanged",
+			input: "/absolute/path",
+			expected: func() string {
+				return "/absolute/path"
+			},
+		},
+		{
+			name:  "tilde expansion",
+			input: "~/documents",
+			expected: func() string {
+				home, _ := os.UserHomeDir()
+				return filepath.Join(home, "documents")
+			},
+		},
+		{
+			name:  "tilde only",
+			input: "~",
+			expected: func() string {
+				home, _ := os.UserHomeDir()
+				return home
+			},
+		},
+		{
+			name:  "environment variable expansion",
+			input: "$HOME",
+			expected: func() string {
+				return os.Getenv("HOME")
+			},
+		},
+		{
+			name:  "relative path unchanged",
+			input: "relative/path",
+			expected: func() string {
+				return "relative/path"
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := expandPath(tt.input)
+			expected := tt.expected()
+
+			// Skip test if environment variable is not set
+			if strings.HasPrefix(tt.input, "$") && expected == "" {
+				t.Skip("Environment variable not set")
+			}
+
+			if result != expected {
+				t.Errorf("expandPath(%q) = %q, want %q", tt.input, result, expected)
+			}
+		})
+	}
+}
+
+func TestProcessContextPaths(t *testing.T) {
+	// Create a temporary directory and file for testing
+	tmpDir := t.TempDir()
+	testFile := filepath.Join(tmpDir, "test.txt")
+	testContent := "test content"
+
+	err := os.WriteFile(testFile, []byte(testContent), 0o644)
+	if err != nil {
+		t.Fatalf("Failed to create test file: %v", err)
+	}
+
+	// Test with absolute path to file
+	result := processContextPaths("", []string{testFile})
+	expected := "# From:" + testFile + "\n" + testContent
+
+	if result != expected {
+		t.Errorf("processContextPaths with absolute path failed.\nGot: %q\nWant: %q", result, expected)
+	}
+
+	// Test with directory path (should process all files in directory)
+	result = processContextPaths("", []string{tmpDir})
+	if !strings.Contains(result, testContent) {
+		t.Errorf("processContextPaths with directory path failed to include file content")
+	}
+
+	// Test with tilde expansion (if we can create a file in home directory)
+	tmpDir = t.TempDir()
+	t.Setenv("HOME", tmpDir)
+	homeTestFile := filepath.Join(tmpDir, "crush_test_file.txt")
+	err = os.WriteFile(homeTestFile, []byte(testContent), 0o644)
+	if err == nil {
+		defer os.Remove(homeTestFile) // Clean up
+
+		tildeFile := "~/crush_test_file.txt"
+		result = processContextPaths("", []string{tildeFile})
+		expected = "# From:" + homeTestFile + "\n" + testContent
+
+		if result != expected {
+			t.Errorf("processContextPaths with tilde expansion failed.\nGot: %q\nWant: %q", result, expected)
+		}
+	}
+}

internal/llm/tools/bash.go 🔗

@@ -4,6 +4,7 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
+	"log/slog"
 	"runtime"
 	"strings"
 	"time"
@@ -41,9 +42,74 @@ const (
 )
 
 var bannedCommands = []string{
-	"alias", "curl", "curlie", "wget", "axel", "aria2c",
-	"nc", "telnet", "lynx", "w3m", "links", "httpie", "xh",
-	"http-prompt", "chrome", "firefox", "safari",
+	// Network/Download tools
+	"alias",
+	"aria2c",
+	"axel",
+	"chrome",
+	"curl",
+	"curlie",
+	"firefox",
+	"http-prompt",
+	"httpie",
+	"links",
+	"lynx",
+	"nc",
+	"safari",
+	"telnet",
+	"w3m",
+	"wget",
+	"xh",
+
+	// System administration
+	"doas",
+	"su",
+	"sudo",
+
+	// Package managers
+	"apk",
+	"apt",
+	"apt-cache",
+	"apt-get",
+	"dnf",
+	"dpkg",
+	"emerge",
+	"home-manager",
+	"makepkg",
+	"opkg",
+	"pacman",
+	"paru",
+	"pkg",
+	"pkg_add",
+	"pkg_delete",
+	"portage",
+	"rpm",
+	"yay",
+	"yum",
+	"zypper",
+
+	// System modification
+	"at",
+	"batch",
+	"chkconfig",
+	"crontab",
+	"fdisk",
+	"mkfs",
+	"mount",
+	"parted",
+	"service",
+	"systemctl",
+	"umount",
+
+	// Network configuration
+	"firewall-cmd",
+	"ifconfig",
+	"ip",
+	"iptables",
+	"netstat",
+	"pfctl",
+	"route",
+	"ufw",
 }
 
 // getSafeReadOnlyCommands returns platform-appropriate safe commands
@@ -244,7 +310,42 @@ Important:
 - Never update git config`, bannedCommandsStr, MaxOutputLength)
 }
 
+func blockFuncs() []shell.BlockFunc {
+	return []shell.BlockFunc{
+		shell.CommandsBlocker(bannedCommands),
+		shell.ArgumentsBlocker([][]string{
+			// System package managers
+			{"apk", "add"},
+			{"apt", "install"},
+			{"apt-get", "install"},
+			{"dnf", "install"},
+			{"emerge"},
+			{"pacman", "-S"},
+			{"pkg", "install"},
+			{"yum", "install"},
+			{"zypper", "install"},
+
+			// Language-specific package managers
+			{"brew", "install"},
+			{"cargo", "install"},
+			{"gem", "install"},
+			{"go", "install"},
+			{"npm", "install", "-g"},
+			{"npm", "install", "--global"},
+			{"pip", "install", "--user"},
+			{"pip3", "install", "--user"},
+			{"pnpm", "add", "-g"},
+			{"pnpm", "add", "--global"},
+			{"yarn", "global", "add"},
+		}),
+	}
+}
+
 func NewBashTool(permission permission.Service, workingDir string) BaseTool {
+	// Set up command blocking on the persistent shell
+	persistentShell := shell.GetPersistentShell(workingDir)
+	persistentShell.SetBlockFuncs(blockFuncs())
+
 	return &bashTool{
 		permissions: permission,
 		workingDir:  workingDir,
@@ -289,13 +390,6 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
 		return NewTextErrorResponse("missing command"), nil
 	}
 
-	baseCmd := strings.Fields(params.Command)[0]
-	for _, banned := range bannedCommands {
-		if strings.EqualFold(baseCmd, banned) {
-			return NewTextErrorResponse(fmt.Sprintf("command '%s' is not allowed", baseCmd)), nil
-		}
-	}
-
 	isSafeReadOnly := false
 	cmdLower := strings.ToLower(params.Command)
 
@@ -349,7 +443,20 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
 	stdout = truncateOutput(stdout)
 	stderr = truncateOutput(stderr)
 
+	slog.Info("Bash command executed",
+		"command", params.Command,
+		"stdout", stdout,
+		"stderr", stderr,
+		"exit_code", exitCode,
+		"interrupted", interrupted,
+		"err", err,
+	)
+
 	errorMessage := stderr
+	if errorMessage == "" && err != nil {
+		errorMessage = err.Error()
+	}
+
 	if interrupted {
 		if errorMessage != "" {
 			errorMessage += "\n"

internal/lsp/transport.go 🔗

@@ -222,29 +222,32 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any
 	}
 
 	// Wait for response
-	resp := <-ch
-
-	if cfg.Options.DebugLSP {
-		slog.Debug("Received response", "id", id)
-	}
-
-	if resp.Error != nil {
-		return fmt.Errorf("request failed: %s (code: %d)", resp.Error.Message, resp.Error.Code)
-	}
+	select {
+	case <-ctx.Done():
+		return ctx.Err()
+	case resp := <-ch:
+		if cfg.Options.DebugLSP {
+			slog.Debug("Received response", "id", id)
+		}
 
-	if result != nil {
-		// If result is a json.RawMessage, just copy the raw bytes
-		if rawMsg, ok := result.(*json.RawMessage); ok {
-			*rawMsg = resp.Result
-			return nil
+		if resp.Error != nil {
+			return fmt.Errorf("request failed: %s (code: %d)", resp.Error.Message, resp.Error.Code)
 		}
-		// Otherwise unmarshal into the provided type
-		if err := json.Unmarshal(resp.Result, result); err != nil {
-			return fmt.Errorf("failed to unmarshal result: %w", err)
+
+		if result != nil {
+			// If result is a json.RawMessage, just copy the raw bytes
+			if rawMsg, ok := result.(*json.RawMessage); ok {
+				*rawMsg = resp.Result
+				return nil
+			}
+			// Otherwise unmarshal into the provided type
+			if err := json.Unmarshal(resp.Result, result); err != nil {
+				return fmt.Errorf("failed to unmarshal result: %w", err)
+			}
 		}
-	}
 
-	return nil
+		return nil
+	}
 }
 
 // Notify sends a notification (a request without an ID that doesn't expect a response)

internal/lsp/watcher/watcher.go 🔗

@@ -21,6 +21,7 @@ import (
 // WorkspaceWatcher manages LSP file watching
 type WorkspaceWatcher struct {
 	client        *lsp.Client
+	name          string
 	workspacePath string
 
 	debounceTime time.Duration
@@ -33,8 +34,9 @@ type WorkspaceWatcher struct {
 }
 
 // NewWorkspaceWatcher creates a new workspace watcher
-func NewWorkspaceWatcher(client *lsp.Client) *WorkspaceWatcher {
+func NewWorkspaceWatcher(name string, client *lsp.Client) *WorkspaceWatcher {
 	return &WorkspaceWatcher{
+		name:          name,
 		client:        client,
 		debounceTime:  300 * time.Millisecond,
 		debounceMap:   make(map[string]*time.Timer),
@@ -95,7 +97,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
 	}
 
 	// Determine server type for specialized handling
-	serverName := getServerNameFromContext(ctx)
+	serverName := w.name
 	slog.Debug("Server type detected", "serverName", serverName)
 
 	// Check if this server has sent file watchers
@@ -325,17 +327,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
 	cfg := config.Get()
 	w.workspacePath = workspacePath
 
-	// Store the watcher in the context for later use
-	ctx = context.WithValue(ctx, "workspaceWatcher", w)
-
-	// If the server name isn't already in the context, try to detect it
-	if _, ok := ctx.Value("serverName").(string); !ok {
-		serverName := getServerNameFromContext(ctx)
-		ctx = context.WithValue(ctx, "serverName", serverName)
-	}
-
-	serverName := getServerNameFromContext(ctx)
-	slog.Debug("Starting workspace watcher", "workspacePath", workspacePath, "serverName", serverName)
+	slog.Debug("Starting workspace watcher", "workspacePath", workspacePath, "serverName", w.name)
 
 	// Register handler for file watcher registrations from the server
 	lsp.RegisterFileWatchHandler(func(id string, watchers []protocol.FileSystemWatcher) {
@@ -697,40 +689,6 @@ func (w *WorkspaceWatcher) notifyFileEvent(ctx context.Context, uri string, chan
 	return w.client.DidChangeWatchedFiles(ctx, params)
 }
 
-// getServerNameFromContext extracts the server name from the context
-// This is a best-effort function that tries to identify which LSP server we're dealing with
-func getServerNameFromContext(ctx context.Context) string {
-	// First check if the server name is directly stored in the context
-	if serverName, ok := ctx.Value("serverName").(string); ok && serverName != "" {
-		return strings.ToLower(serverName)
-	}
-
-	// Otherwise, try to extract server name from the client command path
-	if w, ok := ctx.Value("workspaceWatcher").(*WorkspaceWatcher); ok && w != nil && w.client != nil && w.client.Cmd != nil {
-		path := strings.ToLower(w.client.Cmd.Path)
-
-		// Extract server name from path
-		if strings.Contains(path, "typescript") || strings.Contains(path, "tsserver") || strings.Contains(path, "vtsls") {
-			return "typescript"
-		} else if strings.Contains(path, "gopls") {
-			return "gopls"
-		} else if strings.Contains(path, "rust-analyzer") {
-			return "rust-analyzer"
-		} else if strings.Contains(path, "pyright") || strings.Contains(path, "pylsp") || strings.Contains(path, "python") {
-			return "python"
-		} else if strings.Contains(path, "clangd") {
-			return "clangd"
-		} else if strings.Contains(path, "jdtls") || strings.Contains(path, "java") {
-			return "java"
-		}
-
-		// Return the base name as fallback
-		return filepath.Base(path)
-	}
-
-	return "unknown"
-}
-
 // shouldPreloadFiles determines if we should preload files for a specific language server
 // Some servers work better with preloaded files, others don't need it
 func shouldPreloadFiles(serverName string) bool {
@@ -884,64 +842,63 @@ func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) {
 	}
 
 	// Check if this path should be watched according to server registrations
-	if watched, _ := w.isPathWatched(path); watched {
-		// Get server name for specialized handling
-		serverName := getServerNameFromContext(ctx)
+	if watched, _ := w.isPathWatched(path); !watched {
+		return
+	}
 
-		// Check if the file is a high-priority file that should be opened immediately
-		// This helps with project initialization for certain language servers
-		if isHighPriorityFile(path, serverName) {
-			if cfg.Options.DebugLSP {
-				slog.Debug("Opening high-priority file", "path", path, "serverName", serverName)
-			}
-			if err := w.client.OpenFile(ctx, path); err != nil && cfg.Options.DebugLSP {
-				slog.Error("Error opening high-priority file", "path", path, "error", err)
-			}
-			return
-		}
+	serverName := w.name
 
-		// For non-high-priority files, we'll use different strategies based on server type
-		if shouldPreloadFiles(serverName) {
-			// For servers that benefit from preloading, open files but with limits
+	// Get server name for specialized handling
+	// Check if the file is a high-priority file that should be opened immediately
+	// This helps with project initialization for certain language servers
+	if isHighPriorityFile(path, serverName) {
+		if cfg.Options.DebugLSP {
+			slog.Debug("Opening high-priority file", "path", path, "serverName", serverName)
+		}
+		if err := w.client.OpenFile(ctx, path); err != nil && cfg.Options.DebugLSP {
+			slog.Error("Error opening high-priority file", "path", path, "error", err)
+		}
+		return
+	}
 
-			// Check file size - for preloading we're more conservative
-			if info.Size() > (1 * 1024 * 1024) { // 1MB limit for preloaded files
-				if cfg.Options.DebugLSP {
-					slog.Debug("Skipping large file for preloading", "path", path, "size", info.Size())
-				}
-				return
-			}
+	// For non-high-priority files, we'll use different strategies based on server type
+	if !shouldPreloadFiles(serverName) {
+		return
+	}
+	// For servers that benefit from preloading, open files but with limits
 
-			// Check file extension for common source files
-			ext := strings.ToLower(filepath.Ext(path))
+	// Check file size - for preloading we're more conservative
+	if info.Size() > (1 * 1024 * 1024) { // 1MB limit for preloaded files
+		if cfg.Options.DebugLSP {
+			slog.Debug("Skipping large file for preloading", "path", path, "size", info.Size())
+		}
+		return
+	}
 
-			// Only preload source files for the specific language
-			shouldOpen := false
+	// Check file extension for common source files
+	ext := strings.ToLower(filepath.Ext(path))
 
-			switch serverName {
-			case "typescript", "typescript-language-server", "tsserver", "vtsls":
-				shouldOpen = ext == ".ts" || ext == ".js" || ext == ".tsx" || ext == ".jsx"
-			case "gopls":
-				shouldOpen = ext == ".go"
-			case "rust-analyzer":
-				shouldOpen = ext == ".rs"
-			case "python", "pyright", "pylsp":
-				shouldOpen = ext == ".py"
-			case "clangd":
-				shouldOpen = ext == ".c" || ext == ".cpp" || ext == ".h" || ext == ".hpp"
-			case "java", "jdtls":
-				shouldOpen = ext == ".java"
-			default:
-				// For unknown servers, be conservative
-				shouldOpen = false
-			}
+	// Only preload source files for the specific language
+	var shouldOpen bool
+	switch serverName {
+	case "typescript", "typescript-language-server", "tsserver", "vtsls":
+		shouldOpen = ext == ".ts" || ext == ".js" || ext == ".tsx" || ext == ".jsx"
+	case "gopls":
+		shouldOpen = ext == ".go"
+	case "rust-analyzer":
+		shouldOpen = ext == ".rs"
+	case "python", "pyright", "pylsp":
+		shouldOpen = ext == ".py"
+	case "clangd":
+		shouldOpen = ext == ".c" || ext == ".cpp" || ext == ".h" || ext == ".hpp"
+	case "java", "jdtls":
+		shouldOpen = ext == ".java"
+	}
 
-			if shouldOpen {
-				// Don't need to check if it's already open - the client.OpenFile handles that
-				if err := w.client.OpenFile(ctx, path); err != nil && cfg.Options.DebugLSP {
-					slog.Error("Error opening file", "path", path, "error", err)
-				}
-			}
+	if shouldOpen {
+		// Don't need to check if it's already open - the client.OpenFile handles that
+		if err := w.client.OpenFile(ctx, path); err != nil && cfg.Options.DebugLSP {
+			slog.Error("Error opening file", "path", path, "error", err)
 		}
 	}
 }

internal/shell/command_block_test.go 🔗

@@ -0,0 +1,123 @@
+package shell
+
+import (
+	"context"
+	"os"
+	"strings"
+	"testing"
+)
+
+func TestCommandBlocking(t *testing.T) {
+	tests := []struct {
+		name        string
+		blockFuncs  []BlockFunc
+		command     string
+		shouldBlock bool
+	}{
+		{
+			name: "block simple command",
+			blockFuncs: []BlockFunc{
+				func(args []string) bool {
+					return len(args) > 0 && args[0] == "curl"
+				},
+			},
+			command:     "curl https://example.com",
+			shouldBlock: true,
+		},
+		{
+			name: "allow non-blocked command",
+			blockFuncs: []BlockFunc{
+				func(args []string) bool {
+					return len(args) > 0 && args[0] == "curl"
+				},
+			},
+			command:     "echo hello",
+			shouldBlock: false,
+		},
+		{
+			name: "block subcommand",
+			blockFuncs: []BlockFunc{
+				func(args []string) bool {
+					return len(args) >= 2 && args[0] == "brew" && args[1] == "install"
+				},
+			},
+			command:     "brew install wget",
+			shouldBlock: true,
+		},
+		{
+			name: "allow different subcommand",
+			blockFuncs: []BlockFunc{
+				func(args []string) bool {
+					return len(args) >= 2 && args[0] == "brew" && args[1] == "install"
+				},
+			},
+			command:     "brew list",
+			shouldBlock: false,
+		},
+		{
+			name: "block npm global install with -g",
+			blockFuncs: []BlockFunc{
+				ArgumentsBlocker([][]string{
+					{"npm", "install", "-g"},
+					{"npm", "install", "--global"},
+				}),
+			},
+			command:     "npm install -g typescript",
+			shouldBlock: true,
+		},
+		{
+			name: "block npm global install with --global",
+			blockFuncs: []BlockFunc{
+				ArgumentsBlocker([][]string{
+					{"npm", "install", "-g"},
+					{"npm", "install", "--global"},
+				}),
+			},
+			command:     "npm install --global typescript",
+			shouldBlock: true,
+		},
+		{
+			name: "allow npm local install",
+			blockFuncs: []BlockFunc{
+				ArgumentsBlocker([][]string{
+					{"npm", "install", "-g"},
+					{"npm", "install", "--global"},
+				}),
+			},
+			command:     "npm install typescript",
+			shouldBlock: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			// Create a temporary directory for each test
+			tmpDir, err := os.MkdirTemp("", "shell-test-*")
+			if err != nil {
+				t.Fatalf("Failed to create temp dir: %v", err)
+			}
+			defer os.RemoveAll(tmpDir)
+
+			shell := NewShell(&Options{
+				WorkingDir: tmpDir,
+				BlockFuncs: tt.blockFuncs,
+			})
+
+			_, _, err = shell.Exec(context.Background(), tt.command)
+
+			if tt.shouldBlock {
+				if err == nil {
+					t.Errorf("Expected command to be blocked, but it was allowed")
+				} else if !strings.Contains(err.Error(), "not allowed for security reasons") {
+					t.Errorf("Expected security error, got: %v", err)
+				}
+			} else {
+				// For non-blocked commands, we might get other errors (like command not found)
+				// but we shouldn't get the security error
+				if err != nil && strings.Contains(err.Error(), "not allowed for security reasons") {
+					t.Errorf("Command was unexpectedly blocked: %v", err)
+				}
+			}
+		})
+	}
+}

internal/shell/shell.go 🔗

@@ -44,12 +44,16 @@ type noopLogger struct{}
 
 func (noopLogger) InfoPersist(msg string, keysAndValues ...interface{}) {}
 
+// BlockFunc is a function that determines if a command should be blocked
+type BlockFunc func(args []string) bool
+
 // Shell provides cross-platform shell execution with optional state persistence
 type Shell struct {
-	env    []string
-	cwd    string
-	mu     sync.Mutex
-	logger Logger
+	env        []string
+	cwd        string
+	mu         sync.Mutex
+	logger     Logger
+	blockFuncs []BlockFunc
 }
 
 // Options for creating a new shell
@@ -57,6 +61,7 @@ type Options struct {
 	WorkingDir string
 	Env        []string
 	Logger     Logger
+	BlockFuncs []BlockFunc
 }
 
 // NewShell creates a new shell instance with the given options
@@ -81,9 +86,10 @@ func NewShell(opts *Options) *Shell {
 	}
 
 	return &Shell{
-		cwd:    cwd,
-		env:    env,
-		logger: logger,
+		cwd:        cwd,
+		env:        env,
+		logger:     logger,
+		blockFuncs: opts.BlockFuncs,
 	}
 }
 
@@ -152,6 +158,13 @@ func (s *Shell) SetEnv(key, value string) {
 	s.env = append(s.env, keyPrefix+value)
 }
 
+// SetBlockFuncs sets the command block functions for the shell
+func (s *Shell) SetBlockFuncs(blockFuncs []BlockFunc) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.blockFuncs = blockFuncs
+}
+
 // Windows-specific commands that should use native shell
 var windowsNativeCommands = map[string]bool{
 	"dir":      true,
@@ -203,6 +216,60 @@ func (s *Shell) determineShellType(command string) ShellType {
 	return ShellTypePOSIX
 }
 
+// CommandsBlocker creates a BlockFunc that blocks exact command matches
+func CommandsBlocker(bannedCommands []string) BlockFunc {
+	bannedSet := make(map[string]bool)
+	for _, cmd := range bannedCommands {
+		bannedSet[cmd] = true
+	}
+
+	return func(args []string) bool {
+		if len(args) == 0 {
+			return false
+		}
+		return bannedSet[args[0]]
+	}
+}
+
+// ArgumentsBlocker creates a BlockFunc that blocks specific subcommands
+func ArgumentsBlocker(blockedSubCommands [][]string) BlockFunc {
+	return func(args []string) bool {
+		for _, blocked := range blockedSubCommands {
+			if len(args) >= len(blocked) {
+				match := true
+				for i, part := range blocked {
+					if args[i] != part {
+						match = false
+						break
+					}
+				}
+				if match {
+					return true
+				}
+			}
+		}
+		return false
+	}
+}
+
+func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
+	return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
+		return func(ctx context.Context, args []string) error {
+			if len(args) == 0 {
+				return next(ctx, args)
+			}
+
+			for _, blockFunc := range s.blockFuncs {
+				if blockFunc(args) {
+					return fmt.Errorf("command is not allowed for security reasons: %s", strings.Join(args, " "))
+				}
+			}
+
+			return next(ctx, args)
+		}
+	}
+}
+
 // execWindows executes commands using native Windows shells (cmd.exe or PowerShell)
 func (s *Shell) execWindows(ctx context.Context, command string, shell string) (string, string, error) {
 	var cmd *exec.Cmd
@@ -291,6 +358,7 @@ func (s *Shell) execPOSIX(ctx context.Context, command string) (string, string,
 		interp.Interactive(false),
 		interp.Env(expand.ListEnviron(s.env...)),
 		interp.Dir(s.cwd),
+		interp.ExecHandlers(s.blockHandler()),
 	)
 	if err != nil {
 		return "", "", fmt.Errorf("could not run command: %w", err)

internal/tui/components/chat/editor/editor.go 🔗

@@ -361,8 +361,9 @@ func (m *editorCmp) startCompletions() tea.Msg {
 		})
 	}
 
-	x := m.textarea.Cursor().X + m.x + 1
-	y := m.textarea.Cursor().Y + m.y + 1
+	cur := m.textarea.Cursor()
+	x := cur.X + m.x // adjust for padding
+	y := cur.Y + m.y + 1
 	return completions.OpenCompletionsMsg{
 		Completions: completionItems,
 		X:           x,

internal/tui/components/chat/messages/renderer.go 🔗

@@ -207,6 +207,7 @@ func (br bashRenderer) Render(v *toolCallCmp) string {
 	}
 
 	cmd := strings.ReplaceAll(params.Command, "\n", " ")
+	cmd = strings.ReplaceAll(cmd, "\t", "    ")
 	args := newParamBuilder().addMain(cmd).build()
 
 	return br.renderWithParams(v, "Bash", args, func() string {
@@ -578,8 +579,8 @@ func renderParamList(nested bool, paramsWidth int, params ...string) string {
 		return ""
 	}
 	mainParam := params[0]
-	if paramsWidth-3 >= 0 && len(mainParam) > paramsWidth {
-		mainParam = mainParam[:paramsWidth-3] + "…"
+	if paramsWidth >= 0 && lipgloss.Width(mainParam) > paramsWidth {
+		mainParam = ansi.Truncate(mainParam, paramsWidth, "…")
 	}
 
 	if len(params) == 1 {
@@ -649,7 +650,7 @@ func joinHeaderBody(header, body string) string {
 		return header
 	}
 	body = t.S().Base.PaddingLeft(2).Render(body)
-	return lipgloss.JoinVertical(lipgloss.Left, header, "", body, "")
+	return lipgloss.JoinVertical(lipgloss.Left, header, "", body)
 }
 
 func renderPlainContent(v *toolCallCmp, content string) string {

internal/tui/components/chat/splash/keys.go 🔗

@@ -30,11 +30,11 @@ func DefaultKeyMap() KeyMap {
 			key.WithHelp("↑", "previous item"),
 		),
 		Yes: key.NewBinding(
-			key.WithKeys("y"),
+			key.WithKeys("y", "Y"),
 			key.WithHelp("y", "yes"),
 		),
 		No: key.NewBinding(
-			key.WithKeys("n"),
+			key.WithKeys("n", "N"),
 			key.WithHelp("n", "no"),
 		),
 		Tab: key.NewBinding(

internal/tui/components/completions/completions.go 🔗

@@ -9,6 +9,8 @@ import (
 	"github.com/charmbracelet/lipgloss/v2"
 )
 
+const maxCompletionsHeight = 10
+
 type Completion struct {
 	Title string // The title of the completion item
 	Value any    // The value of the completion item
@@ -43,7 +45,7 @@ type Completions interface {
 type completionsCmp struct {
 	width  int
 	height int  // Height of the completions component`
-	x      int  // X position for the completions popup\
+	x      int  // X position for the completions popup
 	y      int  // Y position for the completions popup
 	open   bool // Indicates if the completions are open
 	keyMap KeyMap
@@ -150,18 +152,25 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 		if !c.open {
 			return c, nil // If completions are not open, do nothing
 		}
-		cmd := c.list.Filter(msg.Query)
-		c.height = max(min(10, len(c.list.Items())), 1)
-		return c, tea.Batch(
-			cmd,
-			c.list.SetSize(c.width, c.height),
-		)
+		var cmds []tea.Cmd
+		cmds = append(cmds, c.list.Filter(msg.Query))
+		itemsLen := len(c.list.Items())
+		c.height = max(min(maxCompletionsHeight, itemsLen), 1)
+		cmds = append(cmds, c.list.SetSize(c.width, c.height))
+		if itemsLen == 0 {
+			// Close completions if no items match the query
+			cmds = append(cmds, util.CmdHandler(CloseCompletionsMsg{}))
+		}
+		return c, tea.Batch(cmds...)
 	}
 	return c, nil
 }
 
 // View implements Completions.
 func (c *completionsCmp) View() string {
+	if !c.open {
+		return ""
+	}
 	if len(c.list.Items()) == 0 {
 		return c.style().Render("No completions found")
 	}

internal/tui/components/dialogs/permissions/keys.go 🔗

@@ -109,9 +109,5 @@ func (k KeyMap) ShortHelp() []key.Binding {
 			key.WithKeys("shift+left", "shift+down", "shift+up", "shift+right"),
 			key.WithHelp("shift+←↓↑→", "scroll"),
 		),
-		key.NewBinding(
-			key.WithKeys("shift+h", "shift+j", "shift+k", "shift+l"),
-			key.WithHelp("shift+hjkl", "scroll"),
-		),
 	}
 }

internal/tui/exp/diffview/diffview.go 🔗

@@ -365,7 +365,8 @@ func (dv *DiffView) renderUnified() string {
 	shouldWrite := func() bool { return printedLines >= 0 }
 
 	getContent := func(in string, ls LineStyle) (content string, leadingEllipsis bool) {
-		content = strings.TrimSuffix(in, "\n")
+		content = strings.ReplaceAll(in, "\r\n", "\n")
+		content = strings.TrimSuffix(content, "\n")
 		content = dv.hightlightCode(content, ls.Code.GetBackground())
 		content = ansi.GraphemeWidth.Cut(content, dv.xOffset, len(content))
 		content = ansi.Truncate(content, dv.codeWidth, "…")
@@ -488,7 +489,8 @@ func (dv *DiffView) renderSplit() string {
 	shouldWrite := func() bool { return printedLines >= 0 }
 
 	getContent := func(in string, ls LineStyle) (content string, leadingEllipsis bool) {
-		content = strings.TrimSuffix(in, "\n")
+		content = strings.ReplaceAll(in, "\r\n", "\n")
+		content = strings.TrimSuffix(content, "\n")
 		content = dv.hightlightCode(content, ls.Code.GetBackground())
 		content = ansi.GraphemeWidth.Cut(content, dv.xOffset, len(content))
 		content = ansi.Truncate(content, dv.codeWidth, "…")

internal/tui/page/chat/chat.go 🔗

@@ -2,7 +2,6 @@ package chat
 
 import (
 	"context"
-	"runtime"
 	"time"
 
 	"github.com/charmbracelet/bubbles/v2/help"
@@ -615,26 +614,12 @@ func (a *chatPage) Help() help.KeyMap {
 			fullList = append(fullList, []key.Binding{v})
 		}
 	case a.isOnboarding && a.splash.IsShowingAPIKey():
-		var pasteKey key.Binding
-		if runtime.GOOS != "darwin" {
-			pasteKey = key.NewBinding(
-				key.WithKeys("ctrl+v"),
-				key.WithHelp("ctrl+v", "paste API key"),
-			)
-		} else {
-			pasteKey = key.NewBinding(
-				key.WithKeys("cmd+v"),
-				key.WithHelp("cmd+v", "paste API key"),
-			)
-		}
 		shortList = append(shortList,
 			// Go back
 			key.NewBinding(
 				key.WithKeys("esc"),
 				key.WithHelp("esc", "back"),
 			),
-			// Paste
-			pasteKey,
 			// Quit
 			key.NewBinding(
 				key.WithKeys("ctrl+c"),

internal/tui/tui.go 🔗

@@ -327,6 +327,9 @@ func (a *appModel) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd {
 			// If the commands dialog is already open, close it
 			return util.CmdHandler(dialogs.CloseDialogMsg{})
 		}
+		if a.dialog.HasDialogs() {
+			return nil // Don't open commands dialog if another dialog is active
+		}
 		return util.CmdHandler(dialogs.OpenDialogMsg{
 			Model: commands.NewCommandDialog(a.selectedSessionID),
 		})

internal/tui/util/util.go 🔗

@@ -1,6 +1,7 @@
 package util
 
 import (
+	"log/slog"
 	"time"
 
 	tea "github.com/charmbracelet/bubbletea/v2"
@@ -22,6 +23,7 @@ func CmdHandler(msg tea.Msg) tea.Cmd {
 }
 
 func ReportError(err error) tea.Cmd {
+	slog.Error("Error reported", "error", err)
 	return CmdHandler(InfoMsg{
 		Type: InfoTypeError,
 		Msg:  err.Error(),

vendor/github.com/Azure/azure-sdk-for-go/sdk/azidentity/go.work.sum 🔗

@@ -1,60 +0,0 @@
-github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.0-beta.1 h1:ODs3brnqQM99Tq1PffODpAViYv3Bf8zOg464MU7p5ew=
-github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.0-beta.1/go.mod h1:3Ug6Qzto9anB6mGlEdgYMDF5zHQ+wwhEaYR4s17PHMw=
-github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.0 h1:fb8kj/Dh4CSwgsOzHeZY4Xh68cFVbzXx+ONXGMY//4w=
-github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.0/go.mod h1:uReU2sSxZExRPBAg3qKzmAucSi51+SP1OhohieR821Q=
-github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0/go.mod h1:okt5dMMTOFjX/aovMlrjvvXoPMBVSPzk9185BT0+eZM=
-github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.2/go.mod h1:yInRyqWXAuaPrgI7p70+lDDgh3mlBohis29jGMISnmc=
-github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI=
-github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
-github.com/keybase/dbus v0.0.0-20220506165403-5aa21ea2c23a/go.mod h1:YPNKjjE7Ubp9dTbnWvsP3HT+hYnY6TfXzubYTBeUxc8=
-github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
-github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
-github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
-github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
-github.com/montanaflynn/stats v0.7.0/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow=
-github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
-github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
-github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
-github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
-github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
-github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
-github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
-github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
-github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
-github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
-github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
-github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
-golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
-golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
-golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
-golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
-golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
-golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
-golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
-golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
-golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
-golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
-golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
-golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
-golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY=
-golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
-golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o=
-golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU=
-golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
-golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U=
-golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0=
-golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
-golang.org/x/term v0.19.0/go.mod h1:2CuTdWZ7KHSQwUzKva0cbMg6q2DMI3Mmxp+gKJbskEk=
-golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
-golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0=
-golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
-golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
-golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
-golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
-gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
-gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
-gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

vendor/github.com/charmbracelet/fang/README.md 🔗

@@ -1,7 +1,7 @@
 # Fang
 
 <p>
-    <img width="485" alt="Charm Fang" src="https://github.com/user-attachments/assets/3f34ea01-3750-4760-beb2-a1b700e110f5">   
+    <img width="485" alt="Charm Fang" src="https://github.com/user-attachments/assets/3f34ea01-3750-4760-beb2-a1b700e110f5">
 </p>
 <p>
     <a href="https://github.com/charmbracelet/fang/releases"><img src="https://img.shields.io/github/release/charmbracelet/fang.svg" alt="Latest Release"></a>
@@ -12,7 +12,7 @@
 The CLI starter kit. A small, experimental library for batteries-included [Cobra][cobra] applications.
 
 <p>
-    <img width="865" alt="fang-02" src="https://github.com/user-attachments/assets/7f68ec3f-2b42-4188-a750-7e2808696132" />
+    <img width="859" alt="The Charm Fang mascot and title treatment" src="https://github.com/user-attachments/assets/5c35e1fa-9577-4f81-a879-3ddb4d4a43f0" />
 </p>
 
 ## Features
@@ -45,6 +45,7 @@ To use it, invoke `fang.Execute` passing your root `*cobra.Command`:
 package main
 
 import (
+	"context"
 	"os"
 
 	"github.com/charmbracelet/fang"
@@ -56,7 +57,7 @@ func main() {
 		Use:   "example",
 		Short: "A simple example program!",
 	}
-	if err := fang.Execute(context.TODO(), cmd); err != nil {
+	if err := fang.Execute(context.Background(), cmd); err != nil {
 		os.Exit(1)
 	}
 }

vendor/github.com/charmbracelet/fang/fang.go 🔗

@@ -4,11 +4,14 @@ package fang
 import (
 	"context"
 	"fmt"
+	"io"
 	"os"
+	"os/signal"
 	"runtime/debug"
 
 	"github.com/charmbracelet/colorprofile"
 	"github.com/charmbracelet/lipgloss/v2"
+	"github.com/charmbracelet/x/term"
 	mango "github.com/muesli/mango-cobra"
 	"github.com/muesli/roff"
 	"github.com/spf13/cobra"
@@ -16,12 +19,24 @@ import (
 
 const shaLen = 7
 
+// ErrorHandler handles an error, printing them to the given [io.Writer].
+//
+// Note that this will only be used if the STDERR is a terminal, and should
+// be used for styling only.
+type ErrorHandler = func(w io.Writer, styles Styles, err error)
+
+// ColorSchemeFunc gets a [lipgloss.LightDarkFunc] and returns a [ColorScheme].
+type ColorSchemeFunc = func(lipgloss.LightDarkFunc) ColorScheme
+
 type settings struct {
 	completions bool
 	manpages    bool
+	skipVersion bool
 	version     string
 	commit      string
-	theme       *ColorScheme
+	colorscheme ColorSchemeFunc
+	errHandler  ErrorHandler
+	signals     []os.Signal
 }
 
 // Option changes fang settings.
@@ -41,10 +56,21 @@ func WithoutManpage() Option {
 	}
 }
 
+// WithColorSchemeFunc sets a function that return colorscheme.
+func WithColorSchemeFunc(cs ColorSchemeFunc) Option {
+	return func(s *settings) {
+		s.colorscheme = cs
+	}
+}
+
 // WithTheme sets the colorscheme.
+//
+// Deprecated: use [WithColorSchemeFunc] instead.
 func WithTheme(theme ColorScheme) Option {
 	return func(s *settings) {
-		s.theme = &theme
+		s.colorscheme = func(lipgloss.LightDarkFunc) ColorScheme {
+			return theme
+		}
 	}
 }
 
@@ -55,6 +81,13 @@ func WithVersion(version string) Option {
 	}
 }
 
+// WithoutVersion skips the `-v`/`--version` functionality.
+func WithoutVersion() Option {
+	return func(s *settings) {
+		s.skipVersion = true
+	}
+}
+
 // WithCommit sets the commit SHA.
 func WithCommit(commit string) Option {
 	return func(s *settings) {
@@ -62,30 +95,45 @@ func WithCommit(commit string) Option {
 	}
 }
 
+// WithErrorHandler sets the error handler.
+func WithErrorHandler(handler ErrorHandler) Option {
+	return func(s *settings) {
+		s.errHandler = handler
+	}
+}
+
+// WithNotifySignal sets the signals that should interrupt the execution of the
+// program.
+func WithNotifySignal(signals ...os.Signal) Option {
+	return func(s *settings) {
+		s.signals = signals
+	}
+}
+
 // Execute applies fang to the command and executes it.
 func Execute(ctx context.Context, root *cobra.Command, options ...Option) error {
 	opts := settings{
 		manpages:    true,
 		completions: true,
+		colorscheme: DefaultColorScheme,
+		errHandler:  DefaultErrorHandler,
 	}
+
 	for _, option := range options {
 		option(&opts)
 	}
 
-	if opts.theme == nil {
-		isDark := lipgloss.HasDarkBackground(os.Stdin, os.Stderr)
-		t := DefaultTheme(isDark)
-		opts.theme = &t
+	helpFunc := func(c *cobra.Command, _ []string) {
+		w := colorprofile.NewWriter(c.OutOrStdout(), os.Environ())
+		helpFn(c, w, makeStyles(mustColorscheme(opts.colorscheme)))
 	}
 
-	styles := makeStyles(*opts.theme)
-
-	root.SetHelpFunc(func(c *cobra.Command, _ []string) {
-		w := colorprofile.NewWriter(c.OutOrStdout(), os.Environ())
-		helpFn(c, w, styles)
-	})
 	root.SilenceUsage = true
 	root.SilenceErrors = true
+	if !opts.skipVersion {
+		root.Version = buildVersion(opts)
+	}
+	root.SetHelpFunc(helpFunc)
 
 	if opts.manpages {
 		root.AddCommand(&cobra.Command{
@@ -108,34 +156,49 @@ func Execute(ctx context.Context, root *cobra.Command, options ...Option) error
 		})
 	}
 
-	if opts.completions {
-		root.InitDefaultCompletionCmd()
-	} else {
+	if !opts.completions {
 		root.CompletionOptions.DisableDefaultCmd = true
 	}
 
-	if opts.version == "" {
-		if info, ok := debug.ReadBuildInfo(); ok && info.Main.Sum != "" {
-			opts.version = info.Main.Version
-			opts.commit = getKey(info, "vcs.revision")
-		} else {
-			opts.version = "unknown (built from source)"
-		}
-	}
-	if len(opts.commit) >= shaLen {
-		opts.version += " (" + opts.commit[:shaLen] + ")"
+	if len(opts.signals) > 0 {
+		var cancel context.CancelFunc
+		ctx, cancel = signal.NotifyContext(ctx, opts.signals...)
+		defer cancel()
 	}
 
-	root.Version = opts.version
-
 	if err := root.ExecuteContext(ctx); err != nil {
+		if w, ok := root.ErrOrStderr().(term.File); ok {
+			// if stderr is not a tty, simply print the error without any
+			// styling or going through an [ErrorHandler]:
+			if !term.IsTerminal(w.Fd()) {
+				_, _ = fmt.Fprintln(w, err.Error())
+				return err //nolint:wrapcheck
+			}
+		}
 		w := colorprofile.NewWriter(root.ErrOrStderr(), os.Environ())
-		writeError(w, styles, err)
+		opts.errHandler(w, makeStyles(mustColorscheme(opts.colorscheme)), err)
 		return err //nolint:wrapcheck
 	}
 	return nil
 }
 
+func buildVersion(opts settings) string {
+	commit := opts.commit
+	version := opts.version
+	if version == "" {
+		if info, ok := debug.ReadBuildInfo(); ok && info.Main.Sum != "" {
+			version = info.Main.Version
+			commit = getKey(info, "vcs.revision")
+		} else {
+			version = "unknown (built from source)"
+		}
+	}
+	if len(commit) >= shaLen {
+		version += " (" + commit[:shaLen] + ")"
+	}
+	return version
+}
+
 func getKey(info *debug.BuildInfo, key string) string {
 	if info == nil {
 		return ""

vendor/github.com/charmbracelet/fang/help.go 🔗

@@ -3,7 +3,10 @@ package fang
 import (
 	"cmp"
 	"fmt"
+	"io"
+	"iter"
 	"os"
+	"reflect"
 	"regexp"
 	"strconv"
 	"strings"
@@ -20,6 +23,7 @@ import (
 const (
 	minSpace = 10
 	shortPad = 2
+	longPad  = 4
 )
 
 var width = sync.OnceValue(func() int {
@@ -45,65 +49,95 @@ func helpFn(c *cobra.Command, w *colorprofile.Writer, styles Styles) {
 		blockWidth = max(blockWidth, lipgloss.Width(ex))
 	}
 	blockWidth = min(width()-padding, blockWidth+padding)
+	blockStyle := styles.Codeblock.Base.Width(blockWidth)
 
-	styles.Codeblock.Base = styles.Codeblock.Base.Width(blockWidth)
+	// if the color profile is ascii or notty, or if the block has no
+	// background color set, remove the vertical padding.
+	if w.Profile <= colorprofile.Ascii || reflect.DeepEqual(blockStyle.GetBackground(), lipgloss.NoColor{}) {
+		blockStyle = blockStyle.PaddingTop(0).PaddingBottom(0)
+	}
 
 	_, _ = fmt.Fprintln(w, styles.Title.Render("usage"))
-	_, _ = fmt.Fprintln(w, styles.Codeblock.Base.Render(usage))
+	_, _ = fmt.Fprintln(w, blockStyle.Render(usage))
 	if len(examples) > 0 {
-		cw := styles.Codeblock.Base.GetWidth() - styles.Codeblock.Base.GetHorizontalPadding()
+		cw := blockStyle.GetWidth() - blockStyle.GetHorizontalPadding()
 		_, _ = fmt.Fprintln(w, styles.Title.Render("examples"))
 		for i, example := range examples {
 			if lipgloss.Width(example) > cw {
 				examples[i] = ansi.Truncate(example, cw, "…")
 			}
 		}
-		_, _ = fmt.Fprintln(w, styles.Codeblock.Base.Render(strings.Join(examples, "\n")))
+		_, _ = fmt.Fprintln(w, blockStyle.Render(strings.Join(examples, "\n")))
 	}
 
+	groups, groupKeys := evalGroups(c)
 	cmds, cmdKeys := evalCmds(c, styles)
 	flags, flagKeys := evalFlags(c, styles)
 	space := calculateSpace(cmdKeys, flagKeys)
 
-	leftPadding := 4
-	if len(cmds) > 0 {
-		_, _ = fmt.Fprintln(w, styles.Title.Render("commands"))
-		for _, k := range cmdKeys {
-			_, _ = fmt.Fprintln(w, lipgloss.JoinHorizontal(
-				lipgloss.Left,
-				lipgloss.NewStyle().PaddingLeft(leftPadding).Render(k),
-				strings.Repeat(" ", space-lipgloss.Width(k)),
-				cmds[k],
-			))
+	for _, groupID := range groupKeys {
+		group := cmds[groupID]
+		if len(group) == 0 {
+			continue
 		}
+		renderGroup(w, styles, space, groups[groupID], func(yield func(string, string) bool) {
+			for _, k := range cmdKeys {
+				cmds, ok := group[k]
+				if !ok {
+					continue
+				}
+				if !yield(k, cmds) {
+					return
+				}
+			}
+		})
 	}
 
 	if len(flags) > 0 {
-		_, _ = fmt.Fprintln(w, styles.Title.Render("flags"))
-		for _, k := range flagKeys {
-			_, _ = fmt.Fprintln(w, lipgloss.JoinHorizontal(
-				lipgloss.Left,
-				lipgloss.NewStyle().PaddingLeft(leftPadding).Render(k),
-				strings.Repeat(" ", space-lipgloss.Width(k)),
-				flags[k],
-			))
-		}
+		renderGroup(w, styles, space, "flags", func(yield func(string, string) bool) {
+			for _, k := range flagKeys {
+				if !yield(k, flags[k]) {
+					return
+				}
+			}
+		})
 	}
 
 	_, _ = fmt.Fprintln(w)
 }
 
-func writeError(w *colorprofile.Writer, styles Styles, err error) {
+// DefaultErrorHandler is the default [ErrorHandler] implementation.
+func DefaultErrorHandler(w io.Writer, styles Styles, err error) {
 	_, _ = fmt.Fprintln(w, styles.ErrorHeader.String())
 	_, _ = fmt.Fprintln(w, styles.ErrorText.Render(err.Error()+"."))
 	_, _ = fmt.Fprintln(w)
-	_, _ = fmt.Fprintln(w, lipgloss.JoinHorizontal(
-		lipgloss.Left,
-		styles.ErrorText.UnsetWidth().Render("Try"),
-		styles.Program.Flag.Render("--help"),
-		styles.ErrorText.UnsetWidth().UnsetMargins().UnsetTransform().PaddingLeft(1).Render("for usage."),
-	))
-	_, _ = fmt.Fprintln(w)
+	if isUsageError(err) {
+		_, _ = fmt.Fprintln(w, lipgloss.JoinHorizontal(
+			lipgloss.Left,
+			styles.ErrorText.UnsetWidth().Render("Try"),
+			styles.Program.Flag.Render(" --help "),
+			styles.ErrorText.UnsetWidth().UnsetMargins().UnsetTransform().Render("for usage."),
+		))
+		_, _ = fmt.Fprintln(w)
+	}
+}
+
+// XXX: this is a hack to detect usage errors.
+// See: https://github.com/spf13/cobra/pull/2266
+func isUsageError(err error) bool {
+	s := err.Error()
+	for _, prefix := range []string{
+		"flag needs an argument:",
+		"unknown flag:",
+		"unknown shorthand flag:",
+		"unknown command",
+		"invalid argument",
+	} {
+		if strings.HasPrefix(s, prefix) {
+			return true
+		}
+	}
+	return false
 }
 
 func writeLongShort(w *colorprofile.Writer, styles Styles, longShort string) {
@@ -118,8 +152,10 @@ var otherArgsRe = regexp.MustCompile(`(\[.*\])`)
 
 // styleUsage stylized styleUsage line for a given command.
 func styleUsage(c *cobra.Command, styles Program, complete bool) string {
-	// XXX: maybe use c.UseLine() here?
 	u := c.Use
+	if complete {
+		u = c.UseLine()
+	}
 	hasArgs := strings.Contains(u, "[args]")
 	hasFlags := strings.Contains(u, "[flags]") || strings.Contains(u, "[--flags]") || c.HasFlags() || c.HasPersistentFlags() || c.HasAvailableFlags()
 	hasCommands := strings.Contains(u, "[command]") || c.HasAvailableSubCommands()
@@ -139,34 +175,38 @@ func styleUsage(c *cobra.Command, styles Program, complete bool) string {
 
 	u = strings.TrimSpace(u)
 
-	useLine := []string{
-		styles.Name.Render(u),
-	}
-	if !complete {
-		useLine[0] = styles.Command.Render(u)
+	useLine := []string{}
+	if complete {
+		parts := strings.Fields(u)
+		useLine = append(useLine, styles.Name.Render(parts[0]))
+		if len(parts) > 1 {
+			useLine = append(useLine, styles.Command.Render(" "+strings.Join(parts[1:], " ")))
+		}
+	} else {
+		useLine = append(useLine, styles.Command.Render(u))
 	}
 	if hasCommands {
 		useLine = append(
 			useLine,
-			styles.DimmedArgument.Render("[command]"),
+			styles.DimmedArgument.Render(" [command]"),
 		)
 	}
 	if hasArgs {
 		useLine = append(
 			useLine,
-			styles.DimmedArgument.Render("[args]"),
+			styles.DimmedArgument.Render(" [args]"),
 		)
 	}
 	for _, arg := range otherArgs {
 		useLine = append(
 			useLine,
-			styles.DimmedArgument.Render(arg),
+			styles.DimmedArgument.Render(" "+arg),
 		)
 	}
 	if hasFlags {
 		useLine = append(
 			useLine,
-			styles.DimmedArgument.Render("[--flags]"),
+			styles.DimmedArgument.Render(" [--flags]"),
 		)
 	}
 	return lipgloss.JoinHorizontal(lipgloss.Left, useLine...)
@@ -180,19 +220,21 @@ func styleExamples(c *cobra.Command, styles Styles) []string {
 	}
 	usage := []string{}
 	examples := strings.Split(c.Example, "\n")
+	var indent bool
 	for i, line := range examples {
 		line = strings.TrimSpace(line)
 		if (i == 0 || i == len(examples)-1) && line == "" {
 			continue
 		}
-		s := styleExample(c, line, styles.Codeblock)
+		s := styleExample(c, line, indent, styles.Codeblock)
 		usage = append(usage, s)
+		indent = len(line) > 1 && (line[len(line)-1] == '\\' || line[len(line)-1] == '|')
 	}
 
 	return usage
 }
 
-func styleExample(c *cobra.Command, line string, styles Codeblock) string {
+func styleExample(c *cobra.Command, line string, indent bool, styles Codeblock) string {
 	if strings.HasPrefix(line, "# ") {
 		return lipgloss.JoinHorizontal(
 			lipgloss.Left,
@@ -200,66 +242,110 @@ func styleExample(c *cobra.Command, line string, styles Codeblock) string {
 		)
 	}
 
-	args := strings.Fields(line)
-	var nextIsFlag bool
 	var isQuotedString bool
+	var foundProgramName bool
+	var isRedirecting bool
+	programName := c.Root().Name()
+	args := strings.Fields(line)
+	var cleanArgs []string
 	for i, arg := range args {
-		if i == 0 {
-			args[i] = styles.Program.Name.Render(arg)
-			continue
+		isQuoteStart := arg[0] == '"' || arg[0] == '\''
+		isQuoteEnd := arg[len(arg)-1] == '"' || arg[len(arg)-1] == '\''
+		isFlag := arg[0] == '-'
+
+		switch i {
+		case 0:
+			args[i] = ""
+			if indent {
+				args[i] = styles.Program.DimmedArgument.Render("  ")
+				indent = false
+			}
+		default:
+			args[i] = styles.Program.DimmedArgument.Render(" ")
 		}
 
-		quoteStart := arg[0] == '"'
-		quoteEnd := arg[len(arg)-1] == '"'
-		flagStart := arg[0] == '-'
-		if i == 1 && !quoteStart && !flagStart {
-			args[i] = styles.Program.Command.Render(arg)
+		if isRedirecting {
+			args[i] += styles.Program.DimmedArgument.Render(arg)
+			isRedirecting = false
 			continue
 		}
-		if quoteStart {
-			isQuotedString = true
-		}
-		if isQuotedString {
-			args[i] = styles.Program.QuotedString.Render(arg)
-			if quoteEnd {
-				isQuotedString = false
+
+		switch arg {
+		case "\\":
+			if i == len(args)-1 {
+				args[i] += styles.Program.DimmedArgument.Render(arg)
+				continue
 			}
+		case "|", "||", "-", "&", "&&":
+			args[i] += styles.Program.DimmedArgument.Render(arg)
 			continue
 		}
-		if nextIsFlag {
-			args[i] = styles.Program.Flag.Render(arg)
+
+		if isRedirect(arg) {
+			args[i] += styles.Program.DimmedArgument.Render(arg)
+			isRedirecting = true
 			continue
 		}
-		var dashes string
-		if strings.HasPrefix(arg, "-") {
-			dashes = "-"
+
+		if !foundProgramName { //nolint:nestif
+			if isQuotedString {
+				args[i] += styles.Program.QuotedString.Render(arg)
+				isQuotedString = !isQuoteEnd
+				continue
+			}
+			if left, right, ok := strings.Cut(arg, "="); ok {
+				args[i] += styles.Program.Flag.Render(left + "=")
+				if right[0] == '"' {
+					isQuotedString = true
+					args[i] += styles.Program.QuotedString.Render(right)
+					continue
+				}
+				args[i] += styles.Program.Argument.Render(right)
+				continue
+			}
+
+			if arg == programName {
+				args[i] += styles.Program.Name.Render(arg)
+				foundProgramName = true
+				continue
+			}
 		}
-		if strings.HasPrefix(arg, "--") {
-			dashes = "--"
+
+		if !isQuoteStart && !isQuotedString && !isFlag {
+			cleanArgs = append(cleanArgs, arg)
+		}
+
+		if !isQuoteStart && !isFlag && isSubCommand(c, cleanArgs, arg) {
+			args[i] += styles.Program.Command.Render(arg)
+			continue
+		}
+		isQuotedString = isQuotedString || isQuoteStart
+		if isQuotedString {
+			args[i] += styles.Program.QuotedString.Render(arg)
+			isQuotedString = !isQuoteEnd
+			continue
 		}
 		// handle a flag
-		if dashes != "" {
+		if isFlag {
 			name, value, ok := strings.Cut(arg, "=")
-			name = strings.TrimPrefix(name, dashes)
 			// it is --flag=value
 			if ok {
-				args[i] = lipgloss.JoinHorizontal(
+				args[i] += lipgloss.JoinHorizontal(
 					lipgloss.Left,
-					styles.Program.Flag.Render(dashes+name+"="),
-					styles.Program.Argument.UnsetPadding().Render(value),
+					styles.Program.Flag.Render(name+"="),
+					styles.Program.Argument.Render(value),
 				)
 				continue
 			}
 			// it is either --bool-flag or --flag value
-			args[i] = lipgloss.JoinHorizontal(
+			args[i] += lipgloss.JoinHorizontal(
 				lipgloss.Left,
-				styles.Program.Flag.Render(dashes+name),
+				styles.Program.Flag.Render(name),
 			)
-			// if the flag is not a bool flag, next arg continues current flag
-			nextIsFlag = !isFlagBool(c, name)
 			continue
 		}
-		args[i] = styles.Program.Argument.Render(arg)
+
+		args[i] += styles.Program.Argument.Render(arg)
 	}
 
 	return lipgloss.JoinHorizontal(
@@ -284,8 +370,7 @@ func evalFlags(c *cobra.Command, styles Styles) (map[string]string, []string) {
 		} else {
 			parts = append(
 				parts,
-				styles.Program.Flag.Render("-"+f.Shorthand),
-				styles.Program.Flag.Render("--"+f.Name),
+				styles.Program.Flag.Render("-"+f.Shorthand+" --"+f.Name),
 			)
 		}
 		key := lipgloss.JoinHorizontal(lipgloss.Left, parts...)
@@ -303,22 +388,50 @@ func evalFlags(c *cobra.Command, styles Styles) (map[string]string, []string) {
 	return flags, keys
 }
 
-func evalCmds(c *cobra.Command, styles Styles) (map[string]string, []string) {
+// result is map[groupID]map[styled cmd name]styled cmd help, and the keys in
+// the order they are defined.
+func evalCmds(c *cobra.Command, styles Styles) (map[string](map[string]string), []string) {
 	padStyle := lipgloss.NewStyle().PaddingLeft(0) //nolint:mnd
 	keys := []string{}
-	cmds := map[string]string{}
+	cmds := map[string]map[string]string{}
 	for _, sc := range c.Commands() {
 		if sc.Hidden {
 			continue
 		}
+		if _, ok := cmds[sc.GroupID]; !ok {
+			cmds[sc.GroupID] = map[string]string{}
+		}
 		key := padStyle.Render(styleUsage(sc, styles.Program, false))
 		help := styles.FlagDescription.Render(sc.Short)
-		cmds[key] = help
+		cmds[sc.GroupID][key] = help
 		keys = append(keys, key)
 	}
 	return cmds, keys
 }
 
+func evalGroups(c *cobra.Command) (map[string]string, []string) {
+	// make sure the default group is the first
+	ids := []string{""}
+	groups := map[string]string{"": "commands"}
+	for _, g := range c.Groups() {
+		groups[g.ID] = g.Title
+		ids = append(ids, g.ID)
+	}
+	return groups, ids
+}
+
+func renderGroup(w io.Writer, styles Styles, space int, name string, items iter.Seq2[string, string]) {
+	_, _ = fmt.Fprintln(w, styles.Title.Render(name))
+	for key, help := range items {
+		_, _ = fmt.Fprintln(w, lipgloss.JoinHorizontal(
+			lipgloss.Left,
+			lipgloss.NewStyle().PaddingLeft(longPad).Render(key),
+			strings.Repeat(" ", space-lipgloss.Width(key)),
+			help,
+		))
+	}
+}
+
 func calculateSpace(k1, k2 []string) int {
 	const spaceBetween = 2
 	space := minSpace
@@ -328,13 +441,18 @@ func calculateSpace(k1, k2 []string) int {
 	return space
 }
 
-func isFlagBool(c *cobra.Command, name string) bool {
-	flag := c.Flags().Lookup(name)
-	if flag == nil && len(name) == 1 {
-		flag = c.Flags().ShorthandLookup(name)
-	}
-	if flag == nil {
-		return false
+func isSubCommand(c *cobra.Command, args []string, word string) bool {
+	cmd, _, _ := c.Root().Traverse(args)
+	return cmd != nil && cmd.Name() == word
+}
+
+var redirectPrefixes = []string{">", "<", "&>", "2>", "1>", ">>", "2>>"}
+
+func isRedirect(s string) bool {
+	for _, p := range redirectPrefixes {
+		if strings.HasPrefix(s, p) {
+			return true
+		}
 	}
-	return flag.Value.Type() == "bool"
+	return false
 }

vendor/github.com/charmbracelet/fang/theme.go 🔗

@@ -2,10 +2,12 @@ package fang
 
 import (
 	"image/color"
+	"os"
 	"strings"
 
 	"github.com/charmbracelet/lipgloss/v2"
 	"github.com/charmbracelet/x/exp/charmtone"
+	"github.com/charmbracelet/x/term"
 	"golang.org/x/text/cases"
 	"golang.org/x/text/language"
 )
@@ -31,8 +33,14 @@ type ColorScheme struct {
 }
 
 // DefaultTheme is the default colorscheme.
+//
+// Deprecated: use [DefaultColorScheme] instead.
 func DefaultTheme(isDark bool) ColorScheme {
-	c := lipgloss.LightDark(isDark)
+	return DefaultColorScheme(lipgloss.LightDark(isDark))
+}
+
+// DefaultColorScheme is the default colorscheme.
+func DefaultColorScheme(c lipgloss.LightDarkFunc) ColorScheme {
 	return ColorScheme{
 		Base:           c(charmtone.Charcoal, charmtone.Ash),
 		Title:          charmtone.Charple,
@@ -45,7 +53,7 @@ func DefaultTheme(isDark bool) ColorScheme {
 		Argument:       c(charmtone.Charcoal, charmtone.Ash),
 		Description:    c(charmtone.Charcoal, charmtone.Ash), // flag and command descriptions
 		FlagDefault:    c(charmtone.Smoke, charmtone.Squid),  // flag default values in descriptions
-		QuotedString:   c(charmtone.Charcoal, charmtone.Ash),
+		QuotedString:   c(charmtone.Coral, charmtone.Salmon),
 		ErrorHeader: [2]color.Color{
 			charmtone.Butter,
 			charmtone.Cherry,
@@ -53,6 +61,26 @@ func DefaultTheme(isDark bool) ColorScheme {
 	}
 }
 
+// AnsiColorScheme is a ANSI colorscheme.
+func AnsiColorScheme(c lipgloss.LightDarkFunc) ColorScheme {
+	base := c(lipgloss.Black, lipgloss.White)
+	return ColorScheme{
+		Base:         base,
+		Title:        lipgloss.Blue,
+		Description:  base,
+		Comment:      c(lipgloss.BrightWhite, lipgloss.BrightBlack),
+		Flag:         lipgloss.Magenta,
+		FlagDefault:  lipgloss.BrightMagenta,
+		Command:      lipgloss.Cyan,
+		QuotedString: lipgloss.Green,
+		Argument:     base,
+		Help:         base,
+		Dash:         base,
+		ErrorHeader:  [2]color.Color{lipgloss.Black, lipgloss.Red},
+		ErrorDetails: lipgloss.Red,
+	}
+}
+
 // Styles represents all the styles used.
 type Styles struct {
 	Text            lipgloss.Style
@@ -84,6 +112,14 @@ type Program struct {
 	QuotedString   lipgloss.Style
 }
 
+func mustColorscheme(cs func(lipgloss.LightDarkFunc) ColorScheme) ColorScheme {
+	var isDark bool
+	if term.IsTerminal(os.Stdout.Fd()) {
+		isDark = lipgloss.HasDarkBackground(os.Stdin, os.Stdout)
+	}
+	return cs(lipgloss.LightDark(isDark))
+}
+
 func makeStyles(cs ColorScheme) Styles {
 	//nolint:mnd
 	return Styles{
@@ -98,8 +134,7 @@ func makeStyles(cs ColorScheme) Styles {
 			Foreground(cs.Description).
 			Transform(titleFirstWord),
 		FlagDefault: lipgloss.NewStyle().
-			Foreground(cs.FlagDefault).
-			PaddingLeft(1),
+			Foreground(cs.FlagDefault),
 		Codeblock: Codeblock{
 			Base: lipgloss.NewStyle().
 				Background(cs.Codeblock).
@@ -116,23 +151,18 @@ func makeStyles(cs ColorScheme) Styles {
 					Background(cs.Codeblock).
 					Foreground(cs.Program),
 				Flag: lipgloss.NewStyle().
-					PaddingLeft(1).
 					Background(cs.Codeblock).
 					Foreground(cs.Flag),
 				Argument: lipgloss.NewStyle().
-					PaddingLeft(1).
 					Background(cs.Codeblock).
 					Foreground(cs.Argument),
 				DimmedArgument: lipgloss.NewStyle().
-					PaddingLeft(1).
 					Background(cs.Codeblock).
 					Foreground(cs.DimmedArgument),
 				Command: lipgloss.NewStyle().
-					PaddingLeft(1).
 					Background(cs.Codeblock).
 					Foreground(cs.Command),
 				QuotedString: lipgloss.NewStyle().
-					PaddingLeft(1).
 					Background(cs.Codeblock).
 					Foreground(cs.QuotedString),
 			},
@@ -141,18 +171,14 @@ func makeStyles(cs ColorScheme) Styles {
 			Name: lipgloss.NewStyle().
 				Foreground(cs.Program),
 			Argument: lipgloss.NewStyle().
-				PaddingLeft(1).
 				Foreground(cs.Argument),
 			DimmedArgument: lipgloss.NewStyle().
-				PaddingLeft(1).
 				Foreground(cs.DimmedArgument),
 			Flag: lipgloss.NewStyle().
-				PaddingLeft(1).
 				Foreground(cs.Flag),
 			Command: lipgloss.NewStyle().
 				Foreground(cs.Command),
 			QuotedString: lipgloss.NewStyle().
-				PaddingLeft(1).
 				Foreground(cs.QuotedString),
 		},
 		Span: lipgloss.NewStyle().

vendor/github.com/mark3labs/mcp-go/client/client.go 🔗

@@ -22,6 +22,7 @@ type Client struct {
 	requestID          atomic.Int64
 	clientCapabilities mcp.ClientCapabilities
 	serverCapabilities mcp.ServerCapabilities
+	samplingHandler    SamplingHandler
 }
 
 type ClientOption func(*Client)
@@ -33,6 +34,21 @@ func WithClientCapabilities(capabilities mcp.ClientCapabilities) ClientOption {
 	}
 }
 
+// WithSamplingHandler sets the sampling handler for the client.
+// When set, the client will declare sampling capability during initialization.
+func WithSamplingHandler(handler SamplingHandler) ClientOption {
+	return func(c *Client) {
+		c.samplingHandler = handler
+  }
+}
+
+// WithSession assumes a MCP Session has already been initialized
+func WithSession() ClientOption {
+	return func(c *Client) {
+		c.initialized = true
+	}
+}
+
 // NewClient creates a new MCP client with the given transport.
 // Usage:
 //
@@ -71,6 +87,12 @@ func (c *Client) Start(ctx context.Context) error {
 			handler(notification)
 		}
 	})
+
+	// Set up request handler for bidirectional communication (e.g., sampling)
+	if bidirectional, ok := c.transport.(transport.BidirectionalInterface); ok {
+		bidirectional.SetRequestHandler(c.handleIncomingRequest)
+	}
+
 	return nil
 }
 
@@ -127,6 +149,12 @@ func (c *Client) Initialize(
 	ctx context.Context,
 	request mcp.InitializeRequest,
 ) (*mcp.InitializeResult, error) {
+	// Merge client capabilities with sampling capability if handler is configured
+	capabilities := request.Params.Capabilities
+	if c.samplingHandler != nil {
+		capabilities.Sampling = &struct{}{}
+	}
+
 	// Ensure we send a params object with all required fields
 	params := struct {
 		ProtocolVersion string                 `json:"protocolVersion"`
@@ -135,7 +163,7 @@ func (c *Client) Initialize(
 	}{
 		ProtocolVersion: request.Params.ProtocolVersion,
 		ClientInfo:      request.Params.ClientInfo,
-		Capabilities:    request.Params.Capabilities, // Will be empty struct if not set
+		Capabilities:    capabilities,
 	}
 
 	response, err := c.sendRequest(ctx, "initialize", params)
@@ -398,6 +426,64 @@ func (c *Client) Complete(
 	return &result, nil
 }
 
+// handleIncomingRequest processes incoming requests from the server.
+// This is the main entry point for server-to-client requests like sampling.
+func (c *Client) handleIncomingRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
+	switch request.Method {
+	case string(mcp.MethodSamplingCreateMessage):
+		return c.handleSamplingRequestTransport(ctx, request)
+	default:
+		return nil, fmt.Errorf("unsupported request method: %s", request.Method)
+	}
+}
+
+// handleSamplingRequestTransport handles sampling requests at the transport level.
+func (c *Client) handleSamplingRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
+	if c.samplingHandler == nil {
+		return nil, fmt.Errorf("no sampling handler configured")
+	}
+
+	// Parse the request parameters
+	var params mcp.CreateMessageParams
+	if request.Params != nil {
+		paramsBytes, err := json.Marshal(request.Params)
+		if err != nil {
+			return nil, fmt.Errorf("failed to marshal params: %w", err)
+		}
+		if err := json.Unmarshal(paramsBytes, &params); err != nil {
+			return nil, fmt.Errorf("failed to unmarshal params: %w", err)
+		}
+	}
+
+	// Create the MCP request
+	mcpRequest := mcp.CreateMessageRequest{
+		Request: mcp.Request{
+			Method: string(mcp.MethodSamplingCreateMessage),
+		},
+		CreateMessageParams: params,
+	}
+
+	// Call the sampling handler
+	result, err := c.samplingHandler.CreateMessage(ctx, mcpRequest)
+	if err != nil {
+		return nil, err
+	}
+
+	// Marshal the result
+	resultBytes, err := json.Marshal(result)
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal result: %w", err)
+	}
+
+	// Create the transport response
+	response := &transport.JSONRPCResponse{
+		JSONRPC: mcp.JSONRPC_VERSION,
+		ID:      request.ID,
+		Result:  json.RawMessage(resultBytes),
+	}
+
+	return response, nil
+}
 func listByPage[T any](
 	ctx context.Context,
 	client *Client,
@@ -432,3 +518,17 @@ func (c *Client) GetServerCapabilities() mcp.ServerCapabilities {
 func (c *Client) GetClientCapabilities() mcp.ClientCapabilities {
 	return c.clientCapabilities
 }
+
+// GetSessionId returns the session ID of the transport.
+// If the transport does not support sessions, it returns an empty string.
+func (c *Client) GetSessionId() string {
+	if c.transport == nil {
+		return ""
+	}
+	return c.transport.GetSessionId()
+}
+
+// IsInitialized returns true if the client has been initialized.
+func (c *Client) IsInitialized() bool {
+	return c.initialized
+}

vendor/github.com/mark3labs/mcp-go/client/http.go 🔗

@@ -13,5 +13,10 @@ func NewStreamableHttpClient(baseURL string, options ...transport.StreamableHTTP
 	if err != nil {
 		return nil, fmt.Errorf("failed to create SSE transport: %w", err)
 	}
-	return NewClient(trans), nil
+	clientOptions := make([]ClientOption, 0)
+	sessionID := trans.GetSessionId()
+	if sessionID != "" {
+		clientOptions = append(clientOptions, WithSession())
+	}
+	return NewClient(trans, clientOptions...), nil
 }

vendor/github.com/mark3labs/mcp-go/client/sampling.go 🔗

@@ -0,0 +1,20 @@
+package client
+
+import (
+	"context"
+
+	"github.com/mark3labs/mcp-go/mcp"
+)
+
+// SamplingHandler defines the interface for handling sampling requests from servers.
+// Clients can implement this interface to provide LLM sampling capabilities to servers.
+type SamplingHandler interface {
+	// CreateMessage handles a sampling request from the server and returns the generated message.
+	// The implementation should:
+	// 1. Validate the request parameters
+	// 2. Optionally prompt the user for approval (human-in-the-loop)
+	// 3. Select an appropriate model based on preferences
+	// 4. Generate the response using the selected model
+	// 5. Return the result with model information and stop reason
+	CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error)
+}

vendor/github.com/mark3labs/mcp-go/client/stdio.go 🔗

@@ -19,10 +19,26 @@ func NewStdioMCPClient(
 	env []string,
 	args ...string,
 ) (*Client, error) {
+	return NewStdioMCPClientWithOptions(command, env, args)
+}
+
+// NewStdioMCPClientWithOptions creates a new stdio-based MCP client that communicates with a subprocess.
+// It launches the specified command with given arguments and sets up stdin/stdout pipes for communication.
+// Optional configuration functions can be provided to customize the transport before it starts,
+// such as setting a custom command function.
+//
+// NOTICE: NewStdioMCPClientWithOptions automatically starts the underlying transport.
+// Don't call the Start method manually.
+// This is for backward compatibility.
+func NewStdioMCPClientWithOptions(
+	command string,
+	env []string,
+	args []string,
+	opts ...transport.StdioOption,
+) (*Client, error) {
+	stdioTransport := transport.NewStdioWithOptions(command, env, args, opts...)
 
-	stdioTransport := transport.NewStdio(command, env, args...)
-	err := stdioTransport.Start(context.Background())
-	if err != nil {
+	if err := stdioTransport.Start(context.Background()); err != nil {
 		return nil, fmt.Errorf("failed to start stdio transport: %w", err)
 	}
 

vendor/github.com/mark3labs/mcp-go/client/transport/interface.go 🔗

@@ -29,6 +29,22 @@ type Interface interface {
 
 	// Close the connection.
 	Close() error
+
+	// GetSessionId returns the session ID of the transport.
+	GetSessionId() string
+}
+
+// RequestHandler defines a function that handles incoming requests from the server.
+type RequestHandler func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error)
+
+// BidirectionalInterface extends Interface to support incoming requests from the server.
+// This is used for features like sampling where the server can send requests to the client.
+type BidirectionalInterface interface {
+	Interface
+
+	// SetRequestHandler sets the handler for incoming requests from the server.
+	// The handler should process the request and return a response.
+	SetRequestHandler(handler RequestHandler)
 }
 
 type JSONRPCRequest struct {
@@ -41,10 +57,10 @@ type JSONRPCRequest struct {
 type JSONRPCResponse struct {
 	JSONRPC string          `json:"jsonrpc"`
 	ID      mcp.RequestId   `json:"id"`
-	Result  json.RawMessage `json:"result"`
+	Result  json.RawMessage `json:"result,omitempty"`
 	Error   *struct {
 		Code    int             `json:"code"`
 		Message string          `json:"message"`
 		Data    json.RawMessage `json:"data"`
-	} `json:"error"`
+	} `json:"error,omitempty"`
 }

vendor/github.com/mark3labs/mcp-go/client/transport/sse.go 🔗

@@ -428,6 +428,12 @@ func (c *SSE) Close() error {
 	return nil
 }
 
+// GetSessionId returns the session ID of the transport.
+// Since SSE does not maintain a session ID, it returns an empty string.
+func (c *SSE) GetSessionId() string {
+	return ""
+}
+
 // SendNotification sends a JSON-RPC notification to the server without expecting a response.
 func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error {
 	if c.endpoint == nil {

vendor/github.com/mark3labs/mcp-go/client/transport/stdio.go 🔗

@@ -23,6 +23,7 @@ type Stdio struct {
 	env     []string
 
 	cmd            *exec.Cmd
+	cmdFunc        CommandFunc
 	stdin          io.WriteCloser
 	stdout         *bufio.Reader
 	stderr         io.ReadCloser
@@ -31,6 +32,28 @@ type Stdio struct {
 	done           chan struct{}
 	onNotification func(mcp.JSONRPCNotification)
 	notifyMu       sync.RWMutex
+	onRequest      RequestHandler
+	requestMu      sync.RWMutex
+	ctx            context.Context
+	ctxMu          sync.RWMutex
+}
+
+// StdioOption defines a function that configures a Stdio transport instance.
+// Options can be used to customize the behavior of the transport before it starts,
+// such as setting a custom command function.
+type StdioOption func(*Stdio)
+
+// CommandFunc is a factory function that returns a custom exec.Cmd used to launch the MCP subprocess.
+// It can be used to apply sandboxing, custom environment control, working directories, etc.
+type CommandFunc func(ctx context.Context, command string, env []string, args []string) (*exec.Cmd, error)
+
+// WithCommandFunc sets a custom command factory function for the stdio transport.
+// The CommandFunc is responsible for constructing the exec.Cmd used to launch the subprocess,
+// allowing control over attributes like environment, working directory, and system-level sandboxing.
+func WithCommandFunc(f CommandFunc) StdioOption {
+	return func(s *Stdio) {
+		s.cmdFunc = f
+	}
 }
 
 // NewIO returns a new stdio-based transport using existing input, output, and
@@ -44,6 +67,7 @@ func NewIO(input io.Reader, output io.WriteCloser, logging io.ReadCloser) *Stdio
 
 		responses: make(map[string]chan *JSONRPCResponse),
 		done:      make(chan struct{}),
+		ctx:       context.Background(),
 	}
 }
 
@@ -55,20 +79,43 @@ func NewStdio(
 	env []string,
 	args ...string,
 ) *Stdio {
+	return NewStdioWithOptions(command, env, args)
+}
 
-	client := &Stdio{
+// NewStdioWithOptions creates a new stdio transport to communicate with a subprocess.
+// It launches the specified command with given arguments and sets up stdin/stdout pipes for communication.
+// Returns an error if the subprocess cannot be started or the pipes cannot be created.
+// Optional configuration functions can be provided to customize the transport before it starts,
+// such as setting a custom command factory.
+func NewStdioWithOptions(
+	command string,
+	env []string,
+	args []string,
+	opts ...StdioOption,
+) *Stdio {
+	s := &Stdio{
 		command: command,
 		args:    args,
 		env:     env,
 
 		responses: make(map[string]chan *JSONRPCResponse),
 		done:      make(chan struct{}),
+		ctx:       context.Background(),
+	}
+
+	for _, opt := range opts {
+		opt(s)
 	}
 
-	return client
+	return s
 }
 
 func (c *Stdio) Start(ctx context.Context) error {
+	// Store the context for use in request handling
+	c.ctxMu.Lock()
+	c.ctx = ctx
+	c.ctxMu.Unlock()
+
 	if err := c.spawnCommand(ctx); err != nil {
 		return err
 	}
@@ -83,18 +130,25 @@ func (c *Stdio) Start(ctx context.Context) error {
 	return nil
 }
 
-// spawnCommand spawns a new process running c.command.
+// spawnCommand spawns a new process running the configured command, args, and env.
+// If an (optional) cmdFunc custom command factory function was configured, it will be used to construct the subprocess;
+// otherwise, the default behavior uses exec.CommandContext with the merged environment.
+// Initializes stdin, stdout, and stderr pipes for JSON-RPC communication.
 func (c *Stdio) spawnCommand(ctx context.Context) error {
 	if c.command == "" {
 		return nil
 	}
 
-	cmd := exec.CommandContext(ctx, c.command, c.args...)
-
-	mergedEnv := os.Environ()
-	mergedEnv = append(mergedEnv, c.env...)
+	var cmd *exec.Cmd
+	var err error
 
-	cmd.Env = mergedEnv
+	// Standard behavior if no command func present.
+	if c.cmdFunc == nil {
+		cmd = exec.CommandContext(ctx, c.command, c.args...)
+		cmd.Env = append(os.Environ(), c.env...)
+	} else if cmd, err = c.cmdFunc(ctx, c.command, c.env, c.args); err != nil {
+		return err
+	}
 
 	stdin, err := cmd.StdinPipe()
 	if err != nil {
@@ -148,6 +202,12 @@ func (c *Stdio) Close() error {
 	return nil
 }
 
+// GetSessionId returns the session ID of the transport.
+// Since stdio does not maintain a session ID, it returns an empty string.
+func (c *Stdio) GetSessionId() string {
+	return ""
+}
+
 // SetNotificationHandler sets the handler function to be called when a notification is received.
 // Only one handler can be set at a time; setting a new one replaces the previous handler.
 func (c *Stdio) SetNotificationHandler(
@@ -158,6 +218,14 @@ func (c *Stdio) SetNotificationHandler(
 	c.onNotification = handler
 }
 
+// SetRequestHandler sets the handler function to be called when a request is received from the server.
+// This enables bidirectional communication for features like sampling.
+func (c *Stdio) SetRequestHandler(handler RequestHandler) {
+	c.requestMu.Lock()
+	defer c.requestMu.Unlock()
+	c.onRequest = handler
+}
+
 // readResponses continuously reads and processes responses from the server's stdout.
 // It handles both responses to requests and notifications, routing them appropriately.
 // Runs until the done channel is closed or an error occurs reading from stdout.
@@ -175,13 +243,18 @@ func (c *Stdio) readResponses() {
 				return
 			}
 
-			var baseMessage JSONRPCResponse
+			// First try to parse as a generic message to check for ID field
+			var baseMessage struct {
+				JSONRPC string         `json:"jsonrpc"`
+				ID      *mcp.RequestId `json:"id,omitempty"`
+				Method  string         `json:"method,omitempty"`
+			}
 			if err := json.Unmarshal([]byte(line), &baseMessage); err != nil {
 				continue
 			}
 
-			// Handle notification
-			if baseMessage.ID.IsNil() {
+			// If it has a method but no ID, it's a notification
+			if baseMessage.Method != "" && baseMessage.ID == nil {
 				var notification mcp.JSONRPCNotification
 				if err := json.Unmarshal([]byte(line), &notification); err != nil {
 					continue
@@ -194,15 +267,30 @@ func (c *Stdio) readResponses() {
 				continue
 			}
 
+			// If it has a method and an ID, it's an incoming request
+			if baseMessage.Method != "" && baseMessage.ID != nil {
+				var request JSONRPCRequest
+				if err := json.Unmarshal([]byte(line), &request); err == nil {
+					c.handleIncomingRequest(request)
+					continue
+				}
+			}
+
+			// Otherwise, it's a response to our request
+			var response JSONRPCResponse
+			if err := json.Unmarshal([]byte(line), &response); err != nil {
+				continue
+			}
+
 			// Create string key for map lookup
-			idKey := baseMessage.ID.String()
+			idKey := response.ID.String()
 
 			c.mu.RLock()
 			ch, exists := c.responses[idKey]
 			c.mu.RUnlock()
 
 			if exists {
-				ch <- &baseMessage
+				ch <- &response
 				c.mu.Lock()
 				delete(c.responses, idKey)
 				c.mu.Unlock()
@@ -281,6 +369,96 @@ func (c *Stdio) SendNotification(
 	return nil
 }
 
+// handleIncomingRequest processes incoming requests from the server.
+// It calls the registered request handler and sends the response back to the server.
+func (c *Stdio) handleIncomingRequest(request JSONRPCRequest) {
+	c.requestMu.RLock()
+	handler := c.onRequest
+	c.requestMu.RUnlock()
+
+	if handler == nil {
+		// Send error response if no handler is configured
+		errorResponse := JSONRPCResponse{
+			JSONRPC: mcp.JSONRPC_VERSION,
+			ID:      request.ID,
+			Error: &struct {
+				Code    int             `json:"code"`
+				Message string          `json:"message"`
+				Data    json.RawMessage `json:"data"`
+			}{
+				Code:    mcp.METHOD_NOT_FOUND,
+				Message: "No request handler configured",
+			},
+		}
+		c.sendResponse(errorResponse)
+		return
+	}
+
+	// Handle the request in a goroutine to avoid blocking
+	go func() {
+		c.ctxMu.RLock()
+		ctx := c.ctx
+		c.ctxMu.RUnlock()
+
+		// Check if context is already cancelled before processing
+		select {
+		case <-ctx.Done():
+			errorResponse := JSONRPCResponse{
+				JSONRPC: mcp.JSONRPC_VERSION,
+				ID:      request.ID,
+				Error: &struct {
+					Code    int             `json:"code"`
+					Message string          `json:"message"`
+					Data    json.RawMessage `json:"data"`
+				}{
+					Code:    mcp.INTERNAL_ERROR,
+					Message: ctx.Err().Error(),
+				},
+			}
+			c.sendResponse(errorResponse)
+			return
+		default:
+		}
+
+		response, err := handler(ctx, request)
+
+		if err != nil {
+			errorResponse := JSONRPCResponse{
+				JSONRPC: mcp.JSONRPC_VERSION,
+				ID:      request.ID,
+				Error: &struct {
+					Code    int             `json:"code"`
+					Message string          `json:"message"`
+					Data    json.RawMessage `json:"data"`
+				}{
+					Code:    mcp.INTERNAL_ERROR,
+					Message: err.Error(),
+				},
+			}
+			c.sendResponse(errorResponse)
+			return
+		}
+
+		if response != nil {
+			c.sendResponse(*response)
+		}
+	}()
+}
+
+// sendResponse sends a response back to the server.
+func (c *Stdio) sendResponse(response JSONRPCResponse) {
+	responseBytes, err := json.Marshal(response)
+	if err != nil {
+		fmt.Printf("Error marshaling response: %v\n", err)
+		return
+	}
+	responseBytes = append(responseBytes, '\n')
+
+	if _, err := c.stdin.Write(responseBytes); err != nil {
+		fmt.Printf("Error writing response: %v\n", err)
+	}
+}
+
 // Stderr returns a reader for the stderr output of the subprocess.
 // This can be used to capture error messages or logs from the subprocess.
 func (c *Stdio) Stderr() io.Reader {

vendor/github.com/mark3labs/mcp-go/client/transport/streamable_http.go 🔗

@@ -17,10 +17,24 @@ import (
 	"time"
 
 	"github.com/mark3labs/mcp-go/mcp"
+	"github.com/mark3labs/mcp-go/util"
 )
 
 type StreamableHTTPCOption func(*StreamableHTTP)
 
+// WithContinuousListening enables receiving server-to-client notifications when no request is in flight.
+// In particular, if you want to receive global notifications from the server (like ToolListChangedNotification),
+// you should enable this option.
+//
+// It will establish a standalone long-live GET HTTP connection to the server.
+// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server
+// NOTICE: Even enabled, the server may not support this feature.
+func WithContinuousListening() StreamableHTTPCOption {
+	return func(sc *StreamableHTTP) {
+		sc.getListeningEnabled = true
+	}
+}
+
 // WithHTTPClient sets a custom HTTP client on the StreamableHTTP transport.
 func WithHTTPBasicClient(client *http.Client) StreamableHTTPCOption {
 	return func(sc *StreamableHTTP) {
@@ -54,6 +68,19 @@ func WithHTTPOAuth(config OAuthConfig) StreamableHTTPCOption {
 	}
 }
 
+func WithLogger(logger util.Logger) StreamableHTTPCOption {
+	return func(sc *StreamableHTTP) {
+		sc.logger = logger
+	}
+}
+
+// WithSession creates a client with a pre-configured session
+func WithSession(sessionID string) StreamableHTTPCOption {
+	return func(sc *StreamableHTTP) {
+		sc.sessionID.Store(sessionID)
+	}
+}
+
 // StreamableHTTP implements Streamable HTTP transport.
 //
 // It transmits JSON-RPC messages over individual HTTP requests. One message per request.
@@ -64,19 +91,22 @@ func WithHTTPOAuth(config OAuthConfig) StreamableHTTPCOption {
 //
 // The current implementation does not support the following features:
 //   - batching
-//   - continuously listening for server notifications when no request is in flight
-//     (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server)
 //   - resuming stream
 //     (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#resumability-and-redelivery)
 //   - server -> client request
 type StreamableHTTP struct {
-	serverURL  *url.URL
-	httpClient *http.Client
-	headers    map[string]string
-	headerFunc HTTPHeaderFunc
+	serverURL           *url.URL
+	httpClient          *http.Client
+	headers             map[string]string
+	headerFunc          HTTPHeaderFunc
+	logger              util.Logger
+	getListeningEnabled bool
 
 	sessionID atomic.Value // string
 
+	initialized     chan struct{}
+	initializedOnce sync.Once
+
 	notificationHandler func(mcp.JSONRPCNotification)
 	notifyMu            sync.RWMutex
 
@@ -95,15 +125,19 @@ func NewStreamableHTTP(serverURL string, options ...StreamableHTTPCOption) (*Str
 	}
 
 	smc := &StreamableHTTP{
-		serverURL:  parsedURL,
-		httpClient: &http.Client{},
-		headers:    make(map[string]string),
-		closed:     make(chan struct{}),
+		serverURL:   parsedURL,
+		httpClient:  &http.Client{},
+		headers:     make(map[string]string),
+		closed:      make(chan struct{}),
+		logger:      util.DefaultLogger(),
+		initialized: make(chan struct{}),
 	}
 	smc.sessionID.Store("") // set initial value to simplify later usage
 
 	for _, opt := range options {
-		opt(smc)
+		if opt != nil {
+			opt(smc)
+		}
 	}
 
 	// If OAuth is configured, set the base URL for metadata discovery
@@ -118,7 +152,20 @@ func NewStreamableHTTP(serverURL string, options ...StreamableHTTPCOption) (*Str
 
 // Start initiates the HTTP connection to the server.
 func (c *StreamableHTTP) Start(ctx context.Context) error {
-	// For Streamable HTTP, we don't need to establish a persistent connection
+	// For Streamable HTTP, we don't need to establish a persistent connection by default
+	if c.getListeningEnabled {
+		go func() {
+			select {
+			case <-c.initialized:
+				ctx, cancel := c.contextAwareOfClientClose(ctx)
+				defer cancel()
+				c.listenForever(ctx)
+			case <-c.closed:
+				return
+			}
+		}()
+	}
+
 	return nil
 }
 
@@ -142,13 +189,13 @@ func (c *StreamableHTTP) Close() error {
 			defer cancel()
 			req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.serverURL.String(), nil)
 			if err != nil {
-				fmt.Printf("failed to create close request\n: %v", err)
+				c.logger.Errorf("failed to create close request: %v", err)
 				return
 			}
 			req.Header.Set(headerKeySessionID, sessionId)
 			res, err := c.httpClient.Do(req)
 			if err != nil {
-				fmt.Printf("failed to send close request\n: %v", err)
+				c.logger.Errorf("failed to send close request: %v", err)
 				return
 			}
 			res.Body.Close()
@@ -185,77 +232,29 @@ func (c *StreamableHTTP) SendRequest(
 	request JSONRPCRequest,
 ) (*JSONRPCResponse, error) {
 
-	// Create a combined context that could be canceled when the client is closed
-	newCtx, cancel := context.WithCancel(ctx)
-	defer cancel()
-	go func() {
-		select {
-		case <-c.closed:
-			cancel()
-		case <-newCtx.Done():
-			// The original context was canceled, no need to do anything
-		}
-	}()
-	ctx = newCtx
-
 	// Marshal request
 	requestBody, err := json.Marshal(request)
 	if err != nil {
 		return nil, fmt.Errorf("failed to marshal request: %w", err)
 	}
 
-	// Create HTTP request
-	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.serverURL.String(), bytes.NewReader(requestBody))
-	if err != nil {
-		return nil, fmt.Errorf("failed to create request: %w", err)
-	}
-
-	// Set headers
-	req.Header.Set("Content-Type", "application/json")
-	req.Header.Set("Accept", "application/json, text/event-stream")
-	sessionID := c.sessionID.Load()
-	if sessionID != "" {
-		req.Header.Set(headerKeySessionID, sessionID.(string))
-	}
-	for k, v := range c.headers {
-		req.Header.Set(k, v)
-	}
-
-	// Add OAuth authorization if configured
-	if c.oauthHandler != nil {
-		authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
-		if err != nil {
-			// If we get an authorization error, return a specific error that can be handled by the client
-			if err.Error() == "no valid token available, authorization required" {
-				return nil, &OAuthAuthorizationRequiredError{
-					Handler: c.oauthHandler,
-				}
-			}
-			return nil, fmt.Errorf("failed to get authorization header: %w", err)
-		}
-		req.Header.Set("Authorization", authHeader)
-	}
-
-	if c.headerFunc != nil {
-		for k, v := range c.headerFunc(ctx) {
-			req.Header.Set(k, v)
-		}
-	}
+	ctx, cancel := c.contextAwareOfClientClose(ctx)
+	defer cancel()
 
-	// Send request
-	resp, err := c.httpClient.Do(req)
+	resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream")
 	if err != nil {
-		return nil, fmt.Errorf("failed to send request: %w", err)
+		if errors.Is(err, ErrSessionTerminated) && request.Method == string(mcp.MethodInitialize) {
+			// If the request is initialize, should not return a SessionTerminated error
+			// It should be a genuine endpoint-routing issue.
+			// ( Fall through to return StatusCode checking. )
+		} else {
+			return nil, fmt.Errorf("failed to send request: %w", err)
+		}
 	}
 	defer resp.Body.Close()
 
 	// Check if we got an error response
 	if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
-		// handle session closed
-		if resp.StatusCode == http.StatusNotFound {
-			c.sessionID.CompareAndSwap(sessionID, "")
-			return nil, fmt.Errorf("session terminated (404). need to re-initialize")
-		}
 
 		// Handle OAuth unauthorized error
 		if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil {
@@ -279,6 +278,10 @@ func (c *StreamableHTTP) SendRequest(
 		if sessionID := resp.Header.Get(headerKeySessionID); sessionID != "" {
 			c.sessionID.Store(sessionID)
 		}
+
+		c.initializedOnce.Do(func() {
+			close(c.initialized)
+		})
 	}
 
 	// Handle different response types
@@ -300,16 +303,77 @@ func (c *StreamableHTTP) SendRequest(
 
 	case "text/event-stream":
 		// Server is using SSE for streaming responses
-		return c.handleSSEResponse(ctx, resp.Body)
+		return c.handleSSEResponse(ctx, resp.Body, false)
 
 	default:
 		return nil, fmt.Errorf("unexpected content type: %s", resp.Header.Get("Content-Type"))
 	}
 }
 
+func (c *StreamableHTTP) sendHTTP(
+	ctx context.Context,
+	method string,
+	body io.Reader,
+	acceptType string,
+) (resp *http.Response, err error) {
+
+	// Create HTTP request
+	req, err := http.NewRequestWithContext(ctx, method, c.serverURL.String(), body)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create request: %w", err)
+	}
+
+	// Set headers
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Accept", acceptType)
+	sessionID := c.sessionID.Load().(string)
+	if sessionID != "" {
+		req.Header.Set(headerKeySessionID, sessionID)
+	}
+	for k, v := range c.headers {
+		req.Header.Set(k, v)
+	}
+
+	// Add OAuth authorization if configured
+	if c.oauthHandler != nil {
+		authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
+		if err != nil {
+			// If we get an authorization error, return a specific error that can be handled by the client
+			if err.Error() == "no valid token available, authorization required" {
+				return nil, &OAuthAuthorizationRequiredError{
+					Handler: c.oauthHandler,
+				}
+			}
+			return nil, fmt.Errorf("failed to get authorization header: %w", err)
+		}
+		req.Header.Set("Authorization", authHeader)
+	}
+
+	if c.headerFunc != nil {
+		for k, v := range c.headerFunc(ctx) {
+			req.Header.Set(k, v)
+		}
+	}
+
+	// Send request
+	resp, err = c.httpClient.Do(req)
+	if err != nil {
+		return nil, fmt.Errorf("failed to send request: %w", err)
+	}
+
+	// universal handling for session terminated
+	if resp.StatusCode == http.StatusNotFound {
+		c.sessionID.CompareAndSwap(sessionID, "")
+		return nil, ErrSessionTerminated
+	}
+
+	return resp, nil
+}
+
 // handleSSEResponse processes an SSE stream for a specific request.
 // It returns the final result for the request once received, or an error.
-func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser) (*JSONRPCResponse, error) {
+// If ignoreResponse is true, it won't return when a response messge is received. This is for continuous listening.
+func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser, ignoreResponse bool) (*JSONRPCResponse, error) {
 
 	// Create a channel for this specific request
 	responseChan := make(chan *JSONRPCResponse, 1)
@@ -328,7 +392,7 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl
 
 			var message JSONRPCResponse
 			if err := json.Unmarshal([]byte(data), &message); err != nil {
-				fmt.Printf("failed to unmarshal message: %v\n", err)
+				c.logger.Errorf("failed to unmarshal message: %v", err)
 				return
 			}
 
@@ -336,7 +400,7 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl
 			if message.ID.IsNil() {
 				var notification mcp.JSONRPCNotification
 				if err := json.Unmarshal([]byte(data), &notification); err != nil {
-					fmt.Printf("failed to unmarshal notification: %v\n", err)
+					c.logger.Errorf("failed to unmarshal notification: %v", err)
 					return
 				}
 				c.notifyMu.RLock()
@@ -347,7 +411,9 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl
 				return
 			}
 
-			responseChan <- &message
+			if !ignoreResponse {
+				responseChan <- &message
+			}
 		})
 	}()
 
@@ -393,7 +459,7 @@ func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, hand
 				case <-ctx.Done():
 					return
 				default:
-					fmt.Printf("SSE stream error: %v\n", err)
+					c.logger.Errorf("SSE stream error: %v", err)
 					return
 				}
 			}
@@ -432,44 +498,10 @@ func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp.
 	}
 
 	// Create HTTP request
-	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.serverURL.String(), bytes.NewReader(requestBody))
-	if err != nil {
-		return fmt.Errorf("failed to create request: %w", err)
-	}
-
-	// Set headers
-	req.Header.Set("Content-Type", "application/json")
-	req.Header.Set("Accept", "application/json, text/event-stream")
-	if sessionID := c.sessionID.Load(); sessionID != "" {
-		req.Header.Set(headerKeySessionID, sessionID.(string))
-	}
-	for k, v := range c.headers {
-		req.Header.Set(k, v)
-	}
-
-	// Add OAuth authorization if configured
-	if c.oauthHandler != nil {
-		authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
-		if err != nil {
-			// If we get an authorization error, return a specific error that can be handled by the client
-			if errors.Is(err, ErrOAuthAuthorizationRequired) {
-				return &OAuthAuthorizationRequiredError{
-					Handler: c.oauthHandler,
-				}
-			}
-			return fmt.Errorf("failed to get authorization header: %w", err)
-		}
-		req.Header.Set("Authorization", authHeader)
-	}
-
-	if c.headerFunc != nil {
-		for k, v := range c.headerFunc(ctx) {
-			req.Header.Set(k, v)
-		}
-	}
+	ctx, cancel := c.contextAwareOfClientClose(ctx)
+	defer cancel()
 
-	// Send request
-	resp, err := c.httpClient.Do(req)
+	resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream")
 	if err != nil {
 		return fmt.Errorf("failed to send request: %w", err)
 	}
@@ -513,3 +545,84 @@ func (c *StreamableHTTP) GetOAuthHandler() *OAuthHandler {
 func (c *StreamableHTTP) IsOAuthEnabled() bool {
 	return c.oauthHandler != nil
 }
+
+func (c *StreamableHTTP) listenForever(ctx context.Context) {
+	c.logger.Infof("listening to server forever")
+	for {
+		err := c.createGETConnectionToServer(ctx)
+		if errors.Is(err, ErrGetMethodNotAllowed) {
+			// server does not support listening
+			c.logger.Errorf("server does not support listening")
+			return
+		}
+
+		select {
+		case <-ctx.Done():
+			return
+		default:
+		}
+
+		if err != nil {
+			c.logger.Errorf("failed to listen to server. retry in 1 second: %v", err)
+		}
+		time.Sleep(retryInterval)
+	}
+}
+
+var (
+	ErrSessionTerminated   = fmt.Errorf("session terminated (404). need to re-initialize")
+	ErrGetMethodNotAllowed = fmt.Errorf("GET method not allowed")
+
+	retryInterval = 1 * time.Second // a variable is convenient for testing
+)
+
+func (c *StreamableHTTP) createGETConnectionToServer(ctx context.Context) error {
+
+	resp, err := c.sendHTTP(ctx, http.MethodGet, nil, "text/event-stream")
+	if err != nil {
+		return fmt.Errorf("failed to send request: %w", err)
+	}
+	defer resp.Body.Close()
+
+	// Check if we got an error response
+	if resp.StatusCode == http.StatusMethodNotAllowed {
+		return ErrGetMethodNotAllowed
+	}
+
+	if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
+		body, _ := io.ReadAll(resp.Body)
+		return fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body)
+	}
+
+	// handle SSE response
+	contentType := resp.Header.Get("Content-Type")
+	if contentType != "text/event-stream" {
+		return fmt.Errorf("unexpected content type: %s", contentType)
+	}
+
+	// When ignoreResponse is true, the function will never return expect context is done.
+	// NOTICE: Due to the ambiguity of the specification, other SDKs may use the GET connection to transfer the response
+	// messages. To be more compatible, we should handle this response, however, as the transport layer is message-based,
+	// currently, there is no convenient way to handle this response.
+	// So we ignore the response here. It's not a bug, but may be not compatible with other SDKs.
+	_, err = c.handleSSEResponse(ctx, resp.Body, true)
+	if err != nil {
+		return fmt.Errorf("failed to handle SSE response: %w", err)
+	}
+
+	return nil
+}
+
+func (c *StreamableHTTP) contextAwareOfClientClose(ctx context.Context) (context.Context, context.CancelFunc) {
+	newCtx, cancel := context.WithCancel(ctx)
+	go func() {
+		select {
+		case <-c.closed:
+			cancel()
+		case <-newCtx.Done():
+			// The original context was canceled
+			cancel()
+		}
+	}()
+	return newCtx, cancel
+}

vendor/github.com/mark3labs/mcp-go/mcp/tools.go 🔗

@@ -945,7 +945,20 @@ func PropertyNames(schema map[string]any) PropertyOption {
 	}
 }
 
-// Items defines the schema for array items
+// Items defines the schema for array items.
+// Accepts any schema definition for maximum flexibility.
+//
+// Example:
+//
+//	Items(map[string]any{
+//	    "type": "object",
+//	    "properties": map[string]any{
+//	        "name": map[string]any{"type": "string"},
+//	        "age": map[string]any{"type": "number"},
+//	    },
+//	})
+//
+// For simple types, use ItemsString(), ItemsNumber(), ItemsBoolean() instead.
 func Items(schema any) PropertyOption {
 	return func(schemaMap map[string]any) {
 		schemaMap["items"] = schema
@@ -972,3 +985,94 @@ func UniqueItems(unique bool) PropertyOption {
 		schema["uniqueItems"] = unique
 	}
 }
+
+// WithStringItems configures an array's items to be of type string.
+//
+// Supported options: Description(), DefaultString(), Enum(), MaxLength(), MinLength(), Pattern()
+// Note: Options like Required() are not valid for item schemas and will be ignored.
+//
+// Examples:
+//
+//	mcp.WithArray("tags", mcp.WithStringItems())
+//	mcp.WithArray("colors", mcp.WithStringItems(mcp.Enum("red", "green", "blue")))
+//	mcp.WithArray("names", mcp.WithStringItems(mcp.MinLength(1), mcp.MaxLength(50)))
+//
+// Limitations: Only supports simple string arrays. Use Items() for complex objects.
+func WithStringItems(opts ...PropertyOption) PropertyOption {
+	return func(schema map[string]any) {
+		itemSchema := map[string]any{
+			"type": "string",
+		}
+
+		for _, opt := range opts {
+			opt(itemSchema)
+		}
+
+		schema["items"] = itemSchema
+	}
+}
+
+// WithStringEnumItems configures an array's items to be of type string with a specified enum.
+// Example:
+//
+//	mcp.WithArray("priority", mcp.WithStringEnumItems([]string{"low", "medium", "high"}))
+//
+// Limitations: Only supports string enums. Use WithStringItems(Enum(...)) for more flexibility.
+func WithStringEnumItems(values []string) PropertyOption {
+	return func(schema map[string]any) {
+		schema["items"] = map[string]any{
+			"type": "string",
+			"enum": values,
+		}
+	}
+}
+
+// WithNumberItems configures an array's items to be of type number.
+//
+// Supported options: Description(), DefaultNumber(), Min(), Max(), MultipleOf()
+// Note: Options like Required() are not valid for item schemas and will be ignored.
+//
+// Examples:
+//
+//	mcp.WithArray("scores", mcp.WithNumberItems(mcp.Min(0), mcp.Max(100)))
+//	mcp.WithArray("prices", mcp.WithNumberItems(mcp.Min(0)))
+//
+// Limitations: Only supports simple number arrays. Use Items() for complex objects.
+func WithNumberItems(opts ...PropertyOption) PropertyOption {
+	return func(schema map[string]any) {
+		itemSchema := map[string]any{
+			"type": "number",
+		}
+
+		for _, opt := range opts {
+			opt(itemSchema)
+		}
+
+		schema["items"] = itemSchema
+	}
+}
+
+// WithBooleanItems configures an array's items to be of type boolean.
+//
+// Supported options: Description(), DefaultBool()
+// Note: Options like Required() are not valid for item schemas and will be ignored.
+//
+// Examples:
+//
+//	mcp.WithArray("flags", mcp.WithBooleanItems())
+//	mcp.WithArray("permissions", mcp.WithBooleanItems(mcp.Description("User permissions")))
+//
+// Limitations: Only supports simple boolean arrays. Use Items() for complex objects.
+func WithBooleanItems(opts ...PropertyOption) PropertyOption {
+	return func(schema map[string]any) {
+		itemSchema := map[string]any{
+			"type": "boolean",
+		}
+
+		for _, opt := range opts {
+			opt(itemSchema)
+		}
+
+		schema["items"] = itemSchema
+	}
+}

vendor/github.com/mark3labs/mcp-go/mcp/types.go 🔗

@@ -763,6 +763,11 @@ const (
 
 /* Sampling */
 
+const (
+	// MethodSamplingCreateMessage allows servers to request LLM completions from clients
+	MethodSamplingCreateMessage MCPMethod = "sampling/createMessage"
+)
+
 // CreateMessageRequest is a request from the server to sample an LLM via the
 // client. The client has full discretion over which model to select. The client
 // should also inform the user before beginning sampling, to allow them to inspect
@@ -865,6 +870,22 @@ type AudioContent struct {
 
 func (AudioContent) isContent() {}
 
+// ResourceLink represents a link to a resource that the client can access.
+type ResourceLink struct {
+	Annotated
+	Type string `json:"type"` // Must be "resource_link"
+	// The URI of the resource.
+	URI string `json:"uri"`
+	// The name of the resource.
+	Name string `json:"name"`
+	// The description of the resource.
+	Description string `json:"description"`
+	// The MIME type of the resource.
+	MIMEType string `json:"mimeType"`
+}
+
+func (ResourceLink) isContent() {}
+
 // EmbeddedResource represents the contents of a resource, embedded into a prompt or tool call result.
 //
 // It is up to the client how best to render embedded resources for the

vendor/github.com/mark3labs/mcp-go/mcp/utils.go 🔗

@@ -222,6 +222,17 @@ func NewAudioContent(data, mimeType string) AudioContent {
 	}
 }
 
+// Helper function to create a new ResourceLink
+func NewResourceLink(uri, name, description, mimeType string) ResourceLink {
+	return ResourceLink{
+		Type:        "resource_link",
+		URI:         uri,
+		Name:        name,
+		Description: description,
+		MIMEType:    mimeType,
+	}
+}
+
 // Helper function to create a new EmbeddedResource
 func NewEmbeddedResource(resource ResourceContents) EmbeddedResource {
 	return EmbeddedResource{
@@ -476,6 +487,16 @@ func ParseContent(contentMap map[string]any) (Content, error) {
 		}
 		return NewAudioContent(data, mimeType), nil
 
+	case "resource_link":
+		uri := ExtractString(contentMap, "uri")
+		name := ExtractString(contentMap, "name")
+		description := ExtractString(contentMap, "description")
+		mimeType := ExtractString(contentMap, "mimeType")
+		if uri == "" || name == "" {
+			return nil, fmt.Errorf("resource_link uri or name is missing")
+		}
+		return NewResourceLink(uri, name, description, mimeType), nil
+
 	case "resource":
 		resourceMap := ExtractMap(contentMap, "resource")
 		if resourceMap == nil {

vendor/github.com/mark3labs/mcp-go/server/sampling.go 🔗

@@ -0,0 +1,37 @@
+package server
+
+import (
+	"context"
+	"fmt"
+
+	"github.com/mark3labs/mcp-go/mcp"
+)
+
+// EnableSampling enables sampling capabilities for the server.
+// This allows the server to send sampling requests to clients that support it.
+func (s *MCPServer) EnableSampling() {
+	s.capabilitiesMu.Lock()
+	defer s.capabilitiesMu.Unlock()
+}
+
+// RequestSampling sends a sampling request to the client.
+// The client must have declared sampling capability during initialization.
+func (s *MCPServer) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
+	session := ClientSessionFromContext(ctx)
+	if session == nil {
+		return nil, fmt.Errorf("no active session")
+	}
+
+	// Check if the session supports sampling requests
+	if samplingSession, ok := session.(SessionWithSampling); ok {
+		return samplingSession.RequestSampling(ctx, request)
+	}
+
+	return nil, fmt.Errorf("session does not support sampling")
+}
+
+// SessionWithSampling extends ClientSession to support sampling requests.
+type SessionWithSampling interface {
+	ClientSession
+	RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error)
+}

vendor/github.com/mark3labs/mcp-go/server/stdio.go 🔗

@@ -9,6 +9,7 @@ import (
 	"log"
 	"os"
 	"os/signal"
+	"sync"
 	"sync/atomic"
 	"syscall"
 
@@ -51,10 +52,21 @@ func WithStdioContextFunc(fn StdioContextFunc) StdioOption {
 
 // stdioSession is a static client session, since stdio has only one client.
 type stdioSession struct {
-	notifications chan mcp.JSONRPCNotification
-	initialized   atomic.Bool
-	loggingLevel  atomic.Value
-	clientInfo    atomic.Value // stores session-specific client info
+	notifications   chan mcp.JSONRPCNotification
+	initialized     atomic.Bool
+	loggingLevel    atomic.Value
+	clientInfo      atomic.Value                     // stores session-specific client info
+	writer          io.Writer                        // for sending requests to client
+	requestID       atomic.Int64                     // for generating unique request IDs
+	mu              sync.RWMutex                     // protects writer
+	pendingRequests map[int64]chan *samplingResponse // for tracking pending sampling requests
+	pendingMu       sync.RWMutex                     // protects pendingRequests
+}
+
+// samplingResponse represents a response to a sampling request
+type samplingResponse struct {
+	result *mcp.CreateMessageResult
+	err    error
 }
 
 func (s *stdioSession) SessionID() string {
@@ -100,14 +112,86 @@ func (s *stdioSession) GetLogLevel() mcp.LoggingLevel {
 	return level.(mcp.LoggingLevel)
 }
 
+// RequestSampling sends a sampling request to the client and waits for the response.
+func (s *stdioSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
+	s.mu.RLock()
+	writer := s.writer
+	s.mu.RUnlock()
+
+	if writer == nil {
+		return nil, fmt.Errorf("no writer available for sending requests")
+	}
+
+	// Generate a unique request ID
+	id := s.requestID.Add(1)
+
+	// Create a response channel for this request
+	responseChan := make(chan *samplingResponse, 1)
+	s.pendingMu.Lock()
+	s.pendingRequests[id] = responseChan
+	s.pendingMu.Unlock()
+
+	// Cleanup function to remove the pending request
+	cleanup := func() {
+		s.pendingMu.Lock()
+		delete(s.pendingRequests, id)
+		s.pendingMu.Unlock()
+	}
+	defer cleanup()
+
+	// Create the JSON-RPC request
+	jsonRPCRequest := struct {
+		JSONRPC string                  `json:"jsonrpc"`
+		ID      int64                   `json:"id"`
+		Method  string                  `json:"method"`
+		Params  mcp.CreateMessageParams `json:"params"`
+	}{
+		JSONRPC: mcp.JSONRPC_VERSION,
+		ID:      id,
+		Method:  string(mcp.MethodSamplingCreateMessage),
+		Params:  request.CreateMessageParams,
+	}
+
+	// Marshal and send the request
+	requestBytes, err := json.Marshal(jsonRPCRequest)
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal sampling request: %w", err)
+	}
+	requestBytes = append(requestBytes, '\n')
+
+	if _, err := writer.Write(requestBytes); err != nil {
+		return nil, fmt.Errorf("failed to write sampling request: %w", err)
+	}
+
+	// Wait for the response or context cancellation
+	select {
+	case <-ctx.Done():
+		return nil, ctx.Err()
+	case response := <-responseChan:
+		if response.err != nil {
+			return nil, response.err
+		}
+		return response.result, nil
+	}
+}
+
+// SetWriter sets the writer for sending requests to the client.
+func (s *stdioSession) SetWriter(writer io.Writer) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.writer = writer
+}
+
 var (
 	_ ClientSession         = (*stdioSession)(nil)
 	_ SessionWithLogging    = (*stdioSession)(nil)
 	_ SessionWithClientInfo = (*stdioSession)(nil)
+	_ SessionWithSampling   = (*stdioSession)(nil)
 )
 
 var stdioSessionInstance = stdioSession{
-	notifications: make(chan mcp.JSONRPCNotification, 100),
+	notifications:   make(chan mcp.JSONRPCNotification, 100),
+	pendingRequests: make(map[int64]chan *samplingResponse),
 }
 
 // NewStdioServer creates a new stdio server wrapper around an MCPServer.
@@ -224,6 +308,9 @@ func (s *StdioServer) Listen(
 	defer s.server.UnregisterSession(ctx, stdioSessionInstance.SessionID())
 	ctx = s.server.WithContext(ctx, &stdioSessionInstance)
 
+	// Set the writer for sending requests to the client
+	stdioSessionInstance.SetWriter(stdout)
+
 	// Add in any custom context.
 	if s.contextFunc != nil {
 		ctx = s.contextFunc(ctx)
@@ -256,7 +343,29 @@ func (s *StdioServer) processMessage(
 		return s.writeResponse(response, writer)
 	}
 
-	// Handle the message using the wrapped server
+	// Check if this is a response to a sampling request
+	if s.handleSamplingResponse(rawMessage) {
+		return nil
+	}
+
+	// Check if this is a tool call that might need sampling (and thus should be processed concurrently)
+	var baseMessage struct {
+		Method string `json:"method"`
+	}
+	if json.Unmarshal(rawMessage, &baseMessage) == nil && baseMessage.Method == "tools/call" {
+		// Process tool calls concurrently to avoid blocking on sampling requests
+		go func() {
+			response := s.server.HandleMessage(ctx, rawMessage)
+			if response != nil {
+				if err := s.writeResponse(response, writer); err != nil {
+					s.errLogger.Printf("Error writing tool response: %v", err)
+				}
+			}
+		}()
+		return nil
+	}
+
+	// Handle other messages synchronously
 	response := s.server.HandleMessage(ctx, rawMessage)
 
 	// Only write response if there is one (not for notifications)
@@ -269,6 +378,65 @@ func (s *StdioServer) processMessage(
 	return nil
 }
 
+// handleSamplingResponse checks if the message is a response to a sampling request
+// and routes it to the appropriate pending request channel.
+func (s *StdioServer) handleSamplingResponse(rawMessage json.RawMessage) bool {
+	return stdioSessionInstance.handleSamplingResponse(rawMessage)
+}
+
+// handleSamplingResponse handles incoming sampling responses for this session
+func (s *stdioSession) handleSamplingResponse(rawMessage json.RawMessage) bool {
+	// Try to parse as a JSON-RPC response
+	var response struct {
+		JSONRPC string          `json:"jsonrpc"`
+		ID      json.Number     `json:"id"`
+		Result  json.RawMessage `json:"result,omitempty"`
+		Error   *struct {
+			Code    int    `json:"code"`
+			Message string `json:"message"`
+		} `json:"error,omitempty"`
+	}
+
+	if err := json.Unmarshal(rawMessage, &response); err != nil {
+		return false
+	}
+	// Parse the ID as int64
+	idInt64, err := response.ID.Int64()
+	if err != nil || (response.Result == nil && response.Error == nil) {
+		return false
+	}
+
+	// Look for a pending request with this ID
+	s.pendingMu.RLock()
+	responseChan, exists := s.pendingRequests[idInt64]
+	s.pendingMu.RUnlock()
+
+	if !exists {
+		return false
+	} // Parse and send the response
+	samplingResp := &samplingResponse{}
+
+	if response.Error != nil {
+		samplingResp.err = fmt.Errorf("sampling request failed: %s", response.Error.Message)
+	} else {
+		var result mcp.CreateMessageResult
+		if err := json.Unmarshal(response.Result, &result); err != nil {
+			samplingResp.err = fmt.Errorf("failed to unmarshal sampling response: %w", err)
+		} else {
+			samplingResp.result = &result
+		}
+	}
+
+	// Send the response (non-blocking)
+	select {
+	case responseChan <- samplingResp:
+	default:
+		// Channel is full or closed, ignore
+	}
+
+	return true
+}
+
 // writeResponse marshals and writes a JSON-RPC response message followed by a newline.
 // Returns an error if marshaling or writing fails.
 func (s *StdioServer) writeResponse(

vendor/github.com/mark3labs/mcp-go/server/streamable_http.go 🔗

@@ -40,7 +40,9 @@ func WithEndpointPath(endpointPath string) StreamableHTTPOption {
 // to StatelessSessionIdManager.
 func WithStateLess(stateLess bool) StreamableHTTPOption {
 	return func(s *StreamableHTTPServer) {
-		s.sessionIdManager = &StatelessSessionIdManager{}
+		if stateLess {
+			s.sessionIdManager = &StatelessSessionIdManager{}
+		}
 	}
 }
 
@@ -374,7 +376,7 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
 	w.Header().Set("Content-Type", "text/event-stream")
 	w.Header().Set("Cache-Control", "no-cache")
 	w.Header().Set("Connection", "keep-alive")
-	w.WriteHeader(http.StatusAccepted)
+	w.WriteHeader(http.StatusOK)
 
 	flusher, ok := w.(http.Flusher)
 	if !ok {

vendor/modules.txt 🔗

@@ -260,8 +260,8 @@ github.com/charmbracelet/bubbletea/v2
 # github.com/charmbracelet/colorprofile v0.3.1
 ## explicit; go 1.23.0
 github.com/charmbracelet/colorprofile
-# github.com/charmbracelet/fang v0.1.0
-## explicit; go 1.23.0
+# github.com/charmbracelet/fang v0.3.1-0.20250711140230-d5ebb8c1d674
+## explicit; go 1.24.0
 github.com/charmbracelet/fang
 # github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe
 ## explicit; go 1.23.0
@@ -269,7 +269,7 @@ github.com/charmbracelet/glamour/v2
 github.com/charmbracelet/glamour/v2/ansi
 github.com/charmbracelet/glamour/v2/internal/autolink
 github.com/charmbracelet/glamour/v2/styles
-# github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.2.0.20250703152125-8e1c474f8a71 => github.com/charmbracelet/lipgloss-internal/v2 v2.0.0-20250710185058-03664cb9cecb
+# github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.3 => github.com/charmbracelet/lipgloss-internal/v2 v2.0.0-20250710185058-03664cb9cecb
 ## explicit; go 1.24.2
 github.com/charmbracelet/lipgloss/v2
 github.com/charmbracelet/lipgloss/v2/table
@@ -403,7 +403,7 @@ github.com/kylelemons/godebug/pretty
 # github.com/lucasb-eyer/go-colorful v1.2.0
 ## explicit; go 1.12
 github.com/lucasb-eyer/go-colorful
-# github.com/mark3labs/mcp-go v0.32.0
+# github.com/mark3labs/mcp-go v0.33.0
 ## explicit; go 1.23
 github.com/mark3labs/mcp-go/client
 github.com/mark3labs/mcp-go/client/transport