From 237dbde92cc514b2b247887c8fd29b0c9cb774aa Mon Sep 17 00:00:00 2001 From: Ayman Bagabas Date: Thu, 10 Jul 2025 15:25:52 -0400 Subject: [PATCH 01/30] fix(tui): editor: position completions correctly and account for padding --- internal/tui/components/chat/editor/editor.go | 5 +++-- internal/tui/components/completions/completions.go | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/internal/tui/components/chat/editor/editor.go b/internal/tui/components/chat/editor/editor.go index 2185715c813dbdcb288bddde0fe70d63046cf731..67ba67f5e6c40f16a89f7bc4fe1b6932c9989754 100644 --- a/internal/tui/components/chat/editor/editor.go +++ b/internal/tui/components/chat/editor/editor.go @@ -361,8 +361,9 @@ func (m *editorCmp) startCompletions() tea.Msg { }) } - x := m.textarea.Cursor().X + m.x + 1 - y := m.textarea.Cursor().Y + m.y + 1 + cur := m.textarea.Cursor() + x := cur.X + m.x // adjust for padding + y := cur.Y + m.y + 1 return completions.OpenCompletionsMsg{ Completions: completionItems, X: x, diff --git a/internal/tui/components/completions/completions.go b/internal/tui/components/completions/completions.go index 29ea86365e9f1532eab3aa1a61214ef74b7f4a05..6409f4bc96f59349046891d872994393550c65ba 100644 --- a/internal/tui/components/completions/completions.go +++ b/internal/tui/components/completions/completions.go @@ -43,7 +43,7 @@ type Completions interface { type completionsCmp struct { width int height int // Height of the completions component` - x int // X position for the completions popup\ + x int // X position for the completions popup y int // Y position for the completions popup open bool // Indicates if the completions are open keyMap KeyMap From 28be21fd629cfc221127f75bde92de32d09fe85c Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Thu, 10 Jul 2025 23:42:08 +0200 Subject: [PATCH 02/30] chore: implement correct banned commands --- internal/llm/tools/bash.go | 236 +++++++++++++++++++++++++++++++- internal/llm/tools/bash_test.go | 165 ++++++++++++++++++++++ 2 files changed, 395 insertions(+), 6 deletions(-) diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index 0a10568a39315f6c4077385b8ca83f6b3e52691c..ea736e81bc05a67d1b42bb5927f537b06f4ada5f 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -10,6 +10,7 @@ import ( "github.com/charmbracelet/crush/internal/permission" "github.com/charmbracelet/crush/internal/shell" + "mvdan.cc/sh/v3/syntax" ) type BashParams struct { @@ -40,10 +41,235 @@ const ( BashNoOutput = "no output" ) +func containsBannedCommand(node syntax.Node) bool { + if node == nil { + return false + } + + switch n := node.(type) { + case *syntax.CallExpr: + if len(n.Args) > 0 { + cmdName := getWordValue(n.Args[0]) + for _, banned := range bannedCommands { + if strings.EqualFold(cmdName, banned) { + return true + } + } + } + for _, arg := range n.Args { + if containsBannedCommand(arg) { + return true + } + } + case *syntax.Word: + if checkWordForBannedCommands(n) { + return true + } + for _, part := range n.Parts { + if containsBannedCommand(part) { + return true + } + } + case *syntax.CmdSubst: + for _, stmt := range n.Stmts { + if containsBannedCommand(stmt) { + return true + } + } + case *syntax.Subshell: + for _, stmt := range n.Stmts { + if containsBannedCommand(stmt) { + return true + } + } + case *syntax.Stmt: + if containsBannedCommand(n.Cmd) { + return true + } + for _, redir := range n.Redirs { + if containsBannedCommand(redir) { + return true + } + } + case *syntax.BinaryCmd: + return containsBannedCommand(n.X) || containsBannedCommand(n.Y) + case *syntax.Block: + for _, stmt := range n.Stmts { + if containsBannedCommand(stmt) { + return true + } + } + case *syntax.IfClause: + for _, stmt := range n.Cond { + if containsBannedCommand(stmt) { + return true + } + } + for _, stmt := range n.Then { + if containsBannedCommand(stmt) { + return true + } + } + if n.Else != nil && containsBannedCommand(n.Else) { + return true + } + case *syntax.WhileClause: + for _, stmt := range n.Cond { + if containsBannedCommand(stmt) { + return true + } + } + for _, stmt := range n.Do { + if containsBannedCommand(stmt) { + return true + } + } + case *syntax.ForClause: + for _, stmt := range n.Do { + if containsBannedCommand(stmt) { + return true + } + } + if containsBannedCommand(n.Loop) { + return true + } + case *syntax.CaseClause: + for _, item := range n.Items { + for _, stmt := range item.Stmts { + if containsBannedCommand(stmt) { + return true + } + } + } + case *syntax.FuncDecl: + return containsBannedCommand(n.Body) + case *syntax.ArithmExp: + return containsBannedCommand(n.X) + case *syntax.Redirect: + return containsBannedCommand(n.Word) + } + return false +} + +func checkWordForBannedCommands(word *syntax.Word) bool { + if word == nil { + return false + } + + for _, part := range word.Parts { + switch p := part.(type) { + case *syntax.SglQuoted: + if checkQuotedStringForBannedCommands(p.Value) { + return true + } + case *syntax.DblQuoted: + var content strings.Builder + for _, qpart := range p.Parts { + if lit, ok := qpart.(*syntax.Lit); ok { + content.WriteString(lit.Value) + } + } + if checkQuotedStringForBannedCommands(content.String()) { + return true + } + } + } + return false +} + +func checkQuotedStringForBannedCommands(content string) bool { + parser := syntax.NewParser() + file, err := parser.Parse(strings.NewReader(content), "") + if err != nil { + return false + } + + if len(file.Stmts) == 0 { + return false + } + + // Simple heuristic: if it looks like prose rather than commands, don't flag it + if len(file.Stmts) == 1 { + stmt := file.Stmts[0] + if callExpr, ok := stmt.Cmd.(*syntax.CallExpr); ok { + if len(callExpr.Args) > 2 { + allText := true + for i, arg := range callExpr.Args { + if i == 0 { + continue + } + argStr := getWordValue(arg) + if strings.HasPrefix(argStr, "-") { + allText = false + break + } + } + if allText { + return false + } + } + } + } + + for _, stmt := range file.Stmts { + if containsBannedCommand(stmt) { + return true + } + } + return false +} + +func getWordValue(word *syntax.Word) string { + if word == nil || len(word.Parts) == 0 { + return "" + } + + var result strings.Builder + for _, part := range word.Parts { + switch p := part.(type) { + case *syntax.Lit: + result.WriteString(p.Value) + case *syntax.SglQuoted: + result.WriteString(p.Value) + case *syntax.DblQuoted: + for _, qpart := range p.Parts { + if lit, ok := qpart.(*syntax.Lit); ok { + result.WriteString(lit.Value) + } + } + } + } + return result.String() +} + +func validateCommand(command string) error { + parser := syntax.NewParser() + file, err := parser.Parse(strings.NewReader(command), "") + if err != nil { + parts := strings.Fields(command) + if len(parts) > 0 { + baseCmd := parts[0] + for _, banned := range bannedCommands { + if strings.EqualFold(baseCmd, banned) { + return fmt.Errorf("command '%s' is not allowed", baseCmd) + } + } + } + return nil + } + + for _, stmt := range file.Stmts { + if containsBannedCommand(stmt) { + return fmt.Errorf("command contains banned operations") + } + } + return nil +} + var bannedCommands = []string{ "alias", "curl", "curlie", "wget", "axel", "aria2c", "nc", "telnet", "lynx", "w3m", "links", "httpie", "xh", - "http-prompt", "chrome", "firefox", "safari", + "http-prompt", "chrome", "firefox", "safari", "sudo", } // getSafeReadOnlyCommands returns platform-appropriate safe commands @@ -289,11 +515,9 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) return NewTextErrorResponse("missing command"), nil } - baseCmd := strings.Fields(params.Command)[0] - for _, banned := range bannedCommands { - if strings.EqualFold(baseCmd, banned) { - return NewTextErrorResponse(fmt.Sprintf("command '%s' is not allowed", baseCmd)), nil - } + + if err := validateCommand(params.Command); err != nil { + return NewTextErrorResponse(err.Error()), nil } isSafeReadOnly := false diff --git a/internal/llm/tools/bash_test.go b/internal/llm/tools/bash_test.go index a810002749408af2bb89cb958b5999dc2da3bcb3..768a3738dd2fe162838f32ee498ac14ed2ba9eee 100644 --- a/internal/llm/tools/bash_test.go +++ b/internal/llm/tools/bash_test.go @@ -94,3 +94,168 @@ func TestPlatformSpecificSafeCommands(t *testing.T) { } } } + +func TestValidateCommand(t *testing.T) { + tests := []struct { + name string + command string + shouldError bool + }{ + // Commands that should be blocked + { + name: "direct sudo", + command: "sudo ls", + shouldError: true, + }, + { + name: "sudo in script", + command: "bash -c 'sudo ls'", + shouldError: true, + }, + { + name: "sudo in command substitution", + command: "$(sudo whoami)", + shouldError: true, + }, + { + name: "sudo in echo command substitution", + command: "echo $(sudo id)", + shouldError: true, + }, + { + name: "sudo in command chain", + command: "ls && sudo rm file", + shouldError: true, + }, + { + name: "sudo in if statement", + command: "if true; then sudo ls; fi", + shouldError: true, + }, + { + name: "sudo in for loop", + command: "for i in 1; do sudo echo $i; done", + shouldError: true, + }, + { + name: "direct curl", + command: "curl http://example.com", + shouldError: true, + }, + { + name: "curl in script", + command: "bash -c 'curl malicious.com'", + shouldError: true, + }, + { + name: "wget command", + command: "wget http://example.com", + shouldError: true, + }, + { + name: "nc command", + command: "nc -l 8080", + shouldError: true, + }, + // Commands that should be allowed + { + name: "simple ls", + command: "ls -la", + shouldError: false, + }, + { + name: "echo command", + command: "echo hello", + shouldError: false, + }, + { + name: "git status", + command: "git status", + shouldError: false, + }, + { + name: "go build", + command: "go build", + shouldError: false, + }, + { + name: "sudo as literal text", + command: "echo 'sudo is just text here'", + shouldError: false, + }, + { + name: "complex allowed command", + command: "find . -name '*.go' | head -10", + shouldError: false, + }, + { + name: "command with environment variables", + command: "FOO=bar go test", + shouldError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateCommand(tt.command) + if tt.shouldError && err == nil { + t.Errorf("Expected error for command %q, but got none", tt.command) + } + if !tt.shouldError && err != nil { + t.Errorf("Expected no error for command %q, but got: %v", tt.command, err) + } + }) + } +} + +func TestContainsBannedCommand(t *testing.T) { + // Test the helper functions directly with some edge cases + tests := []struct { + name string + command string + shouldError bool + }{ + { + name: "nested command substitution", + command: "echo $(echo $(sudo id))", + shouldError: true, + }, + { + name: "subshell with banned command", + command: "(sudo ls)", + shouldError: true, + }, + { + name: "case statement with banned command", + command: "case $1 in start) sudo systemctl start service ;; esac", + shouldError: true, + }, + { + name: "while loop with banned command", + command: "while true; do sudo echo test; done", + shouldError: true, + }, + { + name: "function with banned command", + command: "function test() { sudo ls; }", + shouldError: true, + }, + { + name: "complex valid command", + command: "if [ -f file ]; then echo exists; else echo missing; fi", + shouldError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateCommand(tt.command) + if tt.shouldError && err == nil { + t.Errorf("Expected error for command %q, but got none", tt.command) + } + if !tt.shouldError && err != nil { + t.Errorf("Expected no error for command %q, but got: %v", tt.command, err) + } + }) + } +} From de3b81532ae3c21551ca84cd3004acc54291739c Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 11 Jul 2025 00:00:44 +0200 Subject: [PATCH 03/30] Revert "chore: implement correct banned commands" This reverts commit 28be21fd629cfc221127f75bde92de32d09fe85c. --- internal/llm/tools/bash.go | 236 +------------------------------- internal/llm/tools/bash_test.go | 165 ---------------------- 2 files changed, 6 insertions(+), 395 deletions(-) diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index ea736e81bc05a67d1b42bb5927f537b06f4ada5f..0a10568a39315f6c4077385b8ca83f6b3e52691c 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -10,7 +10,6 @@ import ( "github.com/charmbracelet/crush/internal/permission" "github.com/charmbracelet/crush/internal/shell" - "mvdan.cc/sh/v3/syntax" ) type BashParams struct { @@ -41,235 +40,10 @@ const ( BashNoOutput = "no output" ) -func containsBannedCommand(node syntax.Node) bool { - if node == nil { - return false - } - - switch n := node.(type) { - case *syntax.CallExpr: - if len(n.Args) > 0 { - cmdName := getWordValue(n.Args[0]) - for _, banned := range bannedCommands { - if strings.EqualFold(cmdName, banned) { - return true - } - } - } - for _, arg := range n.Args { - if containsBannedCommand(arg) { - return true - } - } - case *syntax.Word: - if checkWordForBannedCommands(n) { - return true - } - for _, part := range n.Parts { - if containsBannedCommand(part) { - return true - } - } - case *syntax.CmdSubst: - for _, stmt := range n.Stmts { - if containsBannedCommand(stmt) { - return true - } - } - case *syntax.Subshell: - for _, stmt := range n.Stmts { - if containsBannedCommand(stmt) { - return true - } - } - case *syntax.Stmt: - if containsBannedCommand(n.Cmd) { - return true - } - for _, redir := range n.Redirs { - if containsBannedCommand(redir) { - return true - } - } - case *syntax.BinaryCmd: - return containsBannedCommand(n.X) || containsBannedCommand(n.Y) - case *syntax.Block: - for _, stmt := range n.Stmts { - if containsBannedCommand(stmt) { - return true - } - } - case *syntax.IfClause: - for _, stmt := range n.Cond { - if containsBannedCommand(stmt) { - return true - } - } - for _, stmt := range n.Then { - if containsBannedCommand(stmt) { - return true - } - } - if n.Else != nil && containsBannedCommand(n.Else) { - return true - } - case *syntax.WhileClause: - for _, stmt := range n.Cond { - if containsBannedCommand(stmt) { - return true - } - } - for _, stmt := range n.Do { - if containsBannedCommand(stmt) { - return true - } - } - case *syntax.ForClause: - for _, stmt := range n.Do { - if containsBannedCommand(stmt) { - return true - } - } - if containsBannedCommand(n.Loop) { - return true - } - case *syntax.CaseClause: - for _, item := range n.Items { - for _, stmt := range item.Stmts { - if containsBannedCommand(stmt) { - return true - } - } - } - case *syntax.FuncDecl: - return containsBannedCommand(n.Body) - case *syntax.ArithmExp: - return containsBannedCommand(n.X) - case *syntax.Redirect: - return containsBannedCommand(n.Word) - } - return false -} - -func checkWordForBannedCommands(word *syntax.Word) bool { - if word == nil { - return false - } - - for _, part := range word.Parts { - switch p := part.(type) { - case *syntax.SglQuoted: - if checkQuotedStringForBannedCommands(p.Value) { - return true - } - case *syntax.DblQuoted: - var content strings.Builder - for _, qpart := range p.Parts { - if lit, ok := qpart.(*syntax.Lit); ok { - content.WriteString(lit.Value) - } - } - if checkQuotedStringForBannedCommands(content.String()) { - return true - } - } - } - return false -} - -func checkQuotedStringForBannedCommands(content string) bool { - parser := syntax.NewParser() - file, err := parser.Parse(strings.NewReader(content), "") - if err != nil { - return false - } - - if len(file.Stmts) == 0 { - return false - } - - // Simple heuristic: if it looks like prose rather than commands, don't flag it - if len(file.Stmts) == 1 { - stmt := file.Stmts[0] - if callExpr, ok := stmt.Cmd.(*syntax.CallExpr); ok { - if len(callExpr.Args) > 2 { - allText := true - for i, arg := range callExpr.Args { - if i == 0 { - continue - } - argStr := getWordValue(arg) - if strings.HasPrefix(argStr, "-") { - allText = false - break - } - } - if allText { - return false - } - } - } - } - - for _, stmt := range file.Stmts { - if containsBannedCommand(stmt) { - return true - } - } - return false -} - -func getWordValue(word *syntax.Word) string { - if word == nil || len(word.Parts) == 0 { - return "" - } - - var result strings.Builder - for _, part := range word.Parts { - switch p := part.(type) { - case *syntax.Lit: - result.WriteString(p.Value) - case *syntax.SglQuoted: - result.WriteString(p.Value) - case *syntax.DblQuoted: - for _, qpart := range p.Parts { - if lit, ok := qpart.(*syntax.Lit); ok { - result.WriteString(lit.Value) - } - } - } - } - return result.String() -} - -func validateCommand(command string) error { - parser := syntax.NewParser() - file, err := parser.Parse(strings.NewReader(command), "") - if err != nil { - parts := strings.Fields(command) - if len(parts) > 0 { - baseCmd := parts[0] - for _, banned := range bannedCommands { - if strings.EqualFold(baseCmd, banned) { - return fmt.Errorf("command '%s' is not allowed", baseCmd) - } - } - } - return nil - } - - for _, stmt := range file.Stmts { - if containsBannedCommand(stmt) { - return fmt.Errorf("command contains banned operations") - } - } - return nil -} - var bannedCommands = []string{ "alias", "curl", "curlie", "wget", "axel", "aria2c", "nc", "telnet", "lynx", "w3m", "links", "httpie", "xh", - "http-prompt", "chrome", "firefox", "safari", "sudo", + "http-prompt", "chrome", "firefox", "safari", } // getSafeReadOnlyCommands returns platform-appropriate safe commands @@ -515,9 +289,11 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) return NewTextErrorResponse("missing command"), nil } - - if err := validateCommand(params.Command); err != nil { - return NewTextErrorResponse(err.Error()), nil + baseCmd := strings.Fields(params.Command)[0] + for _, banned := range bannedCommands { + if strings.EqualFold(baseCmd, banned) { + return NewTextErrorResponse(fmt.Sprintf("command '%s' is not allowed", baseCmd)), nil + } } isSafeReadOnly := false diff --git a/internal/llm/tools/bash_test.go b/internal/llm/tools/bash_test.go index 768a3738dd2fe162838f32ee498ac14ed2ba9eee..a810002749408af2bb89cb958b5999dc2da3bcb3 100644 --- a/internal/llm/tools/bash_test.go +++ b/internal/llm/tools/bash_test.go @@ -94,168 +94,3 @@ func TestPlatformSpecificSafeCommands(t *testing.T) { } } } - -func TestValidateCommand(t *testing.T) { - tests := []struct { - name string - command string - shouldError bool - }{ - // Commands that should be blocked - { - name: "direct sudo", - command: "sudo ls", - shouldError: true, - }, - { - name: "sudo in script", - command: "bash -c 'sudo ls'", - shouldError: true, - }, - { - name: "sudo in command substitution", - command: "$(sudo whoami)", - shouldError: true, - }, - { - name: "sudo in echo command substitution", - command: "echo $(sudo id)", - shouldError: true, - }, - { - name: "sudo in command chain", - command: "ls && sudo rm file", - shouldError: true, - }, - { - name: "sudo in if statement", - command: "if true; then sudo ls; fi", - shouldError: true, - }, - { - name: "sudo in for loop", - command: "for i in 1; do sudo echo $i; done", - shouldError: true, - }, - { - name: "direct curl", - command: "curl http://example.com", - shouldError: true, - }, - { - name: "curl in script", - command: "bash -c 'curl malicious.com'", - shouldError: true, - }, - { - name: "wget command", - command: "wget http://example.com", - shouldError: true, - }, - { - name: "nc command", - command: "nc -l 8080", - shouldError: true, - }, - // Commands that should be allowed - { - name: "simple ls", - command: "ls -la", - shouldError: false, - }, - { - name: "echo command", - command: "echo hello", - shouldError: false, - }, - { - name: "git status", - command: "git status", - shouldError: false, - }, - { - name: "go build", - command: "go build", - shouldError: false, - }, - { - name: "sudo as literal text", - command: "echo 'sudo is just text here'", - shouldError: false, - }, - { - name: "complex allowed command", - command: "find . -name '*.go' | head -10", - shouldError: false, - }, - { - name: "command with environment variables", - command: "FOO=bar go test", - shouldError: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := validateCommand(tt.command) - if tt.shouldError && err == nil { - t.Errorf("Expected error for command %q, but got none", tt.command) - } - if !tt.shouldError && err != nil { - t.Errorf("Expected no error for command %q, but got: %v", tt.command, err) - } - }) - } -} - -func TestContainsBannedCommand(t *testing.T) { - // Test the helper functions directly with some edge cases - tests := []struct { - name string - command string - shouldError bool - }{ - { - name: "nested command substitution", - command: "echo $(echo $(sudo id))", - shouldError: true, - }, - { - name: "subshell with banned command", - command: "(sudo ls)", - shouldError: true, - }, - { - name: "case statement with banned command", - command: "case $1 in start) sudo systemctl start service ;; esac", - shouldError: true, - }, - { - name: "while loop with banned command", - command: "while true; do sudo echo test; done", - shouldError: true, - }, - { - name: "function with banned command", - command: "function test() { sudo ls; }", - shouldError: true, - }, - { - name: "complex valid command", - command: "if [ -f file ]; then echo exists; else echo missing; fi", - shouldError: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := validateCommand(tt.command) - if tt.shouldError && err == nil { - t.Errorf("Expected error for command %q, but got none", tt.command) - } - if !tt.shouldError && err != nil { - t.Errorf("Expected no error for command %q, but got: %v", tt.command, err) - } - }) - } -} From f99f50427a16e36701abf1e965dd1c605c42557f Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 11 Jul 2025 00:21:19 +0200 Subject: [PATCH 04/30] refactor: improve command blocking system and fix test isolation - Simplify command blocking logic by using utility functions instead of complex closures - Add sudo to banned commands list - Move command blocking from bash tool to shell layer for better separation of concerns - Add comprehensive tests for command blocking functionality - Fix test isolation by using temporary directories to prevent npm package files from polluting source tree - Remove redundant command validation logic from bash tool --- internal/llm/tools/bash.go | 24 ++++-- internal/shell/command_block_test.go | 123 +++++++++++++++++++++++++++ internal/shell/shell.go | 81 ++++++++++++++++-- 3 files changed, 213 insertions(+), 15 deletions(-) create mode 100644 internal/shell/command_block_test.go diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index 0a10568a39315f6c4077385b8ca83f6b3e52691c..2ae1c2956c46fef6f2cc1f0a8a20114bbb8785c1 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -43,7 +43,7 @@ const ( var bannedCommands = []string{ "alias", "curl", "curlie", "wget", "axel", "aria2c", "nc", "telnet", "lynx", "w3m", "links", "httpie", "xh", - "http-prompt", "chrome", "firefox", "safari", + "http-prompt", "chrome", "firefox", "safari", "sudo", } // getSafeReadOnlyCommands returns platform-appropriate safe commands @@ -244,7 +244,22 @@ Important: - Never update git config`, bannedCommandsStr, MaxOutputLength) } +func createCommandBlockFuncs() []shell.CommandBlockFunc { + return []shell.CommandBlockFunc{ + shell.CreateSimpleCommandBlocker(bannedCommands), + shell.CreateSubCommandBlocker([][]string{ + {"brew", "install"}, + {"npm", "install", "-g"}, + {"npm", "install", "--global"}, + }), + } +} + func NewBashTool(permission permission.Service, workingDir string) BaseTool { + // Set up command blocking on the persistent shell + persistentShell := shell.GetPersistentShell(workingDir) + persistentShell.SetBlockFuncs(createCommandBlockFuncs()) + return &bashTool{ permissions: permission, workingDir: workingDir, @@ -289,13 +304,6 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) return NewTextErrorResponse("missing command"), nil } - baseCmd := strings.Fields(params.Command)[0] - for _, banned := range bannedCommands { - if strings.EqualFold(baseCmd, banned) { - return NewTextErrorResponse(fmt.Sprintf("command '%s' is not allowed", baseCmd)), nil - } - } - isSafeReadOnly := false cmdLower := strings.ToLower(params.Command) diff --git a/internal/shell/command_block_test.go b/internal/shell/command_block_test.go new file mode 100644 index 0000000000000000000000000000000000000000..85971748f882cf79fba3ea86d2682ce6ce4f252d --- /dev/null +++ b/internal/shell/command_block_test.go @@ -0,0 +1,123 @@ +package shell + +import ( + "context" + "os" + "strings" + "testing" +) + +func TestCommandBlocking(t *testing.T) { + tests := []struct { + name string + blockFuncs []CommandBlockFunc + command string + shouldBlock bool + }{ + { + name: "block simple command", + blockFuncs: []CommandBlockFunc{ + func(args []string) bool { + return len(args) > 0 && args[0] == "curl" + }, + }, + command: "curl https://example.com", + shouldBlock: true, + }, + { + name: "allow non-blocked command", + blockFuncs: []CommandBlockFunc{ + func(args []string) bool { + return len(args) > 0 && args[0] == "curl" + }, + }, + command: "echo hello", + shouldBlock: false, + }, + { + name: "block subcommand", + blockFuncs: []CommandBlockFunc{ + func(args []string) bool { + return len(args) >= 2 && args[0] == "brew" && args[1] == "install" + }, + }, + command: "brew install wget", + shouldBlock: true, + }, + { + name: "allow different subcommand", + blockFuncs: []CommandBlockFunc{ + func(args []string) bool { + return len(args) >= 2 && args[0] == "brew" && args[1] == "install" + }, + }, + command: "brew list", + shouldBlock: false, + }, + { + name: "block npm global install with -g", + blockFuncs: []CommandBlockFunc{ + CreateSubCommandBlocker([][]string{ + {"npm", "install", "-g"}, + {"npm", "install", "--global"}, + }), + }, + command: "npm install -g typescript", + shouldBlock: true, + }, + { + name: "block npm global install with --global", + blockFuncs: []CommandBlockFunc{ + CreateSubCommandBlocker([][]string{ + {"npm", "install", "-g"}, + {"npm", "install", "--global"}, + }), + }, + command: "npm install --global typescript", + shouldBlock: true, + }, + { + name: "allow npm local install", + blockFuncs: []CommandBlockFunc{ + CreateSubCommandBlocker([][]string{ + {"npm", "install", "-g"}, + {"npm", "install", "--global"}, + }), + }, + command: "npm install typescript", + shouldBlock: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a temporary directory for each test + tmpDir, err := os.MkdirTemp("", "shell-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + shell := NewShell(&Options{ + WorkingDir: tmpDir, + BlockFuncs: tt.blockFuncs, + }) + + _, _, err = shell.Exec(context.Background(), tt.command) + + if tt.shouldBlock { + if err == nil { + t.Errorf("Expected command to be blocked, but it was allowed") + } else if !strings.Contains(err.Error(), "not allowed for security reasons") { + t.Errorf("Expected security error, got: %v", err) + } + } else { + // For non-blocked commands, we might get other errors (like command not found) + // but we shouldn't get the security error + if err != nil && strings.Contains(err.Error(), "not allowed for security reasons") { + t.Errorf("Command was unexpectedly blocked: %v", err) + } + } + }) + } +} \ No newline at end of file diff --git a/internal/shell/shell.go b/internal/shell/shell.go index 0467c9072c5111e4b4ea9a5439519e4edf76af46..815be0907a3fd05f996a24f84f751ce5d776b833 100644 --- a/internal/shell/shell.go +++ b/internal/shell/shell.go @@ -44,12 +44,16 @@ type noopLogger struct{} func (noopLogger) InfoPersist(msg string, keysAndValues ...interface{}) {} +// CommandBlockFunc is a function that determines if a command should be blocked +type CommandBlockFunc func(args []string) bool + // Shell provides cross-platform shell execution with optional state persistence type Shell struct { - env []string - cwd string - mu sync.Mutex - logger Logger + env []string + cwd string + mu sync.Mutex + logger Logger + blockFuncs []CommandBlockFunc } // Options for creating a new shell @@ -57,6 +61,7 @@ type Options struct { WorkingDir string Env []string Logger Logger + BlockFuncs []CommandBlockFunc } // NewShell creates a new shell instance with the given options @@ -81,9 +86,10 @@ func NewShell(opts *Options) *Shell { } return &Shell{ - cwd: cwd, - env: env, - logger: logger, + cwd: cwd, + env: env, + logger: logger, + blockFuncs: opts.BlockFuncs, } } @@ -152,6 +158,13 @@ func (s *Shell) SetEnv(key, value string) { s.env = append(s.env, keyPrefix+value) } +// SetBlockFuncs sets the command block functions for the shell +func (s *Shell) SetBlockFuncs(blockFuncs []CommandBlockFunc) { + s.mu.Lock() + defer s.mu.Unlock() + s.blockFuncs = blockFuncs +} + // Windows-specific commands that should use native shell var windowsNativeCommands = map[string]bool{ "dir": true, @@ -203,6 +216,59 @@ func (s *Shell) determineShellType(command string) ShellType { return ShellTypePOSIX } +// CreateSimpleCommandBlocker creates a CommandBlockFunc that blocks exact command matches +func CreateSimpleCommandBlocker(bannedCommands []string) CommandBlockFunc { + bannedSet := make(map[string]bool) + for _, cmd := range bannedCommands { + bannedSet[cmd] = true + } + + return func(args []string) bool { + if len(args) == 0 { + return false + } + return bannedSet[args[0]] + } +} + +// CreateSubCommandBlocker creates a CommandBlockFunc that blocks specific subcommands +func CreateSubCommandBlocker(blockedSubCommands [][]string) CommandBlockFunc { + return func(args []string) bool { + for _, blocked := range blockedSubCommands { + if len(args) >= len(blocked) { + match := true + for i, part := range blocked { + if args[i] != part { + match = false + break + } + } + if match { + return true + } + } + } + return false + } +} +func (s *Shell) createCommandBlockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc { + return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc { + return func(ctx context.Context, args []string) error { + if len(args) == 0 { + return next(ctx, args) + } + + for _, blockFunc := range s.blockFuncs { + if blockFunc(args) { + return fmt.Errorf("command '%s' is not allowed for security reasons", strings.Join(args, " ")) + } + } + + return next(ctx, args) + } + } +} + // execWindows executes commands using native Windows shells (cmd.exe or PowerShell) func (s *Shell) execWindows(ctx context.Context, command string, shell string) (string, string, error) { var cmd *exec.Cmd @@ -291,6 +357,7 @@ func (s *Shell) execPOSIX(ctx context.Context, command string) (string, string, interp.Interactive(false), interp.Env(expand.ListEnviron(s.env...)), interp.Dir(s.cwd), + interp.ExecHandlers(s.createCommandBlockHandler()), ) if err != nil { return "", "", fmt.Errorf("could not run command: %w", err) From 931029e1579b57948322bdeda0a7c013202b90c8 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 11 Jul 2025 00:28:42 +0200 Subject: [PATCH 05/30] chore: fix errors not showing in bash tool --- internal/llm/tools/bash.go | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index 2ae1c2956c46fef6f2cc1f0a8a20114bbb8785c1..feac86e01cc7769d886f90a926ad20b87741702d 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "log/slog" "runtime" "strings" "time" @@ -259,7 +260,7 @@ func NewBashTool(permission permission.Service, workingDir string) BaseTool { // Set up command blocking on the persistent shell persistentShell := shell.GetPersistentShell(workingDir) persistentShell.SetBlockFuncs(createCommandBlockFuncs()) - + return &bashTool{ permissions: permission, workingDir: workingDir, @@ -357,7 +358,20 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) stdout = truncateOutput(stdout) stderr = truncateOutput(stderr) + slog.Info("Bash command executed", + "command", params.Command, + "stdout", stdout, + "stderr", stderr, + "exit_code", exitCode, + "interrupted", interrupted, + "err", err, + ) + errorMessage := stderr + if errorMessage == "" && err != nil { + errorMessage = err.Error() + } + if interrupted { if errorMessage != "" { errorMessage += "\n" From f8d8ce33e15d33eed1d638ec341a8b76afbe38d8 Mon Sep 17 00:00:00 2001 From: Tai Groot Date: Thu, 10 Jul 2025 18:15:38 -0700 Subject: [PATCH 06/30] allow for custom contextFiles outside of workingDir path --- internal/llm/agent/agent.go | 4 +- internal/llm/prompt/prompt.go | 48 ++++++++++++- internal/llm/prompt/prompt_test.go | 108 +++++++++++++++++++++++++++++ 3 files changed, 155 insertions(+), 5 deletions(-) create mode 100644 internal/llm/prompt/prompt_test.go diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 313b83c0448d8a668e2390368c6797c82dd22452..fbb5b4fd8c6390ff0dfad0e072af35342355ba41 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -149,7 +149,7 @@ func NewAgent( } opts := []provider.ProviderClientOption{ provider.WithModel(agentCfg.Model), - provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)), + provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)), } agentProvider, err := provider.NewProvider(*providerCfg, opts...) if err != nil { @@ -827,7 +827,7 @@ func (a *agent) UpdateModel() error { opts := []provider.ProviderClientOption{ provider.WithModel(a.agentCfg.Model), - provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)), + provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)), } newProvider, err := provider.NewProvider(*currentProviderCfg, opts...) diff --git a/internal/llm/prompt/prompt.go b/internal/llm/prompt/prompt.go index 835279b4f4c0e08e46aaad271b7cb7f2a59b467f..7f1f58d6f7dcb163a7a9c64bf0fac8f3e63455b3 100644 --- a/internal/llm/prompt/prompt.go +++ b/internal/llm/prompt/prompt.go @@ -5,6 +5,9 @@ import ( "path/filepath" "strings" "sync" + + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/env" ) type PromptID string @@ -21,7 +24,7 @@ func GetPrompt(promptID PromptID, provider string, contextPaths ...string) strin basePrompt := "" switch promptID { case PromptCoder: - basePrompt = CoderPrompt(provider) + basePrompt = CoderPrompt(provider, contextPaths...) case PromptTitle: basePrompt = TitlePrompt() case PromptTask: @@ -38,6 +41,32 @@ func getContextFromPaths(workingDir string, contextPaths []string) string { return processContextPaths(workingDir, contextPaths) } +// expandPath expands ~ and environment variables in file paths +func expandPath(path string) string { + // Handle tilde expansion + if strings.HasPrefix(path, "~/") { + homeDir, err := os.UserHomeDir() + if err == nil { + path = filepath.Join(homeDir, path[2:]) + } + } else if path == "~" { + homeDir, err := os.UserHomeDir() + if err == nil { + path = homeDir + } + } + + // Handle environment variable expansion using the same pattern as config + if strings.HasPrefix(path, "$") { + resolver := config.NewEnvironmentVariableResolver(env.New()) + if expanded, err := resolver.ResolveValue(path); err == nil { + path = expanded + } + } + + return path +} + func processContextPaths(workDir string, paths []string) string { var ( wg sync.WaitGroup @@ -53,8 +82,16 @@ func processContextPaths(workDir string, paths []string) string { go func(p string) { defer wg.Done() + // Expand ~ and environment variables before processing + p = expandPath(p) + if strings.HasSuffix(p, "/") { - filepath.WalkDir(filepath.Join(workDir, p), func(path string, d os.DirEntry, err error) error { + // Use absolute path if provided, otherwise join with workDir + dirPath := p + if !filepath.IsAbs(p) { + dirPath = filepath.Join(workDir, p) + } + filepath.WalkDir(dirPath, func(path string, d os.DirEntry, err error) error { if err != nil { return err } @@ -78,7 +115,12 @@ func processContextPaths(workDir string, paths []string) string { return nil }) } else { - fullPath := filepath.Join(workDir, p) + // Expand ~ and environment variables before processing + // Use absolute path if provided, otherwise join with workDir + fullPath := p + if !filepath.IsAbs(p) { + fullPath = filepath.Join(workDir, p) + } // Check if we've already processed this file (case-insensitive) lowerPath := strings.ToLower(fullPath) diff --git a/internal/llm/prompt/prompt_test.go b/internal/llm/prompt/prompt_test.go new file mode 100644 index 0000000000000000000000000000000000000000..77fe86a827749e0f7f0ef285e100c043b908bdea --- /dev/null +++ b/internal/llm/prompt/prompt_test.go @@ -0,0 +1,108 @@ +package prompt + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestExpandPath(t *testing.T) { + tests := []struct { + name string + input string + expected func() string + }{ + { + name: "regular path unchanged", + input: "/absolute/path", + expected: func() string { + return "/absolute/path" + }, + }, + { + name: "tilde expansion", + input: "~/documents", + expected: func() string { + home, _ := os.UserHomeDir() + return filepath.Join(home, "documents") + }, + }, + { + name: "tilde only", + input: "~", + expected: func() string { + home, _ := os.UserHomeDir() + return home + }, + }, + { + name: "environment variable expansion", + input: "$HOME", + expected: func() string { + return os.Getenv("HOME") + }, + }, + { + name: "relative path unchanged", + input: "relative/path", + expected: func() string { + return "relative/path" + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := expandPath(tt.input) + expected := tt.expected() + + // Skip test if environment variable is not set + if strings.HasPrefix(tt.input, "$") && expected == "" { + t.Skip("Environment variable not set") + } + + if result != expected { + t.Errorf("expandPath(%q) = %q, want %q", tt.input, result, expected) + } + }) + } +} + +func TestProcessContextPaths(t *testing.T) { + // Create a temporary directory and file for testing + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.txt") + testContent := "test content" + + err := os.WriteFile(testFile, []byte(testContent), 0o644) + if err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + + // Test with absolute path + result := processContextPaths("", []string{testFile}) + expected := "# From:" + testFile + "\n" + testContent + + if result != expected { + t.Errorf("processContextPaths with absolute path failed.\nGot: %q\nWant: %q", result, expected) + } + + // Test with tilde expansion (if we can create a file in home directory) + home, err := os.UserHomeDir() + if err == nil { + homeTestFile := filepath.Join(home, "crush_test_file.txt") + err = os.WriteFile(homeTestFile, []byte(testContent), 0o644) + if err == nil { + defer os.Remove(homeTestFile) // Clean up + + tildeFile := "~/crush_test_file.txt" + result = processContextPaths("", []string{tildeFile}) + expected = "# From:" + homeTestFile + "\n" + testContent + + if result != expected { + t.Errorf("processContextPaths with tilde expansion failed.\nGot: %q\nWant: %q", result, expected) + } + } + } +} From b7935b4ef8a834da5ec4514e96aeab65a66f537f Mon Sep 17 00:00:00 2001 From: Tai Groot Date: Thu, 10 Jul 2025 18:21:37 -0700 Subject: [PATCH 07/30] Update internal/llm/prompt/prompt_test.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- internal/llm/prompt/prompt_test.go | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/internal/llm/prompt/prompt_test.go b/internal/llm/prompt/prompt_test.go index 77fe86a827749e0f7f0ef285e100c043b908bdea..2087ca149a372209e8cd8c8cdb56aaf8cbc4d68e 100644 --- a/internal/llm/prompt/prompt_test.go +++ b/internal/llm/prompt/prompt_test.go @@ -89,20 +89,19 @@ func TestProcessContextPaths(t *testing.T) { } // Test with tilde expansion (if we can create a file in home directory) - home, err := os.UserHomeDir() + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + homeTestFile := filepath.Join(tmpDir, "crush_test_file.txt") + err := os.WriteFile(homeTestFile, []byte(testContent), 0o644) if err == nil { - homeTestFile := filepath.Join(home, "crush_test_file.txt") - err = os.WriteFile(homeTestFile, []byte(testContent), 0o644) - if err == nil { - defer os.Remove(homeTestFile) // Clean up + defer os.Remove(homeTestFile) // Clean up - tildeFile := "~/crush_test_file.txt" - result = processContextPaths("", []string{tildeFile}) - expected = "# From:" + homeTestFile + "\n" + testContent + tildeFile := "~/crush_test_file.txt" + result = processContextPaths("", []string{tildeFile}) + expected = "# From:" + homeTestFile + "\n" + testContent - if result != expected { - t.Errorf("processContextPaths with tilde expansion failed.\nGot: %q\nWant: %q", result, expected) - } + if result != expected { + t.Errorf("processContextPaths with tilde expansion failed.\nGot: %q\nWant: %q", result, expected) } } } From f6d6ffdc01bd72c1682a5f79cafbd17be89f040e Mon Sep 17 00:00:00 2001 From: Tai Groot Date: Thu, 10 Jul 2025 18:24:18 -0700 Subject: [PATCH 08/30] fixup suggestions from copilot --- internal/llm/prompt/prompt.go | 29 +++++++++++++++-------------- internal/llm/prompt/prompt_test.go | 8 +++++++- 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/internal/llm/prompt/prompt.go b/internal/llm/prompt/prompt.go index 7f1f58d6f7dcb163a7a9c64bf0fac8f3e63455b3..4a2661bb9f663d9f93cf0371ac5d71dd513392c7 100644 --- a/internal/llm/prompt/prompt.go +++ b/internal/llm/prompt/prompt.go @@ -85,13 +85,20 @@ func processContextPaths(workDir string, paths []string) string { // Expand ~ and environment variables before processing p = expandPath(p) - if strings.HasSuffix(p, "/") { - // Use absolute path if provided, otherwise join with workDir - dirPath := p - if !filepath.IsAbs(p) { - dirPath = filepath.Join(workDir, p) - } - filepath.WalkDir(dirPath, func(path string, d os.DirEntry, err error) error { + // Use absolute path if provided, otherwise join with workDir + fullPath := p + if !filepath.IsAbs(p) { + fullPath = filepath.Join(workDir, p) + } + + // Check if the path is a directory using os.Stat + info, err := os.Stat(fullPath) + if err != nil { + return // Skip if path doesn't exist or can't be accessed + } + + if info.IsDir() { + filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error { if err != nil { return err } @@ -115,13 +122,7 @@ func processContextPaths(workDir string, paths []string) string { return nil }) } else { - // Expand ~ and environment variables before processing - // Use absolute path if provided, otherwise join with workDir - fullPath := p - if !filepath.IsAbs(p) { - fullPath = filepath.Join(workDir, p) - } - + // It's a file, process it directly // Check if we've already processed this file (case-insensitive) lowerPath := strings.ToLower(fullPath) diff --git a/internal/llm/prompt/prompt_test.go b/internal/llm/prompt/prompt_test.go index 2087ca149a372209e8cd8c8cdb56aaf8cbc4d68e..3f87c435a18daf251bbe6745ff73085b195cf718 100644 --- a/internal/llm/prompt/prompt_test.go +++ b/internal/llm/prompt/prompt_test.go @@ -80,7 +80,7 @@ func TestProcessContextPaths(t *testing.T) { t.Fatalf("Failed to create test file: %v", err) } - // Test with absolute path + // Test with absolute path to file result := processContextPaths("", []string{testFile}) expected := "# From:" + testFile + "\n" + testContent @@ -88,6 +88,12 @@ func TestProcessContextPaths(t *testing.T) { t.Errorf("processContextPaths with absolute path failed.\nGot: %q\nWant: %q", result, expected) } + // Test with directory path (should process all files in directory) + result = processContextPaths("", []string{tmpDir}) + if !strings.Contains(result, testContent) { + t.Errorf("processContextPaths with directory path failed to include file content") + } + // Test with tilde expansion (if we can create a file in home directory) tmpDir := t.TempDir() t.Setenv("HOME", tmpDir) From 055b30f04e9121a60b1628169d4bb4fd6db0d0c8 Mon Sep 17 00:00:00 2001 From: Ayman Bagabas Date: Thu, 10 Jul 2025 17:13:11 -0400 Subject: [PATCH 09/30] fix(tui): completions: close when no items match query This simply closes the completions component when there are no items matching the query. --- .../tui/components/completions/completions.go | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/internal/tui/components/completions/completions.go b/internal/tui/components/completions/completions.go index 6409f4bc96f59349046891d872994393550c65ba..e9c410d37a91ac245b7766cb7b1469dd4b421eda 100644 --- a/internal/tui/components/completions/completions.go +++ b/internal/tui/components/completions/completions.go @@ -150,18 +150,25 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if !c.open { return c, nil // If completions are not open, do nothing } - cmd := c.list.Filter(msg.Query) - c.height = max(min(10, len(c.list.Items())), 1) - return c, tea.Batch( - cmd, - c.list.SetSize(c.width, c.height), - ) + var cmds []tea.Cmd + cmds = append(cmds, c.list.Filter(msg.Query)) + itemsLen := len(c.list.Items()) + c.height = max(min(10, itemsLen), 1) + cmds = append(cmds, c.list.SetSize(c.width, c.height)) + if itemsLen == 0 { + // Close completions if no items match the query + cmds = append(cmds, util.CmdHandler(CloseCompletionsMsg{})) + } + return c, tea.Batch(cmds...) } return c, nil } // View implements Completions. func (c *completionsCmp) View() string { + if !c.open { + return "" + } if len(c.list.Items()) == 0 { return c.style().Render("No completions found") } From b6a97d222f29c2a383ee02e63636cbe7ec3f13a9 Mon Sep 17 00:00:00 2001 From: Tai Groot Date: Thu, 10 Jul 2025 18:28:40 -0700 Subject: [PATCH 10/30] fixup test --- internal/llm/prompt/prompt_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/llm/prompt/prompt_test.go b/internal/llm/prompt/prompt_test.go index 3f87c435a18daf251bbe6745ff73085b195cf718..ce7fa0fb35cfdf021b886a96a828202001588a7f 100644 --- a/internal/llm/prompt/prompt_test.go +++ b/internal/llm/prompt/prompt_test.go @@ -95,10 +95,10 @@ func TestProcessContextPaths(t *testing.T) { } // Test with tilde expansion (if we can create a file in home directory) - tmpDir := t.TempDir() + tmpDir = t.TempDir() t.Setenv("HOME", tmpDir) homeTestFile := filepath.Join(tmpDir, "crush_test_file.txt") - err := os.WriteFile(homeTestFile, []byte(testContent), 0o644) + err = os.WriteFile(homeTestFile, []byte(testContent), 0o644) if err == nil { defer os.Remove(homeTestFile) // Clean up From 88077377e176c4940249502e4680d12c4e5c5aa3 Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Thu, 10 Jul 2025 21:33:59 -0400 Subject: [PATCH 11/30] chore(tui/completions): pull out magic number --- internal/tui/components/completions/completions.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/internal/tui/components/completions/completions.go b/internal/tui/components/completions/completions.go index e9c410d37a91ac245b7766cb7b1469dd4b421eda..5a6bcfe92e23f38c3f40c84770a0dcc9893e59d5 100644 --- a/internal/tui/components/completions/completions.go +++ b/internal/tui/components/completions/completions.go @@ -9,6 +9,8 @@ import ( "github.com/charmbracelet/lipgloss/v2" ) +const maxCompletionsHeight = 10 + type Completion struct { Title string // The title of the completion item Value any // The value of the completion item @@ -153,7 +155,7 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { var cmds []tea.Cmd cmds = append(cmds, c.list.Filter(msg.Query)) itemsLen := len(c.list.Items()) - c.height = max(min(10, itemsLen), 1) + c.height = max(min(maxCompletionsHeight, itemsLen), 1) cmds = append(cmds, c.list.SetSize(c.width, c.height)) if itemsLen == 0 { // Close completions if no items match the query From 8bd9d06af2ed7e617e80fdbb26589fd2ff1b5151 Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Thu, 10 Jul 2025 21:40:08 -0400 Subject: [PATCH 12/30] fix(logs): typo --- internal/config/load.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/config/load.go b/internal/config/load.go index cc9191fcda5ebfb875fefbac899b21c3597ef0e2..9f2b5e55f1ccc0a687d46083b67e81d6e5fa212a 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -388,7 +388,7 @@ func (cfg *Config) configureSelectedModels(knownProviders []provider.Provider) e } model := cfg.GetModel(large.Provider, large.Model) slog.Info("Configuring selected large model", "provider", large.Provider, "model", large.Model) - slog.Info("MOdel configured", "model", model) + slog.Info("Model configured", "model", model) if model == nil { large = defaultLarge // override the model type to large From 3cbafd39aa6f106b33f49f46e2da36cab5f98d33 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 11 Jul 2025 13:59:20 +0200 Subject: [PATCH 13/30] chore: fix provider --- internal/config/config.go | 2 +- internal/config/resolve.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/config/config.go b/internal/config/config.go index 5c978106bc49f7b5956ea1d1d6e4d994f53eae58..ae8bcfdc35562e680527e99cdc74fd591e849874 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -72,7 +72,7 @@ type ProviderConfig struct { Disable bool `json:"disable,omitempty"` // Extra headers to send with each request to the provider. - ExtraHeaders map[string]string + ExtraHeaders map[string]string `json:"extra_headers,omitempty"` // Used to pass extra parameters to the provider. ExtraParams map[string]string `json:"-"` diff --git a/internal/config/resolve.go b/internal/config/resolve.go index 9c9116661814fe7abee91e2821829442bc65080d..3c97a6456cf7fe5968311746d62b2772b21d6aaa 100644 --- a/internal/config/resolve.go +++ b/internal/config/resolve.go @@ -44,7 +44,7 @@ func (r *shellVariableResolver) ResolveValue(value string) (string, error) { if strings.HasPrefix(value, "$(") && strings.HasSuffix(value, ")") { command := strings.TrimSuffix(strings.TrimPrefix(value, "$("), ")") - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() stdout, _, err := r.shell.Exec(ctx, command) From f6a79e41310f8ed94bfa8903848144b483f8323a Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Fri, 11 Jul 2025 11:02:00 -0300 Subject: [PATCH 14/30] feat: stream content in non-interactive mode (#133) --- cmd/root.go | 18 +------- internal/app/app.go | 64 ++++++++++++++++----------- internal/format/format.go | 91 --------------------------------------- 3 files changed, 40 insertions(+), 133 deletions(-) delete mode 100644 internal/format/format.go diff --git a/cmd/root.go b/cmd/root.go index e27bc46adcf38ae4b36cfba8d0f518690091242f..3a8f4fba0fe759a42ef1e7647223b2b3b11fbc65 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -12,7 +12,6 @@ import ( "github.com/charmbracelet/crush/internal/app" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/db" - "github.com/charmbracelet/crush/internal/format" "github.com/charmbracelet/crush/internal/llm/agent" "github.com/charmbracelet/crush/internal/log" "github.com/charmbracelet/crush/internal/tui" @@ -52,14 +51,8 @@ to assist developers in writing, debugging, and understanding code directly from debug, _ := cmd.Flags().GetBool("debug") cwd, _ := cmd.Flags().GetString("cwd") prompt, _ := cmd.Flags().GetString("prompt") - outputFormat, _ := cmd.Flags().GetString("output-format") quiet, _ := cmd.Flags().GetBool("quiet") - // Validate format option - if !format.IsValid(outputFormat) { - return fmt.Errorf("invalid format option: %s\n%s", outputFormat, format.GetHelpText()) - } - if cwd != "" { err := os.Chdir(cwd) if err != nil { @@ -109,7 +102,7 @@ to assist developers in writing, debugging, and understanding code directly from // Non-interactive mode if prompt != "" { // Run non-interactive flow using the App method - return app.RunNonInteractive(ctx, prompt, outputFormat, quiet) + return app.RunNonInteractive(ctx, prompt, quiet) } // Set up the TUI @@ -164,17 +157,8 @@ func init() { rootCmd.Flags().BoolP("debug", "d", false, "Debug") rootCmd.Flags().StringP("prompt", "p", "", "Prompt to run in non-interactive mode") - // Add format flag with validation logic - rootCmd.Flags().StringP("output-format", "f", format.Text.String(), - "Output format for non-interactive mode (text, json)") - // Add quiet flag to hide spinner in non-interactive mode rootCmd.Flags().BoolP("quiet", "q", false, "Hide spinner in non-interactive mode") - - // Register custom validation for the format flag - rootCmd.RegisterFlagCompletionFunc("output-format", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - return format.SupportedFormats, cobra.ShellCompDirectiveNoFileComp - }) } func maybePrependStdin(prompt string) (string, error) { diff --git a/internal/app/app.go b/internal/app/app.go index 099b092089c4a4e4e0ddcc9ccf79c36ca66acdce..9d0e6f176b14df0b15fd90f4b3651cdefafd6826 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -92,7 +92,7 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) { } // RunNonInteractive handles the execution flow when a prompt is provided via CLI flag. -func (a *App) RunNonInteractive(ctx context.Context, prompt string, outputFormat string, quiet bool) error { +func (a *App) RunNonInteractive(ctx context.Context, prompt string, quiet bool) error { slog.Info("Running in non-interactive mode") // Start spinner if not in quiet mode @@ -100,8 +100,15 @@ func (a *App) RunNonInteractive(ctx context.Context, prompt string, outputFormat if !quiet { spinner = format.NewSpinner(ctx, "Generating") spinner.Start() - defer spinner.Stop() } + // Helper function to stop spinner once + stopSpinner := func() { + if !quiet && spinner != nil { + spinner.Stop() + spinner = nil + } + } + defer stopSpinner() const maxPromptLengthForTitle = 100 titlePrefix := "Non-interactive: " @@ -128,35 +135,42 @@ func (a *App) RunNonInteractive(ctx context.Context, prompt string, outputFormat return fmt.Errorf("failed to start agent processing stream: %w", err) } - result := <-done + messageEvents := a.Messages.Subscribe(ctx) + readBts := 0 - // Stop spinner before printing output - if !quiet && spinner != nil { - spinner.Stop() - } + for { + select { + case result := <-done: + stopSpinner() + + if result.Error != nil { + if errors.Is(result.Error, context.Canceled) || errors.Is(result.Error, agent.ErrRequestCancelled) { + slog.Info("Agent processing cancelled", "session_id", sess.ID) + return nil + } + return fmt.Errorf("agent processing failed: %w", result.Error) + } + + part := result.Message.Content().String()[readBts:] + fmt.Println(part) - if result.Error != nil { - if errors.Is(result.Error, context.Canceled) || errors.Is(result.Error, agent.ErrRequestCancelled) { - slog.Info("Agent processing cancelled", "session_id", sess.ID) + slog.Info("Non-interactive run completed", "session_id", sess.ID) return nil - } - return fmt.Errorf("agent processing failed: %w", result.Error) - } - // Get the text content from the response - content := "No content available" - if result.Message.Content().String() != "" { - content = result.Message.Content().String() - } + case event := <-messageEvents: + msg := event.Payload + if msg.SessionID == sess.ID && msg.Role == message.Assistant && len(msg.Parts) > 0 { + stopSpinner() + part := msg.Content().String()[readBts:] + fmt.Print(part) + readBts += len(part) + } - out, err := format.FormatOutput(content, outputFormat) - if err != nil { - return err + case <-ctx.Done(): + stopSpinner() + return ctx.Err() + } } - - fmt.Println(out) - slog.Info("Non-interactive run completed", "session_id", sess.ID) - return nil } func (app *App) UpdateAgentModel() error { diff --git a/internal/format/format.go b/internal/format/format.go deleted file mode 100644 index 9f5a98910cafa41b924ff516da54ab751eb7f058..0000000000000000000000000000000000000000 --- a/internal/format/format.go +++ /dev/null @@ -1,91 +0,0 @@ -package format - -import ( - "encoding/json" - "fmt" - "strings" -) - -// OutputFormat represents the output format type for non-interactive mode -type OutputFormat string - -const ( - // Text format outputs the AI response as plain text. - Text OutputFormat = "text" - - // JSON format outputs the AI response wrapped in a JSON object. - JSON OutputFormat = "json" -) - -// String returns the string representation of the OutputFormat -func (f OutputFormat) String() string { - return string(f) -} - -// SupportedFormats is a list of all supported output formats as strings -var SupportedFormats = []string{ - string(Text), - string(JSON), -} - -// Parse converts a string to an OutputFormat -func Parse(s string) (OutputFormat, error) { - s = strings.ToLower(strings.TrimSpace(s)) - - switch s { - case string(Text): - return Text, nil - case string(JSON): - return JSON, nil - default: - return "", fmt.Errorf("invalid format: %s", s) - } -} - -// IsValid checks if the provided format string is supported -func IsValid(s string) bool { - _, err := Parse(s) - return err == nil -} - -// GetHelpText returns a formatted string describing all supported formats -func GetHelpText() string { - return fmt.Sprintf(`Supported output formats: -- %s: Plain text output (default) -- %s: Output wrapped in a JSON object`, - Text, JSON) -} - -// FormatOutput formats the AI response according to the specified format -func FormatOutput(content string, formatStr string) (string, error) { - format, err := Parse(formatStr) - if err != nil { - format = Text - } - - switch format { - case JSON: - return formatAsJSON(content) - case Text: - fallthrough - default: - return content, nil - } -} - -// formatAsJSON wraps the content in a simple JSON object -func formatAsJSON(content string) (string, error) { - // Use the JSON package to properly escape the content - response := struct { - Response string `json:"response"` - }{ - Response: content, - } - - jsonBytes, err := json.MarshalIndent(response, "", " ") - if err != nil { - return "", fmt.Errorf("failed to marshal output into JSON: %w", err) - } - - return string(jsonBytes), nil -} From 9a0f9a0982117b55599bcf8003c903bbe9dea679 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 11 Jul 2025 17:43:28 +0200 Subject: [PATCH 15/30] chore: fix bash commands with tabs --- internal/tui/components/chat/messages/renderer.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/internal/tui/components/chat/messages/renderer.go b/internal/tui/components/chat/messages/renderer.go index cad86659e04c6eb77e957e2fef4885000214a953..88f59d7ebc81cecea0bb8ef314de73d720ad2938 100644 --- a/internal/tui/components/chat/messages/renderer.go +++ b/internal/tui/components/chat/messages/renderer.go @@ -207,6 +207,7 @@ func (br bashRenderer) Render(v *toolCallCmp) string { } cmd := strings.ReplaceAll(params.Command, "\n", " ") + cmd = strings.ReplaceAll(cmd, "\t", " ") args := newParamBuilder().addMain(cmd).build() return br.renderWithParams(v, "Bash", args, func() string { @@ -578,8 +579,8 @@ func renderParamList(nested bool, paramsWidth int, params ...string) string { return "" } mainParam := params[0] - if paramsWidth-3 >= 0 && len(mainParam) > paramsWidth { - mainParam = mainParam[:paramsWidth-3] + "…" + if paramsWidth >= 0 && lipgloss.Width(mainParam) > paramsWidth { + mainParam = ansi.Truncate(mainParam, paramsWidth, "…") } if len(params) == 1 { From 2fdbcac0290e1aa977de0139c335210ce35b0faf Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 11 Jul 2025 17:55:01 +0200 Subject: [PATCH 16/30] chore: better naming --- internal/llm/tools/bash.go | 101 ++++++++++++++++++++++++--- internal/shell/command_block_test.go | 24 +++---- internal/shell/shell.go | 25 +++---- 3 files changed, 118 insertions(+), 32 deletions(-) diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index feac86e01cc7769d886f90a926ad20b87741702d..6d7a9a32b3829da02021be80e6e41e28888efd83 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -42,9 +42,74 @@ const ( ) var bannedCommands = []string{ - "alias", "curl", "curlie", "wget", "axel", "aria2c", - "nc", "telnet", "lynx", "w3m", "links", "httpie", "xh", - "http-prompt", "chrome", "firefox", "safari", "sudo", + // Network/Download tools + "alias", + "aria2c", + "axel", + "chrome", + "curl", + "curlie", + "firefox", + "http-prompt", + "httpie", + "links", + "lynx", + "nc", + "safari", + "telnet", + "w3m", + "wget", + "xh", + + // System administration + "doas", + "su", + "sudo", + + // Package managers + "apk", + "apt", + "apt-cache", + "apt-get", + "dnf", + "dpkg", + "emerge", + "home-manager", + "makepkg", + "opkg", + "pacman", + "paru", + "pkg", + "pkg_add", + "pkg_delete", + "portage", + "rpm", + "yay", + "yum", + "zypper", + + // System modification + "at", + "batch", + "chkconfig", + "crontab", + "fdisk", + "mkfs", + "mount", + "parted", + "service", + "systemctl", + "umount", + + // Network configuration + "firewall-cmd", + "ifconfig", + "ip", + "iptables", + "netstat", + "pfctl", + "route", + "ufw", } // getSafeReadOnlyCommands returns platform-appropriate safe commands @@ -245,13 +310,33 @@ Important: - Never update git config`, bannedCommandsStr, MaxOutputLength) } -func createCommandBlockFuncs() []shell.CommandBlockFunc { - return []shell.CommandBlockFunc{ - shell.CreateSimpleCommandBlocker(bannedCommands), - shell.CreateSubCommandBlocker([][]string{ +func blockFuncs() []shell.BlockFunc { + return []shell.BlockFunc{ + shell.CommandsBlocker(bannedCommands), + shell.ArgumentsBlocker([][]string{ + // System package managers + {"apk", "add"}, + {"apt", "install"}, + {"apt-get", "install"}, + {"dnf", "install"}, + {"emerge"}, + {"pacman", "-S"}, + {"pkg", "install"}, + {"yum", "install"}, + {"zypper", "install"}, + + // Language-specific package managers {"brew", "install"}, + {"cargo", "install"}, + {"gem", "install"}, + {"go", "install"}, {"npm", "install", "-g"}, {"npm", "install", "--global"}, + {"pip", "install", "--user"}, + {"pip3", "install", "--user"}, + {"pnpm", "add", "-g"}, + {"pnpm", "add", "--global"}, + {"yarn", "global", "add"}, }), } } @@ -259,7 +344,7 @@ func createCommandBlockFuncs() []shell.CommandBlockFunc { func NewBashTool(permission permission.Service, workingDir string) BaseTool { // Set up command blocking on the persistent shell persistentShell := shell.GetPersistentShell(workingDir) - persistentShell.SetBlockFuncs(createCommandBlockFuncs()) + persistentShell.SetBlockFuncs(blockFuncs()) return &bashTool{ permissions: permission, diff --git a/internal/shell/command_block_test.go b/internal/shell/command_block_test.go index 85971748f882cf79fba3ea86d2682ce6ce4f252d..fd7c46bcd98e54f44abbe982e834f3cbb04cbfa4 100644 --- a/internal/shell/command_block_test.go +++ b/internal/shell/command_block_test.go @@ -10,13 +10,13 @@ import ( func TestCommandBlocking(t *testing.T) { tests := []struct { name string - blockFuncs []CommandBlockFunc + blockFuncs []BlockFunc command string shouldBlock bool }{ { name: "block simple command", - blockFuncs: []CommandBlockFunc{ + blockFuncs: []BlockFunc{ func(args []string) bool { return len(args) > 0 && args[0] == "curl" }, @@ -26,7 +26,7 @@ func TestCommandBlocking(t *testing.T) { }, { name: "allow non-blocked command", - blockFuncs: []CommandBlockFunc{ + blockFuncs: []BlockFunc{ func(args []string) bool { return len(args) > 0 && args[0] == "curl" }, @@ -36,7 +36,7 @@ func TestCommandBlocking(t *testing.T) { }, { name: "block subcommand", - blockFuncs: []CommandBlockFunc{ + blockFuncs: []BlockFunc{ func(args []string) bool { return len(args) >= 2 && args[0] == "brew" && args[1] == "install" }, @@ -46,7 +46,7 @@ func TestCommandBlocking(t *testing.T) { }, { name: "allow different subcommand", - blockFuncs: []CommandBlockFunc{ + blockFuncs: []BlockFunc{ func(args []string) bool { return len(args) >= 2 && args[0] == "brew" && args[1] == "install" }, @@ -56,8 +56,8 @@ func TestCommandBlocking(t *testing.T) { }, { name: "block npm global install with -g", - blockFuncs: []CommandBlockFunc{ - CreateSubCommandBlocker([][]string{ + blockFuncs: []BlockFunc{ + ArgumentsBlocker([][]string{ {"npm", "install", "-g"}, {"npm", "install", "--global"}, }), @@ -67,8 +67,8 @@ func TestCommandBlocking(t *testing.T) { }, { name: "block npm global install with --global", - blockFuncs: []CommandBlockFunc{ - CreateSubCommandBlocker([][]string{ + blockFuncs: []BlockFunc{ + ArgumentsBlocker([][]string{ {"npm", "install", "-g"}, {"npm", "install", "--global"}, }), @@ -78,8 +78,8 @@ func TestCommandBlocking(t *testing.T) { }, { name: "allow npm local install", - blockFuncs: []CommandBlockFunc{ - CreateSubCommandBlocker([][]string{ + blockFuncs: []BlockFunc{ + ArgumentsBlocker([][]string{ {"npm", "install", "-g"}, {"npm", "install", "--global"}, }), @@ -120,4 +120,4 @@ func TestCommandBlocking(t *testing.T) { } }) } -} \ No newline at end of file +} diff --git a/internal/shell/shell.go b/internal/shell/shell.go index 815be0907a3fd05f996a24f84f751ce5d776b833..097af74d0172264efbe899711ff57137abc6ee30 100644 --- a/internal/shell/shell.go +++ b/internal/shell/shell.go @@ -44,8 +44,8 @@ type noopLogger struct{} func (noopLogger) InfoPersist(msg string, keysAndValues ...interface{}) {} -// CommandBlockFunc is a function that determines if a command should be blocked -type CommandBlockFunc func(args []string) bool +// BlockFunc is a function that determines if a command should be blocked +type BlockFunc func(args []string) bool // Shell provides cross-platform shell execution with optional state persistence type Shell struct { @@ -53,7 +53,7 @@ type Shell struct { cwd string mu sync.Mutex logger Logger - blockFuncs []CommandBlockFunc + blockFuncs []BlockFunc } // Options for creating a new shell @@ -61,7 +61,7 @@ type Options struct { WorkingDir string Env []string Logger Logger - BlockFuncs []CommandBlockFunc + BlockFuncs []BlockFunc } // NewShell creates a new shell instance with the given options @@ -159,7 +159,7 @@ func (s *Shell) SetEnv(key, value string) { } // SetBlockFuncs sets the command block functions for the shell -func (s *Shell) SetBlockFuncs(blockFuncs []CommandBlockFunc) { +func (s *Shell) SetBlockFuncs(blockFuncs []BlockFunc) { s.mu.Lock() defer s.mu.Unlock() s.blockFuncs = blockFuncs @@ -216,13 +216,13 @@ func (s *Shell) determineShellType(command string) ShellType { return ShellTypePOSIX } -// CreateSimpleCommandBlocker creates a CommandBlockFunc that blocks exact command matches -func CreateSimpleCommandBlocker(bannedCommands []string) CommandBlockFunc { +// CommandsBlocker creates a BlockFunc that blocks exact command matches +func CommandsBlocker(bannedCommands []string) BlockFunc { bannedSet := make(map[string]bool) for _, cmd := range bannedCommands { bannedSet[cmd] = true } - + return func(args []string) bool { if len(args) == 0 { return false @@ -231,8 +231,8 @@ func CreateSimpleCommandBlocker(bannedCommands []string) CommandBlockFunc { } } -// CreateSubCommandBlocker creates a CommandBlockFunc that blocks specific subcommands -func CreateSubCommandBlocker(blockedSubCommands [][]string) CommandBlockFunc { +// ArgumentsBlocker creates a BlockFunc that blocks specific subcommands +func ArgumentsBlocker(blockedSubCommands [][]string) BlockFunc { return func(args []string) bool { for _, blocked := range blockedSubCommands { if len(args) >= len(blocked) { @@ -251,7 +251,8 @@ func CreateSubCommandBlocker(blockedSubCommands [][]string) CommandBlockFunc { return false } } -func (s *Shell) createCommandBlockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc { + +func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc { return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc { return func(ctx context.Context, args []string) error { if len(args) == 0 { @@ -357,7 +358,7 @@ func (s *Shell) execPOSIX(ctx context.Context, command string) (string, string, interp.Interactive(false), interp.Env(expand.ListEnviron(s.env...)), interp.Dir(s.cwd), - interp.ExecHandlers(s.createCommandBlockHandler()), + interp.ExecHandlers(s.blockHandler()), ) if err != nil { return "", "", fmt.Errorf("could not run command: %w", err) From 0a5274840166e83a8d2e3fbbb8fcec2a6f324e26 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 11 Jul 2025 18:02:42 +0200 Subject: [PATCH 17/30] chore: small change --- internal/shell/shell.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/shell/shell.go b/internal/shell/shell.go index 097af74d0172264efbe899711ff57137abc6ee30..b655c5dbecf5b69c7ad102c53108733515138771 100644 --- a/internal/shell/shell.go +++ b/internal/shell/shell.go @@ -261,7 +261,7 @@ func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHand for _, blockFunc := range s.blockFuncs { if blockFunc(args) { - return fmt.Errorf("command '%s' is not allowed for security reasons", strings.Join(args, " ")) + return fmt.Errorf("command is not allowed for security reasons: %s", strings.Join(args, " ")) } } From 36d4f98c98a30ce816c3c06321e10b6a062c39ff Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 11 Jul 2025 18:10:44 +0200 Subject: [PATCH 18/30] chore: fix extra space --- internal/tui/components/chat/messages/renderer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/tui/components/chat/messages/renderer.go b/internal/tui/components/chat/messages/renderer.go index 88f59d7ebc81cecea0bb8ef314de73d720ad2938..87eb2c8476655fe7d11fc8c787e73b32d4584de4 100644 --- a/internal/tui/components/chat/messages/renderer.go +++ b/internal/tui/components/chat/messages/renderer.go @@ -650,7 +650,7 @@ func joinHeaderBody(header, body string) string { return header } body = t.S().Base.PaddingLeft(2).Render(body) - return lipgloss.JoinVertical(lipgloss.Left, header, "", body, "") + return lipgloss.JoinVertical(lipgloss.Left, header, "", body) } func renderPlainContent(v *toolCallCmp, content string) string { From dfb5080ff6e0fac83daf3084add497170327abad Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 11 Jul 2025 18:12:17 +0200 Subject: [PATCH 19/30] chore: remove extra help --- internal/tui/components/dialogs/permissions/keys.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/internal/tui/components/dialogs/permissions/keys.go b/internal/tui/components/dialogs/permissions/keys.go index 052c5222bc1ff7d7de1eb7e8f8a1378a7c79c1bc..9edc368d275d90d670eeb8f03346184d3edea800 100644 --- a/internal/tui/components/dialogs/permissions/keys.go +++ b/internal/tui/components/dialogs/permissions/keys.go @@ -109,9 +109,5 @@ func (k KeyMap) ShortHelp() []key.Binding { key.WithKeys("shift+left", "shift+down", "shift+up", "shift+right"), key.WithHelp("shift+←↓↑→", "scroll"), ), - key.NewBinding( - key.WithKeys("shift+h", "shift+j", "shift+k", "shift+l"), - key.WithHelp("shift+hjkl", "scroll"), - ), } } From 232860cb8128d640272dcf278d3454038224581d Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 11 Jul 2025 18:17:26 +0200 Subject: [PATCH 20/30] chore: disable commands when other dialogs are open --- internal/tui/tui.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 633766a1d80bf8b0056e8d856b71df04613e1101..365db72299865897feb94879f837baa93bff5e43 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -327,6 +327,9 @@ func (a *appModel) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd { // If the commands dialog is already open, close it return util.CmdHandler(dialogs.CloseDialogMsg{}) } + if a.dialog.HasDialogs() { + return nil // Don't open commands dialog if another dialog is active + } return util.CmdHandler(dialogs.OpenDialogMsg{ Model: commands.NewCommandDialog(a.selectedSessionID), }) From 531b3fd44fc1ee8458f16313eb92fc090cb62f0a Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 11 Jul 2025 18:23:51 +0200 Subject: [PATCH 21/30] chore: fix missing keys --- internal/tui/components/chat/splash/keys.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/tui/components/chat/splash/keys.go b/internal/tui/components/chat/splash/keys.go index 9cf2e3124daa87b0fc62c2ea404fb1c6c86ec649..675c608a94af4aa72b701376f3983506166ac7d7 100644 --- a/internal/tui/components/chat/splash/keys.go +++ b/internal/tui/components/chat/splash/keys.go @@ -30,11 +30,11 @@ func DefaultKeyMap() KeyMap { key.WithHelp("↑", "previous item"), ), Yes: key.NewBinding( - key.WithKeys("y"), + key.WithKeys("y", "Y"), key.WithHelp("y", "yes"), ), No: key.NewBinding( - key.WithKeys("n"), + key.WithKeys("n", "N"), key.WithHelp("n", "no"), ), Tab: key.NewBinding( From 3e820ececc845e719ade074b7c4b4d495be8b20c Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Fri, 11 Jul 2025 15:29:10 -0300 Subject: [PATCH 22/30] chore(deps): update mcp-go (#155) * chore(deps): update mcp-go Signed-off-by: Carlos Alexandro Becker * fix: vendoring Signed-off-by: Carlos Alexandro Becker --------- Signed-off-by: Carlos Alexandro Becker --- go.mod | 2 +- go.sum | 4 +- .../mark3labs/mcp-go/client/client.go | 102 +++++- .../mark3labs/mcp-go/client/http.go | 7 +- .../mark3labs/mcp-go/client/sampling.go | 20 + .../mark3labs/mcp-go/client/stdio.go | 22 +- .../mcp-go/client/transport/inprocess.go | 4 + .../mcp-go/client/transport/interface.go | 20 +- .../mark3labs/mcp-go/client/transport/sse.go | 6 + .../mcp-go/client/transport/stdio.go | 204 ++++++++++- .../client/transport/streamable_http.go | 343 ++++++++++++------ .../github.com/mark3labs/mcp-go/mcp/tools.go | 106 +++++- .../github.com/mark3labs/mcp-go/mcp/types.go | 21 ++ .../github.com/mark3labs/mcp-go/mcp/utils.go | 21 ++ .../mark3labs/mcp-go/server/sampling.go | 37 ++ .../mark3labs/mcp-go/server/stdio.go | 180 ++++++++- .../mcp-go/server/streamable_http.go | 6 +- vendor/modules.txt | 2 +- 18 files changed, 959 insertions(+), 148 deletions(-) create mode 100644 vendor/github.com/mark3labs/mcp-go/client/sampling.go create mode 100644 vendor/github.com/mark3labs/mcp-go/server/sampling.go diff --git a/go.mod b/go.mod index 35907121af5791acc5cfc5f3aa07f10df9eba763..2a9d6d5dfbaa827a5c8a57cadbe716dd956e1401 100644 --- a/go.mod +++ b/go.mod @@ -29,7 +29,7 @@ require ( github.com/fsnotify/fsnotify v1.8.0 github.com/google/uuid v1.6.0 github.com/joho/godotenv v1.5.1 - github.com/mark3labs/mcp-go v0.32.0 + github.com/mark3labs/mcp-go v0.33.0 github.com/muesli/termenv v0.16.0 github.com/ncruces/go-sqlite3 v0.25.0 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 diff --git a/go.sum b/go.sum index 50e30a46d4a47cb210add9c3fe61f0c9fb8e6c26..1d40961a3dce4180d9a06d17e3843f8d8709567b 100644 --- a/go.sum +++ b/go.sum @@ -165,8 +165,8 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0 github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= -github.com/mark3labs/mcp-go v0.32.0 h1:fgwmbfL2gbd67obg57OfV2Dnrhs1HtSdlY/i5fn7MU8= -github.com/mark3labs/mcp-go v0.32.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= +github.com/mark3labs/mcp-go v0.33.0 h1:naxhjnTIs/tyPZmWUZFuG0lDmdA6sUyYGGf3gsHvTCc= +github.com/mark3labs/mcp-go v0.33.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= diff --git a/vendor/github.com/mark3labs/mcp-go/client/client.go b/vendor/github.com/mark3labs/mcp-go/client/client.go index dd0e31a013595ccbb900a10fe413e02d1ed9d0ad..e2c466586050cf69e2015e83056fdaf6eda949f6 100644 --- a/vendor/github.com/mark3labs/mcp-go/client/client.go +++ b/vendor/github.com/mark3labs/mcp-go/client/client.go @@ -22,6 +22,7 @@ type Client struct { requestID atomic.Int64 clientCapabilities mcp.ClientCapabilities serverCapabilities mcp.ServerCapabilities + samplingHandler SamplingHandler } type ClientOption func(*Client) @@ -33,6 +34,21 @@ func WithClientCapabilities(capabilities mcp.ClientCapabilities) ClientOption { } } +// WithSamplingHandler sets the sampling handler for the client. +// When set, the client will declare sampling capability during initialization. +func WithSamplingHandler(handler SamplingHandler) ClientOption { + return func(c *Client) { + c.samplingHandler = handler + } +} + +// WithSession assumes a MCP Session has already been initialized +func WithSession() ClientOption { + return func(c *Client) { + c.initialized = true + } +} + // NewClient creates a new MCP client with the given transport. // Usage: // @@ -71,6 +87,12 @@ func (c *Client) Start(ctx context.Context) error { handler(notification) } }) + + // Set up request handler for bidirectional communication (e.g., sampling) + if bidirectional, ok := c.transport.(transport.BidirectionalInterface); ok { + bidirectional.SetRequestHandler(c.handleIncomingRequest) + } + return nil } @@ -127,6 +149,12 @@ func (c *Client) Initialize( ctx context.Context, request mcp.InitializeRequest, ) (*mcp.InitializeResult, error) { + // Merge client capabilities with sampling capability if handler is configured + capabilities := request.Params.Capabilities + if c.samplingHandler != nil { + capabilities.Sampling = &struct{}{} + } + // Ensure we send a params object with all required fields params := struct { ProtocolVersion string `json:"protocolVersion"` @@ -135,7 +163,7 @@ func (c *Client) Initialize( }{ ProtocolVersion: request.Params.ProtocolVersion, ClientInfo: request.Params.ClientInfo, - Capabilities: request.Params.Capabilities, // Will be empty struct if not set + Capabilities: capabilities, } response, err := c.sendRequest(ctx, "initialize", params) @@ -398,6 +426,64 @@ func (c *Client) Complete( return &result, nil } +// handleIncomingRequest processes incoming requests from the server. +// This is the main entry point for server-to-client requests like sampling. +func (c *Client) handleIncomingRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) { + switch request.Method { + case string(mcp.MethodSamplingCreateMessage): + return c.handleSamplingRequestTransport(ctx, request) + default: + return nil, fmt.Errorf("unsupported request method: %s", request.Method) + } +} + +// handleSamplingRequestTransport handles sampling requests at the transport level. +func (c *Client) handleSamplingRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) { + if c.samplingHandler == nil { + return nil, fmt.Errorf("no sampling handler configured") + } + + // Parse the request parameters + var params mcp.CreateMessageParams + if request.Params != nil { + paramsBytes, err := json.Marshal(request.Params) + if err != nil { + return nil, fmt.Errorf("failed to marshal params: %w", err) + } + if err := json.Unmarshal(paramsBytes, ¶ms); err != nil { + return nil, fmt.Errorf("failed to unmarshal params: %w", err) + } + } + + // Create the MCP request + mcpRequest := mcp.CreateMessageRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodSamplingCreateMessage), + }, + CreateMessageParams: params, + } + + // Call the sampling handler + result, err := c.samplingHandler.CreateMessage(ctx, mcpRequest) + if err != nil { + return nil, err + } + + // Marshal the result + resultBytes, err := json.Marshal(result) + if err != nil { + return nil, fmt.Errorf("failed to marshal result: %w", err) + } + + // Create the transport response + response := &transport.JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: request.ID, + Result: json.RawMessage(resultBytes), + } + + return response, nil +} func listByPage[T any]( ctx context.Context, client *Client, @@ -432,3 +518,17 @@ func (c *Client) GetServerCapabilities() mcp.ServerCapabilities { func (c *Client) GetClientCapabilities() mcp.ClientCapabilities { return c.clientCapabilities } + +// GetSessionId returns the session ID of the transport. +// If the transport does not support sessions, it returns an empty string. +func (c *Client) GetSessionId() string { + if c.transport == nil { + return "" + } + return c.transport.GetSessionId() +} + +// IsInitialized returns true if the client has been initialized. +func (c *Client) IsInitialized() bool { + return c.initialized +} diff --git a/vendor/github.com/mark3labs/mcp-go/client/http.go b/vendor/github.com/mark3labs/mcp-go/client/http.go index cb3be35d64cfc731efe2cef0c268a018c53a9538..d001a1e63d08e42d7457adbf5c497d93d029e203 100644 --- a/vendor/github.com/mark3labs/mcp-go/client/http.go +++ b/vendor/github.com/mark3labs/mcp-go/client/http.go @@ -13,5 +13,10 @@ func NewStreamableHttpClient(baseURL string, options ...transport.StreamableHTTP if err != nil { return nil, fmt.Errorf("failed to create SSE transport: %w", err) } - return NewClient(trans), nil + clientOptions := make([]ClientOption, 0) + sessionID := trans.GetSessionId() + if sessionID != "" { + clientOptions = append(clientOptions, WithSession()) + } + return NewClient(trans, clientOptions...), nil } diff --git a/vendor/github.com/mark3labs/mcp-go/client/sampling.go b/vendor/github.com/mark3labs/mcp-go/client/sampling.go new file mode 100644 index 0000000000000000000000000000000000000000..245e2c1f7f305ddb75658a345eddcaba5e2898e3 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/client/sampling.go @@ -0,0 +1,20 @@ +package client + +import ( + "context" + + "github.com/mark3labs/mcp-go/mcp" +) + +// SamplingHandler defines the interface for handling sampling requests from servers. +// Clients can implement this interface to provide LLM sampling capabilities to servers. +type SamplingHandler interface { + // CreateMessage handles a sampling request from the server and returns the generated message. + // The implementation should: + // 1. Validate the request parameters + // 2. Optionally prompt the user for approval (human-in-the-loop) + // 3. Select an appropriate model based on preferences + // 4. Generate the response using the selected model + // 5. Return the result with model information and stop reason + CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) +} diff --git a/vendor/github.com/mark3labs/mcp-go/client/stdio.go b/vendor/github.com/mark3labs/mcp-go/client/stdio.go index 100c08a7cc0529ca30ca1386f74f6ea4f9be4654..199ec14c381b57c12d691495179cf0c45029d29e 100644 --- a/vendor/github.com/mark3labs/mcp-go/client/stdio.go +++ b/vendor/github.com/mark3labs/mcp-go/client/stdio.go @@ -19,10 +19,26 @@ func NewStdioMCPClient( env []string, args ...string, ) (*Client, error) { + return NewStdioMCPClientWithOptions(command, env, args) +} + +// NewStdioMCPClientWithOptions creates a new stdio-based MCP client that communicates with a subprocess. +// It launches the specified command with given arguments and sets up stdin/stdout pipes for communication. +// Optional configuration functions can be provided to customize the transport before it starts, +// such as setting a custom command function. +// +// NOTICE: NewStdioMCPClientWithOptions automatically starts the underlying transport. +// Don't call the Start method manually. +// This is for backward compatibility. +func NewStdioMCPClientWithOptions( + command string, + env []string, + args []string, + opts ...transport.StdioOption, +) (*Client, error) { + stdioTransport := transport.NewStdioWithOptions(command, env, args, opts...) - stdioTransport := transport.NewStdio(command, env, args...) - err := stdioTransport.Start(context.Background()) - if err != nil { + if err := stdioTransport.Start(context.Background()); err != nil { return nil, fmt.Errorf("failed to start stdio transport: %w", err) } diff --git a/vendor/github.com/mark3labs/mcp-go/client/transport/inprocess.go b/vendor/github.com/mark3labs/mcp-go/client/transport/inprocess.go index 90fc2fae1f05ebf635b46d0fc415e0260348d3a0..0e2393f0731bcf361d5544da99304d1f07e08706 100644 --- a/vendor/github.com/mark3labs/mcp-go/client/transport/inprocess.go +++ b/vendor/github.com/mark3labs/mcp-go/client/transport/inprocess.go @@ -68,3 +68,7 @@ func (c *InProcessTransport) SetNotificationHandler(handler func(notification mc func (*InProcessTransport) Close() error { return nil } + +func (c *InProcessTransport) GetSessionId() string { + return "" +} diff --git a/vendor/github.com/mark3labs/mcp-go/client/transport/interface.go b/vendor/github.com/mark3labs/mcp-go/client/transport/interface.go index c83c7c65a3a8b0c7a301564144516242919fe2a5..5f8ed6180b6404a1a0f4085c5557aeb1789e3485 100644 --- a/vendor/github.com/mark3labs/mcp-go/client/transport/interface.go +++ b/vendor/github.com/mark3labs/mcp-go/client/transport/interface.go @@ -29,6 +29,22 @@ type Interface interface { // Close the connection. Close() error + + // GetSessionId returns the session ID of the transport. + GetSessionId() string +} + +// RequestHandler defines a function that handles incoming requests from the server. +type RequestHandler func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) + +// BidirectionalInterface extends Interface to support incoming requests from the server. +// This is used for features like sampling where the server can send requests to the client. +type BidirectionalInterface interface { + Interface + + // SetRequestHandler sets the handler for incoming requests from the server. + // The handler should process the request and return a response. + SetRequestHandler(handler RequestHandler) } type JSONRPCRequest struct { @@ -41,10 +57,10 @@ type JSONRPCRequest struct { type JSONRPCResponse struct { JSONRPC string `json:"jsonrpc"` ID mcp.RequestId `json:"id"` - Result json.RawMessage `json:"result"` + Result json.RawMessage `json:"result,omitempty"` Error *struct { Code int `json:"code"` Message string `json:"message"` Data json.RawMessage `json:"data"` - } `json:"error"` + } `json:"error,omitempty"` } diff --git a/vendor/github.com/mark3labs/mcp-go/client/transport/sse.go b/vendor/github.com/mark3labs/mcp-go/client/transport/sse.go index b22ff62d40124b765b633d2b1700c7407d92d041..ffe3247f0ecd87a4e9c68df72b726a8cc44a7736 100644 --- a/vendor/github.com/mark3labs/mcp-go/client/transport/sse.go +++ b/vendor/github.com/mark3labs/mcp-go/client/transport/sse.go @@ -428,6 +428,12 @@ func (c *SSE) Close() error { return nil } +// GetSessionId returns the session ID of the transport. +// Since SSE does not maintain a session ID, it returns an empty string. +func (c *SSE) GetSessionId() string { + return "" +} + // SendNotification sends a JSON-RPC notification to the server without expecting a response. func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error { if c.endpoint == nil { diff --git a/vendor/github.com/mark3labs/mcp-go/client/transport/stdio.go b/vendor/github.com/mark3labs/mcp-go/client/transport/stdio.go index c300c405f7e3880f0b94e1e09e3ee5ca7def732a..c36dc2d37737d71d9028ba11485932c23bb09f9f 100644 --- a/vendor/github.com/mark3labs/mcp-go/client/transport/stdio.go +++ b/vendor/github.com/mark3labs/mcp-go/client/transport/stdio.go @@ -23,6 +23,7 @@ type Stdio struct { env []string cmd *exec.Cmd + cmdFunc CommandFunc stdin io.WriteCloser stdout *bufio.Reader stderr io.ReadCloser @@ -31,6 +32,28 @@ type Stdio struct { done chan struct{} onNotification func(mcp.JSONRPCNotification) notifyMu sync.RWMutex + onRequest RequestHandler + requestMu sync.RWMutex + ctx context.Context + ctxMu sync.RWMutex +} + +// StdioOption defines a function that configures a Stdio transport instance. +// Options can be used to customize the behavior of the transport before it starts, +// such as setting a custom command function. +type StdioOption func(*Stdio) + +// CommandFunc is a factory function that returns a custom exec.Cmd used to launch the MCP subprocess. +// It can be used to apply sandboxing, custom environment control, working directories, etc. +type CommandFunc func(ctx context.Context, command string, env []string, args []string) (*exec.Cmd, error) + +// WithCommandFunc sets a custom command factory function for the stdio transport. +// The CommandFunc is responsible for constructing the exec.Cmd used to launch the subprocess, +// allowing control over attributes like environment, working directory, and system-level sandboxing. +func WithCommandFunc(f CommandFunc) StdioOption { + return func(s *Stdio) { + s.cmdFunc = f + } } // NewIO returns a new stdio-based transport using existing input, output, and @@ -44,6 +67,7 @@ func NewIO(input io.Reader, output io.WriteCloser, logging io.ReadCloser) *Stdio responses: make(map[string]chan *JSONRPCResponse), done: make(chan struct{}), + ctx: context.Background(), } } @@ -55,20 +79,43 @@ func NewStdio( env []string, args ...string, ) *Stdio { + return NewStdioWithOptions(command, env, args) +} - client := &Stdio{ +// NewStdioWithOptions creates a new stdio transport to communicate with a subprocess. +// It launches the specified command with given arguments and sets up stdin/stdout pipes for communication. +// Returns an error if the subprocess cannot be started or the pipes cannot be created. +// Optional configuration functions can be provided to customize the transport before it starts, +// such as setting a custom command factory. +func NewStdioWithOptions( + command string, + env []string, + args []string, + opts ...StdioOption, +) *Stdio { + s := &Stdio{ command: command, args: args, env: env, responses: make(map[string]chan *JSONRPCResponse), done: make(chan struct{}), + ctx: context.Background(), + } + + for _, opt := range opts { + opt(s) } - return client + return s } func (c *Stdio) Start(ctx context.Context) error { + // Store the context for use in request handling + c.ctxMu.Lock() + c.ctx = ctx + c.ctxMu.Unlock() + if err := c.spawnCommand(ctx); err != nil { return err } @@ -83,18 +130,25 @@ func (c *Stdio) Start(ctx context.Context) error { return nil } -// spawnCommand spawns a new process running c.command. +// spawnCommand spawns a new process running the configured command, args, and env. +// If an (optional) cmdFunc custom command factory function was configured, it will be used to construct the subprocess; +// otherwise, the default behavior uses exec.CommandContext with the merged environment. +// Initializes stdin, stdout, and stderr pipes for JSON-RPC communication. func (c *Stdio) spawnCommand(ctx context.Context) error { if c.command == "" { return nil } - cmd := exec.CommandContext(ctx, c.command, c.args...) - - mergedEnv := os.Environ() - mergedEnv = append(mergedEnv, c.env...) + var cmd *exec.Cmd + var err error - cmd.Env = mergedEnv + // Standard behavior if no command func present. + if c.cmdFunc == nil { + cmd = exec.CommandContext(ctx, c.command, c.args...) + cmd.Env = append(os.Environ(), c.env...) + } else if cmd, err = c.cmdFunc(ctx, c.command, c.env, c.args); err != nil { + return err + } stdin, err := cmd.StdinPipe() if err != nil { @@ -148,6 +202,12 @@ func (c *Stdio) Close() error { return nil } +// GetSessionId returns the session ID of the transport. +// Since stdio does not maintain a session ID, it returns an empty string. +func (c *Stdio) GetSessionId() string { + return "" +} + // SetNotificationHandler sets the handler function to be called when a notification is received. // Only one handler can be set at a time; setting a new one replaces the previous handler. func (c *Stdio) SetNotificationHandler( @@ -158,6 +218,14 @@ func (c *Stdio) SetNotificationHandler( c.onNotification = handler } +// SetRequestHandler sets the handler function to be called when a request is received from the server. +// This enables bidirectional communication for features like sampling. +func (c *Stdio) SetRequestHandler(handler RequestHandler) { + c.requestMu.Lock() + defer c.requestMu.Unlock() + c.onRequest = handler +} + // readResponses continuously reads and processes responses from the server's stdout. // It handles both responses to requests and notifications, routing them appropriately. // Runs until the done channel is closed or an error occurs reading from stdout. @@ -175,13 +243,18 @@ func (c *Stdio) readResponses() { return } - var baseMessage JSONRPCResponse + // First try to parse as a generic message to check for ID field + var baseMessage struct { + JSONRPC string `json:"jsonrpc"` + ID *mcp.RequestId `json:"id,omitempty"` + Method string `json:"method,omitempty"` + } if err := json.Unmarshal([]byte(line), &baseMessage); err != nil { continue } - // Handle notification - if baseMessage.ID.IsNil() { + // If it has a method but no ID, it's a notification + if baseMessage.Method != "" && baseMessage.ID == nil { var notification mcp.JSONRPCNotification if err := json.Unmarshal([]byte(line), ¬ification); err != nil { continue @@ -194,15 +267,30 @@ func (c *Stdio) readResponses() { continue } + // If it has a method and an ID, it's an incoming request + if baseMessage.Method != "" && baseMessage.ID != nil { + var request JSONRPCRequest + if err := json.Unmarshal([]byte(line), &request); err == nil { + c.handleIncomingRequest(request) + continue + } + } + + // Otherwise, it's a response to our request + var response JSONRPCResponse + if err := json.Unmarshal([]byte(line), &response); err != nil { + continue + } + // Create string key for map lookup - idKey := baseMessage.ID.String() + idKey := response.ID.String() c.mu.RLock() ch, exists := c.responses[idKey] c.mu.RUnlock() if exists { - ch <- &baseMessage + ch <- &response c.mu.Lock() delete(c.responses, idKey) c.mu.Unlock() @@ -281,6 +369,96 @@ func (c *Stdio) SendNotification( return nil } +// handleIncomingRequest processes incoming requests from the server. +// It calls the registered request handler and sends the response back to the server. +func (c *Stdio) handleIncomingRequest(request JSONRPCRequest) { + c.requestMu.RLock() + handler := c.onRequest + c.requestMu.RUnlock() + + if handler == nil { + // Send error response if no handler is configured + errorResponse := JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: request.ID, + Error: &struct { + Code int `json:"code"` + Message string `json:"message"` + Data json.RawMessage `json:"data"` + }{ + Code: mcp.METHOD_NOT_FOUND, + Message: "No request handler configured", + }, + } + c.sendResponse(errorResponse) + return + } + + // Handle the request in a goroutine to avoid blocking + go func() { + c.ctxMu.RLock() + ctx := c.ctx + c.ctxMu.RUnlock() + + // Check if context is already cancelled before processing + select { + case <-ctx.Done(): + errorResponse := JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: request.ID, + Error: &struct { + Code int `json:"code"` + Message string `json:"message"` + Data json.RawMessage `json:"data"` + }{ + Code: mcp.INTERNAL_ERROR, + Message: ctx.Err().Error(), + }, + } + c.sendResponse(errorResponse) + return + default: + } + + response, err := handler(ctx, request) + + if err != nil { + errorResponse := JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: request.ID, + Error: &struct { + Code int `json:"code"` + Message string `json:"message"` + Data json.RawMessage `json:"data"` + }{ + Code: mcp.INTERNAL_ERROR, + Message: err.Error(), + }, + } + c.sendResponse(errorResponse) + return + } + + if response != nil { + c.sendResponse(*response) + } + }() +} + +// sendResponse sends a response back to the server. +func (c *Stdio) sendResponse(response JSONRPCResponse) { + responseBytes, err := json.Marshal(response) + if err != nil { + fmt.Printf("Error marshaling response: %v\n", err) + return + } + responseBytes = append(responseBytes, '\n') + + if _, err := c.stdin.Write(responseBytes); err != nil { + fmt.Printf("Error writing response: %v\n", err) + } +} + // Stderr returns a reader for the stderr output of the subprocess. // This can be used to capture error messages or logs from the subprocess. func (c *Stdio) Stderr() io.Reader { diff --git a/vendor/github.com/mark3labs/mcp-go/client/transport/streamable_http.go b/vendor/github.com/mark3labs/mcp-go/client/transport/streamable_http.go index 50bde9c288d39e32e00bc5691cbaf75addd740f5..e358751b3344c3783be539cc5daa3e09ffa81020 100644 --- a/vendor/github.com/mark3labs/mcp-go/client/transport/streamable_http.go +++ b/vendor/github.com/mark3labs/mcp-go/client/transport/streamable_http.go @@ -17,10 +17,24 @@ import ( "time" "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/util" ) type StreamableHTTPCOption func(*StreamableHTTP) +// WithContinuousListening enables receiving server-to-client notifications when no request is in flight. +// In particular, if you want to receive global notifications from the server (like ToolListChangedNotification), +// you should enable this option. +// +// It will establish a standalone long-live GET HTTP connection to the server. +// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server +// NOTICE: Even enabled, the server may not support this feature. +func WithContinuousListening() StreamableHTTPCOption { + return func(sc *StreamableHTTP) { + sc.getListeningEnabled = true + } +} + // WithHTTPClient sets a custom HTTP client on the StreamableHTTP transport. func WithHTTPBasicClient(client *http.Client) StreamableHTTPCOption { return func(sc *StreamableHTTP) { @@ -54,6 +68,19 @@ func WithHTTPOAuth(config OAuthConfig) StreamableHTTPCOption { } } +func WithLogger(logger util.Logger) StreamableHTTPCOption { + return func(sc *StreamableHTTP) { + sc.logger = logger + } +} + +// WithSession creates a client with a pre-configured session +func WithSession(sessionID string) StreamableHTTPCOption { + return func(sc *StreamableHTTP) { + sc.sessionID.Store(sessionID) + } +} + // StreamableHTTP implements Streamable HTTP transport. // // It transmits JSON-RPC messages over individual HTTP requests. One message per request. @@ -64,19 +91,22 @@ func WithHTTPOAuth(config OAuthConfig) StreamableHTTPCOption { // // The current implementation does not support the following features: // - batching -// - continuously listening for server notifications when no request is in flight -// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server) // - resuming stream // (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#resumability-and-redelivery) // - server -> client request type StreamableHTTP struct { - serverURL *url.URL - httpClient *http.Client - headers map[string]string - headerFunc HTTPHeaderFunc + serverURL *url.URL + httpClient *http.Client + headers map[string]string + headerFunc HTTPHeaderFunc + logger util.Logger + getListeningEnabled bool sessionID atomic.Value // string + initialized chan struct{} + initializedOnce sync.Once + notificationHandler func(mcp.JSONRPCNotification) notifyMu sync.RWMutex @@ -95,15 +125,19 @@ func NewStreamableHTTP(serverURL string, options ...StreamableHTTPCOption) (*Str } smc := &StreamableHTTP{ - serverURL: parsedURL, - httpClient: &http.Client{}, - headers: make(map[string]string), - closed: make(chan struct{}), + serverURL: parsedURL, + httpClient: &http.Client{}, + headers: make(map[string]string), + closed: make(chan struct{}), + logger: util.DefaultLogger(), + initialized: make(chan struct{}), } smc.sessionID.Store("") // set initial value to simplify later usage for _, opt := range options { - opt(smc) + if opt != nil { + opt(smc) + } } // If OAuth is configured, set the base URL for metadata discovery @@ -118,7 +152,20 @@ func NewStreamableHTTP(serverURL string, options ...StreamableHTTPCOption) (*Str // Start initiates the HTTP connection to the server. func (c *StreamableHTTP) Start(ctx context.Context) error { - // For Streamable HTTP, we don't need to establish a persistent connection + // For Streamable HTTP, we don't need to establish a persistent connection by default + if c.getListeningEnabled { + go func() { + select { + case <-c.initialized: + ctx, cancel := c.contextAwareOfClientClose(ctx) + defer cancel() + c.listenForever(ctx) + case <-c.closed: + return + } + }() + } + return nil } @@ -142,13 +189,13 @@ func (c *StreamableHTTP) Close() error { defer cancel() req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.serverURL.String(), nil) if err != nil { - fmt.Printf("failed to create close request\n: %v", err) + c.logger.Errorf("failed to create close request: %v", err) return } req.Header.Set(headerKeySessionID, sessionId) res, err := c.httpClient.Do(req) if err != nil { - fmt.Printf("failed to send close request\n: %v", err) + c.logger.Errorf("failed to send close request: %v", err) return } res.Body.Close() @@ -185,77 +232,29 @@ func (c *StreamableHTTP) SendRequest( request JSONRPCRequest, ) (*JSONRPCResponse, error) { - // Create a combined context that could be canceled when the client is closed - newCtx, cancel := context.WithCancel(ctx) - defer cancel() - go func() { - select { - case <-c.closed: - cancel() - case <-newCtx.Done(): - // The original context was canceled, no need to do anything - } - }() - ctx = newCtx - // Marshal request requestBody, err := json.Marshal(request) if err != nil { return nil, fmt.Errorf("failed to marshal request: %w", err) } - // Create HTTP request - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.serverURL.String(), bytes.NewReader(requestBody)) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - - // Set headers - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json, text/event-stream") - sessionID := c.sessionID.Load() - if sessionID != "" { - req.Header.Set(headerKeySessionID, sessionID.(string)) - } - for k, v := range c.headers { - req.Header.Set(k, v) - } - - // Add OAuth authorization if configured - if c.oauthHandler != nil { - authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx) - if err != nil { - // If we get an authorization error, return a specific error that can be handled by the client - if err.Error() == "no valid token available, authorization required" { - return nil, &OAuthAuthorizationRequiredError{ - Handler: c.oauthHandler, - } - } - return nil, fmt.Errorf("failed to get authorization header: %w", err) - } - req.Header.Set("Authorization", authHeader) - } - - if c.headerFunc != nil { - for k, v := range c.headerFunc(ctx) { - req.Header.Set(k, v) - } - } + ctx, cancel := c.contextAwareOfClientClose(ctx) + defer cancel() - // Send request - resp, err := c.httpClient.Do(req) + resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream") if err != nil { - return nil, fmt.Errorf("failed to send request: %w", err) + if errors.Is(err, ErrSessionTerminated) && request.Method == string(mcp.MethodInitialize) { + // If the request is initialize, should not return a SessionTerminated error + // It should be a genuine endpoint-routing issue. + // ( Fall through to return StatusCode checking. ) + } else { + return nil, fmt.Errorf("failed to send request: %w", err) + } } defer resp.Body.Close() // Check if we got an error response if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { - // handle session closed - if resp.StatusCode == http.StatusNotFound { - c.sessionID.CompareAndSwap(sessionID, "") - return nil, fmt.Errorf("session terminated (404). need to re-initialize") - } // Handle OAuth unauthorized error if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil { @@ -279,6 +278,10 @@ func (c *StreamableHTTP) SendRequest( if sessionID := resp.Header.Get(headerKeySessionID); sessionID != "" { c.sessionID.Store(sessionID) } + + c.initializedOnce.Do(func() { + close(c.initialized) + }) } // Handle different response types @@ -300,16 +303,77 @@ func (c *StreamableHTTP) SendRequest( case "text/event-stream": // Server is using SSE for streaming responses - return c.handleSSEResponse(ctx, resp.Body) + return c.handleSSEResponse(ctx, resp.Body, false) default: return nil, fmt.Errorf("unexpected content type: %s", resp.Header.Get("Content-Type")) } } +func (c *StreamableHTTP) sendHTTP( + ctx context.Context, + method string, + body io.Reader, + acceptType string, +) (resp *http.Response, err error) { + + // Create HTTP request + req, err := http.NewRequestWithContext(ctx, method, c.serverURL.String(), body) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + // Set headers + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", acceptType) + sessionID := c.sessionID.Load().(string) + if sessionID != "" { + req.Header.Set(headerKeySessionID, sessionID) + } + for k, v := range c.headers { + req.Header.Set(k, v) + } + + // Add OAuth authorization if configured + if c.oauthHandler != nil { + authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx) + if err != nil { + // If we get an authorization error, return a specific error that can be handled by the client + if err.Error() == "no valid token available, authorization required" { + return nil, &OAuthAuthorizationRequiredError{ + Handler: c.oauthHandler, + } + } + return nil, fmt.Errorf("failed to get authorization header: %w", err) + } + req.Header.Set("Authorization", authHeader) + } + + if c.headerFunc != nil { + for k, v := range c.headerFunc(ctx) { + req.Header.Set(k, v) + } + } + + // Send request + resp, err = c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + + // universal handling for session terminated + if resp.StatusCode == http.StatusNotFound { + c.sessionID.CompareAndSwap(sessionID, "") + return nil, ErrSessionTerminated + } + + return resp, nil +} + // handleSSEResponse processes an SSE stream for a specific request. // It returns the final result for the request once received, or an error. -func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser) (*JSONRPCResponse, error) { +// If ignoreResponse is true, it won't return when a response messge is received. This is for continuous listening. +func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser, ignoreResponse bool) (*JSONRPCResponse, error) { // Create a channel for this specific request responseChan := make(chan *JSONRPCResponse, 1) @@ -328,7 +392,7 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl var message JSONRPCResponse if err := json.Unmarshal([]byte(data), &message); err != nil { - fmt.Printf("failed to unmarshal message: %v\n", err) + c.logger.Errorf("failed to unmarshal message: %v", err) return } @@ -336,7 +400,7 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl if message.ID.IsNil() { var notification mcp.JSONRPCNotification if err := json.Unmarshal([]byte(data), ¬ification); err != nil { - fmt.Printf("failed to unmarshal notification: %v\n", err) + c.logger.Errorf("failed to unmarshal notification: %v", err) return } c.notifyMu.RLock() @@ -347,7 +411,9 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl return } - responseChan <- &message + if !ignoreResponse { + responseChan <- &message + } }) }() @@ -393,7 +459,7 @@ func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, hand case <-ctx.Done(): return default: - fmt.Printf("SSE stream error: %v\n", err) + c.logger.Errorf("SSE stream error: %v", err) return } } @@ -432,44 +498,10 @@ func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp. } // Create HTTP request - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.serverURL.String(), bytes.NewReader(requestBody)) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - - // Set headers - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json, text/event-stream") - if sessionID := c.sessionID.Load(); sessionID != "" { - req.Header.Set(headerKeySessionID, sessionID.(string)) - } - for k, v := range c.headers { - req.Header.Set(k, v) - } - - // Add OAuth authorization if configured - if c.oauthHandler != nil { - authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx) - if err != nil { - // If we get an authorization error, return a specific error that can be handled by the client - if errors.Is(err, ErrOAuthAuthorizationRequired) { - return &OAuthAuthorizationRequiredError{ - Handler: c.oauthHandler, - } - } - return fmt.Errorf("failed to get authorization header: %w", err) - } - req.Header.Set("Authorization", authHeader) - } - - if c.headerFunc != nil { - for k, v := range c.headerFunc(ctx) { - req.Header.Set(k, v) - } - } + ctx, cancel := c.contextAwareOfClientClose(ctx) + defer cancel() - // Send request - resp, err := c.httpClient.Do(req) + resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream") if err != nil { return fmt.Errorf("failed to send request: %w", err) } @@ -513,3 +545,84 @@ func (c *StreamableHTTP) GetOAuthHandler() *OAuthHandler { func (c *StreamableHTTP) IsOAuthEnabled() bool { return c.oauthHandler != nil } + +func (c *StreamableHTTP) listenForever(ctx context.Context) { + c.logger.Infof("listening to server forever") + for { + err := c.createGETConnectionToServer(ctx) + if errors.Is(err, ErrGetMethodNotAllowed) { + // server does not support listening + c.logger.Errorf("server does not support listening") + return + } + + select { + case <-ctx.Done(): + return + default: + } + + if err != nil { + c.logger.Errorf("failed to listen to server. retry in 1 second: %v", err) + } + time.Sleep(retryInterval) + } +} + +var ( + ErrSessionTerminated = fmt.Errorf("session terminated (404). need to re-initialize") + ErrGetMethodNotAllowed = fmt.Errorf("GET method not allowed") + + retryInterval = 1 * time.Second // a variable is convenient for testing +) + +func (c *StreamableHTTP) createGETConnectionToServer(ctx context.Context) error { + + resp, err := c.sendHTTP(ctx, http.MethodGet, nil, "text/event-stream") + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + // Check if we got an error response + if resp.StatusCode == http.StatusMethodNotAllowed { + return ErrGetMethodNotAllowed + } + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body) + } + + // handle SSE response + contentType := resp.Header.Get("Content-Type") + if contentType != "text/event-stream" { + return fmt.Errorf("unexpected content type: %s", contentType) + } + + // When ignoreResponse is true, the function will never return expect context is done. + // NOTICE: Due to the ambiguity of the specification, other SDKs may use the GET connection to transfer the response + // messages. To be more compatible, we should handle this response, however, as the transport layer is message-based, + // currently, there is no convenient way to handle this response. + // So we ignore the response here. It's not a bug, but may be not compatible with other SDKs. + _, err = c.handleSSEResponse(ctx, resp.Body, true) + if err != nil { + return fmt.Errorf("failed to handle SSE response: %w", err) + } + + return nil +} + +func (c *StreamableHTTP) contextAwareOfClientClose(ctx context.Context) (context.Context, context.CancelFunc) { + newCtx, cancel := context.WithCancel(ctx) + go func() { + select { + case <-c.closed: + cancel() + case <-newCtx.Done(): + // The original context was canceled + cancel() + } + }() + return newCtx, cancel +} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/tools.go b/vendor/github.com/mark3labs/mcp-go/mcp/tools.go index 5f3524b0212d12a89b14fd1e2f4b2e6ba4dbd806..3e3931b09c9aedfce1f6e58a80be180e107b3116 100644 --- a/vendor/github.com/mark3labs/mcp-go/mcp/tools.go +++ b/vendor/github.com/mark3labs/mcp-go/mcp/tools.go @@ -945,7 +945,20 @@ func PropertyNames(schema map[string]any) PropertyOption { } } -// Items defines the schema for array items +// Items defines the schema for array items. +// Accepts any schema definition for maximum flexibility. +// +// Example: +// +// Items(map[string]any{ +// "type": "object", +// "properties": map[string]any{ +// "name": map[string]any{"type": "string"}, +// "age": map[string]any{"type": "number"}, +// }, +// }) +// +// For simple types, use ItemsString(), ItemsNumber(), ItemsBoolean() instead. func Items(schema any) PropertyOption { return func(schemaMap map[string]any) { schemaMap["items"] = schema @@ -972,3 +985,94 @@ func UniqueItems(unique bool) PropertyOption { schema["uniqueItems"] = unique } } + +// WithStringItems configures an array's items to be of type string. +// +// Supported options: Description(), DefaultString(), Enum(), MaxLength(), MinLength(), Pattern() +// Note: Options like Required() are not valid for item schemas and will be ignored. +// +// Examples: +// +// mcp.WithArray("tags", mcp.WithStringItems()) +// mcp.WithArray("colors", mcp.WithStringItems(mcp.Enum("red", "green", "blue"))) +// mcp.WithArray("names", mcp.WithStringItems(mcp.MinLength(1), mcp.MaxLength(50))) +// +// Limitations: Only supports simple string arrays. Use Items() for complex objects. +func WithStringItems(opts ...PropertyOption) PropertyOption { + return func(schema map[string]any) { + itemSchema := map[string]any{ + "type": "string", + } + + for _, opt := range opts { + opt(itemSchema) + } + + schema["items"] = itemSchema + } +} + +// WithStringEnumItems configures an array's items to be of type string with a specified enum. +// Example: +// +// mcp.WithArray("priority", mcp.WithStringEnumItems([]string{"low", "medium", "high"})) +// +// Limitations: Only supports string enums. Use WithStringItems(Enum(...)) for more flexibility. +func WithStringEnumItems(values []string) PropertyOption { + return func(schema map[string]any) { + schema["items"] = map[string]any{ + "type": "string", + "enum": values, + } + } +} + +// WithNumberItems configures an array's items to be of type number. +// +// Supported options: Description(), DefaultNumber(), Min(), Max(), MultipleOf() +// Note: Options like Required() are not valid for item schemas and will be ignored. +// +// Examples: +// +// mcp.WithArray("scores", mcp.WithNumberItems(mcp.Min(0), mcp.Max(100))) +// mcp.WithArray("prices", mcp.WithNumberItems(mcp.Min(0))) +// +// Limitations: Only supports simple number arrays. Use Items() for complex objects. +func WithNumberItems(opts ...PropertyOption) PropertyOption { + return func(schema map[string]any) { + itemSchema := map[string]any{ + "type": "number", + } + + for _, opt := range opts { + opt(itemSchema) + } + + schema["items"] = itemSchema + } +} + +// WithBooleanItems configures an array's items to be of type boolean. +// +// Supported options: Description(), DefaultBool() +// Note: Options like Required() are not valid for item schemas and will be ignored. +// +// Examples: +// +// mcp.WithArray("flags", mcp.WithBooleanItems()) +// mcp.WithArray("permissions", mcp.WithBooleanItems(mcp.Description("User permissions"))) +// +// Limitations: Only supports simple boolean arrays. Use Items() for complex objects. +func WithBooleanItems(opts ...PropertyOption) PropertyOption { + return func(schema map[string]any) { + itemSchema := map[string]any{ + "type": "boolean", + } + + for _, opt := range opts { + opt(itemSchema) + } + + schema["items"] = itemSchema + } +} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/types.go b/vendor/github.com/mark3labs/mcp-go/mcp/types.go index 0091d2e42d380253ee03d0d1b5cde8597775be8f..241b55ce9b549941d764a2ca5b4ba11e551d301d 100644 --- a/vendor/github.com/mark3labs/mcp-go/mcp/types.go +++ b/vendor/github.com/mark3labs/mcp-go/mcp/types.go @@ -763,6 +763,11 @@ const ( /* Sampling */ +const ( + // MethodSamplingCreateMessage allows servers to request LLM completions from clients + MethodSamplingCreateMessage MCPMethod = "sampling/createMessage" +) + // CreateMessageRequest is a request from the server to sample an LLM via the // client. The client has full discretion over which model to select. The client // should also inform the user before beginning sampling, to allow them to inspect @@ -865,6 +870,22 @@ type AudioContent struct { func (AudioContent) isContent() {} +// ResourceLink represents a link to a resource that the client can access. +type ResourceLink struct { + Annotated + Type string `json:"type"` // Must be "resource_link" + // The URI of the resource. + URI string `json:"uri"` + // The name of the resource. + Name string `json:"name"` + // The description of the resource. + Description string `json:"description"` + // The MIME type of the resource. + MIMEType string `json:"mimeType"` +} + +func (ResourceLink) isContent() {} + // EmbeddedResource represents the contents of a resource, embedded into a prompt or tool call result. // // It is up to the client how best to render embedded resources for the diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/utils.go b/vendor/github.com/mark3labs/mcp-go/mcp/utils.go index 55bef7a997e2a406f111b2fb399812ca1941ab96..3e652efd7e842d24bc6ab13fa119d21f272a8ba7 100644 --- a/vendor/github.com/mark3labs/mcp-go/mcp/utils.go +++ b/vendor/github.com/mark3labs/mcp-go/mcp/utils.go @@ -222,6 +222,17 @@ func NewAudioContent(data, mimeType string) AudioContent { } } +// Helper function to create a new ResourceLink +func NewResourceLink(uri, name, description, mimeType string) ResourceLink { + return ResourceLink{ + Type: "resource_link", + URI: uri, + Name: name, + Description: description, + MIMEType: mimeType, + } +} + // Helper function to create a new EmbeddedResource func NewEmbeddedResource(resource ResourceContents) EmbeddedResource { return EmbeddedResource{ @@ -476,6 +487,16 @@ func ParseContent(contentMap map[string]any) (Content, error) { } return NewAudioContent(data, mimeType), nil + case "resource_link": + uri := ExtractString(contentMap, "uri") + name := ExtractString(contentMap, "name") + description := ExtractString(contentMap, "description") + mimeType := ExtractString(contentMap, "mimeType") + if uri == "" || name == "" { + return nil, fmt.Errorf("resource_link uri or name is missing") + } + return NewResourceLink(uri, name, description, mimeType), nil + case "resource": resourceMap := ExtractMap(contentMap, "resource") if resourceMap == nil { diff --git a/vendor/github.com/mark3labs/mcp-go/server/sampling.go b/vendor/github.com/mark3labs/mcp-go/server/sampling.go new file mode 100644 index 0000000000000000000000000000000000000000..b633b24d07ebfeeedd9b49468d7aadb411c87b4c --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/sampling.go @@ -0,0 +1,37 @@ +package server + +import ( + "context" + "fmt" + + "github.com/mark3labs/mcp-go/mcp" +) + +// EnableSampling enables sampling capabilities for the server. +// This allows the server to send sampling requests to clients that support it. +func (s *MCPServer) EnableSampling() { + s.capabilitiesMu.Lock() + defer s.capabilitiesMu.Unlock() +} + +// RequestSampling sends a sampling request to the client. +// The client must have declared sampling capability during initialization. +func (s *MCPServer) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + session := ClientSessionFromContext(ctx) + if session == nil { + return nil, fmt.Errorf("no active session") + } + + // Check if the session supports sampling requests + if samplingSession, ok := session.(SessionWithSampling); ok { + return samplingSession.RequestSampling(ctx, request) + } + + return nil, fmt.Errorf("session does not support sampling") +} + +// SessionWithSampling extends ClientSession to support sampling requests. +type SessionWithSampling interface { + ClientSession + RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) +} diff --git a/vendor/github.com/mark3labs/mcp-go/server/stdio.go b/vendor/github.com/mark3labs/mcp-go/server/stdio.go index 746a7d96f6c3635ec05c6bc2d7b92820824a8e20..33ac9bb8854527db09ea31a0be4d109521fa0c37 100644 --- a/vendor/github.com/mark3labs/mcp-go/server/stdio.go +++ b/vendor/github.com/mark3labs/mcp-go/server/stdio.go @@ -9,6 +9,7 @@ import ( "log" "os" "os/signal" + "sync" "sync/atomic" "syscall" @@ -51,10 +52,21 @@ func WithStdioContextFunc(fn StdioContextFunc) StdioOption { // stdioSession is a static client session, since stdio has only one client. type stdioSession struct { - notifications chan mcp.JSONRPCNotification - initialized atomic.Bool - loggingLevel atomic.Value - clientInfo atomic.Value // stores session-specific client info + notifications chan mcp.JSONRPCNotification + initialized atomic.Bool + loggingLevel atomic.Value + clientInfo atomic.Value // stores session-specific client info + writer io.Writer // for sending requests to client + requestID atomic.Int64 // for generating unique request IDs + mu sync.RWMutex // protects writer + pendingRequests map[int64]chan *samplingResponse // for tracking pending sampling requests + pendingMu sync.RWMutex // protects pendingRequests +} + +// samplingResponse represents a response to a sampling request +type samplingResponse struct { + result *mcp.CreateMessageResult + err error } func (s *stdioSession) SessionID() string { @@ -100,14 +112,86 @@ func (s *stdioSession) GetLogLevel() mcp.LoggingLevel { return level.(mcp.LoggingLevel) } +// RequestSampling sends a sampling request to the client and waits for the response. +func (s *stdioSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + s.mu.RLock() + writer := s.writer + s.mu.RUnlock() + + if writer == nil { + return nil, fmt.Errorf("no writer available for sending requests") + } + + // Generate a unique request ID + id := s.requestID.Add(1) + + // Create a response channel for this request + responseChan := make(chan *samplingResponse, 1) + s.pendingMu.Lock() + s.pendingRequests[id] = responseChan + s.pendingMu.Unlock() + + // Cleanup function to remove the pending request + cleanup := func() { + s.pendingMu.Lock() + delete(s.pendingRequests, id) + s.pendingMu.Unlock() + } + defer cleanup() + + // Create the JSON-RPC request + jsonRPCRequest := struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Method string `json:"method"` + Params mcp.CreateMessageParams `json:"params"` + }{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: id, + Method: string(mcp.MethodSamplingCreateMessage), + Params: request.CreateMessageParams, + } + + // Marshal and send the request + requestBytes, err := json.Marshal(jsonRPCRequest) + if err != nil { + return nil, fmt.Errorf("failed to marshal sampling request: %w", err) + } + requestBytes = append(requestBytes, '\n') + + if _, err := writer.Write(requestBytes); err != nil { + return nil, fmt.Errorf("failed to write sampling request: %w", err) + } + + // Wait for the response or context cancellation + select { + case <-ctx.Done(): + return nil, ctx.Err() + case response := <-responseChan: + if response.err != nil { + return nil, response.err + } + return response.result, nil + } +} + +// SetWriter sets the writer for sending requests to the client. +func (s *stdioSession) SetWriter(writer io.Writer) { + s.mu.Lock() + defer s.mu.Unlock() + s.writer = writer +} + var ( _ ClientSession = (*stdioSession)(nil) _ SessionWithLogging = (*stdioSession)(nil) _ SessionWithClientInfo = (*stdioSession)(nil) + _ SessionWithSampling = (*stdioSession)(nil) ) var stdioSessionInstance = stdioSession{ - notifications: make(chan mcp.JSONRPCNotification, 100), + notifications: make(chan mcp.JSONRPCNotification, 100), + pendingRequests: make(map[int64]chan *samplingResponse), } // NewStdioServer creates a new stdio server wrapper around an MCPServer. @@ -224,6 +308,9 @@ func (s *StdioServer) Listen( defer s.server.UnregisterSession(ctx, stdioSessionInstance.SessionID()) ctx = s.server.WithContext(ctx, &stdioSessionInstance) + // Set the writer for sending requests to the client + stdioSessionInstance.SetWriter(stdout) + // Add in any custom context. if s.contextFunc != nil { ctx = s.contextFunc(ctx) @@ -256,7 +343,29 @@ func (s *StdioServer) processMessage( return s.writeResponse(response, writer) } - // Handle the message using the wrapped server + // Check if this is a response to a sampling request + if s.handleSamplingResponse(rawMessage) { + return nil + } + + // Check if this is a tool call that might need sampling (and thus should be processed concurrently) + var baseMessage struct { + Method string `json:"method"` + } + if json.Unmarshal(rawMessage, &baseMessage) == nil && baseMessage.Method == "tools/call" { + // Process tool calls concurrently to avoid blocking on sampling requests + go func() { + response := s.server.HandleMessage(ctx, rawMessage) + if response != nil { + if err := s.writeResponse(response, writer); err != nil { + s.errLogger.Printf("Error writing tool response: %v", err) + } + } + }() + return nil + } + + // Handle other messages synchronously response := s.server.HandleMessage(ctx, rawMessage) // Only write response if there is one (not for notifications) @@ -269,6 +378,65 @@ func (s *StdioServer) processMessage( return nil } +// handleSamplingResponse checks if the message is a response to a sampling request +// and routes it to the appropriate pending request channel. +func (s *StdioServer) handleSamplingResponse(rawMessage json.RawMessage) bool { + return stdioSessionInstance.handleSamplingResponse(rawMessage) +} + +// handleSamplingResponse handles incoming sampling responses for this session +func (s *stdioSession) handleSamplingResponse(rawMessage json.RawMessage) bool { + // Try to parse as a JSON-RPC response + var response struct { + JSONRPC string `json:"jsonrpc"` + ID json.Number `json:"id"` + Result json.RawMessage `json:"result,omitempty"` + Error *struct { + Code int `json:"code"` + Message string `json:"message"` + } `json:"error,omitempty"` + } + + if err := json.Unmarshal(rawMessage, &response); err != nil { + return false + } + // Parse the ID as int64 + idInt64, err := response.ID.Int64() + if err != nil || (response.Result == nil && response.Error == nil) { + return false + } + + // Look for a pending request with this ID + s.pendingMu.RLock() + responseChan, exists := s.pendingRequests[idInt64] + s.pendingMu.RUnlock() + + if !exists { + return false + } // Parse and send the response + samplingResp := &samplingResponse{} + + if response.Error != nil { + samplingResp.err = fmt.Errorf("sampling request failed: %s", response.Error.Message) + } else { + var result mcp.CreateMessageResult + if err := json.Unmarshal(response.Result, &result); err != nil { + samplingResp.err = fmt.Errorf("failed to unmarshal sampling response: %w", err) + } else { + samplingResp.result = &result + } + } + + // Send the response (non-blocking) + select { + case responseChan <- samplingResp: + default: + // Channel is full or closed, ignore + } + + return true +} + // writeResponse marshals and writes a JSON-RPC response message followed by a newline. // Returns an error if marshaling or writing fails. func (s *StdioServer) writeResponse( diff --git a/vendor/github.com/mark3labs/mcp-go/server/streamable_http.go b/vendor/github.com/mark3labs/mcp-go/server/streamable_http.go index e9a011fb1c31b46771e3baeffe666ef9a71ef1a1..1312c9753a5ddc2d37a2d3c9f6266cc80d517e2e 100644 --- a/vendor/github.com/mark3labs/mcp-go/server/streamable_http.go +++ b/vendor/github.com/mark3labs/mcp-go/server/streamable_http.go @@ -40,7 +40,9 @@ func WithEndpointPath(endpointPath string) StreamableHTTPOption { // to StatelessSessionIdManager. func WithStateLess(stateLess bool) StreamableHTTPOption { return func(s *StreamableHTTPServer) { - s.sessionIdManager = &StatelessSessionIdManager{} + if stateLess { + s.sessionIdManager = &StatelessSessionIdManager{} + } } } @@ -374,7 +376,7 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") - w.WriteHeader(http.StatusAccepted) + w.WriteHeader(http.StatusOK) flusher, ok := w.(http.Flusher) if !ok { diff --git a/vendor/modules.txt b/vendor/modules.txt index 33d95285eebb41a1038aa2d95233bbcc96a87151..8cbc2b93ffb7ce6c044bca3f157defbf2db3d00c 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -403,7 +403,7 @@ github.com/kylelemons/godebug/pretty # github.com/lucasb-eyer/go-colorful v1.2.0 ## explicit; go 1.12 github.com/lucasb-eyer/go-colorful -# github.com/mark3labs/mcp-go v0.32.0 +# github.com/mark3labs/mcp-go v0.33.0 ## explicit; go 1.23 github.com/mark3labs/mcp-go/client github.com/mark3labs/mcp-go/client/transport From fb21333164037023f626bbfb60d6e9d07210e494 Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Fri, 11 Jul 2025 16:09:06 -0300 Subject: [PATCH 23/30] refactor: remove weird context value usage (#153) * refactor: remove weird context value usage * fix: improvements * fix: more cleanup * fix: diff --- internal/app/lsp.go | 5 +- internal/lsp/watcher/watcher.go | 153 ++++++++++++-------------------- 2 files changed, 56 insertions(+), 102 deletions(-) diff --git a/internal/app/lsp.go b/internal/app/lsp.go index ba98d4b3a074c2e9abcef87eb3030a21be669eab..33506016690645dd714c682ddd2e65e992d2d1f9 100644 --- a/internal/app/lsp.go +++ b/internal/app/lsp.go @@ -59,11 +59,8 @@ func (app *App) createAndStartLSPClient(ctx context.Context, name string, comman // Create a child context that can be canceled when the app is shutting down watchCtx, cancelFunc := context.WithCancel(ctx) - // Create a context with the server name for better identification - watchCtx = context.WithValue(watchCtx, "serverName", name) - // Create the workspace watcher - workspaceWatcher := watcher.NewWorkspaceWatcher(lspClient) + workspaceWatcher := watcher.NewWorkspaceWatcher(name, lspClient) // Store the cancel function to be called during cleanup app.cancelFuncsMutex.Lock() diff --git a/internal/lsp/watcher/watcher.go b/internal/lsp/watcher/watcher.go index a6d27f057e06ea7026a6eed0308979991a44fb9d..5bd016eebe413a17acca29ef628612825d40b923 100644 --- a/internal/lsp/watcher/watcher.go +++ b/internal/lsp/watcher/watcher.go @@ -21,6 +21,7 @@ import ( // WorkspaceWatcher manages LSP file watching type WorkspaceWatcher struct { client *lsp.Client + name string workspacePath string debounceTime time.Duration @@ -33,8 +34,9 @@ type WorkspaceWatcher struct { } // NewWorkspaceWatcher creates a new workspace watcher -func NewWorkspaceWatcher(client *lsp.Client) *WorkspaceWatcher { +func NewWorkspaceWatcher(name string, client *lsp.Client) *WorkspaceWatcher { return &WorkspaceWatcher{ + name: name, client: client, debounceTime: 300 * time.Millisecond, debounceMap: make(map[string]*time.Timer), @@ -95,7 +97,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc } // Determine server type for specialized handling - serverName := getServerNameFromContext(ctx) + serverName := w.name slog.Debug("Server type detected", "serverName", serverName) // Check if this server has sent file watchers @@ -325,17 +327,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str cfg := config.Get() w.workspacePath = workspacePath - // Store the watcher in the context for later use - ctx = context.WithValue(ctx, "workspaceWatcher", w) - - // If the server name isn't already in the context, try to detect it - if _, ok := ctx.Value("serverName").(string); !ok { - serverName := getServerNameFromContext(ctx) - ctx = context.WithValue(ctx, "serverName", serverName) - } - - serverName := getServerNameFromContext(ctx) - slog.Debug("Starting workspace watcher", "workspacePath", workspacePath, "serverName", serverName) + slog.Debug("Starting workspace watcher", "workspacePath", workspacePath, "serverName", w.name) // Register handler for file watcher registrations from the server lsp.RegisterFileWatchHandler(func(id string, watchers []protocol.FileSystemWatcher) { @@ -697,40 +689,6 @@ func (w *WorkspaceWatcher) notifyFileEvent(ctx context.Context, uri string, chan return w.client.DidChangeWatchedFiles(ctx, params) } -// getServerNameFromContext extracts the server name from the context -// This is a best-effort function that tries to identify which LSP server we're dealing with -func getServerNameFromContext(ctx context.Context) string { - // First check if the server name is directly stored in the context - if serverName, ok := ctx.Value("serverName").(string); ok && serverName != "" { - return strings.ToLower(serverName) - } - - // Otherwise, try to extract server name from the client command path - if w, ok := ctx.Value("workspaceWatcher").(*WorkspaceWatcher); ok && w != nil && w.client != nil && w.client.Cmd != nil { - path := strings.ToLower(w.client.Cmd.Path) - - // Extract server name from path - if strings.Contains(path, "typescript") || strings.Contains(path, "tsserver") || strings.Contains(path, "vtsls") { - return "typescript" - } else if strings.Contains(path, "gopls") { - return "gopls" - } else if strings.Contains(path, "rust-analyzer") { - return "rust-analyzer" - } else if strings.Contains(path, "pyright") || strings.Contains(path, "pylsp") || strings.Contains(path, "python") { - return "python" - } else if strings.Contains(path, "clangd") { - return "clangd" - } else if strings.Contains(path, "jdtls") || strings.Contains(path, "java") { - return "java" - } - - // Return the base name as fallback - return filepath.Base(path) - } - - return "unknown" -} - // shouldPreloadFiles determines if we should preload files for a specific language server // Some servers work better with preloaded files, others don't need it func shouldPreloadFiles(serverName string) bool { @@ -884,64 +842,63 @@ func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) { } // Check if this path should be watched according to server registrations - if watched, _ := w.isPathWatched(path); watched { - // Get server name for specialized handling - serverName := getServerNameFromContext(ctx) + if watched, _ := w.isPathWatched(path); !watched { + return + } - // Check if the file is a high-priority file that should be opened immediately - // This helps with project initialization for certain language servers - if isHighPriorityFile(path, serverName) { - if cfg.Options.DebugLSP { - slog.Debug("Opening high-priority file", "path", path, "serverName", serverName) - } - if err := w.client.OpenFile(ctx, path); err != nil && cfg.Options.DebugLSP { - slog.Error("Error opening high-priority file", "path", path, "error", err) - } - return - } + serverName := w.name - // For non-high-priority files, we'll use different strategies based on server type - if shouldPreloadFiles(serverName) { - // For servers that benefit from preloading, open files but with limits + // Get server name for specialized handling + // Check if the file is a high-priority file that should be opened immediately + // This helps with project initialization for certain language servers + if isHighPriorityFile(path, serverName) { + if cfg.Options.DebugLSP { + slog.Debug("Opening high-priority file", "path", path, "serverName", serverName) + } + if err := w.client.OpenFile(ctx, path); err != nil && cfg.Options.DebugLSP { + slog.Error("Error opening high-priority file", "path", path, "error", err) + } + return + } - // Check file size - for preloading we're more conservative - if info.Size() > (1 * 1024 * 1024) { // 1MB limit for preloaded files - if cfg.Options.DebugLSP { - slog.Debug("Skipping large file for preloading", "path", path, "size", info.Size()) - } - return - } + // For non-high-priority files, we'll use different strategies based on server type + if !shouldPreloadFiles(serverName) { + return + } + // For servers that benefit from preloading, open files but with limits - // Check file extension for common source files - ext := strings.ToLower(filepath.Ext(path)) + // Check file size - for preloading we're more conservative + if info.Size() > (1 * 1024 * 1024) { // 1MB limit for preloaded files + if cfg.Options.DebugLSP { + slog.Debug("Skipping large file for preloading", "path", path, "size", info.Size()) + } + return + } - // Only preload source files for the specific language - shouldOpen := false + // Check file extension for common source files + ext := strings.ToLower(filepath.Ext(path)) - switch serverName { - case "typescript", "typescript-language-server", "tsserver", "vtsls": - shouldOpen = ext == ".ts" || ext == ".js" || ext == ".tsx" || ext == ".jsx" - case "gopls": - shouldOpen = ext == ".go" - case "rust-analyzer": - shouldOpen = ext == ".rs" - case "python", "pyright", "pylsp": - shouldOpen = ext == ".py" - case "clangd": - shouldOpen = ext == ".c" || ext == ".cpp" || ext == ".h" || ext == ".hpp" - case "java", "jdtls": - shouldOpen = ext == ".java" - default: - // For unknown servers, be conservative - shouldOpen = false - } + // Only preload source files for the specific language + var shouldOpen bool + switch serverName { + case "typescript", "typescript-language-server", "tsserver", "vtsls": + shouldOpen = ext == ".ts" || ext == ".js" || ext == ".tsx" || ext == ".jsx" + case "gopls": + shouldOpen = ext == ".go" + case "rust-analyzer": + shouldOpen = ext == ".rs" + case "python", "pyright", "pylsp": + shouldOpen = ext == ".py" + case "clangd": + shouldOpen = ext == ".c" || ext == ".cpp" || ext == ".h" || ext == ".hpp" + case "java", "jdtls": + shouldOpen = ext == ".java" + } - if shouldOpen { - // Don't need to check if it's already open - the client.OpenFile handles that - if err := w.client.OpenFile(ctx, path); err != nil && cfg.Options.DebugLSP { - slog.Error("Error opening file", "path", path, "error", err) - } - } + if shouldOpen { + // Don't need to check if it's already open - the client.OpenFile handles that + if err := w.client.OpenFile(ctx, path); err != nil && cfg.Options.DebugLSP { + slog.Error("Error opening file", "path", path, "error", err) } } } From 95f5e75458eb471ef7cba3c0b5a39000284e2823 Mon Sep 17 00:00:00 2001 From: Andrey Nering Date: Fri, 11 Jul 2025 17:25:33 -0300 Subject: [PATCH 24/30] ci: fix diffview tests on ci on windows --- internal/tui/exp/diffview/diffview.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/internal/tui/exp/diffview/diffview.go b/internal/tui/exp/diffview/diffview.go index bb51a7e505666e67fd9e914a135a0dd7632bb184..1cb56a678f51d0809c584edc1bedd73befc59966 100644 --- a/internal/tui/exp/diffview/diffview.go +++ b/internal/tui/exp/diffview/diffview.go @@ -365,7 +365,8 @@ func (dv *DiffView) renderUnified() string { shouldWrite := func() bool { return printedLines >= 0 } getContent := func(in string, ls LineStyle) (content string, leadingEllipsis bool) { - content = strings.TrimSuffix(in, "\n") + content = strings.ReplaceAll(in, "\r\n", "\n") + content = strings.TrimSuffix(content, "\n") content = dv.hightlightCode(content, ls.Code.GetBackground()) content = ansi.GraphemeWidth.Cut(content, dv.xOffset, len(content)) content = ansi.Truncate(content, dv.codeWidth, "…") @@ -488,7 +489,8 @@ func (dv *DiffView) renderSplit() string { shouldWrite := func() bool { return printedLines >= 0 } getContent := func(in string, ls LineStyle) (content string, leadingEllipsis bool) { - content = strings.TrimSuffix(in, "\n") + content = strings.ReplaceAll(in, "\r\n", "\n") + content = strings.TrimSuffix(content, "\n") content = dv.hightlightCode(content, ls.Code.GetBackground()) content = ansi.GraphemeWidth.Cut(content, dv.xOffset, len(content)) content = ansi.Truncate(content, dv.codeWidth, "…") From 2aa5f6151e3607bf6d221a7f308c214bce03671c Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Mon, 14 Jul 2025 10:16:36 -0300 Subject: [PATCH 25/30] fix: non-interactive and context cancellation when SIGINT (#143) * fix: non-interactive and context cancellation when SIGINT * fix: spinner interrupt * refactor: remove weird context value usage * fix: improvements * fix: spinner competing for signal handling * fix: vendoring Signed-off-by: Carlos Alexandro Becker * fix: ctrl+c in raw mode * fix: sigkill can't be handled --------- Signed-off-by: Carlos Alexandro Becker --- .gitignore | 1 + cmd/root.go | 6 +- go.mod | 4 +- go.sum | 4 +- internal/app/app.go | 5 +- internal/format/spinner.go | 50 ++- internal/lsp/transport.go | 41 +-- .../sdk/azidentity/go.work.sum | 60 ---- .../github.com/charmbracelet/fang/README.md | 7 +- vendor/github.com/charmbracelet/fang/fang.go | 119 +++++-- vendor/github.com/charmbracelet/fang/help.go | 298 ++++++++++++------ vendor/github.com/charmbracelet/fang/theme.go | 52 ++- vendor/modules.txt | 6 +- 13 files changed, 416 insertions(+), 237 deletions(-) delete mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azidentity/go.work.sum diff --git a/.gitignore b/.gitignore index b28e5a0c727163e8f3585522e680d1df2ad6e621..2f16f744432d89e0a72fd6ea8e359678a64b6d42 100644 --- a/.gitignore +++ b/.gitignore @@ -16,6 +16,7 @@ # Go workspace file go.work +go.work.sum # IDE specific files .idea/ diff --git a/cmd/root.go b/cmd/root.go index 3a8f4fba0fe759a42ef1e7647223b2b3b11fbc65..c5231accd672c43a564e4d0a174fb762d06ee044 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -6,6 +6,7 @@ import ( "io" "log/slog" "os" + "syscall" "time" tea "github.com/charmbracelet/bubbletea/v2" @@ -72,9 +73,7 @@ to assist developers in writing, debugging, and understanding code directly from return err } - // Create main context for the application - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := cmd.Context() // Connect DB, this will also run migrations conn, err := db.Connect(ctx, cfg.Options.DataDirectory) @@ -145,6 +144,7 @@ func Execute() { context.Background(), rootCmd, fang.WithVersion(version.Version), + fang.WithNotifySignal(os.Interrupt, syscall.SIGTERM), ); err != nil { os.Exit(1) } diff --git a/go.mod b/go.mod index 2a9d6d5dfbaa827a5c8a57cadbe716dd956e1401..d510a774a03c27ceca623400257228763cc2e9a1 100644 --- a/go.mod +++ b/go.mod @@ -18,9 +18,9 @@ require ( github.com/charlievieth/fastwalk v1.0.11 github.com/charmbracelet/bubbles/v2 v2.0.0-beta.1.0.20250710161907-a4c42b579198 github.com/charmbracelet/bubbletea/v2 v2.0.0-beta.1 - github.com/charmbracelet/fang v0.1.0 + github.com/charmbracelet/fang v0.3.1-0.20250711140230-d5ebb8c1d674 github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe - github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.2.0.20250703152125-8e1c474f8a71 + github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.3 github.com/charmbracelet/log/v2 v2.0.0-20250226163916-c379e29ff706 github.com/charmbracelet/x/ansi v0.9.3 github.com/charmbracelet/x/exp/charmtone v0.0.0-20250708181618-a60a724ba6c3 diff --git a/go.sum b/go.sum index 1d40961a3dce4180d9a06d17e3843f8d8709567b..d7004401154b86ce0658162c06bfc610a0c77126 100644 --- a/go.sum +++ b/go.sum @@ -74,8 +74,8 @@ github.com/charmbracelet/bubbletea-internal/v2 v2.0.0-20250710185017-3c0ffd25e59 github.com/charmbracelet/bubbletea-internal/v2 v2.0.0-20250710185017-3c0ffd25e595/go.mod h1:+Tl7rePElw6OKt382t04zXwtPFoPXxAaJzNrYmtsLds= github.com/charmbracelet/colorprofile v0.3.1 h1:k8dTHMd7fgw4bnFd7jXTLZrSU/CQrKnL3m+AxCzDz40= github.com/charmbracelet/colorprofile v0.3.1/go.mod h1:/GkGusxNs8VB/RSOh3fu0TJmQ4ICMMPApIIVn0KszZ0= -github.com/charmbracelet/fang v0.1.0 h1:SlZS2crf3/zQh7Mr4+W+7QR1k+L08rrPX5rm5z3d7Wg= -github.com/charmbracelet/fang v0.1.0/go.mod h1:Zl/zeUQ8EtQuGyiV0ZKZlZPDowKRTzu8s/367EpN/fc= +github.com/charmbracelet/fang v0.3.1-0.20250711140230-d5ebb8c1d674 h1:+Cz+VfxD5DO+JT1LlswXWhre0HYLj6l2HW8HVGfMuC0= +github.com/charmbracelet/fang v0.3.1-0.20250711140230-d5ebb8c1d674/go.mod h1:9gCUAHmVx5BwSafeyNr3GI0GgvlB1WYjL21SkPp1jyU= github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe h1:i6ce4CcAlPpTj2ER69m1DBeLZ3RRcHnKExuwhKa3GfY= github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe/go.mod h1:p3Q+aN4eQKeM5jhrmXPMgPrlKbmc59rWSnMsSA3udhk= github.com/charmbracelet/lipgloss-internal/v2 v2.0.0-20250710185058-03664cb9cecb h1:lswj7CYZVYbLn2OhYJsXOMRQQGdRIfyuSnh5FdVSMr0= diff --git a/internal/app/app.go b/internal/app/app.go index 9d0e6f176b14df0b15fd90f4b3651cdefafd6826..c3dae3d88a2be7c4cd5491e089b97695b08a7a23 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -95,10 +95,13 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) { func (a *App) RunNonInteractive(ctx context.Context, prompt string, quiet bool) error { slog.Info("Running in non-interactive mode") + ctx, cancel := context.WithCancel(ctx) + defer cancel() + // Start spinner if not in quiet mode var spinner *format.Spinner if !quiet { - spinner = format.NewSpinner(ctx, "Generating") + spinner = format.NewSpinner(ctx, cancel, "Generating") spinner.Start() } // Helper function to stop spinner once diff --git a/internal/format/spinner.go b/internal/format/spinner.go index 9377bd3b4c145fc6866ac1e0f4e63dff8ab51619..da64fb93ce262e04a0b5fb9da8c4aea8403d10d8 100644 --- a/internal/format/spinner.go +++ b/internal/format/spinner.go @@ -18,24 +18,48 @@ type Spinner struct { prog *tea.Program } +type model struct { + cancel context.CancelFunc + anim anim.Anim +} + +func (m model) Init() tea.Cmd { return m.anim.Init() } +func (m model) View() string { return m.anim.View() } + +// Update implements tea.Model. +func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case tea.KeyPressMsg: + switch msg.String() { + case "ctrl+c", "esc": + m.cancel() + return m, tea.Quit + } + } + mm, cmd := m.anim.Update(msg) + m.anim = mm.(anim.Anim) + return m, cmd +} + // NewSpinner creates a new spinner with the given message -func NewSpinner(ctx context.Context, message string) *Spinner { +func NewSpinner(ctx context.Context, cancel context.CancelFunc, message string) *Spinner { t := styles.CurrentTheme() - model := anim.New(anim.Settings{ - Size: 10, - Label: message, - LabelColor: t.FgBase, - GradColorA: t.Primary, - GradColorB: t.Secondary, - CycleColors: true, - }) + model := model{ + anim: anim.New(anim.Settings{ + Size: 10, + Label: message, + LabelColor: t.FgBase, + GradColorA: t.Primary, + GradColorB: t.Secondary, + CycleColors: true, + }), + cancel: cancel, + } prog := tea.NewProgram( model, - tea.WithInput(nil), tea.WithOutput(os.Stderr), tea.WithContext(ctx), - tea.WithoutCatchPanics(), ) return &Spinner{ @@ -47,13 +71,13 @@ func NewSpinner(ctx context.Context, message string) *Spinner { // Start begins the spinner animation func (s *Spinner) Start() { go func() { + defer close(s.done) _, err := s.prog.Run() // ensures line is cleared fmt.Fprint(os.Stderr, ansi.EraseEntireLine) - if err != nil && !errors.Is(err, context.Canceled) { + if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, tea.ErrInterrupted) { fmt.Fprintf(os.Stderr, "Error running spinner: %v\n", err) } - close(s.done) }() } diff --git a/internal/lsp/transport.go b/internal/lsp/transport.go index 431a099fa1cda5e5035de7ce6c10ef3761e397ea..9a3dfd261fb68b1afdd17f614daab761f9294327 100644 --- a/internal/lsp/transport.go +++ b/internal/lsp/transport.go @@ -222,29 +222,32 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any } // Wait for response - resp := <-ch - - if cfg.Options.DebugLSP { - slog.Debug("Received response", "id", id) - } - - if resp.Error != nil { - return fmt.Errorf("request failed: %s (code: %d)", resp.Error.Message, resp.Error.Code) - } + select { + case <-ctx.Done(): + return ctx.Err() + case resp := <-ch: + if cfg.Options.DebugLSP { + slog.Debug("Received response", "id", id) + } - if result != nil { - // If result is a json.RawMessage, just copy the raw bytes - if rawMsg, ok := result.(*json.RawMessage); ok { - *rawMsg = resp.Result - return nil + if resp.Error != nil { + return fmt.Errorf("request failed: %s (code: %d)", resp.Error.Message, resp.Error.Code) } - // Otherwise unmarshal into the provided type - if err := json.Unmarshal(resp.Result, result); err != nil { - return fmt.Errorf("failed to unmarshal result: %w", err) + + if result != nil { + // If result is a json.RawMessage, just copy the raw bytes + if rawMsg, ok := result.(*json.RawMessage); ok { + *rawMsg = resp.Result + return nil + } + // Otherwise unmarshal into the provided type + if err := json.Unmarshal(resp.Result, result); err != nil { + return fmt.Errorf("failed to unmarshal result: %w", err) + } } - } - return nil + return nil + } } // Notify sends a notification (a request without an ID that doesn't expect a response) diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azidentity/go.work.sum b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azidentity/go.work.sum deleted file mode 100644 index c592f283b6bdb1cb2b13aa4b0769b94811a1cfe9..0000000000000000000000000000000000000000 --- a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azidentity/go.work.sum +++ /dev/null @@ -1,60 +0,0 @@ -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.0-beta.1 h1:ODs3brnqQM99Tq1PffODpAViYv3Bf8zOg464MU7p5ew= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.0-beta.1/go.mod h1:3Ug6Qzto9anB6mGlEdgYMDF5zHQ+wwhEaYR4s17PHMw= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.0 h1:fb8kj/Dh4CSwgsOzHeZY4Xh68cFVbzXx+ONXGMY//4w= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.0/go.mod h1:uReU2sSxZExRPBAg3qKzmAucSi51+SP1OhohieR821Q= -github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0/go.mod h1:okt5dMMTOFjX/aovMlrjvvXoPMBVSPzk9185BT0+eZM= -github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.2/go.mod h1:yInRyqWXAuaPrgI7p70+lDDgh3mlBohis29jGMISnmc= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI= -github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/keybase/dbus v0.0.0-20220506165403-5aa21ea2c23a/go.mod h1:YPNKjjE7Ubp9dTbnWvsP3HT+hYnY6TfXzubYTBeUxc8= -github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= -github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/montanaflynn/stats v0.7.0/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow= -github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= -github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= -github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= -github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= -golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= -golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= -golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= -golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= -golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= -golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= -golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= -golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o= -golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= -golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= -golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= -golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0= -golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= -golang.org/x/term v0.19.0/go.mod h1:2CuTdWZ7KHSQwUzKva0cbMg6q2DMI3Mmxp+gKJbskEk= -golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= -golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0= -golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= -gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/vendor/github.com/charmbracelet/fang/README.md b/vendor/github.com/charmbracelet/fang/README.md index 88a225cfd6e698d15dd29a9af0a5dca74b61ecf7..575b40ce13fa57eb0e41082943a3c21e05c82777 100644 --- a/vendor/github.com/charmbracelet/fang/README.md +++ b/vendor/github.com/charmbracelet/fang/README.md @@ -1,7 +1,7 @@ # Fang

- Charm Fang + Charm Fang

Latest Release @@ -12,7 +12,7 @@ The CLI starter kit. A small, experimental library for batteries-included [Cobra][cobra] applications.

- fang-02 + The Charm Fang mascot and title treatment

## Features @@ -45,6 +45,7 @@ To use it, invoke `fang.Execute` passing your root `*cobra.Command`: package main import ( + "context" "os" "github.com/charmbracelet/fang" @@ -56,7 +57,7 @@ func main() { Use: "example", Short: "A simple example program!", } - if err := fang.Execute(context.TODO(), cmd); err != nil { + if err := fang.Execute(context.Background(), cmd); err != nil { os.Exit(1) } } diff --git a/vendor/github.com/charmbracelet/fang/fang.go b/vendor/github.com/charmbracelet/fang/fang.go index c1f9bc06a5299c991bac569aa6868e3d08fcd37c..6a6ab99a63fc4debf404694473d23e0a576d2fab 100644 --- a/vendor/github.com/charmbracelet/fang/fang.go +++ b/vendor/github.com/charmbracelet/fang/fang.go @@ -4,11 +4,14 @@ package fang import ( "context" "fmt" + "io" "os" + "os/signal" "runtime/debug" "github.com/charmbracelet/colorprofile" "github.com/charmbracelet/lipgloss/v2" + "github.com/charmbracelet/x/term" mango "github.com/muesli/mango-cobra" "github.com/muesli/roff" "github.com/spf13/cobra" @@ -16,12 +19,24 @@ import ( const shaLen = 7 +// ErrorHandler handles an error, printing them to the given [io.Writer]. +// +// Note that this will only be used if the STDERR is a terminal, and should +// be used for styling only. +type ErrorHandler = func(w io.Writer, styles Styles, err error) + +// ColorSchemeFunc gets a [lipgloss.LightDarkFunc] and returns a [ColorScheme]. +type ColorSchemeFunc = func(lipgloss.LightDarkFunc) ColorScheme + type settings struct { completions bool manpages bool + skipVersion bool version string commit string - theme *ColorScheme + colorscheme ColorSchemeFunc + errHandler ErrorHandler + signals []os.Signal } // Option changes fang settings. @@ -41,10 +56,21 @@ func WithoutManpage() Option { } } +// WithColorSchemeFunc sets a function that return colorscheme. +func WithColorSchemeFunc(cs ColorSchemeFunc) Option { + return func(s *settings) { + s.colorscheme = cs + } +} + // WithTheme sets the colorscheme. +// +// Deprecated: use [WithColorSchemeFunc] instead. func WithTheme(theme ColorScheme) Option { return func(s *settings) { - s.theme = &theme + s.colorscheme = func(lipgloss.LightDarkFunc) ColorScheme { + return theme + } } } @@ -55,6 +81,13 @@ func WithVersion(version string) Option { } } +// WithoutVersion skips the `-v`/`--version` functionality. +func WithoutVersion() Option { + return func(s *settings) { + s.skipVersion = true + } +} + // WithCommit sets the commit SHA. func WithCommit(commit string) Option { return func(s *settings) { @@ -62,30 +95,45 @@ func WithCommit(commit string) Option { } } +// WithErrorHandler sets the error handler. +func WithErrorHandler(handler ErrorHandler) Option { + return func(s *settings) { + s.errHandler = handler + } +} + +// WithNotifySignal sets the signals that should interrupt the execution of the +// program. +func WithNotifySignal(signals ...os.Signal) Option { + return func(s *settings) { + s.signals = signals + } +} + // Execute applies fang to the command and executes it. func Execute(ctx context.Context, root *cobra.Command, options ...Option) error { opts := settings{ manpages: true, completions: true, + colorscheme: DefaultColorScheme, + errHandler: DefaultErrorHandler, } + for _, option := range options { option(&opts) } - if opts.theme == nil { - isDark := lipgloss.HasDarkBackground(os.Stdin, os.Stderr) - t := DefaultTheme(isDark) - opts.theme = &t + helpFunc := func(c *cobra.Command, _ []string) { + w := colorprofile.NewWriter(c.OutOrStdout(), os.Environ()) + helpFn(c, w, makeStyles(mustColorscheme(opts.colorscheme))) } - styles := makeStyles(*opts.theme) - - root.SetHelpFunc(func(c *cobra.Command, _ []string) { - w := colorprofile.NewWriter(c.OutOrStdout(), os.Environ()) - helpFn(c, w, styles) - }) root.SilenceUsage = true root.SilenceErrors = true + if !opts.skipVersion { + root.Version = buildVersion(opts) + } + root.SetHelpFunc(helpFunc) if opts.manpages { root.AddCommand(&cobra.Command{ @@ -108,34 +156,49 @@ func Execute(ctx context.Context, root *cobra.Command, options ...Option) error }) } - if opts.completions { - root.InitDefaultCompletionCmd() - } else { + if !opts.completions { root.CompletionOptions.DisableDefaultCmd = true } - if opts.version == "" { - if info, ok := debug.ReadBuildInfo(); ok && info.Main.Sum != "" { - opts.version = info.Main.Version - opts.commit = getKey(info, "vcs.revision") - } else { - opts.version = "unknown (built from source)" - } - } - if len(opts.commit) >= shaLen { - opts.version += " (" + opts.commit[:shaLen] + ")" + if len(opts.signals) > 0 { + var cancel context.CancelFunc + ctx, cancel = signal.NotifyContext(ctx, opts.signals...) + defer cancel() } - root.Version = opts.version - if err := root.ExecuteContext(ctx); err != nil { + if w, ok := root.ErrOrStderr().(term.File); ok { + // if stderr is not a tty, simply print the error without any + // styling or going through an [ErrorHandler]: + if !term.IsTerminal(w.Fd()) { + _, _ = fmt.Fprintln(w, err.Error()) + return err //nolint:wrapcheck + } + } w := colorprofile.NewWriter(root.ErrOrStderr(), os.Environ()) - writeError(w, styles, err) + opts.errHandler(w, makeStyles(mustColorscheme(opts.colorscheme)), err) return err //nolint:wrapcheck } return nil } +func buildVersion(opts settings) string { + commit := opts.commit + version := opts.version + if version == "" { + if info, ok := debug.ReadBuildInfo(); ok && info.Main.Sum != "" { + version = info.Main.Version + commit = getKey(info, "vcs.revision") + } else { + version = "unknown (built from source)" + } + } + if len(commit) >= shaLen { + version += " (" + commit[:shaLen] + ")" + } + return version +} + func getKey(info *debug.BuildInfo, key string) string { if info == nil { return "" diff --git a/vendor/github.com/charmbracelet/fang/help.go b/vendor/github.com/charmbracelet/fang/help.go index 340090eadf1f779c0e702b03440d7e7efb29b62b..ba2a6185844787e83753c51c3415d5ccc06e36ec 100644 --- a/vendor/github.com/charmbracelet/fang/help.go +++ b/vendor/github.com/charmbracelet/fang/help.go @@ -3,7 +3,10 @@ package fang import ( "cmp" "fmt" + "io" + "iter" "os" + "reflect" "regexp" "strconv" "strings" @@ -20,6 +23,7 @@ import ( const ( minSpace = 10 shortPad = 2 + longPad = 4 ) var width = sync.OnceValue(func() int { @@ -45,65 +49,95 @@ func helpFn(c *cobra.Command, w *colorprofile.Writer, styles Styles) { blockWidth = max(blockWidth, lipgloss.Width(ex)) } blockWidth = min(width()-padding, blockWidth+padding) + blockStyle := styles.Codeblock.Base.Width(blockWidth) - styles.Codeblock.Base = styles.Codeblock.Base.Width(blockWidth) + // if the color profile is ascii or notty, or if the block has no + // background color set, remove the vertical padding. + if w.Profile <= colorprofile.Ascii || reflect.DeepEqual(blockStyle.GetBackground(), lipgloss.NoColor{}) { + blockStyle = blockStyle.PaddingTop(0).PaddingBottom(0) + } _, _ = fmt.Fprintln(w, styles.Title.Render("usage")) - _, _ = fmt.Fprintln(w, styles.Codeblock.Base.Render(usage)) + _, _ = fmt.Fprintln(w, blockStyle.Render(usage)) if len(examples) > 0 { - cw := styles.Codeblock.Base.GetWidth() - styles.Codeblock.Base.GetHorizontalPadding() + cw := blockStyle.GetWidth() - blockStyle.GetHorizontalPadding() _, _ = fmt.Fprintln(w, styles.Title.Render("examples")) for i, example := range examples { if lipgloss.Width(example) > cw { examples[i] = ansi.Truncate(example, cw, "…") } } - _, _ = fmt.Fprintln(w, styles.Codeblock.Base.Render(strings.Join(examples, "\n"))) + _, _ = fmt.Fprintln(w, blockStyle.Render(strings.Join(examples, "\n"))) } + groups, groupKeys := evalGroups(c) cmds, cmdKeys := evalCmds(c, styles) flags, flagKeys := evalFlags(c, styles) space := calculateSpace(cmdKeys, flagKeys) - leftPadding := 4 - if len(cmds) > 0 { - _, _ = fmt.Fprintln(w, styles.Title.Render("commands")) - for _, k := range cmdKeys { - _, _ = fmt.Fprintln(w, lipgloss.JoinHorizontal( - lipgloss.Left, - lipgloss.NewStyle().PaddingLeft(leftPadding).Render(k), - strings.Repeat(" ", space-lipgloss.Width(k)), - cmds[k], - )) + for _, groupID := range groupKeys { + group := cmds[groupID] + if len(group) == 0 { + continue } + renderGroup(w, styles, space, groups[groupID], func(yield func(string, string) bool) { + for _, k := range cmdKeys { + cmds, ok := group[k] + if !ok { + continue + } + if !yield(k, cmds) { + return + } + } + }) } if len(flags) > 0 { - _, _ = fmt.Fprintln(w, styles.Title.Render("flags")) - for _, k := range flagKeys { - _, _ = fmt.Fprintln(w, lipgloss.JoinHorizontal( - lipgloss.Left, - lipgloss.NewStyle().PaddingLeft(leftPadding).Render(k), - strings.Repeat(" ", space-lipgloss.Width(k)), - flags[k], - )) - } + renderGroup(w, styles, space, "flags", func(yield func(string, string) bool) { + for _, k := range flagKeys { + if !yield(k, flags[k]) { + return + } + } + }) } _, _ = fmt.Fprintln(w) } -func writeError(w *colorprofile.Writer, styles Styles, err error) { +// DefaultErrorHandler is the default [ErrorHandler] implementation. +func DefaultErrorHandler(w io.Writer, styles Styles, err error) { _, _ = fmt.Fprintln(w, styles.ErrorHeader.String()) _, _ = fmt.Fprintln(w, styles.ErrorText.Render(err.Error()+".")) _, _ = fmt.Fprintln(w) - _, _ = fmt.Fprintln(w, lipgloss.JoinHorizontal( - lipgloss.Left, - styles.ErrorText.UnsetWidth().Render("Try"), - styles.Program.Flag.Render("--help"), - styles.ErrorText.UnsetWidth().UnsetMargins().UnsetTransform().PaddingLeft(1).Render("for usage."), - )) - _, _ = fmt.Fprintln(w) + if isUsageError(err) { + _, _ = fmt.Fprintln(w, lipgloss.JoinHorizontal( + lipgloss.Left, + styles.ErrorText.UnsetWidth().Render("Try"), + styles.Program.Flag.Render(" --help "), + styles.ErrorText.UnsetWidth().UnsetMargins().UnsetTransform().Render("for usage."), + )) + _, _ = fmt.Fprintln(w) + } +} + +// XXX: this is a hack to detect usage errors. +// See: https://github.com/spf13/cobra/pull/2266 +func isUsageError(err error) bool { + s := err.Error() + for _, prefix := range []string{ + "flag needs an argument:", + "unknown flag:", + "unknown shorthand flag:", + "unknown command", + "invalid argument", + } { + if strings.HasPrefix(s, prefix) { + return true + } + } + return false } func writeLongShort(w *colorprofile.Writer, styles Styles, longShort string) { @@ -118,8 +152,10 @@ var otherArgsRe = regexp.MustCompile(`(\[.*\])`) // styleUsage stylized styleUsage line for a given command. func styleUsage(c *cobra.Command, styles Program, complete bool) string { - // XXX: maybe use c.UseLine() here? u := c.Use + if complete { + u = c.UseLine() + } hasArgs := strings.Contains(u, "[args]") hasFlags := strings.Contains(u, "[flags]") || strings.Contains(u, "[--flags]") || c.HasFlags() || c.HasPersistentFlags() || c.HasAvailableFlags() hasCommands := strings.Contains(u, "[command]") || c.HasAvailableSubCommands() @@ -139,34 +175,38 @@ func styleUsage(c *cobra.Command, styles Program, complete bool) string { u = strings.TrimSpace(u) - useLine := []string{ - styles.Name.Render(u), - } - if !complete { - useLine[0] = styles.Command.Render(u) + useLine := []string{} + if complete { + parts := strings.Fields(u) + useLine = append(useLine, styles.Name.Render(parts[0])) + if len(parts) > 1 { + useLine = append(useLine, styles.Command.Render(" "+strings.Join(parts[1:], " "))) + } + } else { + useLine = append(useLine, styles.Command.Render(u)) } if hasCommands { useLine = append( useLine, - styles.DimmedArgument.Render("[command]"), + styles.DimmedArgument.Render(" [command]"), ) } if hasArgs { useLine = append( useLine, - styles.DimmedArgument.Render("[args]"), + styles.DimmedArgument.Render(" [args]"), ) } for _, arg := range otherArgs { useLine = append( useLine, - styles.DimmedArgument.Render(arg), + styles.DimmedArgument.Render(" "+arg), ) } if hasFlags { useLine = append( useLine, - styles.DimmedArgument.Render("[--flags]"), + styles.DimmedArgument.Render(" [--flags]"), ) } return lipgloss.JoinHorizontal(lipgloss.Left, useLine...) @@ -180,19 +220,21 @@ func styleExamples(c *cobra.Command, styles Styles) []string { } usage := []string{} examples := strings.Split(c.Example, "\n") + var indent bool for i, line := range examples { line = strings.TrimSpace(line) if (i == 0 || i == len(examples)-1) && line == "" { continue } - s := styleExample(c, line, styles.Codeblock) + s := styleExample(c, line, indent, styles.Codeblock) usage = append(usage, s) + indent = len(line) > 1 && (line[len(line)-1] == '\\' || line[len(line)-1] == '|') } return usage } -func styleExample(c *cobra.Command, line string, styles Codeblock) string { +func styleExample(c *cobra.Command, line string, indent bool, styles Codeblock) string { if strings.HasPrefix(line, "# ") { return lipgloss.JoinHorizontal( lipgloss.Left, @@ -200,66 +242,110 @@ func styleExample(c *cobra.Command, line string, styles Codeblock) string { ) } - args := strings.Fields(line) - var nextIsFlag bool var isQuotedString bool + var foundProgramName bool + var isRedirecting bool + programName := c.Root().Name() + args := strings.Fields(line) + var cleanArgs []string for i, arg := range args { - if i == 0 { - args[i] = styles.Program.Name.Render(arg) - continue + isQuoteStart := arg[0] == '"' || arg[0] == '\'' + isQuoteEnd := arg[len(arg)-1] == '"' || arg[len(arg)-1] == '\'' + isFlag := arg[0] == '-' + + switch i { + case 0: + args[i] = "" + if indent { + args[i] = styles.Program.DimmedArgument.Render(" ") + indent = false + } + default: + args[i] = styles.Program.DimmedArgument.Render(" ") } - quoteStart := arg[0] == '"' - quoteEnd := arg[len(arg)-1] == '"' - flagStart := arg[0] == '-' - if i == 1 && !quoteStart && !flagStart { - args[i] = styles.Program.Command.Render(arg) + if isRedirecting { + args[i] += styles.Program.DimmedArgument.Render(arg) + isRedirecting = false continue } - if quoteStart { - isQuotedString = true - } - if isQuotedString { - args[i] = styles.Program.QuotedString.Render(arg) - if quoteEnd { - isQuotedString = false + + switch arg { + case "\\": + if i == len(args)-1 { + args[i] += styles.Program.DimmedArgument.Render(arg) + continue } + case "|", "||", "-", "&", "&&": + args[i] += styles.Program.DimmedArgument.Render(arg) continue } - if nextIsFlag { - args[i] = styles.Program.Flag.Render(arg) + + if isRedirect(arg) { + args[i] += styles.Program.DimmedArgument.Render(arg) + isRedirecting = true continue } - var dashes string - if strings.HasPrefix(arg, "-") { - dashes = "-" + + if !foundProgramName { //nolint:nestif + if isQuotedString { + args[i] += styles.Program.QuotedString.Render(arg) + isQuotedString = !isQuoteEnd + continue + } + if left, right, ok := strings.Cut(arg, "="); ok { + args[i] += styles.Program.Flag.Render(left + "=") + if right[0] == '"' { + isQuotedString = true + args[i] += styles.Program.QuotedString.Render(right) + continue + } + args[i] += styles.Program.Argument.Render(right) + continue + } + + if arg == programName { + args[i] += styles.Program.Name.Render(arg) + foundProgramName = true + continue + } } - if strings.HasPrefix(arg, "--") { - dashes = "--" + + if !isQuoteStart && !isQuotedString && !isFlag { + cleanArgs = append(cleanArgs, arg) + } + + if !isQuoteStart && !isFlag && isSubCommand(c, cleanArgs, arg) { + args[i] += styles.Program.Command.Render(arg) + continue + } + isQuotedString = isQuotedString || isQuoteStart + if isQuotedString { + args[i] += styles.Program.QuotedString.Render(arg) + isQuotedString = !isQuoteEnd + continue } // handle a flag - if dashes != "" { + if isFlag { name, value, ok := strings.Cut(arg, "=") - name = strings.TrimPrefix(name, dashes) // it is --flag=value if ok { - args[i] = lipgloss.JoinHorizontal( + args[i] += lipgloss.JoinHorizontal( lipgloss.Left, - styles.Program.Flag.Render(dashes+name+"="), - styles.Program.Argument.UnsetPadding().Render(value), + styles.Program.Flag.Render(name+"="), + styles.Program.Argument.Render(value), ) continue } // it is either --bool-flag or --flag value - args[i] = lipgloss.JoinHorizontal( + args[i] += lipgloss.JoinHorizontal( lipgloss.Left, - styles.Program.Flag.Render(dashes+name), + styles.Program.Flag.Render(name), ) - // if the flag is not a bool flag, next arg continues current flag - nextIsFlag = !isFlagBool(c, name) continue } - args[i] = styles.Program.Argument.Render(arg) + + args[i] += styles.Program.Argument.Render(arg) } return lipgloss.JoinHorizontal( @@ -284,8 +370,7 @@ func evalFlags(c *cobra.Command, styles Styles) (map[string]string, []string) { } else { parts = append( parts, - styles.Program.Flag.Render("-"+f.Shorthand), - styles.Program.Flag.Render("--"+f.Name), + styles.Program.Flag.Render("-"+f.Shorthand+" --"+f.Name), ) } key := lipgloss.JoinHorizontal(lipgloss.Left, parts...) @@ -303,22 +388,50 @@ func evalFlags(c *cobra.Command, styles Styles) (map[string]string, []string) { return flags, keys } -func evalCmds(c *cobra.Command, styles Styles) (map[string]string, []string) { +// result is map[groupID]map[styled cmd name]styled cmd help, and the keys in +// the order they are defined. +func evalCmds(c *cobra.Command, styles Styles) (map[string](map[string]string), []string) { padStyle := lipgloss.NewStyle().PaddingLeft(0) //nolint:mnd keys := []string{} - cmds := map[string]string{} + cmds := map[string]map[string]string{} for _, sc := range c.Commands() { if sc.Hidden { continue } + if _, ok := cmds[sc.GroupID]; !ok { + cmds[sc.GroupID] = map[string]string{} + } key := padStyle.Render(styleUsage(sc, styles.Program, false)) help := styles.FlagDescription.Render(sc.Short) - cmds[key] = help + cmds[sc.GroupID][key] = help keys = append(keys, key) } return cmds, keys } +func evalGroups(c *cobra.Command) (map[string]string, []string) { + // make sure the default group is the first + ids := []string{""} + groups := map[string]string{"": "commands"} + for _, g := range c.Groups() { + groups[g.ID] = g.Title + ids = append(ids, g.ID) + } + return groups, ids +} + +func renderGroup(w io.Writer, styles Styles, space int, name string, items iter.Seq2[string, string]) { + _, _ = fmt.Fprintln(w, styles.Title.Render(name)) + for key, help := range items { + _, _ = fmt.Fprintln(w, lipgloss.JoinHorizontal( + lipgloss.Left, + lipgloss.NewStyle().PaddingLeft(longPad).Render(key), + strings.Repeat(" ", space-lipgloss.Width(key)), + help, + )) + } +} + func calculateSpace(k1, k2 []string) int { const spaceBetween = 2 space := minSpace @@ -328,13 +441,18 @@ func calculateSpace(k1, k2 []string) int { return space } -func isFlagBool(c *cobra.Command, name string) bool { - flag := c.Flags().Lookup(name) - if flag == nil && len(name) == 1 { - flag = c.Flags().ShorthandLookup(name) - } - if flag == nil { - return false +func isSubCommand(c *cobra.Command, args []string, word string) bool { + cmd, _, _ := c.Root().Traverse(args) + return cmd != nil && cmd.Name() == word +} + +var redirectPrefixes = []string{">", "<", "&>", "2>", "1>", ">>", "2>>"} + +func isRedirect(s string) bool { + for _, p := range redirectPrefixes { + if strings.HasPrefix(s, p) { + return true + } } - return flag.Value.Type() == "bool" + return false } diff --git a/vendor/github.com/charmbracelet/fang/theme.go b/vendor/github.com/charmbracelet/fang/theme.go index 8e3389f6e84b4cc66ed0369f2425c4cc7c27d1b4..12cc868089d475d397691e757f55614a4614e44d 100644 --- a/vendor/github.com/charmbracelet/fang/theme.go +++ b/vendor/github.com/charmbracelet/fang/theme.go @@ -2,10 +2,12 @@ package fang import ( "image/color" + "os" "strings" "github.com/charmbracelet/lipgloss/v2" "github.com/charmbracelet/x/exp/charmtone" + "github.com/charmbracelet/x/term" "golang.org/x/text/cases" "golang.org/x/text/language" ) @@ -31,8 +33,14 @@ type ColorScheme struct { } // DefaultTheme is the default colorscheme. +// +// Deprecated: use [DefaultColorScheme] instead. func DefaultTheme(isDark bool) ColorScheme { - c := lipgloss.LightDark(isDark) + return DefaultColorScheme(lipgloss.LightDark(isDark)) +} + +// DefaultColorScheme is the default colorscheme. +func DefaultColorScheme(c lipgloss.LightDarkFunc) ColorScheme { return ColorScheme{ Base: c(charmtone.Charcoal, charmtone.Ash), Title: charmtone.Charple, @@ -45,7 +53,7 @@ func DefaultTheme(isDark bool) ColorScheme { Argument: c(charmtone.Charcoal, charmtone.Ash), Description: c(charmtone.Charcoal, charmtone.Ash), // flag and command descriptions FlagDefault: c(charmtone.Smoke, charmtone.Squid), // flag default values in descriptions - QuotedString: c(charmtone.Charcoal, charmtone.Ash), + QuotedString: c(charmtone.Coral, charmtone.Salmon), ErrorHeader: [2]color.Color{ charmtone.Butter, charmtone.Cherry, @@ -53,6 +61,26 @@ func DefaultTheme(isDark bool) ColorScheme { } } +// AnsiColorScheme is a ANSI colorscheme. +func AnsiColorScheme(c lipgloss.LightDarkFunc) ColorScheme { + base := c(lipgloss.Black, lipgloss.White) + return ColorScheme{ + Base: base, + Title: lipgloss.Blue, + Description: base, + Comment: c(lipgloss.BrightWhite, lipgloss.BrightBlack), + Flag: lipgloss.Magenta, + FlagDefault: lipgloss.BrightMagenta, + Command: lipgloss.Cyan, + QuotedString: lipgloss.Green, + Argument: base, + Help: base, + Dash: base, + ErrorHeader: [2]color.Color{lipgloss.Black, lipgloss.Red}, + ErrorDetails: lipgloss.Red, + } +} + // Styles represents all the styles used. type Styles struct { Text lipgloss.Style @@ -84,6 +112,14 @@ type Program struct { QuotedString lipgloss.Style } +func mustColorscheme(cs func(lipgloss.LightDarkFunc) ColorScheme) ColorScheme { + var isDark bool + if term.IsTerminal(os.Stdout.Fd()) { + isDark = lipgloss.HasDarkBackground(os.Stdin, os.Stdout) + } + return cs(lipgloss.LightDark(isDark)) +} + func makeStyles(cs ColorScheme) Styles { //nolint:mnd return Styles{ @@ -98,8 +134,7 @@ func makeStyles(cs ColorScheme) Styles { Foreground(cs.Description). Transform(titleFirstWord), FlagDefault: lipgloss.NewStyle(). - Foreground(cs.FlagDefault). - PaddingLeft(1), + Foreground(cs.FlagDefault), Codeblock: Codeblock{ Base: lipgloss.NewStyle(). Background(cs.Codeblock). @@ -116,23 +151,18 @@ func makeStyles(cs ColorScheme) Styles { Background(cs.Codeblock). Foreground(cs.Program), Flag: lipgloss.NewStyle(). - PaddingLeft(1). Background(cs.Codeblock). Foreground(cs.Flag), Argument: lipgloss.NewStyle(). - PaddingLeft(1). Background(cs.Codeblock). Foreground(cs.Argument), DimmedArgument: lipgloss.NewStyle(). - PaddingLeft(1). Background(cs.Codeblock). Foreground(cs.DimmedArgument), Command: lipgloss.NewStyle(). - PaddingLeft(1). Background(cs.Codeblock). Foreground(cs.Command), QuotedString: lipgloss.NewStyle(). - PaddingLeft(1). Background(cs.Codeblock). Foreground(cs.QuotedString), }, @@ -141,18 +171,14 @@ func makeStyles(cs ColorScheme) Styles { Name: lipgloss.NewStyle(). Foreground(cs.Program), Argument: lipgloss.NewStyle(). - PaddingLeft(1). Foreground(cs.Argument), DimmedArgument: lipgloss.NewStyle(). - PaddingLeft(1). Foreground(cs.DimmedArgument), Flag: lipgloss.NewStyle(). - PaddingLeft(1). Foreground(cs.Flag), Command: lipgloss.NewStyle(). Foreground(cs.Command), QuotedString: lipgloss.NewStyle(). - PaddingLeft(1). Foreground(cs.QuotedString), }, Span: lipgloss.NewStyle(). diff --git a/vendor/modules.txt b/vendor/modules.txt index 8cbc2b93ffb7ce6c044bca3f157defbf2db3d00c..ebdc8318f987500b38cb989a7a0de6bea45caf5f 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -260,8 +260,8 @@ github.com/charmbracelet/bubbletea/v2 # github.com/charmbracelet/colorprofile v0.3.1 ## explicit; go 1.23.0 github.com/charmbracelet/colorprofile -# github.com/charmbracelet/fang v0.1.0 -## explicit; go 1.23.0 +# github.com/charmbracelet/fang v0.3.1-0.20250711140230-d5ebb8c1d674 +## explicit; go 1.24.0 github.com/charmbracelet/fang # github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe ## explicit; go 1.23.0 @@ -269,7 +269,7 @@ github.com/charmbracelet/glamour/v2 github.com/charmbracelet/glamour/v2/ansi github.com/charmbracelet/glamour/v2/internal/autolink github.com/charmbracelet/glamour/v2/styles -# github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.2.0.20250703152125-8e1c474f8a71 => github.com/charmbracelet/lipgloss-internal/v2 v2.0.0-20250710185058-03664cb9cecb +# github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.3 => github.com/charmbracelet/lipgloss-internal/v2 v2.0.0-20250710185058-03664cb9cecb ## explicit; go 1.24.2 github.com/charmbracelet/lipgloss/v2 github.com/charmbracelet/lipgloss/v2/table From 911dc9b40f3a9ca014ace15a5a42df2786622823 Mon Sep 17 00:00:00 2001 From: Ayman Bagabas Date: Mon, 14 Jul 2025 09:55:06 -0400 Subject: [PATCH 26/30] fix(tui): chat: remove paste key binding from onboarding The keybinding does not bind to any action in the app but rather assumes the terminal application's default paste behavior. We shouldn't assume that and instead leave it to the terminal emulator and user to handle pasting. On most macOS terminals the default paste keybinding is `cmd+v` and on most Windows and Linux terminals it is `ctrl+shift+v`. --- internal/tui/page/chat/chat.go | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/internal/tui/page/chat/chat.go b/internal/tui/page/chat/chat.go index 33267772e96662f14934a8417149259c7d22541a..5c4b7738580db046920ac7812c7a493c21e996ee 100644 --- a/internal/tui/page/chat/chat.go +++ b/internal/tui/page/chat/chat.go @@ -2,7 +2,6 @@ package chat import ( "context" - "runtime" "time" "github.com/charmbracelet/bubbles/v2/help" @@ -615,26 +614,12 @@ func (a *chatPage) Help() help.KeyMap { fullList = append(fullList, []key.Binding{v}) } case a.isOnboarding && a.splash.IsShowingAPIKey(): - var pasteKey key.Binding - if runtime.GOOS != "darwin" { - pasteKey = key.NewBinding( - key.WithKeys("ctrl+v"), - key.WithHelp("ctrl+v", "paste API key"), - ) - } else { - pasteKey = key.NewBinding( - key.WithKeys("cmd+v"), - key.WithHelp("cmd+v", "paste API key"), - ) - } shortList = append(shortList, // Go back key.NewBinding( key.WithKeys("esc"), key.WithHelp("esc", "back"), ), - // Paste - pasteKey, // Quit key.NewBinding( key.WithKeys("ctrl+c"), From 5953ff8872597a63549e431f6043f6f32a6defa1 Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Mon, 14 Jul 2025 11:31:23 -0300 Subject: [PATCH 27/30] fix: signal Signed-off-by: Carlos Alexandro Becker --- cmd/root.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cmd/root.go b/cmd/root.go index c5231accd672c43a564e4d0a174fb762d06ee044..9ae26b993dd1be7374907305ae4cc90036cb05d6 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -6,7 +6,6 @@ import ( "io" "log/slog" "os" - "syscall" "time" tea "github.com/charmbracelet/bubbletea/v2" @@ -144,7 +143,7 @@ func Execute() { context.Background(), rootCmd, fang.WithVersion(version.Version), - fang.WithNotifySignal(os.Interrupt, syscall.SIGTERM), + fang.WithNotifySignal(os.Interrupt), ); err != nil { os.Exit(1) } From 682aa7964fb4091670fedebed579fe68fc372a99 Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Mon, 14 Jul 2025 14:17:11 -0300 Subject: [PATCH 28/30] fix: use fur production (#177) Signed-off-by: Carlos Alexandro Becker --- internal/fur/client/client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/fur/client/client.go b/internal/fur/client/client.go index 5f0ddeaeee708d4b5475403ce1874591f7e9bb2c..d007c9aee18f77c8b03fe804726b4196e474d0b4 100644 --- a/internal/fur/client/client.go +++ b/internal/fur/client/client.go @@ -10,7 +10,7 @@ import ( "github.com/charmbracelet/crush/internal/fur/provider" ) -const defaultURL = "https://fur.charmcli.dev" +const defaultURL = "https://fur.charm.sh" // Client represents a client for the fur service. type Client struct { From f2b9ed007c1fed6446697f7ceb746aea9d40c1e3 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 14 Jul 2025 20:05:46 +0200 Subject: [PATCH 29/30] fix: agent init --- internal/config/config.go | 31 +++++++++++++++++ internal/config/load.go | 34 +------------------ internal/tui/components/chat/splash/splash.go | 1 + internal/tui/util/util.go | 2 ++ 4 files changed, 35 insertions(+), 33 deletions(-) diff --git a/internal/config/config.go b/internal/config/config.go index ae8bcfdc35562e680527e99cdc74fd591e849874..5108a5cbee1684b92f779243b35aa3a50f162e60 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -371,3 +371,34 @@ func (c *Config) SetProviderAPIKey(providerID, apiKey string) error { c.Providers[providerID] = providerConfig return nil } + +func (c *Config) SetupAgents() { + agents := map[string]Agent{ + "coder": { + ID: "coder", + Name: "Coder", + Description: "An agent that helps with executing coding tasks.", + Model: SelectedModelTypeLarge, + ContextPaths: c.Options.ContextPaths, + // All tools allowed + }, + "task": { + ID: "task", + Name: "Task", + Description: "An agent that helps with searching for context and finding implementation details.", + Model: SelectedModelTypeLarge, + ContextPaths: c.Options.ContextPaths, + AllowedTools: []string{ + "glob", + "grep", + "ls", + "sourcegraph", + "view", + }, + // NO MCPs or LSPs by default + AllowedMCP: map[string][]string{}, + AllowedLSP: []string{}, + }, + } + c.Agents = agents +} diff --git a/internal/config/load.go b/internal/config/load.go index 9f2b5e55f1ccc0a687d46083b67e81d6e5fa212a..d80a1ebfccddc509e983a6bc17084931c2a7dec3 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -83,37 +83,7 @@ func Load(workingDir string, debug bool) (*Config, error) { if err := cfg.configureSelectedModels(providers); err != nil { return nil, fmt.Errorf("failed to configure selected models: %w", err) } - - // TODO: remove the agents concept from the config - agents := map[string]Agent{ - "coder": { - ID: "coder", - Name: "Coder", - Description: "An agent that helps with executing coding tasks.", - Model: SelectedModelTypeLarge, - ContextPaths: cfg.Options.ContextPaths, - // All tools allowed - }, - "task": { - ID: "task", - Name: "Task", - Description: "An agent that helps with searching for context and finding implementation details.", - Model: SelectedModelTypeLarge, - ContextPaths: cfg.Options.ContextPaths, - AllowedTools: []string{ - "glob", - "grep", - "ls", - "sourcegraph", - "view", - }, - // NO MCPs or LSPs by default - AllowedMCP: map[string][]string{}, - AllowedLSP: []string{}, - }, - } - cfg.Agents = agents - + cfg.SetupAgents() return cfg, nil } @@ -387,8 +357,6 @@ func (cfg *Config) configureSelectedModels(knownProviders []provider.Provider) e large.Provider = largeModelSelected.Provider } model := cfg.GetModel(large.Provider, large.Model) - slog.Info("Configuring selected large model", "provider", large.Provider, "model", large.Model) - slog.Info("Model configured", "model", model) if model == nil { large = defaultLarge // override the model type to large diff --git a/internal/tui/components/chat/splash/splash.go b/internal/tui/components/chat/splash/splash.go index 722aaea6f75c6ef0bef7e0a9ec2de319c6d71bfb..5b343e6c5538cc17b476e521e6f2bfaf6b3490cb 100644 --- a/internal/tui/components/chat/splash/splash.go +++ b/internal/tui/components/chat/splash/splash.go @@ -313,6 +313,7 @@ func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd { return util.ReportError(err) } } + cfg.SetupAgents() return nil } diff --git a/internal/tui/util/util.go b/internal/tui/util/util.go index d737acb3f06a155ab52cc7eed7d32a634d85d582..1f4ea30c49c8fb0517a5068d3b7f05970638743a 100644 --- a/internal/tui/util/util.go +++ b/internal/tui/util/util.go @@ -1,6 +1,7 @@ package util import ( + "log/slog" "time" tea "github.com/charmbracelet/bubbletea/v2" @@ -22,6 +23,7 @@ func CmdHandler(msg tea.Msg) tea.Cmd { } func ReportError(err error) tea.Cmd { + slog.Error("Error reported", "error", err) return CmdHandler(InfoMsg{ Type: InfoTypeError, Msg: err.Error(), From 0b9532082732424c7d92b8a2bd157b72b50d2fd0 Mon Sep 17 00:00:00 2001 From: Andrey Nering Date: Mon, 14 Jul 2025 17:59:37 -0300 Subject: [PATCH 30/30] fix: address panic when deciding which model to use --- internal/config/load.go | 1 + 1 file changed, 1 insertion(+) diff --git a/internal/config/load.go b/internal/config/load.go index d80a1ebfccddc509e983a6bc17084931c2a7dec3..81cb4398e5b3a7a2147ab5388b37088788ea041b 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -301,6 +301,7 @@ func (cfg *Config) defaultModelSelection(knownProviders []provider.Provider) (la defaultSmallModel := cfg.GetModel(string(p.ID), p.DefaultSmallModelID) if defaultSmallModel == nil { err = fmt.Errorf("default small model %s not found for provider %s", p.DefaultSmallModelID, p.ID) + return } smallModel = SelectedModel{ Provider: string(p.ID),