.gitignore 🔗
@@ -16,6 +16,7 @@
# Go workspace file
go.work
+go.work.sum
# IDE specific files
.idea/
Tai Groot created
.gitignore | 1
cmd/root.go | 23
go.mod | 6
go.sum | 8
internal/app/app.go | 69
internal/app/lsp.go | 5
internal/config/config.go | 33
internal/config/load.go | 35
internal/config/resolve.go | 2
internal/format/format.go | 91
internal/format/spinner.go | 50
internal/fur/client/client.go | 2
internal/llm/agent/agent.go | 4
internal/llm/prompt/prompt.go | 53
internal/llm/prompt/prompt_test.go | 113
internal/llm/tools/bash.go | 127
internal/lsp/transport.go | 41
internal/lsp/watcher/watcher.go | 153
internal/shell/command_block_test.go | 123
internal/shell/shell.go | 82
internal/tui/components/chat/editor/editor.go | 5
internal/tui/components/chat/messages/renderer.go | 7
internal/tui/components/chat/splash/keys.go | 4
internal/tui/components/chat/splash/splash.go | 1
internal/tui/components/completions/completions.go | 23
internal/tui/components/dialogs/permissions/keys.go | 4
internal/tui/exp/diffview/diffview.go | 6
internal/tui/page/chat/chat.go | 15
internal/tui/tui.go | 3
internal/tui/util/util.go | 2
vendor/github.com/Azure/azure-sdk-for-go/sdk/azidentity/go.work.sum | 60
vendor/github.com/charmbracelet/fang/README.md | 7
vendor/github.com/charmbracelet/fang/fang.go | 119
vendor/github.com/charmbracelet/fang/help.go | 298
vendor/github.com/charmbracelet/fang/theme.go | 52
vendor/github.com/mark3labs/mcp-go/client/client.go | 102
vendor/github.com/mark3labs/mcp-go/client/http.go | 7
vendor/github.com/mark3labs/mcp-go/client/sampling.go | 20
vendor/github.com/mark3labs/mcp-go/client/stdio.go | 22
vendor/github.com/mark3labs/mcp-go/client/transport/inprocess.go | 4
vendor/github.com/mark3labs/mcp-go/client/transport/interface.go | 20
vendor/github.com/mark3labs/mcp-go/client/transport/sse.go | 6
vendor/github.com/mark3labs/mcp-go/client/transport/stdio.go | 204
vendor/github.com/mark3labs/mcp-go/client/transport/streamable_http.go | 343
vendor/github.com/mark3labs/mcp-go/mcp/tools.go | 106
vendor/github.com/mark3labs/mcp-go/mcp/types.go | 21
vendor/github.com/mark3labs/mcp-go/mcp/utils.go | 21
vendor/github.com/mark3labs/mcp-go/server/sampling.go | 37
vendor/github.com/mark3labs/mcp-go/server/stdio.go | 180
vendor/github.com/mark3labs/mcp-go/server/streamable_http.go | 6
vendor/modules.txt | 8
51 files changed, 2,019 insertions(+), 715 deletions(-)
@@ -16,6 +16,7 @@
# Go workspace file
go.work
+go.work.sum
# IDE specific files
.idea/
@@ -12,7 +12,6 @@ import (
"github.com/charmbracelet/crush/internal/app"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/db"
- "github.com/charmbracelet/crush/internal/format"
"github.com/charmbracelet/crush/internal/llm/agent"
"github.com/charmbracelet/crush/internal/log"
"github.com/charmbracelet/crush/internal/tui"
@@ -52,14 +51,8 @@ to assist developers in writing, debugging, and understanding code directly from
debug, _ := cmd.Flags().GetBool("debug")
cwd, _ := cmd.Flags().GetString("cwd")
prompt, _ := cmd.Flags().GetString("prompt")
- outputFormat, _ := cmd.Flags().GetString("output-format")
quiet, _ := cmd.Flags().GetBool("quiet")
- // Validate format option
- if !format.IsValid(outputFormat) {
- return fmt.Errorf("invalid format option: %s\n%s", outputFormat, format.GetHelpText())
- }
-
if cwd != "" {
err := os.Chdir(cwd)
if err != nil {
@@ -79,9 +72,7 @@ to assist developers in writing, debugging, and understanding code directly from
return err
}
- // Create main context for the application
- ctx, cancel := context.WithCancel(context.Background())
- defer cancel()
+ ctx := cmd.Context()
// Connect DB, this will also run migrations
conn, err := db.Connect(ctx, cfg.Options.DataDirectory)
@@ -109,7 +100,7 @@ to assist developers in writing, debugging, and understanding code directly from
// Non-interactive mode
if prompt != "" {
// Run non-interactive flow using the App method
- return app.RunNonInteractive(ctx, prompt, outputFormat, quiet)
+ return app.RunNonInteractive(ctx, prompt, quiet)
}
// Set up the TUI
@@ -152,6 +143,7 @@ func Execute() {
context.Background(),
rootCmd,
fang.WithVersion(version.Version),
+ fang.WithNotifySignal(os.Interrupt),
); err != nil {
os.Exit(1)
}
@@ -164,17 +156,8 @@ func init() {
rootCmd.Flags().BoolP("debug", "d", false, "Debug")
rootCmd.Flags().StringP("prompt", "p", "", "Prompt to run in non-interactive mode")
- // Add format flag with validation logic
- rootCmd.Flags().StringP("output-format", "f", format.Text.String(),
- "Output format for non-interactive mode (text, json)")
-
// Add quiet flag to hide spinner in non-interactive mode
rootCmd.Flags().BoolP("quiet", "q", false, "Hide spinner in non-interactive mode")
-
- // Register custom validation for the format flag
- rootCmd.RegisterFlagCompletionFunc("output-format", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
- return format.SupportedFormats, cobra.ShellCompDirectiveNoFileComp
- })
}
func maybePrependStdin(prompt string) (string, error) {
@@ -18,9 +18,9 @@ require (
github.com/charlievieth/fastwalk v1.0.11
github.com/charmbracelet/bubbles/v2 v2.0.0-beta.1.0.20250710161907-a4c42b579198
github.com/charmbracelet/bubbletea/v2 v2.0.0-beta.1
- github.com/charmbracelet/fang v0.1.0
+ github.com/charmbracelet/fang v0.3.1-0.20250711140230-d5ebb8c1d674
github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe
- github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.2.0.20250703152125-8e1c474f8a71
+ github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.3
github.com/charmbracelet/log/v2 v2.0.0-20250226163916-c379e29ff706
github.com/charmbracelet/x/ansi v0.9.3
github.com/charmbracelet/x/exp/charmtone v0.0.0-20250708181618-a60a724ba6c3
@@ -29,7 +29,7 @@ require (
github.com/fsnotify/fsnotify v1.8.0
github.com/google/uuid v1.6.0
github.com/joho/godotenv v1.5.1
- github.com/mark3labs/mcp-go v0.32.0
+ github.com/mark3labs/mcp-go v0.33.0
github.com/muesli/termenv v0.16.0
github.com/ncruces/go-sqlite3 v0.25.0
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
@@ -74,8 +74,8 @@ github.com/charmbracelet/bubbletea-internal/v2 v2.0.0-20250710185017-3c0ffd25e59
github.com/charmbracelet/bubbletea-internal/v2 v2.0.0-20250710185017-3c0ffd25e595/go.mod h1:+Tl7rePElw6OKt382t04zXwtPFoPXxAaJzNrYmtsLds=
github.com/charmbracelet/colorprofile v0.3.1 h1:k8dTHMd7fgw4bnFd7jXTLZrSU/CQrKnL3m+AxCzDz40=
github.com/charmbracelet/colorprofile v0.3.1/go.mod h1:/GkGusxNs8VB/RSOh3fu0TJmQ4ICMMPApIIVn0KszZ0=
-github.com/charmbracelet/fang v0.1.0 h1:SlZS2crf3/zQh7Mr4+W+7QR1k+L08rrPX5rm5z3d7Wg=
-github.com/charmbracelet/fang v0.1.0/go.mod h1:Zl/zeUQ8EtQuGyiV0ZKZlZPDowKRTzu8s/367EpN/fc=
+github.com/charmbracelet/fang v0.3.1-0.20250711140230-d5ebb8c1d674 h1:+Cz+VfxD5DO+JT1LlswXWhre0HYLj6l2HW8HVGfMuC0=
+github.com/charmbracelet/fang v0.3.1-0.20250711140230-d5ebb8c1d674/go.mod h1:9gCUAHmVx5BwSafeyNr3GI0GgvlB1WYjL21SkPp1jyU=
github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe h1:i6ce4CcAlPpTj2ER69m1DBeLZ3RRcHnKExuwhKa3GfY=
github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe/go.mod h1:p3Q+aN4eQKeM5jhrmXPMgPrlKbmc59rWSnMsSA3udhk=
github.com/charmbracelet/lipgloss-internal/v2 v2.0.0-20250710185058-03664cb9cecb h1:lswj7CYZVYbLn2OhYJsXOMRQQGdRIfyuSnh5FdVSMr0=
@@ -165,8 +165,8 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
-github.com/mark3labs/mcp-go v0.32.0 h1:fgwmbfL2gbd67obg57OfV2Dnrhs1HtSdlY/i5fn7MU8=
-github.com/mark3labs/mcp-go v0.32.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4=
+github.com/mark3labs/mcp-go v0.33.0 h1:naxhjnTIs/tyPZmWUZFuG0lDmdA6sUyYGGf3gsHvTCc=
+github.com/mark3labs/mcp-go v0.33.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
@@ -92,16 +92,26 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) {
}
// RunNonInteractive handles the execution flow when a prompt is provided via CLI flag.
-func (a *App) RunNonInteractive(ctx context.Context, prompt string, outputFormat string, quiet bool) error {
+func (a *App) RunNonInteractive(ctx context.Context, prompt string, quiet bool) error {
slog.Info("Running in non-interactive mode")
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
// Start spinner if not in quiet mode
var spinner *format.Spinner
if !quiet {
- spinner = format.NewSpinner(ctx, "Generating")
+ spinner = format.NewSpinner(ctx, cancel, "Generating")
spinner.Start()
- defer spinner.Stop()
}
+ // Helper function to stop spinner once
+ stopSpinner := func() {
+ if !quiet && spinner != nil {
+ spinner.Stop()
+ spinner = nil
+ }
+ }
+ defer stopSpinner()
const maxPromptLengthForTitle = 100
titlePrefix := "Non-interactive: "
@@ -128,35 +138,42 @@ func (a *App) RunNonInteractive(ctx context.Context, prompt string, outputFormat
return fmt.Errorf("failed to start agent processing stream: %w", err)
}
- result := <-done
+ messageEvents := a.Messages.Subscribe(ctx)
+ readBts := 0
- // Stop spinner before printing output
- if !quiet && spinner != nil {
- spinner.Stop()
- }
+ for {
+ select {
+ case result := <-done:
+ stopSpinner()
+
+ if result.Error != nil {
+ if errors.Is(result.Error, context.Canceled) || errors.Is(result.Error, agent.ErrRequestCancelled) {
+ slog.Info("Agent processing cancelled", "session_id", sess.ID)
+ return nil
+ }
+ return fmt.Errorf("agent processing failed: %w", result.Error)
+ }
+
+ part := result.Message.Content().String()[readBts:]
+ fmt.Println(part)
- if result.Error != nil {
- if errors.Is(result.Error, context.Canceled) || errors.Is(result.Error, agent.ErrRequestCancelled) {
- slog.Info("Agent processing cancelled", "session_id", sess.ID)
+ slog.Info("Non-interactive run completed", "session_id", sess.ID)
return nil
- }
- return fmt.Errorf("agent processing failed: %w", result.Error)
- }
- // Get the text content from the response
- content := "No content available"
- if result.Message.Content().String() != "" {
- content = result.Message.Content().String()
- }
+ case event := <-messageEvents:
+ msg := event.Payload
+ if msg.SessionID == sess.ID && msg.Role == message.Assistant && len(msg.Parts) > 0 {
+ stopSpinner()
+ part := msg.Content().String()[readBts:]
+ fmt.Print(part)
+ readBts += len(part)
+ }
- out, err := format.FormatOutput(content, outputFormat)
- if err != nil {
- return err
+ case <-ctx.Done():
+ stopSpinner()
+ return ctx.Err()
+ }
}
-
- fmt.Println(out)
- slog.Info("Non-interactive run completed", "session_id", sess.ID)
- return nil
}
func (app *App) UpdateAgentModel() error {
@@ -59,11 +59,8 @@ func (app *App) createAndStartLSPClient(ctx context.Context, name string, comman
// Create a child context that can be canceled when the app is shutting down
watchCtx, cancelFunc := context.WithCancel(ctx)
- // Create a context with the server name for better identification
- watchCtx = context.WithValue(watchCtx, "serverName", name)
-
// Create the workspace watcher
- workspaceWatcher := watcher.NewWorkspaceWatcher(lspClient)
+ workspaceWatcher := watcher.NewWorkspaceWatcher(name, lspClient)
// Store the cancel function to be called during cleanup
app.cancelFuncsMutex.Lock()
@@ -72,7 +72,7 @@ type ProviderConfig struct {
Disable bool `json:"disable,omitempty"`
// Extra headers to send with each request to the provider.
- ExtraHeaders map[string]string
+ ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
// Used to pass extra parameters to the provider.
ExtraParams map[string]string `json:"-"`
@@ -371,3 +371,34 @@ func (c *Config) SetProviderAPIKey(providerID, apiKey string) error {
c.Providers[providerID] = providerConfig
return nil
}
+
+func (c *Config) SetupAgents() {
+ agents := map[string]Agent{
+ "coder": {
+ ID: "coder",
+ Name: "Coder",
+ Description: "An agent that helps with executing coding tasks.",
+ Model: SelectedModelTypeLarge,
+ ContextPaths: c.Options.ContextPaths,
+ // All tools allowed
+ },
+ "task": {
+ ID: "task",
+ Name: "Task",
+ Description: "An agent that helps with searching for context and finding implementation details.",
+ Model: SelectedModelTypeLarge,
+ ContextPaths: c.Options.ContextPaths,
+ AllowedTools: []string{
+ "glob",
+ "grep",
+ "ls",
+ "sourcegraph",
+ "view",
+ },
+ // NO MCPs or LSPs by default
+ AllowedMCP: map[string][]string{},
+ AllowedLSP: []string{},
+ },
+ }
+ c.Agents = agents
+}
@@ -83,37 +83,7 @@ func Load(workingDir string, debug bool) (*Config, error) {
if err := cfg.configureSelectedModels(providers); err != nil {
return nil, fmt.Errorf("failed to configure selected models: %w", err)
}
-
- // TODO: remove the agents concept from the config
- agents := map[string]Agent{
- "coder": {
- ID: "coder",
- Name: "Coder",
- Description: "An agent that helps with executing coding tasks.",
- Model: SelectedModelTypeLarge,
- ContextPaths: cfg.Options.ContextPaths,
- // All tools allowed
- },
- "task": {
- ID: "task",
- Name: "Task",
- Description: "An agent that helps with searching for context and finding implementation details.",
- Model: SelectedModelTypeLarge,
- ContextPaths: cfg.Options.ContextPaths,
- AllowedTools: []string{
- "glob",
- "grep",
- "ls",
- "sourcegraph",
- "view",
- },
- // NO MCPs or LSPs by default
- AllowedMCP: map[string][]string{},
- AllowedLSP: []string{},
- },
- }
- cfg.Agents = agents
-
+ cfg.SetupAgents()
return cfg, nil
}
@@ -331,6 +301,7 @@ func (cfg *Config) defaultModelSelection(knownProviders []provider.Provider) (la
defaultSmallModel := cfg.GetModel(string(p.ID), p.DefaultSmallModelID)
if defaultSmallModel == nil {
err = fmt.Errorf("default small model %s not found for provider %s", p.DefaultSmallModelID, p.ID)
+ return
}
smallModel = SelectedModel{
Provider: string(p.ID),
@@ -387,8 +358,6 @@ func (cfg *Config) configureSelectedModels(knownProviders []provider.Provider) e
large.Provider = largeModelSelected.Provider
}
model := cfg.GetModel(large.Provider, large.Model)
- slog.Info("Configuring selected large model", "provider", large.Provider, "model", large.Model)
- slog.Info("Model configured", "model", model)
if model == nil {
large = defaultLarge
// override the model type to large
@@ -44,7 +44,7 @@ func (r *shellVariableResolver) ResolveValue(value string) (string, error) {
if strings.HasPrefix(value, "$(") && strings.HasSuffix(value, ")") {
command := strings.TrimSuffix(strings.TrimPrefix(value, "$("), ")")
- ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
stdout, _, err := r.shell.Exec(ctx, command)
@@ -1,91 +0,0 @@
-package format
-
-import (
- "encoding/json"
- "fmt"
- "strings"
-)
-
-// OutputFormat represents the output format type for non-interactive mode
-type OutputFormat string
-
-const (
- // Text format outputs the AI response as plain text.
- Text OutputFormat = "text"
-
- // JSON format outputs the AI response wrapped in a JSON object.
- JSON OutputFormat = "json"
-)
-
-// String returns the string representation of the OutputFormat
-func (f OutputFormat) String() string {
- return string(f)
-}
-
-// SupportedFormats is a list of all supported output formats as strings
-var SupportedFormats = []string{
- string(Text),
- string(JSON),
-}
-
-// Parse converts a string to an OutputFormat
-func Parse(s string) (OutputFormat, error) {
- s = strings.ToLower(strings.TrimSpace(s))
-
- switch s {
- case string(Text):
- return Text, nil
- case string(JSON):
- return JSON, nil
- default:
- return "", fmt.Errorf("invalid format: %s", s)
- }
-}
-
-// IsValid checks if the provided format string is supported
-func IsValid(s string) bool {
- _, err := Parse(s)
- return err == nil
-}
-
-// GetHelpText returns a formatted string describing all supported formats
-func GetHelpText() string {
- return fmt.Sprintf(`Supported output formats:
-- %s: Plain text output (default)
-- %s: Output wrapped in a JSON object`,
- Text, JSON)
-}
-
-// FormatOutput formats the AI response according to the specified format
-func FormatOutput(content string, formatStr string) (string, error) {
- format, err := Parse(formatStr)
- if err != nil {
- format = Text
- }
-
- switch format {
- case JSON:
- return formatAsJSON(content)
- case Text:
- fallthrough
- default:
- return content, nil
- }
-}
-
-// formatAsJSON wraps the content in a simple JSON object
-func formatAsJSON(content string) (string, error) {
- // Use the JSON package to properly escape the content
- response := struct {
- Response string `json:"response"`
- }{
- Response: content,
- }
-
- jsonBytes, err := json.MarshalIndent(response, "", " ")
- if err != nil {
- return "", fmt.Errorf("failed to marshal output into JSON: %w", err)
- }
-
- return string(jsonBytes), nil
-}
@@ -18,24 +18,48 @@ type Spinner struct {
prog *tea.Program
}
+type model struct {
+ cancel context.CancelFunc
+ anim anim.Anim
+}
+
+func (m model) Init() tea.Cmd { return m.anim.Init() }
+func (m model) View() string { return m.anim.View() }
+
+// Update implements tea.Model.
+func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
+ switch msg := msg.(type) {
+ case tea.KeyPressMsg:
+ switch msg.String() {
+ case "ctrl+c", "esc":
+ m.cancel()
+ return m, tea.Quit
+ }
+ }
+ mm, cmd := m.anim.Update(msg)
+ m.anim = mm.(anim.Anim)
+ return m, cmd
+}
+
// NewSpinner creates a new spinner with the given message
-func NewSpinner(ctx context.Context, message string) *Spinner {
+func NewSpinner(ctx context.Context, cancel context.CancelFunc, message string) *Spinner {
t := styles.CurrentTheme()
- model := anim.New(anim.Settings{
- Size: 10,
- Label: message,
- LabelColor: t.FgBase,
- GradColorA: t.Primary,
- GradColorB: t.Secondary,
- CycleColors: true,
- })
+ model := model{
+ anim: anim.New(anim.Settings{
+ Size: 10,
+ Label: message,
+ LabelColor: t.FgBase,
+ GradColorA: t.Primary,
+ GradColorB: t.Secondary,
+ CycleColors: true,
+ }),
+ cancel: cancel,
+ }
prog := tea.NewProgram(
model,
- tea.WithInput(nil),
tea.WithOutput(os.Stderr),
tea.WithContext(ctx),
- tea.WithoutCatchPanics(),
)
return &Spinner{
@@ -47,13 +71,13 @@ func NewSpinner(ctx context.Context, message string) *Spinner {
// Start begins the spinner animation
func (s *Spinner) Start() {
go func() {
+ defer close(s.done)
_, err := s.prog.Run()
// ensures line is cleared
fmt.Fprint(os.Stderr, ansi.EraseEntireLine)
- if err != nil && !errors.Is(err, context.Canceled) {
+ if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, tea.ErrInterrupted) {
fmt.Fprintf(os.Stderr, "Error running spinner: %v\n", err)
}
- close(s.done)
}()
}
@@ -10,7 +10,7 @@ import (
"github.com/charmbracelet/crush/internal/fur/provider"
)
-const defaultURL = "https://fur.charmcli.dev"
+const defaultURL = "https://fur.charm.sh"
// Client represents a client for the fur service.
type Client struct {
@@ -149,7 +149,7 @@ func NewAgent(
}
opts := []provider.ProviderClientOption{
provider.WithModel(agentCfg.Model),
- provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
+ provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
}
agentProvider, err := provider.NewProvider(*providerCfg, opts...)
if err != nil {
@@ -827,7 +827,7 @@ func (a *agent) UpdateModel() error {
opts := []provider.ProviderClientOption{
provider.WithModel(a.agentCfg.Model),
- provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)),
+ provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
}
newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
@@ -5,6 +5,9 @@ import (
"path/filepath"
"strings"
"sync"
+
+ "github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/env"
)
type PromptID string
@@ -21,7 +24,7 @@ func GetPrompt(promptID PromptID, provider string, contextPaths ...string) strin
basePrompt := ""
switch promptID {
case PromptCoder:
- basePrompt = CoderPrompt(provider)
+ basePrompt = CoderPrompt(provider, contextPaths...)
case PromptTitle:
basePrompt = TitlePrompt()
case PromptTask:
@@ -38,6 +41,32 @@ func getContextFromPaths(workingDir string, contextPaths []string) string {
return processContextPaths(workingDir, contextPaths)
}
+// expandPath expands ~ and environment variables in file paths
+func expandPath(path string) string {
+ // Handle tilde expansion
+ if strings.HasPrefix(path, "~/") {
+ homeDir, err := os.UserHomeDir()
+ if err == nil {
+ path = filepath.Join(homeDir, path[2:])
+ }
+ } else if path == "~" {
+ homeDir, err := os.UserHomeDir()
+ if err == nil {
+ path = homeDir
+ }
+ }
+
+ // Handle environment variable expansion using the same pattern as config
+ if strings.HasPrefix(path, "$") {
+ resolver := config.NewEnvironmentVariableResolver(env.New())
+ if expanded, err := resolver.ResolveValue(path); err == nil {
+ path = expanded
+ }
+ }
+
+ return path
+}
+
func processContextPaths(workDir string, paths []string) string {
var (
wg sync.WaitGroup
@@ -53,8 +82,23 @@ func processContextPaths(workDir string, paths []string) string {
go func(p string) {
defer wg.Done()
- if strings.HasSuffix(p, "/") {
- filepath.WalkDir(filepath.Join(workDir, p), func(path string, d os.DirEntry, err error) error {
+ // Expand ~ and environment variables before processing
+ p = expandPath(p)
+
+ // Use absolute path if provided, otherwise join with workDir
+ fullPath := p
+ if !filepath.IsAbs(p) {
+ fullPath = filepath.Join(workDir, p)
+ }
+
+ // Check if the path is a directory using os.Stat
+ info, err := os.Stat(fullPath)
+ if err != nil {
+ return // Skip if path doesn't exist or can't be accessed
+ }
+
+ if info.IsDir() {
+ filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error {
if err != nil {
return err
}
@@ -78,8 +122,7 @@ func processContextPaths(workDir string, paths []string) string {
return nil
})
} else {
- fullPath := filepath.Join(workDir, p)
-
+ // It's a file, process it directly
// Check if we've already processed this file (case-insensitive)
lowerPath := strings.ToLower(fullPath)
@@ -0,0 +1,113 @@
+package prompt
+
+import (
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+)
+
+func TestExpandPath(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ expected func() string
+ }{
+ {
+ name: "regular path unchanged",
+ input: "/absolute/path",
+ expected: func() string {
+ return "/absolute/path"
+ },
+ },
+ {
+ name: "tilde expansion",
+ input: "~/documents",
+ expected: func() string {
+ home, _ := os.UserHomeDir()
+ return filepath.Join(home, "documents")
+ },
+ },
+ {
+ name: "tilde only",
+ input: "~",
+ expected: func() string {
+ home, _ := os.UserHomeDir()
+ return home
+ },
+ },
+ {
+ name: "environment variable expansion",
+ input: "$HOME",
+ expected: func() string {
+ return os.Getenv("HOME")
+ },
+ },
+ {
+ name: "relative path unchanged",
+ input: "relative/path",
+ expected: func() string {
+ return "relative/path"
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := expandPath(tt.input)
+ expected := tt.expected()
+
+ // Skip test if environment variable is not set
+ if strings.HasPrefix(tt.input, "$") && expected == "" {
+ t.Skip("Environment variable not set")
+ }
+
+ if result != expected {
+ t.Errorf("expandPath(%q) = %q, want %q", tt.input, result, expected)
+ }
+ })
+ }
+}
+
+func TestProcessContextPaths(t *testing.T) {
+ // Create a temporary directory and file for testing
+ tmpDir := t.TempDir()
+ testFile := filepath.Join(tmpDir, "test.txt")
+ testContent := "test content"
+
+ err := os.WriteFile(testFile, []byte(testContent), 0o644)
+ if err != nil {
+ t.Fatalf("Failed to create test file: %v", err)
+ }
+
+ // Test with absolute path to file
+ result := processContextPaths("", []string{testFile})
+ expected := "# From:" + testFile + "\n" + testContent
+
+ if result != expected {
+ t.Errorf("processContextPaths with absolute path failed.\nGot: %q\nWant: %q", result, expected)
+ }
+
+ // Test with directory path (should process all files in directory)
+ result = processContextPaths("", []string{tmpDir})
+ if !strings.Contains(result, testContent) {
+ t.Errorf("processContextPaths with directory path failed to include file content")
+ }
+
+ // Test with tilde expansion (if we can create a file in home directory)
+ tmpDir = t.TempDir()
+ t.Setenv("HOME", tmpDir)
+ homeTestFile := filepath.Join(tmpDir, "crush_test_file.txt")
+ err = os.WriteFile(homeTestFile, []byte(testContent), 0o644)
+ if err == nil {
+ defer os.Remove(homeTestFile) // Clean up
+
+ tildeFile := "~/crush_test_file.txt"
+ result = processContextPaths("", []string{tildeFile})
+ expected = "# From:" + homeTestFile + "\n" + testContent
+
+ if result != expected {
+ t.Errorf("processContextPaths with tilde expansion failed.\nGot: %q\nWant: %q", result, expected)
+ }
+ }
+}
@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
+ "log/slog"
"runtime"
"strings"
"time"
@@ -41,9 +42,74 @@ const (
)
var bannedCommands = []string{
- "alias", "curl", "curlie", "wget", "axel", "aria2c",
- "nc", "telnet", "lynx", "w3m", "links", "httpie", "xh",
- "http-prompt", "chrome", "firefox", "safari",
+ // Network/Download tools
+ "alias",
+ "aria2c",
+ "axel",
+ "chrome",
+ "curl",
+ "curlie",
+ "firefox",
+ "http-prompt",
+ "httpie",
+ "links",
+ "lynx",
+ "nc",
+ "safari",
+ "telnet",
+ "w3m",
+ "wget",
+ "xh",
+
+ // System administration
+ "doas",
+ "su",
+ "sudo",
+
+ // Package managers
+ "apk",
+ "apt",
+ "apt-cache",
+ "apt-get",
+ "dnf",
+ "dpkg",
+ "emerge",
+ "home-manager",
+ "makepkg",
+ "opkg",
+ "pacman",
+ "paru",
+ "pkg",
+ "pkg_add",
+ "pkg_delete",
+ "portage",
+ "rpm",
+ "yay",
+ "yum",
+ "zypper",
+
+ // System modification
+ "at",
+ "batch",
+ "chkconfig",
+ "crontab",
+ "fdisk",
+ "mkfs",
+ "mount",
+ "parted",
+ "service",
+ "systemctl",
+ "umount",
+
+ // Network configuration
+ "firewall-cmd",
+ "ifconfig",
+ "ip",
+ "iptables",
+ "netstat",
+ "pfctl",
+ "route",
+ "ufw",
}
// getSafeReadOnlyCommands returns platform-appropriate safe commands
@@ -244,7 +310,42 @@ Important:
- Never update git config`, bannedCommandsStr, MaxOutputLength)
}
+func blockFuncs() []shell.BlockFunc {
+ return []shell.BlockFunc{
+ shell.CommandsBlocker(bannedCommands),
+ shell.ArgumentsBlocker([][]string{
+ // System package managers
+ {"apk", "add"},
+ {"apt", "install"},
+ {"apt-get", "install"},
+ {"dnf", "install"},
+ {"emerge"},
+ {"pacman", "-S"},
+ {"pkg", "install"},
+ {"yum", "install"},
+ {"zypper", "install"},
+
+ // Language-specific package managers
+ {"brew", "install"},
+ {"cargo", "install"},
+ {"gem", "install"},
+ {"go", "install"},
+ {"npm", "install", "-g"},
+ {"npm", "install", "--global"},
+ {"pip", "install", "--user"},
+ {"pip3", "install", "--user"},
+ {"pnpm", "add", "-g"},
+ {"pnpm", "add", "--global"},
+ {"yarn", "global", "add"},
+ }),
+ }
+}
+
func NewBashTool(permission permission.Service, workingDir string) BaseTool {
+ // Set up command blocking on the persistent shell
+ persistentShell := shell.GetPersistentShell(workingDir)
+ persistentShell.SetBlockFuncs(blockFuncs())
+
return &bashTool{
permissions: permission,
workingDir: workingDir,
@@ -289,13 +390,6 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
return NewTextErrorResponse("missing command"), nil
}
- baseCmd := strings.Fields(params.Command)[0]
- for _, banned := range bannedCommands {
- if strings.EqualFold(baseCmd, banned) {
- return NewTextErrorResponse(fmt.Sprintf("command '%s' is not allowed", baseCmd)), nil
- }
- }
-
isSafeReadOnly := false
cmdLower := strings.ToLower(params.Command)
@@ -349,7 +443,20 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
stdout = truncateOutput(stdout)
stderr = truncateOutput(stderr)
+ slog.Info("Bash command executed",
+ "command", params.Command,
+ "stdout", stdout,
+ "stderr", stderr,
+ "exit_code", exitCode,
+ "interrupted", interrupted,
+ "err", err,
+ )
+
errorMessage := stderr
+ if errorMessage == "" && err != nil {
+ errorMessage = err.Error()
+ }
+
if interrupted {
if errorMessage != "" {
errorMessage += "\n"
@@ -222,29 +222,32 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any
}
// Wait for response
- resp := <-ch
-
- if cfg.Options.DebugLSP {
- slog.Debug("Received response", "id", id)
- }
-
- if resp.Error != nil {
- return fmt.Errorf("request failed: %s (code: %d)", resp.Error.Message, resp.Error.Code)
- }
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case resp := <-ch:
+ if cfg.Options.DebugLSP {
+ slog.Debug("Received response", "id", id)
+ }
- if result != nil {
- // If result is a json.RawMessage, just copy the raw bytes
- if rawMsg, ok := result.(*json.RawMessage); ok {
- *rawMsg = resp.Result
- return nil
+ if resp.Error != nil {
+ return fmt.Errorf("request failed: %s (code: %d)", resp.Error.Message, resp.Error.Code)
}
- // Otherwise unmarshal into the provided type
- if err := json.Unmarshal(resp.Result, result); err != nil {
- return fmt.Errorf("failed to unmarshal result: %w", err)
+
+ if result != nil {
+ // If result is a json.RawMessage, just copy the raw bytes
+ if rawMsg, ok := result.(*json.RawMessage); ok {
+ *rawMsg = resp.Result
+ return nil
+ }
+ // Otherwise unmarshal into the provided type
+ if err := json.Unmarshal(resp.Result, result); err != nil {
+ return fmt.Errorf("failed to unmarshal result: %w", err)
+ }
}
- }
- return nil
+ return nil
+ }
}
// Notify sends a notification (a request without an ID that doesn't expect a response)
@@ -21,6 +21,7 @@ import (
// WorkspaceWatcher manages LSP file watching
type WorkspaceWatcher struct {
client *lsp.Client
+ name string
workspacePath string
debounceTime time.Duration
@@ -33,8 +34,9 @@ type WorkspaceWatcher struct {
}
// NewWorkspaceWatcher creates a new workspace watcher
-func NewWorkspaceWatcher(client *lsp.Client) *WorkspaceWatcher {
+func NewWorkspaceWatcher(name string, client *lsp.Client) *WorkspaceWatcher {
return &WorkspaceWatcher{
+ name: name,
client: client,
debounceTime: 300 * time.Millisecond,
debounceMap: make(map[string]*time.Timer),
@@ -95,7 +97,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
}
// Determine server type for specialized handling
- serverName := getServerNameFromContext(ctx)
+ serverName := w.name
slog.Debug("Server type detected", "serverName", serverName)
// Check if this server has sent file watchers
@@ -325,17 +327,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
cfg := config.Get()
w.workspacePath = workspacePath
- // Store the watcher in the context for later use
- ctx = context.WithValue(ctx, "workspaceWatcher", w)
-
- // If the server name isn't already in the context, try to detect it
- if _, ok := ctx.Value("serverName").(string); !ok {
- serverName := getServerNameFromContext(ctx)
- ctx = context.WithValue(ctx, "serverName", serverName)
- }
-
- serverName := getServerNameFromContext(ctx)
- slog.Debug("Starting workspace watcher", "workspacePath", workspacePath, "serverName", serverName)
+ slog.Debug("Starting workspace watcher", "workspacePath", workspacePath, "serverName", w.name)
// Register handler for file watcher registrations from the server
lsp.RegisterFileWatchHandler(func(id string, watchers []protocol.FileSystemWatcher) {
@@ -697,40 +689,6 @@ func (w *WorkspaceWatcher) notifyFileEvent(ctx context.Context, uri string, chan
return w.client.DidChangeWatchedFiles(ctx, params)
}
-// getServerNameFromContext extracts the server name from the context
-// This is a best-effort function that tries to identify which LSP server we're dealing with
-func getServerNameFromContext(ctx context.Context) string {
- // First check if the server name is directly stored in the context
- if serverName, ok := ctx.Value("serverName").(string); ok && serverName != "" {
- return strings.ToLower(serverName)
- }
-
- // Otherwise, try to extract server name from the client command path
- if w, ok := ctx.Value("workspaceWatcher").(*WorkspaceWatcher); ok && w != nil && w.client != nil && w.client.Cmd != nil {
- path := strings.ToLower(w.client.Cmd.Path)
-
- // Extract server name from path
- if strings.Contains(path, "typescript") || strings.Contains(path, "tsserver") || strings.Contains(path, "vtsls") {
- return "typescript"
- } else if strings.Contains(path, "gopls") {
- return "gopls"
- } else if strings.Contains(path, "rust-analyzer") {
- return "rust-analyzer"
- } else if strings.Contains(path, "pyright") || strings.Contains(path, "pylsp") || strings.Contains(path, "python") {
- return "python"
- } else if strings.Contains(path, "clangd") {
- return "clangd"
- } else if strings.Contains(path, "jdtls") || strings.Contains(path, "java") {
- return "java"
- }
-
- // Return the base name as fallback
- return filepath.Base(path)
- }
-
- return "unknown"
-}
-
// shouldPreloadFiles determines if we should preload files for a specific language server
// Some servers work better with preloaded files, others don't need it
func shouldPreloadFiles(serverName string) bool {
@@ -884,64 +842,63 @@ func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) {
}
// Check if this path should be watched according to server registrations
- if watched, _ := w.isPathWatched(path); watched {
- // Get server name for specialized handling
- serverName := getServerNameFromContext(ctx)
+ if watched, _ := w.isPathWatched(path); !watched {
+ return
+ }
- // Check if the file is a high-priority file that should be opened immediately
- // This helps with project initialization for certain language servers
- if isHighPriorityFile(path, serverName) {
- if cfg.Options.DebugLSP {
- slog.Debug("Opening high-priority file", "path", path, "serverName", serverName)
- }
- if err := w.client.OpenFile(ctx, path); err != nil && cfg.Options.DebugLSP {
- slog.Error("Error opening high-priority file", "path", path, "error", err)
- }
- return
- }
+ serverName := w.name
- // For non-high-priority files, we'll use different strategies based on server type
- if shouldPreloadFiles(serverName) {
- // For servers that benefit from preloading, open files but with limits
+ // Get server name for specialized handling
+ // Check if the file is a high-priority file that should be opened immediately
+ // This helps with project initialization for certain language servers
+ if isHighPriorityFile(path, serverName) {
+ if cfg.Options.DebugLSP {
+ slog.Debug("Opening high-priority file", "path", path, "serverName", serverName)
+ }
+ if err := w.client.OpenFile(ctx, path); err != nil && cfg.Options.DebugLSP {
+ slog.Error("Error opening high-priority file", "path", path, "error", err)
+ }
+ return
+ }
- // Check file size - for preloading we're more conservative
- if info.Size() > (1 * 1024 * 1024) { // 1MB limit for preloaded files
- if cfg.Options.DebugLSP {
- slog.Debug("Skipping large file for preloading", "path", path, "size", info.Size())
- }
- return
- }
+ // For non-high-priority files, we'll use different strategies based on server type
+ if !shouldPreloadFiles(serverName) {
+ return
+ }
+ // For servers that benefit from preloading, open files but with limits
- // Check file extension for common source files
- ext := strings.ToLower(filepath.Ext(path))
+ // Check file size - for preloading we're more conservative
+ if info.Size() > (1 * 1024 * 1024) { // 1MB limit for preloaded files
+ if cfg.Options.DebugLSP {
+ slog.Debug("Skipping large file for preloading", "path", path, "size", info.Size())
+ }
+ return
+ }
- // Only preload source files for the specific language
- shouldOpen := false
+ // Check file extension for common source files
+ ext := strings.ToLower(filepath.Ext(path))
- switch serverName {
- case "typescript", "typescript-language-server", "tsserver", "vtsls":
- shouldOpen = ext == ".ts" || ext == ".js" || ext == ".tsx" || ext == ".jsx"
- case "gopls":
- shouldOpen = ext == ".go"
- case "rust-analyzer":
- shouldOpen = ext == ".rs"
- case "python", "pyright", "pylsp":
- shouldOpen = ext == ".py"
- case "clangd":
- shouldOpen = ext == ".c" || ext == ".cpp" || ext == ".h" || ext == ".hpp"
- case "java", "jdtls":
- shouldOpen = ext == ".java"
- default:
- // For unknown servers, be conservative
- shouldOpen = false
- }
+ // Only preload source files for the specific language
+ var shouldOpen bool
+ switch serverName {
+ case "typescript", "typescript-language-server", "tsserver", "vtsls":
+ shouldOpen = ext == ".ts" || ext == ".js" || ext == ".tsx" || ext == ".jsx"
+ case "gopls":
+ shouldOpen = ext == ".go"
+ case "rust-analyzer":
+ shouldOpen = ext == ".rs"
+ case "python", "pyright", "pylsp":
+ shouldOpen = ext == ".py"
+ case "clangd":
+ shouldOpen = ext == ".c" || ext == ".cpp" || ext == ".h" || ext == ".hpp"
+ case "java", "jdtls":
+ shouldOpen = ext == ".java"
+ }
- if shouldOpen {
- // Don't need to check if it's already open - the client.OpenFile handles that
- if err := w.client.OpenFile(ctx, path); err != nil && cfg.Options.DebugLSP {
- slog.Error("Error opening file", "path", path, "error", err)
- }
- }
+ if shouldOpen {
+ // Don't need to check if it's already open - the client.OpenFile handles that
+ if err := w.client.OpenFile(ctx, path); err != nil && cfg.Options.DebugLSP {
+ slog.Error("Error opening file", "path", path, "error", err)
}
}
}
@@ -0,0 +1,123 @@
+package shell
+
+import (
+ "context"
+ "os"
+ "strings"
+ "testing"
+)
+
+func TestCommandBlocking(t *testing.T) {
+ tests := []struct {
+ name string
+ blockFuncs []BlockFunc
+ command string
+ shouldBlock bool
+ }{
+ {
+ name: "block simple command",
+ blockFuncs: []BlockFunc{
+ func(args []string) bool {
+ return len(args) > 0 && args[0] == "curl"
+ },
+ },
+ command: "curl https://example.com",
+ shouldBlock: true,
+ },
+ {
+ name: "allow non-blocked command",
+ blockFuncs: []BlockFunc{
+ func(args []string) bool {
+ return len(args) > 0 && args[0] == "curl"
+ },
+ },
+ command: "echo hello",
+ shouldBlock: false,
+ },
+ {
+ name: "block subcommand",
+ blockFuncs: []BlockFunc{
+ func(args []string) bool {
+ return len(args) >= 2 && args[0] == "brew" && args[1] == "install"
+ },
+ },
+ command: "brew install wget",
+ shouldBlock: true,
+ },
+ {
+ name: "allow different subcommand",
+ blockFuncs: []BlockFunc{
+ func(args []string) bool {
+ return len(args) >= 2 && args[0] == "brew" && args[1] == "install"
+ },
+ },
+ command: "brew list",
+ shouldBlock: false,
+ },
+ {
+ name: "block npm global install with -g",
+ blockFuncs: []BlockFunc{
+ ArgumentsBlocker([][]string{
+ {"npm", "install", "-g"},
+ {"npm", "install", "--global"},
+ }),
+ },
+ command: "npm install -g typescript",
+ shouldBlock: true,
+ },
+ {
+ name: "block npm global install with --global",
+ blockFuncs: []BlockFunc{
+ ArgumentsBlocker([][]string{
+ {"npm", "install", "-g"},
+ {"npm", "install", "--global"},
+ }),
+ },
+ command: "npm install --global typescript",
+ shouldBlock: true,
+ },
+ {
+ name: "allow npm local install",
+ blockFuncs: []BlockFunc{
+ ArgumentsBlocker([][]string{
+ {"npm", "install", "-g"},
+ {"npm", "install", "--global"},
+ }),
+ },
+ command: "npm install typescript",
+ shouldBlock: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Create a temporary directory for each test
+ tmpDir, err := os.MkdirTemp("", "shell-test-*")
+ if err != nil {
+ t.Fatalf("Failed to create temp dir: %v", err)
+ }
+ defer os.RemoveAll(tmpDir)
+
+ shell := NewShell(&Options{
+ WorkingDir: tmpDir,
+ BlockFuncs: tt.blockFuncs,
+ })
+
+ _, _, err = shell.Exec(context.Background(), tt.command)
+
+ if tt.shouldBlock {
+ if err == nil {
+ t.Errorf("Expected command to be blocked, but it was allowed")
+ } else if !strings.Contains(err.Error(), "not allowed for security reasons") {
+ t.Errorf("Expected security error, got: %v", err)
+ }
+ } else {
+ // For non-blocked commands, we might get other errors (like command not found)
+ // but we shouldn't get the security error
+ if err != nil && strings.Contains(err.Error(), "not allowed for security reasons") {
+ t.Errorf("Command was unexpectedly blocked: %v", err)
+ }
+ }
+ })
+ }
+}
@@ -44,12 +44,16 @@ type noopLogger struct{}
func (noopLogger) InfoPersist(msg string, keysAndValues ...interface{}) {}
+// BlockFunc is a function that determines if a command should be blocked
+type BlockFunc func(args []string) bool
+
// Shell provides cross-platform shell execution with optional state persistence
type Shell struct {
- env []string
- cwd string
- mu sync.Mutex
- logger Logger
+ env []string
+ cwd string
+ mu sync.Mutex
+ logger Logger
+ blockFuncs []BlockFunc
}
// Options for creating a new shell
@@ -57,6 +61,7 @@ type Options struct {
WorkingDir string
Env []string
Logger Logger
+ BlockFuncs []BlockFunc
}
// NewShell creates a new shell instance with the given options
@@ -81,9 +86,10 @@ func NewShell(opts *Options) *Shell {
}
return &Shell{
- cwd: cwd,
- env: env,
- logger: logger,
+ cwd: cwd,
+ env: env,
+ logger: logger,
+ blockFuncs: opts.BlockFuncs,
}
}
@@ -152,6 +158,13 @@ func (s *Shell) SetEnv(key, value string) {
s.env = append(s.env, keyPrefix+value)
}
+// SetBlockFuncs sets the command block functions for the shell
+func (s *Shell) SetBlockFuncs(blockFuncs []BlockFunc) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.blockFuncs = blockFuncs
+}
+
// Windows-specific commands that should use native shell
var windowsNativeCommands = map[string]bool{
"dir": true,
@@ -203,6 +216,60 @@ func (s *Shell) determineShellType(command string) ShellType {
return ShellTypePOSIX
}
+// CommandsBlocker creates a BlockFunc that blocks exact command matches
+func CommandsBlocker(bannedCommands []string) BlockFunc {
+ bannedSet := make(map[string]bool)
+ for _, cmd := range bannedCommands {
+ bannedSet[cmd] = true
+ }
+
+ return func(args []string) bool {
+ if len(args) == 0 {
+ return false
+ }
+ return bannedSet[args[0]]
+ }
+}
+
+// ArgumentsBlocker creates a BlockFunc that blocks specific subcommands
+func ArgumentsBlocker(blockedSubCommands [][]string) BlockFunc {
+ return func(args []string) bool {
+ for _, blocked := range blockedSubCommands {
+ if len(args) >= len(blocked) {
+ match := true
+ for i, part := range blocked {
+ if args[i] != part {
+ match = false
+ break
+ }
+ }
+ if match {
+ return true
+ }
+ }
+ }
+ return false
+ }
+}
+
+func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
+ return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
+ return func(ctx context.Context, args []string) error {
+ if len(args) == 0 {
+ return next(ctx, args)
+ }
+
+ for _, blockFunc := range s.blockFuncs {
+ if blockFunc(args) {
+ return fmt.Errorf("command is not allowed for security reasons: %s", strings.Join(args, " "))
+ }
+ }
+
+ return next(ctx, args)
+ }
+ }
+}
+
// execWindows executes commands using native Windows shells (cmd.exe or PowerShell)
func (s *Shell) execWindows(ctx context.Context, command string, shell string) (string, string, error) {
var cmd *exec.Cmd
@@ -291,6 +358,7 @@ func (s *Shell) execPOSIX(ctx context.Context, command string) (string, string,
interp.Interactive(false),
interp.Env(expand.ListEnviron(s.env...)),
interp.Dir(s.cwd),
+ interp.ExecHandlers(s.blockHandler()),
)
if err != nil {
return "", "", fmt.Errorf("could not run command: %w", err)
@@ -361,8 +361,9 @@ func (m *editorCmp) startCompletions() tea.Msg {
})
}
- x := m.textarea.Cursor().X + m.x + 1
- y := m.textarea.Cursor().Y + m.y + 1
+ cur := m.textarea.Cursor()
+ x := cur.X + m.x // adjust for padding
+ y := cur.Y + m.y + 1
return completions.OpenCompletionsMsg{
Completions: completionItems,
X: x,
@@ -207,6 +207,7 @@ func (br bashRenderer) Render(v *toolCallCmp) string {
}
cmd := strings.ReplaceAll(params.Command, "\n", " ")
+ cmd = strings.ReplaceAll(cmd, "\t", " ")
args := newParamBuilder().addMain(cmd).build()
return br.renderWithParams(v, "Bash", args, func() string {
@@ -578,8 +579,8 @@ func renderParamList(nested bool, paramsWidth int, params ...string) string {
return ""
}
mainParam := params[0]
- if paramsWidth-3 >= 0 && len(mainParam) > paramsWidth {
- mainParam = mainParam[:paramsWidth-3] + "…"
+ if paramsWidth >= 0 && lipgloss.Width(mainParam) > paramsWidth {
+ mainParam = ansi.Truncate(mainParam, paramsWidth, "…")
}
if len(params) == 1 {
@@ -649,7 +650,7 @@ func joinHeaderBody(header, body string) string {
return header
}
body = t.S().Base.PaddingLeft(2).Render(body)
- return lipgloss.JoinVertical(lipgloss.Left, header, "", body, "")
+ return lipgloss.JoinVertical(lipgloss.Left, header, "", body)
}
func renderPlainContent(v *toolCallCmp, content string) string {
@@ -30,11 +30,11 @@ func DefaultKeyMap() KeyMap {
key.WithHelp("↑", "previous item"),
),
Yes: key.NewBinding(
- key.WithKeys("y"),
+ key.WithKeys("y", "Y"),
key.WithHelp("y", "yes"),
),
No: key.NewBinding(
- key.WithKeys("n"),
+ key.WithKeys("n", "N"),
key.WithHelp("n", "no"),
),
Tab: key.NewBinding(
@@ -313,6 +313,7 @@ func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd {
return util.ReportError(err)
}
}
+ cfg.SetupAgents()
return nil
}
@@ -9,6 +9,8 @@ import (
"github.com/charmbracelet/lipgloss/v2"
)
+const maxCompletionsHeight = 10
+
type Completion struct {
Title string // The title of the completion item
Value any // The value of the completion item
@@ -43,7 +45,7 @@ type Completions interface {
type completionsCmp struct {
width int
height int // Height of the completions component`
- x int // X position for the completions popup\
+ x int // X position for the completions popup
y int // Y position for the completions popup
open bool // Indicates if the completions are open
keyMap KeyMap
@@ -150,18 +152,25 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if !c.open {
return c, nil // If completions are not open, do nothing
}
- cmd := c.list.Filter(msg.Query)
- c.height = max(min(10, len(c.list.Items())), 1)
- return c, tea.Batch(
- cmd,
- c.list.SetSize(c.width, c.height),
- )
+ var cmds []tea.Cmd
+ cmds = append(cmds, c.list.Filter(msg.Query))
+ itemsLen := len(c.list.Items())
+ c.height = max(min(maxCompletionsHeight, itemsLen), 1)
+ cmds = append(cmds, c.list.SetSize(c.width, c.height))
+ if itemsLen == 0 {
+ // Close completions if no items match the query
+ cmds = append(cmds, util.CmdHandler(CloseCompletionsMsg{}))
+ }
+ return c, tea.Batch(cmds...)
}
return c, nil
}
// View implements Completions.
func (c *completionsCmp) View() string {
+ if !c.open {
+ return ""
+ }
if len(c.list.Items()) == 0 {
return c.style().Render("No completions found")
}
@@ -109,9 +109,5 @@ func (k KeyMap) ShortHelp() []key.Binding {
key.WithKeys("shift+left", "shift+down", "shift+up", "shift+right"),
key.WithHelp("shift+←↓↑→", "scroll"),
),
- key.NewBinding(
- key.WithKeys("shift+h", "shift+j", "shift+k", "shift+l"),
- key.WithHelp("shift+hjkl", "scroll"),
- ),
}
}
@@ -365,7 +365,8 @@ func (dv *DiffView) renderUnified() string {
shouldWrite := func() bool { return printedLines >= 0 }
getContent := func(in string, ls LineStyle) (content string, leadingEllipsis bool) {
- content = strings.TrimSuffix(in, "\n")
+ content = strings.ReplaceAll(in, "\r\n", "\n")
+ content = strings.TrimSuffix(content, "\n")
content = dv.hightlightCode(content, ls.Code.GetBackground())
content = ansi.GraphemeWidth.Cut(content, dv.xOffset, len(content))
content = ansi.Truncate(content, dv.codeWidth, "…")
@@ -488,7 +489,8 @@ func (dv *DiffView) renderSplit() string {
shouldWrite := func() bool { return printedLines >= 0 }
getContent := func(in string, ls LineStyle) (content string, leadingEllipsis bool) {
- content = strings.TrimSuffix(in, "\n")
+ content = strings.ReplaceAll(in, "\r\n", "\n")
+ content = strings.TrimSuffix(content, "\n")
content = dv.hightlightCode(content, ls.Code.GetBackground())
content = ansi.GraphemeWidth.Cut(content, dv.xOffset, len(content))
content = ansi.Truncate(content, dv.codeWidth, "…")
@@ -2,7 +2,6 @@ package chat
import (
"context"
- "runtime"
"time"
"github.com/charmbracelet/bubbles/v2/help"
@@ -615,26 +614,12 @@ func (a *chatPage) Help() help.KeyMap {
fullList = append(fullList, []key.Binding{v})
}
case a.isOnboarding && a.splash.IsShowingAPIKey():
- var pasteKey key.Binding
- if runtime.GOOS != "darwin" {
- pasteKey = key.NewBinding(
- key.WithKeys("ctrl+v"),
- key.WithHelp("ctrl+v", "paste API key"),
- )
- } else {
- pasteKey = key.NewBinding(
- key.WithKeys("cmd+v"),
- key.WithHelp("cmd+v", "paste API key"),
- )
- }
shortList = append(shortList,
// Go back
key.NewBinding(
key.WithKeys("esc"),
key.WithHelp("esc", "back"),
),
- // Paste
- pasteKey,
// Quit
key.NewBinding(
key.WithKeys("ctrl+c"),
@@ -327,6 +327,9 @@ func (a *appModel) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd {
// If the commands dialog is already open, close it
return util.CmdHandler(dialogs.CloseDialogMsg{})
}
+ if a.dialog.HasDialogs() {
+ return nil // Don't open commands dialog if another dialog is active
+ }
return util.CmdHandler(dialogs.OpenDialogMsg{
Model: commands.NewCommandDialog(a.selectedSessionID),
})
@@ -1,6 +1,7 @@
package util
import (
+ "log/slog"
"time"
tea "github.com/charmbracelet/bubbletea/v2"
@@ -22,6 +23,7 @@ func CmdHandler(msg tea.Msg) tea.Cmd {
}
func ReportError(err error) tea.Cmd {
+ slog.Error("Error reported", "error", err)
return CmdHandler(InfoMsg{
Type: InfoTypeError,
Msg: err.Error(),
@@ -1,60 +0,0 @@
-github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.0-beta.1 h1:ODs3brnqQM99Tq1PffODpAViYv3Bf8zOg464MU7p5ew=
-github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.0-beta.1/go.mod h1:3Ug6Qzto9anB6mGlEdgYMDF5zHQ+wwhEaYR4s17PHMw=
-github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.0 h1:fb8kj/Dh4CSwgsOzHeZY4Xh68cFVbzXx+ONXGMY//4w=
-github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.0/go.mod h1:uReU2sSxZExRPBAg3qKzmAucSi51+SP1OhohieR821Q=
-github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0/go.mod h1:okt5dMMTOFjX/aovMlrjvvXoPMBVSPzk9185BT0+eZM=
-github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.2/go.mod h1:yInRyqWXAuaPrgI7p70+lDDgh3mlBohis29jGMISnmc=
-github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI=
-github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
-github.com/keybase/dbus v0.0.0-20220506165403-5aa21ea2c23a/go.mod h1:YPNKjjE7Ubp9dTbnWvsP3HT+hYnY6TfXzubYTBeUxc8=
-github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
-github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
-github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
-github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
-github.com/montanaflynn/stats v0.7.0/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow=
-github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
-github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
-github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
-github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
-github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
-github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
-github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
-github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
-github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
-github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
-github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
-github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
-golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
-golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
-golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
-golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
-golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
-golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
-golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
-golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
-golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
-golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
-golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
-golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
-golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY=
-golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
-golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o=
-golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU=
-golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
-golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U=
-golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0=
-golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
-golang.org/x/term v0.19.0/go.mod h1:2CuTdWZ7KHSQwUzKva0cbMg6q2DMI3Mmxp+gKJbskEk=
-golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
-golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0=
-golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
-golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
-golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
-golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
-gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
-gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
-gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
@@ -1,7 +1,7 @@
# Fang
<p>
- <img width="485" alt="Charm Fang" src="https://github.com/user-attachments/assets/3f34ea01-3750-4760-beb2-a1b700e110f5">
+ <img width="485" alt="Charm Fang" src="https://github.com/user-attachments/assets/3f34ea01-3750-4760-beb2-a1b700e110f5">
</p>
<p>
<a href="https://github.com/charmbracelet/fang/releases"><img src="https://img.shields.io/github/release/charmbracelet/fang.svg" alt="Latest Release"></a>
@@ -12,7 +12,7 @@
The CLI starter kit. A small, experimental library for batteries-included [Cobra][cobra] applications.
<p>
- <img width="865" alt="fang-02" src="https://github.com/user-attachments/assets/7f68ec3f-2b42-4188-a750-7e2808696132" />
+ <img width="859" alt="The Charm Fang mascot and title treatment" src="https://github.com/user-attachments/assets/5c35e1fa-9577-4f81-a879-3ddb4d4a43f0" />
</p>
## Features
@@ -45,6 +45,7 @@ To use it, invoke `fang.Execute` passing your root `*cobra.Command`:
package main
import (
+ "context"
"os"
"github.com/charmbracelet/fang"
@@ -56,7 +57,7 @@ func main() {
Use: "example",
Short: "A simple example program!",
}
- if err := fang.Execute(context.TODO(), cmd); err != nil {
+ if err := fang.Execute(context.Background(), cmd); err != nil {
os.Exit(1)
}
}
@@ -4,11 +4,14 @@ package fang
import (
"context"
"fmt"
+ "io"
"os"
+ "os/signal"
"runtime/debug"
"github.com/charmbracelet/colorprofile"
"github.com/charmbracelet/lipgloss/v2"
+ "github.com/charmbracelet/x/term"
mango "github.com/muesli/mango-cobra"
"github.com/muesli/roff"
"github.com/spf13/cobra"
@@ -16,12 +19,24 @@ import (
const shaLen = 7
+// ErrorHandler handles an error, printing them to the given [io.Writer].
+//
+// Note that this will only be used if the STDERR is a terminal, and should
+// be used for styling only.
+type ErrorHandler = func(w io.Writer, styles Styles, err error)
+
+// ColorSchemeFunc gets a [lipgloss.LightDarkFunc] and returns a [ColorScheme].
+type ColorSchemeFunc = func(lipgloss.LightDarkFunc) ColorScheme
+
type settings struct {
completions bool
manpages bool
+ skipVersion bool
version string
commit string
- theme *ColorScheme
+ colorscheme ColorSchemeFunc
+ errHandler ErrorHandler
+ signals []os.Signal
}
// Option changes fang settings.
@@ -41,10 +56,21 @@ func WithoutManpage() Option {
}
}
+// WithColorSchemeFunc sets a function that return colorscheme.
+func WithColorSchemeFunc(cs ColorSchemeFunc) Option {
+ return func(s *settings) {
+ s.colorscheme = cs
+ }
+}
+
// WithTheme sets the colorscheme.
+//
+// Deprecated: use [WithColorSchemeFunc] instead.
func WithTheme(theme ColorScheme) Option {
return func(s *settings) {
- s.theme = &theme
+ s.colorscheme = func(lipgloss.LightDarkFunc) ColorScheme {
+ return theme
+ }
}
}
@@ -55,6 +81,13 @@ func WithVersion(version string) Option {
}
}
+// WithoutVersion skips the `-v`/`--version` functionality.
+func WithoutVersion() Option {
+ return func(s *settings) {
+ s.skipVersion = true
+ }
+}
+
// WithCommit sets the commit SHA.
func WithCommit(commit string) Option {
return func(s *settings) {
@@ -62,30 +95,45 @@ func WithCommit(commit string) Option {
}
}
+// WithErrorHandler sets the error handler.
+func WithErrorHandler(handler ErrorHandler) Option {
+ return func(s *settings) {
+ s.errHandler = handler
+ }
+}
+
+// WithNotifySignal sets the signals that should interrupt the execution of the
+// program.
+func WithNotifySignal(signals ...os.Signal) Option {
+ return func(s *settings) {
+ s.signals = signals
+ }
+}
+
// Execute applies fang to the command and executes it.
func Execute(ctx context.Context, root *cobra.Command, options ...Option) error {
opts := settings{
manpages: true,
completions: true,
+ colorscheme: DefaultColorScheme,
+ errHandler: DefaultErrorHandler,
}
+
for _, option := range options {
option(&opts)
}
- if opts.theme == nil {
- isDark := lipgloss.HasDarkBackground(os.Stdin, os.Stderr)
- t := DefaultTheme(isDark)
- opts.theme = &t
+ helpFunc := func(c *cobra.Command, _ []string) {
+ w := colorprofile.NewWriter(c.OutOrStdout(), os.Environ())
+ helpFn(c, w, makeStyles(mustColorscheme(opts.colorscheme)))
}
- styles := makeStyles(*opts.theme)
-
- root.SetHelpFunc(func(c *cobra.Command, _ []string) {
- w := colorprofile.NewWriter(c.OutOrStdout(), os.Environ())
- helpFn(c, w, styles)
- })
root.SilenceUsage = true
root.SilenceErrors = true
+ if !opts.skipVersion {
+ root.Version = buildVersion(opts)
+ }
+ root.SetHelpFunc(helpFunc)
if opts.manpages {
root.AddCommand(&cobra.Command{
@@ -108,34 +156,49 @@ func Execute(ctx context.Context, root *cobra.Command, options ...Option) error
})
}
- if opts.completions {
- root.InitDefaultCompletionCmd()
- } else {
+ if !opts.completions {
root.CompletionOptions.DisableDefaultCmd = true
}
- if opts.version == "" {
- if info, ok := debug.ReadBuildInfo(); ok && info.Main.Sum != "" {
- opts.version = info.Main.Version
- opts.commit = getKey(info, "vcs.revision")
- } else {
- opts.version = "unknown (built from source)"
- }
- }
- if len(opts.commit) >= shaLen {
- opts.version += " (" + opts.commit[:shaLen] + ")"
+ if len(opts.signals) > 0 {
+ var cancel context.CancelFunc
+ ctx, cancel = signal.NotifyContext(ctx, opts.signals...)
+ defer cancel()
}
- root.Version = opts.version
-
if err := root.ExecuteContext(ctx); err != nil {
+ if w, ok := root.ErrOrStderr().(term.File); ok {
+ // if stderr is not a tty, simply print the error without any
+ // styling or going through an [ErrorHandler]:
+ if !term.IsTerminal(w.Fd()) {
+ _, _ = fmt.Fprintln(w, err.Error())
+ return err //nolint:wrapcheck
+ }
+ }
w := colorprofile.NewWriter(root.ErrOrStderr(), os.Environ())
- writeError(w, styles, err)
+ opts.errHandler(w, makeStyles(mustColorscheme(opts.colorscheme)), err)
return err //nolint:wrapcheck
}
return nil
}
+func buildVersion(opts settings) string {
+ commit := opts.commit
+ version := opts.version
+ if version == "" {
+ if info, ok := debug.ReadBuildInfo(); ok && info.Main.Sum != "" {
+ version = info.Main.Version
+ commit = getKey(info, "vcs.revision")
+ } else {
+ version = "unknown (built from source)"
+ }
+ }
+ if len(commit) >= shaLen {
+ version += " (" + commit[:shaLen] + ")"
+ }
+ return version
+}
+
func getKey(info *debug.BuildInfo, key string) string {
if info == nil {
return ""
@@ -3,7 +3,10 @@ package fang
import (
"cmp"
"fmt"
+ "io"
+ "iter"
"os"
+ "reflect"
"regexp"
"strconv"
"strings"
@@ -20,6 +23,7 @@ import (
const (
minSpace = 10
shortPad = 2
+ longPad = 4
)
var width = sync.OnceValue(func() int {
@@ -45,65 +49,95 @@ func helpFn(c *cobra.Command, w *colorprofile.Writer, styles Styles) {
blockWidth = max(blockWidth, lipgloss.Width(ex))
}
blockWidth = min(width()-padding, blockWidth+padding)
+ blockStyle := styles.Codeblock.Base.Width(blockWidth)
- styles.Codeblock.Base = styles.Codeblock.Base.Width(blockWidth)
+ // if the color profile is ascii or notty, or if the block has no
+ // background color set, remove the vertical padding.
+ if w.Profile <= colorprofile.Ascii || reflect.DeepEqual(blockStyle.GetBackground(), lipgloss.NoColor{}) {
+ blockStyle = blockStyle.PaddingTop(0).PaddingBottom(0)
+ }
_, _ = fmt.Fprintln(w, styles.Title.Render("usage"))
- _, _ = fmt.Fprintln(w, styles.Codeblock.Base.Render(usage))
+ _, _ = fmt.Fprintln(w, blockStyle.Render(usage))
if len(examples) > 0 {
- cw := styles.Codeblock.Base.GetWidth() - styles.Codeblock.Base.GetHorizontalPadding()
+ cw := blockStyle.GetWidth() - blockStyle.GetHorizontalPadding()
_, _ = fmt.Fprintln(w, styles.Title.Render("examples"))
for i, example := range examples {
if lipgloss.Width(example) > cw {
examples[i] = ansi.Truncate(example, cw, "…")
}
}
- _, _ = fmt.Fprintln(w, styles.Codeblock.Base.Render(strings.Join(examples, "\n")))
+ _, _ = fmt.Fprintln(w, blockStyle.Render(strings.Join(examples, "\n")))
}
+ groups, groupKeys := evalGroups(c)
cmds, cmdKeys := evalCmds(c, styles)
flags, flagKeys := evalFlags(c, styles)
space := calculateSpace(cmdKeys, flagKeys)
- leftPadding := 4
- if len(cmds) > 0 {
- _, _ = fmt.Fprintln(w, styles.Title.Render("commands"))
- for _, k := range cmdKeys {
- _, _ = fmt.Fprintln(w, lipgloss.JoinHorizontal(
- lipgloss.Left,
- lipgloss.NewStyle().PaddingLeft(leftPadding).Render(k),
- strings.Repeat(" ", space-lipgloss.Width(k)),
- cmds[k],
- ))
+ for _, groupID := range groupKeys {
+ group := cmds[groupID]
+ if len(group) == 0 {
+ continue
}
+ renderGroup(w, styles, space, groups[groupID], func(yield func(string, string) bool) {
+ for _, k := range cmdKeys {
+ cmds, ok := group[k]
+ if !ok {
+ continue
+ }
+ if !yield(k, cmds) {
+ return
+ }
+ }
+ })
}
if len(flags) > 0 {
- _, _ = fmt.Fprintln(w, styles.Title.Render("flags"))
- for _, k := range flagKeys {
- _, _ = fmt.Fprintln(w, lipgloss.JoinHorizontal(
- lipgloss.Left,
- lipgloss.NewStyle().PaddingLeft(leftPadding).Render(k),
- strings.Repeat(" ", space-lipgloss.Width(k)),
- flags[k],
- ))
- }
+ renderGroup(w, styles, space, "flags", func(yield func(string, string) bool) {
+ for _, k := range flagKeys {
+ if !yield(k, flags[k]) {
+ return
+ }
+ }
+ })
}
_, _ = fmt.Fprintln(w)
}
-func writeError(w *colorprofile.Writer, styles Styles, err error) {
+// DefaultErrorHandler is the default [ErrorHandler] implementation.
+func DefaultErrorHandler(w io.Writer, styles Styles, err error) {
_, _ = fmt.Fprintln(w, styles.ErrorHeader.String())
_, _ = fmt.Fprintln(w, styles.ErrorText.Render(err.Error()+"."))
_, _ = fmt.Fprintln(w)
- _, _ = fmt.Fprintln(w, lipgloss.JoinHorizontal(
- lipgloss.Left,
- styles.ErrorText.UnsetWidth().Render("Try"),
- styles.Program.Flag.Render("--help"),
- styles.ErrorText.UnsetWidth().UnsetMargins().UnsetTransform().PaddingLeft(1).Render("for usage."),
- ))
- _, _ = fmt.Fprintln(w)
+ if isUsageError(err) {
+ _, _ = fmt.Fprintln(w, lipgloss.JoinHorizontal(
+ lipgloss.Left,
+ styles.ErrorText.UnsetWidth().Render("Try"),
+ styles.Program.Flag.Render(" --help "),
+ styles.ErrorText.UnsetWidth().UnsetMargins().UnsetTransform().Render("for usage."),
+ ))
+ _, _ = fmt.Fprintln(w)
+ }
+}
+
+// XXX: this is a hack to detect usage errors.
+// See: https://github.com/spf13/cobra/pull/2266
+func isUsageError(err error) bool {
+ s := err.Error()
+ for _, prefix := range []string{
+ "flag needs an argument:",
+ "unknown flag:",
+ "unknown shorthand flag:",
+ "unknown command",
+ "invalid argument",
+ } {
+ if strings.HasPrefix(s, prefix) {
+ return true
+ }
+ }
+ return false
}
func writeLongShort(w *colorprofile.Writer, styles Styles, longShort string) {
@@ -118,8 +152,10 @@ var otherArgsRe = regexp.MustCompile(`(\[.*\])`)
// styleUsage stylized styleUsage line for a given command.
func styleUsage(c *cobra.Command, styles Program, complete bool) string {
- // XXX: maybe use c.UseLine() here?
u := c.Use
+ if complete {
+ u = c.UseLine()
+ }
hasArgs := strings.Contains(u, "[args]")
hasFlags := strings.Contains(u, "[flags]") || strings.Contains(u, "[--flags]") || c.HasFlags() || c.HasPersistentFlags() || c.HasAvailableFlags()
hasCommands := strings.Contains(u, "[command]") || c.HasAvailableSubCommands()
@@ -139,34 +175,38 @@ func styleUsage(c *cobra.Command, styles Program, complete bool) string {
u = strings.TrimSpace(u)
- useLine := []string{
- styles.Name.Render(u),
- }
- if !complete {
- useLine[0] = styles.Command.Render(u)
+ useLine := []string{}
+ if complete {
+ parts := strings.Fields(u)
+ useLine = append(useLine, styles.Name.Render(parts[0]))
+ if len(parts) > 1 {
+ useLine = append(useLine, styles.Command.Render(" "+strings.Join(parts[1:], " ")))
+ }
+ } else {
+ useLine = append(useLine, styles.Command.Render(u))
}
if hasCommands {
useLine = append(
useLine,
- styles.DimmedArgument.Render("[command]"),
+ styles.DimmedArgument.Render(" [command]"),
)
}
if hasArgs {
useLine = append(
useLine,
- styles.DimmedArgument.Render("[args]"),
+ styles.DimmedArgument.Render(" [args]"),
)
}
for _, arg := range otherArgs {
useLine = append(
useLine,
- styles.DimmedArgument.Render(arg),
+ styles.DimmedArgument.Render(" "+arg),
)
}
if hasFlags {
useLine = append(
useLine,
- styles.DimmedArgument.Render("[--flags]"),
+ styles.DimmedArgument.Render(" [--flags]"),
)
}
return lipgloss.JoinHorizontal(lipgloss.Left, useLine...)
@@ -180,19 +220,21 @@ func styleExamples(c *cobra.Command, styles Styles) []string {
}
usage := []string{}
examples := strings.Split(c.Example, "\n")
+ var indent bool
for i, line := range examples {
line = strings.TrimSpace(line)
if (i == 0 || i == len(examples)-1) && line == "" {
continue
}
- s := styleExample(c, line, styles.Codeblock)
+ s := styleExample(c, line, indent, styles.Codeblock)
usage = append(usage, s)
+ indent = len(line) > 1 && (line[len(line)-1] == '\\' || line[len(line)-1] == '|')
}
return usage
}
-func styleExample(c *cobra.Command, line string, styles Codeblock) string {
+func styleExample(c *cobra.Command, line string, indent bool, styles Codeblock) string {
if strings.HasPrefix(line, "# ") {
return lipgloss.JoinHorizontal(
lipgloss.Left,
@@ -200,66 +242,110 @@ func styleExample(c *cobra.Command, line string, styles Codeblock) string {
)
}
- args := strings.Fields(line)
- var nextIsFlag bool
var isQuotedString bool
+ var foundProgramName bool
+ var isRedirecting bool
+ programName := c.Root().Name()
+ args := strings.Fields(line)
+ var cleanArgs []string
for i, arg := range args {
- if i == 0 {
- args[i] = styles.Program.Name.Render(arg)
- continue
+ isQuoteStart := arg[0] == '"' || arg[0] == '\''
+ isQuoteEnd := arg[len(arg)-1] == '"' || arg[len(arg)-1] == '\''
+ isFlag := arg[0] == '-'
+
+ switch i {
+ case 0:
+ args[i] = ""
+ if indent {
+ args[i] = styles.Program.DimmedArgument.Render(" ")
+ indent = false
+ }
+ default:
+ args[i] = styles.Program.DimmedArgument.Render(" ")
}
- quoteStart := arg[0] == '"'
- quoteEnd := arg[len(arg)-1] == '"'
- flagStart := arg[0] == '-'
- if i == 1 && !quoteStart && !flagStart {
- args[i] = styles.Program.Command.Render(arg)
+ if isRedirecting {
+ args[i] += styles.Program.DimmedArgument.Render(arg)
+ isRedirecting = false
continue
}
- if quoteStart {
- isQuotedString = true
- }
- if isQuotedString {
- args[i] = styles.Program.QuotedString.Render(arg)
- if quoteEnd {
- isQuotedString = false
+
+ switch arg {
+ case "\\":
+ if i == len(args)-1 {
+ args[i] += styles.Program.DimmedArgument.Render(arg)
+ continue
}
+ case "|", "||", "-", "&", "&&":
+ args[i] += styles.Program.DimmedArgument.Render(arg)
continue
}
- if nextIsFlag {
- args[i] = styles.Program.Flag.Render(arg)
+
+ if isRedirect(arg) {
+ args[i] += styles.Program.DimmedArgument.Render(arg)
+ isRedirecting = true
continue
}
- var dashes string
- if strings.HasPrefix(arg, "-") {
- dashes = "-"
+
+ if !foundProgramName { //nolint:nestif
+ if isQuotedString {
+ args[i] += styles.Program.QuotedString.Render(arg)
+ isQuotedString = !isQuoteEnd
+ continue
+ }
+ if left, right, ok := strings.Cut(arg, "="); ok {
+ args[i] += styles.Program.Flag.Render(left + "=")
+ if right[0] == '"' {
+ isQuotedString = true
+ args[i] += styles.Program.QuotedString.Render(right)
+ continue
+ }
+ args[i] += styles.Program.Argument.Render(right)
+ continue
+ }
+
+ if arg == programName {
+ args[i] += styles.Program.Name.Render(arg)
+ foundProgramName = true
+ continue
+ }
}
- if strings.HasPrefix(arg, "--") {
- dashes = "--"
+
+ if !isQuoteStart && !isQuotedString && !isFlag {
+ cleanArgs = append(cleanArgs, arg)
+ }
+
+ if !isQuoteStart && !isFlag && isSubCommand(c, cleanArgs, arg) {
+ args[i] += styles.Program.Command.Render(arg)
+ continue
+ }
+ isQuotedString = isQuotedString || isQuoteStart
+ if isQuotedString {
+ args[i] += styles.Program.QuotedString.Render(arg)
+ isQuotedString = !isQuoteEnd
+ continue
}
// handle a flag
- if dashes != "" {
+ if isFlag {
name, value, ok := strings.Cut(arg, "=")
- name = strings.TrimPrefix(name, dashes)
// it is --flag=value
if ok {
- args[i] = lipgloss.JoinHorizontal(
+ args[i] += lipgloss.JoinHorizontal(
lipgloss.Left,
- styles.Program.Flag.Render(dashes+name+"="),
- styles.Program.Argument.UnsetPadding().Render(value),
+ styles.Program.Flag.Render(name+"="),
+ styles.Program.Argument.Render(value),
)
continue
}
// it is either --bool-flag or --flag value
- args[i] = lipgloss.JoinHorizontal(
+ args[i] += lipgloss.JoinHorizontal(
lipgloss.Left,
- styles.Program.Flag.Render(dashes+name),
+ styles.Program.Flag.Render(name),
)
- // if the flag is not a bool flag, next arg continues current flag
- nextIsFlag = !isFlagBool(c, name)
continue
}
- args[i] = styles.Program.Argument.Render(arg)
+
+ args[i] += styles.Program.Argument.Render(arg)
}
return lipgloss.JoinHorizontal(
@@ -284,8 +370,7 @@ func evalFlags(c *cobra.Command, styles Styles) (map[string]string, []string) {
} else {
parts = append(
parts,
- styles.Program.Flag.Render("-"+f.Shorthand),
- styles.Program.Flag.Render("--"+f.Name),
+ styles.Program.Flag.Render("-"+f.Shorthand+" --"+f.Name),
)
}
key := lipgloss.JoinHorizontal(lipgloss.Left, parts...)
@@ -303,22 +388,50 @@ func evalFlags(c *cobra.Command, styles Styles) (map[string]string, []string) {
return flags, keys
}
-func evalCmds(c *cobra.Command, styles Styles) (map[string]string, []string) {
+// result is map[groupID]map[styled cmd name]styled cmd help, and the keys in
+// the order they are defined.
+func evalCmds(c *cobra.Command, styles Styles) (map[string](map[string]string), []string) {
padStyle := lipgloss.NewStyle().PaddingLeft(0) //nolint:mnd
keys := []string{}
- cmds := map[string]string{}
+ cmds := map[string]map[string]string{}
for _, sc := range c.Commands() {
if sc.Hidden {
continue
}
+ if _, ok := cmds[sc.GroupID]; !ok {
+ cmds[sc.GroupID] = map[string]string{}
+ }
key := padStyle.Render(styleUsage(sc, styles.Program, false))
help := styles.FlagDescription.Render(sc.Short)
- cmds[key] = help
+ cmds[sc.GroupID][key] = help
keys = append(keys, key)
}
return cmds, keys
}
+func evalGroups(c *cobra.Command) (map[string]string, []string) {
+ // make sure the default group is the first
+ ids := []string{""}
+ groups := map[string]string{"": "commands"}
+ for _, g := range c.Groups() {
+ groups[g.ID] = g.Title
+ ids = append(ids, g.ID)
+ }
+ return groups, ids
+}
+
+func renderGroup(w io.Writer, styles Styles, space int, name string, items iter.Seq2[string, string]) {
+ _, _ = fmt.Fprintln(w, styles.Title.Render(name))
+ for key, help := range items {
+ _, _ = fmt.Fprintln(w, lipgloss.JoinHorizontal(
+ lipgloss.Left,
+ lipgloss.NewStyle().PaddingLeft(longPad).Render(key),
+ strings.Repeat(" ", space-lipgloss.Width(key)),
+ help,
+ ))
+ }
+}
+
func calculateSpace(k1, k2 []string) int {
const spaceBetween = 2
space := minSpace
@@ -328,13 +441,18 @@ func calculateSpace(k1, k2 []string) int {
return space
}
-func isFlagBool(c *cobra.Command, name string) bool {
- flag := c.Flags().Lookup(name)
- if flag == nil && len(name) == 1 {
- flag = c.Flags().ShorthandLookup(name)
- }
- if flag == nil {
- return false
+func isSubCommand(c *cobra.Command, args []string, word string) bool {
+ cmd, _, _ := c.Root().Traverse(args)
+ return cmd != nil && cmd.Name() == word
+}
+
+var redirectPrefixes = []string{">", "<", "&>", "2>", "1>", ">>", "2>>"}
+
+func isRedirect(s string) bool {
+ for _, p := range redirectPrefixes {
+ if strings.HasPrefix(s, p) {
+ return true
+ }
}
- return flag.Value.Type() == "bool"
+ return false
}
@@ -2,10 +2,12 @@ package fang
import (
"image/color"
+ "os"
"strings"
"github.com/charmbracelet/lipgloss/v2"
"github.com/charmbracelet/x/exp/charmtone"
+ "github.com/charmbracelet/x/term"
"golang.org/x/text/cases"
"golang.org/x/text/language"
)
@@ -31,8 +33,14 @@ type ColorScheme struct {
}
// DefaultTheme is the default colorscheme.
+//
+// Deprecated: use [DefaultColorScheme] instead.
func DefaultTheme(isDark bool) ColorScheme {
- c := lipgloss.LightDark(isDark)
+ return DefaultColorScheme(lipgloss.LightDark(isDark))
+}
+
+// DefaultColorScheme is the default colorscheme.
+func DefaultColorScheme(c lipgloss.LightDarkFunc) ColorScheme {
return ColorScheme{
Base: c(charmtone.Charcoal, charmtone.Ash),
Title: charmtone.Charple,
@@ -45,7 +53,7 @@ func DefaultTheme(isDark bool) ColorScheme {
Argument: c(charmtone.Charcoal, charmtone.Ash),
Description: c(charmtone.Charcoal, charmtone.Ash), // flag and command descriptions
FlagDefault: c(charmtone.Smoke, charmtone.Squid), // flag default values in descriptions
- QuotedString: c(charmtone.Charcoal, charmtone.Ash),
+ QuotedString: c(charmtone.Coral, charmtone.Salmon),
ErrorHeader: [2]color.Color{
charmtone.Butter,
charmtone.Cherry,
@@ -53,6 +61,26 @@ func DefaultTheme(isDark bool) ColorScheme {
}
}
+// AnsiColorScheme is a ANSI colorscheme.
+func AnsiColorScheme(c lipgloss.LightDarkFunc) ColorScheme {
+ base := c(lipgloss.Black, lipgloss.White)
+ return ColorScheme{
+ Base: base,
+ Title: lipgloss.Blue,
+ Description: base,
+ Comment: c(lipgloss.BrightWhite, lipgloss.BrightBlack),
+ Flag: lipgloss.Magenta,
+ FlagDefault: lipgloss.BrightMagenta,
+ Command: lipgloss.Cyan,
+ QuotedString: lipgloss.Green,
+ Argument: base,
+ Help: base,
+ Dash: base,
+ ErrorHeader: [2]color.Color{lipgloss.Black, lipgloss.Red},
+ ErrorDetails: lipgloss.Red,
+ }
+}
+
// Styles represents all the styles used.
type Styles struct {
Text lipgloss.Style
@@ -84,6 +112,14 @@ type Program struct {
QuotedString lipgloss.Style
}
+func mustColorscheme(cs func(lipgloss.LightDarkFunc) ColorScheme) ColorScheme {
+ var isDark bool
+ if term.IsTerminal(os.Stdout.Fd()) {
+ isDark = lipgloss.HasDarkBackground(os.Stdin, os.Stdout)
+ }
+ return cs(lipgloss.LightDark(isDark))
+}
+
func makeStyles(cs ColorScheme) Styles {
//nolint:mnd
return Styles{
@@ -98,8 +134,7 @@ func makeStyles(cs ColorScheme) Styles {
Foreground(cs.Description).
Transform(titleFirstWord),
FlagDefault: lipgloss.NewStyle().
- Foreground(cs.FlagDefault).
- PaddingLeft(1),
+ Foreground(cs.FlagDefault),
Codeblock: Codeblock{
Base: lipgloss.NewStyle().
Background(cs.Codeblock).
@@ -116,23 +151,18 @@ func makeStyles(cs ColorScheme) Styles {
Background(cs.Codeblock).
Foreground(cs.Program),
Flag: lipgloss.NewStyle().
- PaddingLeft(1).
Background(cs.Codeblock).
Foreground(cs.Flag),
Argument: lipgloss.NewStyle().
- PaddingLeft(1).
Background(cs.Codeblock).
Foreground(cs.Argument),
DimmedArgument: lipgloss.NewStyle().
- PaddingLeft(1).
Background(cs.Codeblock).
Foreground(cs.DimmedArgument),
Command: lipgloss.NewStyle().
- PaddingLeft(1).
Background(cs.Codeblock).
Foreground(cs.Command),
QuotedString: lipgloss.NewStyle().
- PaddingLeft(1).
Background(cs.Codeblock).
Foreground(cs.QuotedString),
},
@@ -141,18 +171,14 @@ func makeStyles(cs ColorScheme) Styles {
Name: lipgloss.NewStyle().
Foreground(cs.Program),
Argument: lipgloss.NewStyle().
- PaddingLeft(1).
Foreground(cs.Argument),
DimmedArgument: lipgloss.NewStyle().
- PaddingLeft(1).
Foreground(cs.DimmedArgument),
Flag: lipgloss.NewStyle().
- PaddingLeft(1).
Foreground(cs.Flag),
Command: lipgloss.NewStyle().
Foreground(cs.Command),
QuotedString: lipgloss.NewStyle().
- PaddingLeft(1).
Foreground(cs.QuotedString),
},
Span: lipgloss.NewStyle().
@@ -22,6 +22,7 @@ type Client struct {
requestID atomic.Int64
clientCapabilities mcp.ClientCapabilities
serverCapabilities mcp.ServerCapabilities
+ samplingHandler SamplingHandler
}
type ClientOption func(*Client)
@@ -33,6 +34,21 @@ func WithClientCapabilities(capabilities mcp.ClientCapabilities) ClientOption {
}
}
+// WithSamplingHandler sets the sampling handler for the client.
+// When set, the client will declare sampling capability during initialization.
+func WithSamplingHandler(handler SamplingHandler) ClientOption {
+ return func(c *Client) {
+ c.samplingHandler = handler
+ }
+}
+
+// WithSession assumes a MCP Session has already been initialized
+func WithSession() ClientOption {
+ return func(c *Client) {
+ c.initialized = true
+ }
+}
+
// NewClient creates a new MCP client with the given transport.
// Usage:
//
@@ -71,6 +87,12 @@ func (c *Client) Start(ctx context.Context) error {
handler(notification)
}
})
+
+ // Set up request handler for bidirectional communication (e.g., sampling)
+ if bidirectional, ok := c.transport.(transport.BidirectionalInterface); ok {
+ bidirectional.SetRequestHandler(c.handleIncomingRequest)
+ }
+
return nil
}
@@ -127,6 +149,12 @@ func (c *Client) Initialize(
ctx context.Context,
request mcp.InitializeRequest,
) (*mcp.InitializeResult, error) {
+ // Merge client capabilities with sampling capability if handler is configured
+ capabilities := request.Params.Capabilities
+ if c.samplingHandler != nil {
+ capabilities.Sampling = &struct{}{}
+ }
+
// Ensure we send a params object with all required fields
params := struct {
ProtocolVersion string `json:"protocolVersion"`
@@ -135,7 +163,7 @@ func (c *Client) Initialize(
}{
ProtocolVersion: request.Params.ProtocolVersion,
ClientInfo: request.Params.ClientInfo,
- Capabilities: request.Params.Capabilities, // Will be empty struct if not set
+ Capabilities: capabilities,
}
response, err := c.sendRequest(ctx, "initialize", params)
@@ -398,6 +426,64 @@ func (c *Client) Complete(
return &result, nil
}
+// handleIncomingRequest processes incoming requests from the server.
+// This is the main entry point for server-to-client requests like sampling.
+func (c *Client) handleIncomingRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
+ switch request.Method {
+ case string(mcp.MethodSamplingCreateMessage):
+ return c.handleSamplingRequestTransport(ctx, request)
+ default:
+ return nil, fmt.Errorf("unsupported request method: %s", request.Method)
+ }
+}
+
+// handleSamplingRequestTransport handles sampling requests at the transport level.
+func (c *Client) handleSamplingRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
+ if c.samplingHandler == nil {
+ return nil, fmt.Errorf("no sampling handler configured")
+ }
+
+ // Parse the request parameters
+ var params mcp.CreateMessageParams
+ if request.Params != nil {
+ paramsBytes, err := json.Marshal(request.Params)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal params: %w", err)
+ }
+ if err := json.Unmarshal(paramsBytes, ¶ms); err != nil {
+ return nil, fmt.Errorf("failed to unmarshal params: %w", err)
+ }
+ }
+
+ // Create the MCP request
+ mcpRequest := mcp.CreateMessageRequest{
+ Request: mcp.Request{
+ Method: string(mcp.MethodSamplingCreateMessage),
+ },
+ CreateMessageParams: params,
+ }
+
+ // Call the sampling handler
+ result, err := c.samplingHandler.CreateMessage(ctx, mcpRequest)
+ if err != nil {
+ return nil, err
+ }
+
+ // Marshal the result
+ resultBytes, err := json.Marshal(result)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal result: %w", err)
+ }
+
+ // Create the transport response
+ response := &transport.JSONRPCResponse{
+ JSONRPC: mcp.JSONRPC_VERSION,
+ ID: request.ID,
+ Result: json.RawMessage(resultBytes),
+ }
+
+ return response, nil
+}
func listByPage[T any](
ctx context.Context,
client *Client,
@@ -432,3 +518,17 @@ func (c *Client) GetServerCapabilities() mcp.ServerCapabilities {
func (c *Client) GetClientCapabilities() mcp.ClientCapabilities {
return c.clientCapabilities
}
+
+// GetSessionId returns the session ID of the transport.
+// If the transport does not support sessions, it returns an empty string.
+func (c *Client) GetSessionId() string {
+ if c.transport == nil {
+ return ""
+ }
+ return c.transport.GetSessionId()
+}
+
+// IsInitialized returns true if the client has been initialized.
+func (c *Client) IsInitialized() bool {
+ return c.initialized
+}
@@ -13,5 +13,10 @@ func NewStreamableHttpClient(baseURL string, options ...transport.StreamableHTTP
if err != nil {
return nil, fmt.Errorf("failed to create SSE transport: %w", err)
}
- return NewClient(trans), nil
+ clientOptions := make([]ClientOption, 0)
+ sessionID := trans.GetSessionId()
+ if sessionID != "" {
+ clientOptions = append(clientOptions, WithSession())
+ }
+ return NewClient(trans, clientOptions...), nil
}
@@ -0,0 +1,20 @@
+package client
+
+import (
+ "context"
+
+ "github.com/mark3labs/mcp-go/mcp"
+)
+
+// SamplingHandler defines the interface for handling sampling requests from servers.
+// Clients can implement this interface to provide LLM sampling capabilities to servers.
+type SamplingHandler interface {
+ // CreateMessage handles a sampling request from the server and returns the generated message.
+ // The implementation should:
+ // 1. Validate the request parameters
+ // 2. Optionally prompt the user for approval (human-in-the-loop)
+ // 3. Select an appropriate model based on preferences
+ // 4. Generate the response using the selected model
+ // 5. Return the result with model information and stop reason
+ CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error)
+}
@@ -19,10 +19,26 @@ func NewStdioMCPClient(
env []string,
args ...string,
) (*Client, error) {
+ return NewStdioMCPClientWithOptions(command, env, args)
+}
+
+// NewStdioMCPClientWithOptions creates a new stdio-based MCP client that communicates with a subprocess.
+// It launches the specified command with given arguments and sets up stdin/stdout pipes for communication.
+// Optional configuration functions can be provided to customize the transport before it starts,
+// such as setting a custom command function.
+//
+// NOTICE: NewStdioMCPClientWithOptions automatically starts the underlying transport.
+// Don't call the Start method manually.
+// This is for backward compatibility.
+func NewStdioMCPClientWithOptions(
+ command string,
+ env []string,
+ args []string,
+ opts ...transport.StdioOption,
+) (*Client, error) {
+ stdioTransport := transport.NewStdioWithOptions(command, env, args, opts...)
- stdioTransport := transport.NewStdio(command, env, args...)
- err := stdioTransport.Start(context.Background())
- if err != nil {
+ if err := stdioTransport.Start(context.Background()); err != nil {
return nil, fmt.Errorf("failed to start stdio transport: %w", err)
}
@@ -68,3 +68,7 @@ func (c *InProcessTransport) SetNotificationHandler(handler func(notification mc
func (*InProcessTransport) Close() error {
return nil
}
+
+func (c *InProcessTransport) GetSessionId() string {
+ return ""
+}
@@ -29,6 +29,22 @@ type Interface interface {
// Close the connection.
Close() error
+
+ // GetSessionId returns the session ID of the transport.
+ GetSessionId() string
+}
+
+// RequestHandler defines a function that handles incoming requests from the server.
+type RequestHandler func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error)
+
+// BidirectionalInterface extends Interface to support incoming requests from the server.
+// This is used for features like sampling where the server can send requests to the client.
+type BidirectionalInterface interface {
+ Interface
+
+ // SetRequestHandler sets the handler for incoming requests from the server.
+ // The handler should process the request and return a response.
+ SetRequestHandler(handler RequestHandler)
}
type JSONRPCRequest struct {
@@ -41,10 +57,10 @@ type JSONRPCRequest struct {
type JSONRPCResponse struct {
JSONRPC string `json:"jsonrpc"`
ID mcp.RequestId `json:"id"`
- Result json.RawMessage `json:"result"`
+ Result json.RawMessage `json:"result,omitempty"`
Error *struct {
Code int `json:"code"`
Message string `json:"message"`
Data json.RawMessage `json:"data"`
- } `json:"error"`
+ } `json:"error,omitempty"`
}
@@ -428,6 +428,12 @@ func (c *SSE) Close() error {
return nil
}
+// GetSessionId returns the session ID of the transport.
+// Since SSE does not maintain a session ID, it returns an empty string.
+func (c *SSE) GetSessionId() string {
+ return ""
+}
+
// SendNotification sends a JSON-RPC notification to the server without expecting a response.
func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error {
if c.endpoint == nil {
@@ -23,6 +23,7 @@ type Stdio struct {
env []string
cmd *exec.Cmd
+ cmdFunc CommandFunc
stdin io.WriteCloser
stdout *bufio.Reader
stderr io.ReadCloser
@@ -31,6 +32,28 @@ type Stdio struct {
done chan struct{}
onNotification func(mcp.JSONRPCNotification)
notifyMu sync.RWMutex
+ onRequest RequestHandler
+ requestMu sync.RWMutex
+ ctx context.Context
+ ctxMu sync.RWMutex
+}
+
+// StdioOption defines a function that configures a Stdio transport instance.
+// Options can be used to customize the behavior of the transport before it starts,
+// such as setting a custom command function.
+type StdioOption func(*Stdio)
+
+// CommandFunc is a factory function that returns a custom exec.Cmd used to launch the MCP subprocess.
+// It can be used to apply sandboxing, custom environment control, working directories, etc.
+type CommandFunc func(ctx context.Context, command string, env []string, args []string) (*exec.Cmd, error)
+
+// WithCommandFunc sets a custom command factory function for the stdio transport.
+// The CommandFunc is responsible for constructing the exec.Cmd used to launch the subprocess,
+// allowing control over attributes like environment, working directory, and system-level sandboxing.
+func WithCommandFunc(f CommandFunc) StdioOption {
+ return func(s *Stdio) {
+ s.cmdFunc = f
+ }
}
// NewIO returns a new stdio-based transport using existing input, output, and
@@ -44,6 +67,7 @@ func NewIO(input io.Reader, output io.WriteCloser, logging io.ReadCloser) *Stdio
responses: make(map[string]chan *JSONRPCResponse),
done: make(chan struct{}),
+ ctx: context.Background(),
}
}
@@ -55,20 +79,43 @@ func NewStdio(
env []string,
args ...string,
) *Stdio {
+ return NewStdioWithOptions(command, env, args)
+}
- client := &Stdio{
+// NewStdioWithOptions creates a new stdio transport to communicate with a subprocess.
+// It launches the specified command with given arguments and sets up stdin/stdout pipes for communication.
+// Returns an error if the subprocess cannot be started or the pipes cannot be created.
+// Optional configuration functions can be provided to customize the transport before it starts,
+// such as setting a custom command factory.
+func NewStdioWithOptions(
+ command string,
+ env []string,
+ args []string,
+ opts ...StdioOption,
+) *Stdio {
+ s := &Stdio{
command: command,
args: args,
env: env,
responses: make(map[string]chan *JSONRPCResponse),
done: make(chan struct{}),
+ ctx: context.Background(),
+ }
+
+ for _, opt := range opts {
+ opt(s)
}
- return client
+ return s
}
func (c *Stdio) Start(ctx context.Context) error {
+ // Store the context for use in request handling
+ c.ctxMu.Lock()
+ c.ctx = ctx
+ c.ctxMu.Unlock()
+
if err := c.spawnCommand(ctx); err != nil {
return err
}
@@ -83,18 +130,25 @@ func (c *Stdio) Start(ctx context.Context) error {
return nil
}
-// spawnCommand spawns a new process running c.command.
+// spawnCommand spawns a new process running the configured command, args, and env.
+// If an (optional) cmdFunc custom command factory function was configured, it will be used to construct the subprocess;
+// otherwise, the default behavior uses exec.CommandContext with the merged environment.
+// Initializes stdin, stdout, and stderr pipes for JSON-RPC communication.
func (c *Stdio) spawnCommand(ctx context.Context) error {
if c.command == "" {
return nil
}
- cmd := exec.CommandContext(ctx, c.command, c.args...)
-
- mergedEnv := os.Environ()
- mergedEnv = append(mergedEnv, c.env...)
+ var cmd *exec.Cmd
+ var err error
- cmd.Env = mergedEnv
+ // Standard behavior if no command func present.
+ if c.cmdFunc == nil {
+ cmd = exec.CommandContext(ctx, c.command, c.args...)
+ cmd.Env = append(os.Environ(), c.env...)
+ } else if cmd, err = c.cmdFunc(ctx, c.command, c.env, c.args); err != nil {
+ return err
+ }
stdin, err := cmd.StdinPipe()
if err != nil {
@@ -148,6 +202,12 @@ func (c *Stdio) Close() error {
return nil
}
+// GetSessionId returns the session ID of the transport.
+// Since stdio does not maintain a session ID, it returns an empty string.
+func (c *Stdio) GetSessionId() string {
+ return ""
+}
+
// SetNotificationHandler sets the handler function to be called when a notification is received.
// Only one handler can be set at a time; setting a new one replaces the previous handler.
func (c *Stdio) SetNotificationHandler(
@@ -158,6 +218,14 @@ func (c *Stdio) SetNotificationHandler(
c.onNotification = handler
}
+// SetRequestHandler sets the handler function to be called when a request is received from the server.
+// This enables bidirectional communication for features like sampling.
+func (c *Stdio) SetRequestHandler(handler RequestHandler) {
+ c.requestMu.Lock()
+ defer c.requestMu.Unlock()
+ c.onRequest = handler
+}
+
// readResponses continuously reads and processes responses from the server's stdout.
// It handles both responses to requests and notifications, routing them appropriately.
// Runs until the done channel is closed or an error occurs reading from stdout.
@@ -175,13 +243,18 @@ func (c *Stdio) readResponses() {
return
}
- var baseMessage JSONRPCResponse
+ // First try to parse as a generic message to check for ID field
+ var baseMessage struct {
+ JSONRPC string `json:"jsonrpc"`
+ ID *mcp.RequestId `json:"id,omitempty"`
+ Method string `json:"method,omitempty"`
+ }
if err := json.Unmarshal([]byte(line), &baseMessage); err != nil {
continue
}
- // Handle notification
- if baseMessage.ID.IsNil() {
+ // If it has a method but no ID, it's a notification
+ if baseMessage.Method != "" && baseMessage.ID == nil {
var notification mcp.JSONRPCNotification
if err := json.Unmarshal([]byte(line), ¬ification); err != nil {
continue
@@ -194,15 +267,30 @@ func (c *Stdio) readResponses() {
continue
}
+ // If it has a method and an ID, it's an incoming request
+ if baseMessage.Method != "" && baseMessage.ID != nil {
+ var request JSONRPCRequest
+ if err := json.Unmarshal([]byte(line), &request); err == nil {
+ c.handleIncomingRequest(request)
+ continue
+ }
+ }
+
+ // Otherwise, it's a response to our request
+ var response JSONRPCResponse
+ if err := json.Unmarshal([]byte(line), &response); err != nil {
+ continue
+ }
+
// Create string key for map lookup
- idKey := baseMessage.ID.String()
+ idKey := response.ID.String()
c.mu.RLock()
ch, exists := c.responses[idKey]
c.mu.RUnlock()
if exists {
- ch <- &baseMessage
+ ch <- &response
c.mu.Lock()
delete(c.responses, idKey)
c.mu.Unlock()
@@ -281,6 +369,96 @@ func (c *Stdio) SendNotification(
return nil
}
+// handleIncomingRequest processes incoming requests from the server.
+// It calls the registered request handler and sends the response back to the server.
+func (c *Stdio) handleIncomingRequest(request JSONRPCRequest) {
+ c.requestMu.RLock()
+ handler := c.onRequest
+ c.requestMu.RUnlock()
+
+ if handler == nil {
+ // Send error response if no handler is configured
+ errorResponse := JSONRPCResponse{
+ JSONRPC: mcp.JSONRPC_VERSION,
+ ID: request.ID,
+ Error: &struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+ Data json.RawMessage `json:"data"`
+ }{
+ Code: mcp.METHOD_NOT_FOUND,
+ Message: "No request handler configured",
+ },
+ }
+ c.sendResponse(errorResponse)
+ return
+ }
+
+ // Handle the request in a goroutine to avoid blocking
+ go func() {
+ c.ctxMu.RLock()
+ ctx := c.ctx
+ c.ctxMu.RUnlock()
+
+ // Check if context is already cancelled before processing
+ select {
+ case <-ctx.Done():
+ errorResponse := JSONRPCResponse{
+ JSONRPC: mcp.JSONRPC_VERSION,
+ ID: request.ID,
+ Error: &struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+ Data json.RawMessage `json:"data"`
+ }{
+ Code: mcp.INTERNAL_ERROR,
+ Message: ctx.Err().Error(),
+ },
+ }
+ c.sendResponse(errorResponse)
+ return
+ default:
+ }
+
+ response, err := handler(ctx, request)
+
+ if err != nil {
+ errorResponse := JSONRPCResponse{
+ JSONRPC: mcp.JSONRPC_VERSION,
+ ID: request.ID,
+ Error: &struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+ Data json.RawMessage `json:"data"`
+ }{
+ Code: mcp.INTERNAL_ERROR,
+ Message: err.Error(),
+ },
+ }
+ c.sendResponse(errorResponse)
+ return
+ }
+
+ if response != nil {
+ c.sendResponse(*response)
+ }
+ }()
+}
+
+// sendResponse sends a response back to the server.
+func (c *Stdio) sendResponse(response JSONRPCResponse) {
+ responseBytes, err := json.Marshal(response)
+ if err != nil {
+ fmt.Printf("Error marshaling response: %v\n", err)
+ return
+ }
+ responseBytes = append(responseBytes, '\n')
+
+ if _, err := c.stdin.Write(responseBytes); err != nil {
+ fmt.Printf("Error writing response: %v\n", err)
+ }
+}
+
// Stderr returns a reader for the stderr output of the subprocess.
// This can be used to capture error messages or logs from the subprocess.
func (c *Stdio) Stderr() io.Reader {
@@ -17,10 +17,24 @@ import (
"time"
"github.com/mark3labs/mcp-go/mcp"
+ "github.com/mark3labs/mcp-go/util"
)
type StreamableHTTPCOption func(*StreamableHTTP)
+// WithContinuousListening enables receiving server-to-client notifications when no request is in flight.
+// In particular, if you want to receive global notifications from the server (like ToolListChangedNotification),
+// you should enable this option.
+//
+// It will establish a standalone long-live GET HTTP connection to the server.
+// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server
+// NOTICE: Even enabled, the server may not support this feature.
+func WithContinuousListening() StreamableHTTPCOption {
+ return func(sc *StreamableHTTP) {
+ sc.getListeningEnabled = true
+ }
+}
+
// WithHTTPClient sets a custom HTTP client on the StreamableHTTP transport.
func WithHTTPBasicClient(client *http.Client) StreamableHTTPCOption {
return func(sc *StreamableHTTP) {
@@ -54,6 +68,19 @@ func WithHTTPOAuth(config OAuthConfig) StreamableHTTPCOption {
}
}
+func WithLogger(logger util.Logger) StreamableHTTPCOption {
+ return func(sc *StreamableHTTP) {
+ sc.logger = logger
+ }
+}
+
+// WithSession creates a client with a pre-configured session
+func WithSession(sessionID string) StreamableHTTPCOption {
+ return func(sc *StreamableHTTP) {
+ sc.sessionID.Store(sessionID)
+ }
+}
+
// StreamableHTTP implements Streamable HTTP transport.
//
// It transmits JSON-RPC messages over individual HTTP requests. One message per request.
@@ -64,19 +91,22 @@ func WithHTTPOAuth(config OAuthConfig) StreamableHTTPCOption {
//
// The current implementation does not support the following features:
// - batching
-// - continuously listening for server notifications when no request is in flight
-// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server)
// - resuming stream
// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#resumability-and-redelivery)
// - server -> client request
type StreamableHTTP struct {
- serverURL *url.URL
- httpClient *http.Client
- headers map[string]string
- headerFunc HTTPHeaderFunc
+ serverURL *url.URL
+ httpClient *http.Client
+ headers map[string]string
+ headerFunc HTTPHeaderFunc
+ logger util.Logger
+ getListeningEnabled bool
sessionID atomic.Value // string
+ initialized chan struct{}
+ initializedOnce sync.Once
+
notificationHandler func(mcp.JSONRPCNotification)
notifyMu sync.RWMutex
@@ -95,15 +125,19 @@ func NewStreamableHTTP(serverURL string, options ...StreamableHTTPCOption) (*Str
}
smc := &StreamableHTTP{
- serverURL: parsedURL,
- httpClient: &http.Client{},
- headers: make(map[string]string),
- closed: make(chan struct{}),
+ serverURL: parsedURL,
+ httpClient: &http.Client{},
+ headers: make(map[string]string),
+ closed: make(chan struct{}),
+ logger: util.DefaultLogger(),
+ initialized: make(chan struct{}),
}
smc.sessionID.Store("") // set initial value to simplify later usage
for _, opt := range options {
- opt(smc)
+ if opt != nil {
+ opt(smc)
+ }
}
// If OAuth is configured, set the base URL for metadata discovery
@@ -118,7 +152,20 @@ func NewStreamableHTTP(serverURL string, options ...StreamableHTTPCOption) (*Str
// Start initiates the HTTP connection to the server.
func (c *StreamableHTTP) Start(ctx context.Context) error {
- // For Streamable HTTP, we don't need to establish a persistent connection
+ // For Streamable HTTP, we don't need to establish a persistent connection by default
+ if c.getListeningEnabled {
+ go func() {
+ select {
+ case <-c.initialized:
+ ctx, cancel := c.contextAwareOfClientClose(ctx)
+ defer cancel()
+ c.listenForever(ctx)
+ case <-c.closed:
+ return
+ }
+ }()
+ }
+
return nil
}
@@ -142,13 +189,13 @@ func (c *StreamableHTTP) Close() error {
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.serverURL.String(), nil)
if err != nil {
- fmt.Printf("failed to create close request\n: %v", err)
+ c.logger.Errorf("failed to create close request: %v", err)
return
}
req.Header.Set(headerKeySessionID, sessionId)
res, err := c.httpClient.Do(req)
if err != nil {
- fmt.Printf("failed to send close request\n: %v", err)
+ c.logger.Errorf("failed to send close request: %v", err)
return
}
res.Body.Close()
@@ -185,77 +232,29 @@ func (c *StreamableHTTP) SendRequest(
request JSONRPCRequest,
) (*JSONRPCResponse, error) {
- // Create a combined context that could be canceled when the client is closed
- newCtx, cancel := context.WithCancel(ctx)
- defer cancel()
- go func() {
- select {
- case <-c.closed:
- cancel()
- case <-newCtx.Done():
- // The original context was canceled, no need to do anything
- }
- }()
- ctx = newCtx
-
// Marshal request
requestBody, err := json.Marshal(request)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
- // Create HTTP request
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.serverURL.String(), bytes.NewReader(requestBody))
- if err != nil {
- return nil, fmt.Errorf("failed to create request: %w", err)
- }
-
- // Set headers
- req.Header.Set("Content-Type", "application/json")
- req.Header.Set("Accept", "application/json, text/event-stream")
- sessionID := c.sessionID.Load()
- if sessionID != "" {
- req.Header.Set(headerKeySessionID, sessionID.(string))
- }
- for k, v := range c.headers {
- req.Header.Set(k, v)
- }
-
- // Add OAuth authorization if configured
- if c.oauthHandler != nil {
- authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
- if err != nil {
- // If we get an authorization error, return a specific error that can be handled by the client
- if err.Error() == "no valid token available, authorization required" {
- return nil, &OAuthAuthorizationRequiredError{
- Handler: c.oauthHandler,
- }
- }
- return nil, fmt.Errorf("failed to get authorization header: %w", err)
- }
- req.Header.Set("Authorization", authHeader)
- }
-
- if c.headerFunc != nil {
- for k, v := range c.headerFunc(ctx) {
- req.Header.Set(k, v)
- }
- }
+ ctx, cancel := c.contextAwareOfClientClose(ctx)
+ defer cancel()
- // Send request
- resp, err := c.httpClient.Do(req)
+ resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream")
if err != nil {
- return nil, fmt.Errorf("failed to send request: %w", err)
+ if errors.Is(err, ErrSessionTerminated) && request.Method == string(mcp.MethodInitialize) {
+ // If the request is initialize, should not return a SessionTerminated error
+ // It should be a genuine endpoint-routing issue.
+ // ( Fall through to return StatusCode checking. )
+ } else {
+ return nil, fmt.Errorf("failed to send request: %w", err)
+ }
}
defer resp.Body.Close()
// Check if we got an error response
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
- // handle session closed
- if resp.StatusCode == http.StatusNotFound {
- c.sessionID.CompareAndSwap(sessionID, "")
- return nil, fmt.Errorf("session terminated (404). need to re-initialize")
- }
// Handle OAuth unauthorized error
if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil {
@@ -279,6 +278,10 @@ func (c *StreamableHTTP) SendRequest(
if sessionID := resp.Header.Get(headerKeySessionID); sessionID != "" {
c.sessionID.Store(sessionID)
}
+
+ c.initializedOnce.Do(func() {
+ close(c.initialized)
+ })
}
// Handle different response types
@@ -300,16 +303,77 @@ func (c *StreamableHTTP) SendRequest(
case "text/event-stream":
// Server is using SSE for streaming responses
- return c.handleSSEResponse(ctx, resp.Body)
+ return c.handleSSEResponse(ctx, resp.Body, false)
default:
return nil, fmt.Errorf("unexpected content type: %s", resp.Header.Get("Content-Type"))
}
}
+func (c *StreamableHTTP) sendHTTP(
+ ctx context.Context,
+ method string,
+ body io.Reader,
+ acceptType string,
+) (resp *http.Response, err error) {
+
+ // Create HTTP request
+ req, err := http.NewRequestWithContext(ctx, method, c.serverURL.String(), body)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create request: %w", err)
+ }
+
+ // Set headers
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Accept", acceptType)
+ sessionID := c.sessionID.Load().(string)
+ if sessionID != "" {
+ req.Header.Set(headerKeySessionID, sessionID)
+ }
+ for k, v := range c.headers {
+ req.Header.Set(k, v)
+ }
+
+ // Add OAuth authorization if configured
+ if c.oauthHandler != nil {
+ authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
+ if err != nil {
+ // If we get an authorization error, return a specific error that can be handled by the client
+ if err.Error() == "no valid token available, authorization required" {
+ return nil, &OAuthAuthorizationRequiredError{
+ Handler: c.oauthHandler,
+ }
+ }
+ return nil, fmt.Errorf("failed to get authorization header: %w", err)
+ }
+ req.Header.Set("Authorization", authHeader)
+ }
+
+ if c.headerFunc != nil {
+ for k, v := range c.headerFunc(ctx) {
+ req.Header.Set(k, v)
+ }
+ }
+
+ // Send request
+ resp, err = c.httpClient.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("failed to send request: %w", err)
+ }
+
+ // universal handling for session terminated
+ if resp.StatusCode == http.StatusNotFound {
+ c.sessionID.CompareAndSwap(sessionID, "")
+ return nil, ErrSessionTerminated
+ }
+
+ return resp, nil
+}
+
// handleSSEResponse processes an SSE stream for a specific request.
// It returns the final result for the request once received, or an error.
-func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser) (*JSONRPCResponse, error) {
+// If ignoreResponse is true, it won't return when a response messge is received. This is for continuous listening.
+func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser, ignoreResponse bool) (*JSONRPCResponse, error) {
// Create a channel for this specific request
responseChan := make(chan *JSONRPCResponse, 1)
@@ -328,7 +392,7 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl
var message JSONRPCResponse
if err := json.Unmarshal([]byte(data), &message); err != nil {
- fmt.Printf("failed to unmarshal message: %v\n", err)
+ c.logger.Errorf("failed to unmarshal message: %v", err)
return
}
@@ -336,7 +400,7 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl
if message.ID.IsNil() {
var notification mcp.JSONRPCNotification
if err := json.Unmarshal([]byte(data), ¬ification); err != nil {
- fmt.Printf("failed to unmarshal notification: %v\n", err)
+ c.logger.Errorf("failed to unmarshal notification: %v", err)
return
}
c.notifyMu.RLock()
@@ -347,7 +411,9 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl
return
}
- responseChan <- &message
+ if !ignoreResponse {
+ responseChan <- &message
+ }
})
}()
@@ -393,7 +459,7 @@ func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, hand
case <-ctx.Done():
return
default:
- fmt.Printf("SSE stream error: %v\n", err)
+ c.logger.Errorf("SSE stream error: %v", err)
return
}
}
@@ -432,44 +498,10 @@ func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp.
}
// Create HTTP request
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.serverURL.String(), bytes.NewReader(requestBody))
- if err != nil {
- return fmt.Errorf("failed to create request: %w", err)
- }
-
- // Set headers
- req.Header.Set("Content-Type", "application/json")
- req.Header.Set("Accept", "application/json, text/event-stream")
- if sessionID := c.sessionID.Load(); sessionID != "" {
- req.Header.Set(headerKeySessionID, sessionID.(string))
- }
- for k, v := range c.headers {
- req.Header.Set(k, v)
- }
-
- // Add OAuth authorization if configured
- if c.oauthHandler != nil {
- authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
- if err != nil {
- // If we get an authorization error, return a specific error that can be handled by the client
- if errors.Is(err, ErrOAuthAuthorizationRequired) {
- return &OAuthAuthorizationRequiredError{
- Handler: c.oauthHandler,
- }
- }
- return fmt.Errorf("failed to get authorization header: %w", err)
- }
- req.Header.Set("Authorization", authHeader)
- }
-
- if c.headerFunc != nil {
- for k, v := range c.headerFunc(ctx) {
- req.Header.Set(k, v)
- }
- }
+ ctx, cancel := c.contextAwareOfClientClose(ctx)
+ defer cancel()
- // Send request
- resp, err := c.httpClient.Do(req)
+ resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream")
if err != nil {
return fmt.Errorf("failed to send request: %w", err)
}
@@ -513,3 +545,84 @@ func (c *StreamableHTTP) GetOAuthHandler() *OAuthHandler {
func (c *StreamableHTTP) IsOAuthEnabled() bool {
return c.oauthHandler != nil
}
+
+func (c *StreamableHTTP) listenForever(ctx context.Context) {
+ c.logger.Infof("listening to server forever")
+ for {
+ err := c.createGETConnectionToServer(ctx)
+ if errors.Is(err, ErrGetMethodNotAllowed) {
+ // server does not support listening
+ c.logger.Errorf("server does not support listening")
+ return
+ }
+
+ select {
+ case <-ctx.Done():
+ return
+ default:
+ }
+
+ if err != nil {
+ c.logger.Errorf("failed to listen to server. retry in 1 second: %v", err)
+ }
+ time.Sleep(retryInterval)
+ }
+}
+
+var (
+ ErrSessionTerminated = fmt.Errorf("session terminated (404). need to re-initialize")
+ ErrGetMethodNotAllowed = fmt.Errorf("GET method not allowed")
+
+ retryInterval = 1 * time.Second // a variable is convenient for testing
+)
+
+func (c *StreamableHTTP) createGETConnectionToServer(ctx context.Context) error {
+
+ resp, err := c.sendHTTP(ctx, http.MethodGet, nil, "text/event-stream")
+ if err != nil {
+ return fmt.Errorf("failed to send request: %w", err)
+ }
+ defer resp.Body.Close()
+
+ // Check if we got an error response
+ if resp.StatusCode == http.StatusMethodNotAllowed {
+ return ErrGetMethodNotAllowed
+ }
+
+ if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
+ body, _ := io.ReadAll(resp.Body)
+ return fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body)
+ }
+
+ // handle SSE response
+ contentType := resp.Header.Get("Content-Type")
+ if contentType != "text/event-stream" {
+ return fmt.Errorf("unexpected content type: %s", contentType)
+ }
+
+ // When ignoreResponse is true, the function will never return expect context is done.
+ // NOTICE: Due to the ambiguity of the specification, other SDKs may use the GET connection to transfer the response
+ // messages. To be more compatible, we should handle this response, however, as the transport layer is message-based,
+ // currently, there is no convenient way to handle this response.
+ // So we ignore the response here. It's not a bug, but may be not compatible with other SDKs.
+ _, err = c.handleSSEResponse(ctx, resp.Body, true)
+ if err != nil {
+ return fmt.Errorf("failed to handle SSE response: %w", err)
+ }
+
+ return nil
+}
+
+func (c *StreamableHTTP) contextAwareOfClientClose(ctx context.Context) (context.Context, context.CancelFunc) {
+ newCtx, cancel := context.WithCancel(ctx)
+ go func() {
+ select {
+ case <-c.closed:
+ cancel()
+ case <-newCtx.Done():
+ // The original context was canceled
+ cancel()
+ }
+ }()
+ return newCtx, cancel
+}
@@ -945,7 +945,20 @@ func PropertyNames(schema map[string]any) PropertyOption {
}
}
-// Items defines the schema for array items
+// Items defines the schema for array items.
+// Accepts any schema definition for maximum flexibility.
+//
+// Example:
+//
+// Items(map[string]any{
+// "type": "object",
+// "properties": map[string]any{
+// "name": map[string]any{"type": "string"},
+// "age": map[string]any{"type": "number"},
+// },
+// })
+//
+// For simple types, use ItemsString(), ItemsNumber(), ItemsBoolean() instead.
func Items(schema any) PropertyOption {
return func(schemaMap map[string]any) {
schemaMap["items"] = schema
@@ -972,3 +985,94 @@ func UniqueItems(unique bool) PropertyOption {
schema["uniqueItems"] = unique
}
}
+
+// WithStringItems configures an array's items to be of type string.
+//
+// Supported options: Description(), DefaultString(), Enum(), MaxLength(), MinLength(), Pattern()
+// Note: Options like Required() are not valid for item schemas and will be ignored.
+//
+// Examples:
+//
+// mcp.WithArray("tags", mcp.WithStringItems())
+// mcp.WithArray("colors", mcp.WithStringItems(mcp.Enum("red", "green", "blue")))
+// mcp.WithArray("names", mcp.WithStringItems(mcp.MinLength(1), mcp.MaxLength(50)))
+//
+// Limitations: Only supports simple string arrays. Use Items() for complex objects.
+func WithStringItems(opts ...PropertyOption) PropertyOption {
+ return func(schema map[string]any) {
+ itemSchema := map[string]any{
+ "type": "string",
+ }
+
+ for _, opt := range opts {
+ opt(itemSchema)
+ }
+
+ schema["items"] = itemSchema
+ }
+}
+
+// WithStringEnumItems configures an array's items to be of type string with a specified enum.
+// Example:
+//
+// mcp.WithArray("priority", mcp.WithStringEnumItems([]string{"low", "medium", "high"}))
+//
+// Limitations: Only supports string enums. Use WithStringItems(Enum(...)) for more flexibility.
+func WithStringEnumItems(values []string) PropertyOption {
+ return func(schema map[string]any) {
+ schema["items"] = map[string]any{
+ "type": "string",
+ "enum": values,
+ }
+ }
+}
+
+// WithNumberItems configures an array's items to be of type number.
+//
+// Supported options: Description(), DefaultNumber(), Min(), Max(), MultipleOf()
+// Note: Options like Required() are not valid for item schemas and will be ignored.
+//
+// Examples:
+//
+// mcp.WithArray("scores", mcp.WithNumberItems(mcp.Min(0), mcp.Max(100)))
+// mcp.WithArray("prices", mcp.WithNumberItems(mcp.Min(0)))
+//
+// Limitations: Only supports simple number arrays. Use Items() for complex objects.
+func WithNumberItems(opts ...PropertyOption) PropertyOption {
+ return func(schema map[string]any) {
+ itemSchema := map[string]any{
+ "type": "number",
+ }
+
+ for _, opt := range opts {
+ opt(itemSchema)
+ }
+
+ schema["items"] = itemSchema
+ }
+}
+
+// WithBooleanItems configures an array's items to be of type boolean.
+//
+// Supported options: Description(), DefaultBool()
+// Note: Options like Required() are not valid for item schemas and will be ignored.
+//
+// Examples:
+//
+// mcp.WithArray("flags", mcp.WithBooleanItems())
+// mcp.WithArray("permissions", mcp.WithBooleanItems(mcp.Description("User permissions")))
+//
+// Limitations: Only supports simple boolean arrays. Use Items() for complex objects.
+func WithBooleanItems(opts ...PropertyOption) PropertyOption {
+ return func(schema map[string]any) {
+ itemSchema := map[string]any{
+ "type": "boolean",
+ }
+
+ for _, opt := range opts {
+ opt(itemSchema)
+ }
+
+ schema["items"] = itemSchema
+ }
+}
@@ -763,6 +763,11 @@ const (
/* Sampling */
+const (
+ // MethodSamplingCreateMessage allows servers to request LLM completions from clients
+ MethodSamplingCreateMessage MCPMethod = "sampling/createMessage"
+)
+
// CreateMessageRequest is a request from the server to sample an LLM via the
// client. The client has full discretion over which model to select. The client
// should also inform the user before beginning sampling, to allow them to inspect
@@ -865,6 +870,22 @@ type AudioContent struct {
func (AudioContent) isContent() {}
+// ResourceLink represents a link to a resource that the client can access.
+type ResourceLink struct {
+ Annotated
+ Type string `json:"type"` // Must be "resource_link"
+ // The URI of the resource.
+ URI string `json:"uri"`
+ // The name of the resource.
+ Name string `json:"name"`
+ // The description of the resource.
+ Description string `json:"description"`
+ // The MIME type of the resource.
+ MIMEType string `json:"mimeType"`
+}
+
+func (ResourceLink) isContent() {}
+
// EmbeddedResource represents the contents of a resource, embedded into a prompt or tool call result.
//
// It is up to the client how best to render embedded resources for the
@@ -222,6 +222,17 @@ func NewAudioContent(data, mimeType string) AudioContent {
}
}
+// Helper function to create a new ResourceLink
+func NewResourceLink(uri, name, description, mimeType string) ResourceLink {
+ return ResourceLink{
+ Type: "resource_link",
+ URI: uri,
+ Name: name,
+ Description: description,
+ MIMEType: mimeType,
+ }
+}
+
// Helper function to create a new EmbeddedResource
func NewEmbeddedResource(resource ResourceContents) EmbeddedResource {
return EmbeddedResource{
@@ -476,6 +487,16 @@ func ParseContent(contentMap map[string]any) (Content, error) {
}
return NewAudioContent(data, mimeType), nil
+ case "resource_link":
+ uri := ExtractString(contentMap, "uri")
+ name := ExtractString(contentMap, "name")
+ description := ExtractString(contentMap, "description")
+ mimeType := ExtractString(contentMap, "mimeType")
+ if uri == "" || name == "" {
+ return nil, fmt.Errorf("resource_link uri or name is missing")
+ }
+ return NewResourceLink(uri, name, description, mimeType), nil
+
case "resource":
resourceMap := ExtractMap(contentMap, "resource")
if resourceMap == nil {
@@ -0,0 +1,37 @@
+package server
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/mark3labs/mcp-go/mcp"
+)
+
+// EnableSampling enables sampling capabilities for the server.
+// This allows the server to send sampling requests to clients that support it.
+func (s *MCPServer) EnableSampling() {
+ s.capabilitiesMu.Lock()
+ defer s.capabilitiesMu.Unlock()
+}
+
+// RequestSampling sends a sampling request to the client.
+// The client must have declared sampling capability during initialization.
+func (s *MCPServer) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
+ session := ClientSessionFromContext(ctx)
+ if session == nil {
+ return nil, fmt.Errorf("no active session")
+ }
+
+ // Check if the session supports sampling requests
+ if samplingSession, ok := session.(SessionWithSampling); ok {
+ return samplingSession.RequestSampling(ctx, request)
+ }
+
+ return nil, fmt.Errorf("session does not support sampling")
+}
+
+// SessionWithSampling extends ClientSession to support sampling requests.
+type SessionWithSampling interface {
+ ClientSession
+ RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error)
+}
@@ -9,6 +9,7 @@ import (
"log"
"os"
"os/signal"
+ "sync"
"sync/atomic"
"syscall"
@@ -51,10 +52,21 @@ func WithStdioContextFunc(fn StdioContextFunc) StdioOption {
// stdioSession is a static client session, since stdio has only one client.
type stdioSession struct {
- notifications chan mcp.JSONRPCNotification
- initialized atomic.Bool
- loggingLevel atomic.Value
- clientInfo atomic.Value // stores session-specific client info
+ notifications chan mcp.JSONRPCNotification
+ initialized atomic.Bool
+ loggingLevel atomic.Value
+ clientInfo atomic.Value // stores session-specific client info
+ writer io.Writer // for sending requests to client
+ requestID atomic.Int64 // for generating unique request IDs
+ mu sync.RWMutex // protects writer
+ pendingRequests map[int64]chan *samplingResponse // for tracking pending sampling requests
+ pendingMu sync.RWMutex // protects pendingRequests
+}
+
+// samplingResponse represents a response to a sampling request
+type samplingResponse struct {
+ result *mcp.CreateMessageResult
+ err error
}
func (s *stdioSession) SessionID() string {
@@ -100,14 +112,86 @@ func (s *stdioSession) GetLogLevel() mcp.LoggingLevel {
return level.(mcp.LoggingLevel)
}
+// RequestSampling sends a sampling request to the client and waits for the response.
+func (s *stdioSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
+ s.mu.RLock()
+ writer := s.writer
+ s.mu.RUnlock()
+
+ if writer == nil {
+ return nil, fmt.Errorf("no writer available for sending requests")
+ }
+
+ // Generate a unique request ID
+ id := s.requestID.Add(1)
+
+ // Create a response channel for this request
+ responseChan := make(chan *samplingResponse, 1)
+ s.pendingMu.Lock()
+ s.pendingRequests[id] = responseChan
+ s.pendingMu.Unlock()
+
+ // Cleanup function to remove the pending request
+ cleanup := func() {
+ s.pendingMu.Lock()
+ delete(s.pendingRequests, id)
+ s.pendingMu.Unlock()
+ }
+ defer cleanup()
+
+ // Create the JSON-RPC request
+ jsonRPCRequest := struct {
+ JSONRPC string `json:"jsonrpc"`
+ ID int64 `json:"id"`
+ Method string `json:"method"`
+ Params mcp.CreateMessageParams `json:"params"`
+ }{
+ JSONRPC: mcp.JSONRPC_VERSION,
+ ID: id,
+ Method: string(mcp.MethodSamplingCreateMessage),
+ Params: request.CreateMessageParams,
+ }
+
+ // Marshal and send the request
+ requestBytes, err := json.Marshal(jsonRPCRequest)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal sampling request: %w", err)
+ }
+ requestBytes = append(requestBytes, '\n')
+
+ if _, err := writer.Write(requestBytes); err != nil {
+ return nil, fmt.Errorf("failed to write sampling request: %w", err)
+ }
+
+ // Wait for the response or context cancellation
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case response := <-responseChan:
+ if response.err != nil {
+ return nil, response.err
+ }
+ return response.result, nil
+ }
+}
+
+// SetWriter sets the writer for sending requests to the client.
+func (s *stdioSession) SetWriter(writer io.Writer) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.writer = writer
+}
+
var (
_ ClientSession = (*stdioSession)(nil)
_ SessionWithLogging = (*stdioSession)(nil)
_ SessionWithClientInfo = (*stdioSession)(nil)
+ _ SessionWithSampling = (*stdioSession)(nil)
)
var stdioSessionInstance = stdioSession{
- notifications: make(chan mcp.JSONRPCNotification, 100),
+ notifications: make(chan mcp.JSONRPCNotification, 100),
+ pendingRequests: make(map[int64]chan *samplingResponse),
}
// NewStdioServer creates a new stdio server wrapper around an MCPServer.
@@ -224,6 +308,9 @@ func (s *StdioServer) Listen(
defer s.server.UnregisterSession(ctx, stdioSessionInstance.SessionID())
ctx = s.server.WithContext(ctx, &stdioSessionInstance)
+ // Set the writer for sending requests to the client
+ stdioSessionInstance.SetWriter(stdout)
+
// Add in any custom context.
if s.contextFunc != nil {
ctx = s.contextFunc(ctx)
@@ -256,7 +343,29 @@ func (s *StdioServer) processMessage(
return s.writeResponse(response, writer)
}
- // Handle the message using the wrapped server
+ // Check if this is a response to a sampling request
+ if s.handleSamplingResponse(rawMessage) {
+ return nil
+ }
+
+ // Check if this is a tool call that might need sampling (and thus should be processed concurrently)
+ var baseMessage struct {
+ Method string `json:"method"`
+ }
+ if json.Unmarshal(rawMessage, &baseMessage) == nil && baseMessage.Method == "tools/call" {
+ // Process tool calls concurrently to avoid blocking on sampling requests
+ go func() {
+ response := s.server.HandleMessage(ctx, rawMessage)
+ if response != nil {
+ if err := s.writeResponse(response, writer); err != nil {
+ s.errLogger.Printf("Error writing tool response: %v", err)
+ }
+ }
+ }()
+ return nil
+ }
+
+ // Handle other messages synchronously
response := s.server.HandleMessage(ctx, rawMessage)
// Only write response if there is one (not for notifications)
@@ -269,6 +378,65 @@ func (s *StdioServer) processMessage(
return nil
}
+// handleSamplingResponse checks if the message is a response to a sampling request
+// and routes it to the appropriate pending request channel.
+func (s *StdioServer) handleSamplingResponse(rawMessage json.RawMessage) bool {
+ return stdioSessionInstance.handleSamplingResponse(rawMessage)
+}
+
+// handleSamplingResponse handles incoming sampling responses for this session
+func (s *stdioSession) handleSamplingResponse(rawMessage json.RawMessage) bool {
+ // Try to parse as a JSON-RPC response
+ var response struct {
+ JSONRPC string `json:"jsonrpc"`
+ ID json.Number `json:"id"`
+ Result json.RawMessage `json:"result,omitempty"`
+ Error *struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+ } `json:"error,omitempty"`
+ }
+
+ if err := json.Unmarshal(rawMessage, &response); err != nil {
+ return false
+ }
+ // Parse the ID as int64
+ idInt64, err := response.ID.Int64()
+ if err != nil || (response.Result == nil && response.Error == nil) {
+ return false
+ }
+
+ // Look for a pending request with this ID
+ s.pendingMu.RLock()
+ responseChan, exists := s.pendingRequests[idInt64]
+ s.pendingMu.RUnlock()
+
+ if !exists {
+ return false
+ } // Parse and send the response
+ samplingResp := &samplingResponse{}
+
+ if response.Error != nil {
+ samplingResp.err = fmt.Errorf("sampling request failed: %s", response.Error.Message)
+ } else {
+ var result mcp.CreateMessageResult
+ if err := json.Unmarshal(response.Result, &result); err != nil {
+ samplingResp.err = fmt.Errorf("failed to unmarshal sampling response: %w", err)
+ } else {
+ samplingResp.result = &result
+ }
+ }
+
+ // Send the response (non-blocking)
+ select {
+ case responseChan <- samplingResp:
+ default:
+ // Channel is full or closed, ignore
+ }
+
+ return true
+}
+
// writeResponse marshals and writes a JSON-RPC response message followed by a newline.
// Returns an error if marshaling or writing fails.
func (s *StdioServer) writeResponse(
@@ -40,7 +40,9 @@ func WithEndpointPath(endpointPath string) StreamableHTTPOption {
// to StatelessSessionIdManager.
func WithStateLess(stateLess bool) StreamableHTTPOption {
return func(s *StreamableHTTPServer) {
- s.sessionIdManager = &StatelessSessionIdManager{}
+ if stateLess {
+ s.sessionIdManager = &StatelessSessionIdManager{}
+ }
}
}
@@ -374,7 +376,7 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
- w.WriteHeader(http.StatusAccepted)
+ w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher)
if !ok {
@@ -260,8 +260,8 @@ github.com/charmbracelet/bubbletea/v2
# github.com/charmbracelet/colorprofile v0.3.1
## explicit; go 1.23.0
github.com/charmbracelet/colorprofile
-# github.com/charmbracelet/fang v0.1.0
-## explicit; go 1.23.0
+# github.com/charmbracelet/fang v0.3.1-0.20250711140230-d5ebb8c1d674
+## explicit; go 1.24.0
github.com/charmbracelet/fang
# github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe
## explicit; go 1.23.0
@@ -269,7 +269,7 @@ github.com/charmbracelet/glamour/v2
github.com/charmbracelet/glamour/v2/ansi
github.com/charmbracelet/glamour/v2/internal/autolink
github.com/charmbracelet/glamour/v2/styles
-# github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.2.0.20250703152125-8e1c474f8a71 => github.com/charmbracelet/lipgloss-internal/v2 v2.0.0-20250710185058-03664cb9cecb
+# github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.3 => github.com/charmbracelet/lipgloss-internal/v2 v2.0.0-20250710185058-03664cb9cecb
## explicit; go 1.24.2
github.com/charmbracelet/lipgloss/v2
github.com/charmbracelet/lipgloss/v2/table
@@ -403,7 +403,7 @@ github.com/kylelemons/godebug/pretty
# github.com/lucasb-eyer/go-colorful v1.2.0
## explicit; go 1.12
github.com/lucasb-eyer/go-colorful
-# github.com/mark3labs/mcp-go v0.32.0
+# github.com/mark3labs/mcp-go v0.33.0
## explicit; go 1.23
github.com/mark3labs/mcp-go/client
github.com/mark3labs/mcp-go/client/transport