Detailed changes
@@ -13,12 +13,12 @@ require (
github.com/JohannesKaufmann/html-to-markdown v1.6.0
github.com/MakeNowJust/heredoc v1.0.0
github.com/PuerkitoBio/goquery v1.11.0
- github.com/alecthomas/chroma/v2 v2.21.1
+ github.com/alecthomas/chroma/v2 v2.22.0
github.com/atotto/clipboard v0.1.4
github.com/aymanbagabas/go-udiff v0.3.1
- github.com/bmatcuk/doublestar/v4 v4.9.1
+ github.com/bmatcuk/doublestar/v4 v4.9.2
github.com/charlievieth/fastwalk v1.0.14
- github.com/charmbracelet/catwalk v0.12.2
+ github.com/charmbracelet/catwalk v0.13.0
github.com/charmbracelet/colorprofile v0.4.1
github.com/charmbracelet/fang v0.4.4
github.com/charmbracelet/ultraviolet v0.0.0-20251212194010-b927aa605560
@@ -58,10 +58,10 @@ require (
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
github.com/zeebo/xxh3 v1.0.2
- golang.org/x/mod v0.31.0
+ golang.org/x/mod v0.32.0
golang.org/x/net v0.48.0
golang.org/x/sync v0.19.0
- golang.org/x/text v0.32.0
+ golang.org/x/text v0.33.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1
gopkg.in/yaml.v3 v3.0.1
modernc.org/sqlite v1.43.0
@@ -39,8 +39,8 @@ github.com/RealAlexandreAI/json-repair v0.0.14 h1:4kTqotVonDVTio5n2yweRUELVcNe2x
github.com/RealAlexandreAI/json-repair v0.0.14/go.mod h1:GKJi5borR78O8c7HCVbgqjhoiVibZ6hJldxbc6dGrAI=
github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0=
github.com/alecthomas/assert/v2 v2.11.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k=
-github.com/alecthomas/chroma/v2 v2.21.1 h1:FaSDrp6N+3pphkNKU6HPCiYLgm8dbe5UXIXcoBhZSWA=
-github.com/alecthomas/chroma/v2 v2.21.1/go.mod h1:NqVhfBR0lte5Ouh3DcthuUCTUpDC9cxBOfyMbMQPs3o=
+github.com/alecthomas/chroma/v2 v2.22.0 h1:PqEhf+ezz5F5owoDeOUKFzW+W3ZJDShNCaHg4sZuItI=
+github.com/alecthomas/chroma/v2 v2.22.0/go.mod h1:NqVhfBR0lte5Ouh3DcthuUCTUpDC9cxBOfyMbMQPs3o=
github.com/alecthomas/repr v0.5.2 h1:SU73FTI9D1P5UNtvseffFSGmdNci/O6RsqzeXJtP0Qs=
github.com/alecthomas/repr v0.5.2/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
github.com/andybalholm/cascadia v1.3.2/go.mod h1:7gtRlve5FxPPgIgX36uWBX58OdBsSS6lUvCFb+h7KvU=
@@ -86,16 +86,16 @@ github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuP
github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
-github.com/bmatcuk/doublestar/v4 v4.9.1 h1:X8jg9rRZmJd4yRy7ZeNDRnM+T3ZfHv15JiBJ/avrEXE=
-github.com/bmatcuk/doublestar/v4 v4.9.1/go.mod h1:xBQ8jztBU6kakFMg+8WGxn0c6z1fTSPVIjEY1Wr7jzc=
+github.com/bmatcuk/doublestar/v4 v4.9.2 h1:b0mc6WyRSYLjzofB2v/0cuDUZ+MqoGyH3r0dVij35GI=
+github.com/bmatcuk/doublestar/v4 v4.9.2/go.mod h1:xBQ8jztBU6kakFMg+8WGxn0c6z1fTSPVIjEY1Wr7jzc=
github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs=
github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
github.com/charlievieth/fastwalk v1.0.14 h1:3Eh5uaFGwHZd8EGwTjJnSpBkfwfsak9h6ICgnWlhAyg=
github.com/charlievieth/fastwalk v1.0.14/go.mod h1:diVcUreiU1aQ4/Wu3NbxxH4/KYdKpLDojrQ1Bb2KgNY=
github.com/charmbracelet/anthropic-sdk-go v0.0.0-20251024181547-21d6f3d9a904 h1:rwLdEpG9wE6kL69KkEKDiWprO8pQOZHZXeod6+9K+mw=
github.com/charmbracelet/anthropic-sdk-go v0.0.0-20251024181547-21d6f3d9a904/go.mod h1:8TIYxZxsuCqqeJ0lga/b91tBwrbjoHDC66Sq5t8N2R4=
-github.com/charmbracelet/catwalk v0.12.2 h1:zq9b+7kiumof/Dzvqi/oHnwMBgSN/M2Yt82vlIAiKMU=
-github.com/charmbracelet/catwalk v0.12.2/go.mod h1:qg+Yl9oaZTkTvRscqbxfttzOFQ4v0pOT5XwC7b5O0NQ=
+github.com/charmbracelet/catwalk v0.13.0 h1:L+chddP+PJvX3Vl+hqlWW5HAwBErlkL/friQXih1JQI=
+github.com/charmbracelet/catwalk v0.13.0/go.mod h1:qg+Yl9oaZTkTvRscqbxfttzOFQ4v0pOT5XwC7b5O0NQ=
github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco3ej+ngLk1kk=
github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk=
github.com/charmbracelet/fang v0.4.4 h1:G4qKxF6or/eTPgmAolwPuRNyuci3hTUGGX1rj1YkHJY=
@@ -396,8 +396,8 @@ golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
-golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
-golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
+golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c=
+golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
@@ -462,8 +462,8 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
-golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
-golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
+golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
+golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@@ -472,8 +472,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
-golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
-golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
+golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
+golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
@@ -68,6 +68,7 @@ type SessionAgent interface {
Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
SetModels(large Model, small Model)
SetTools(tools []fantasy.AgentTool)
+ SetSystemPrompt(systemPrompt string)
Cancel(sessionID string)
CancelAll()
IsSessionBusy(sessionID string) bool
@@ -444,7 +445,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
Content: content,
IsError: true,
}
- _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
+ _, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
Role: message.Tool,
Parts: []message.ContentPart{
toolResult,
@@ -561,15 +562,7 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan
return err
}
- summaryPromptText := "Provide a detailed summary of our conversation above."
- if len(currentSession.Todos) > 0 {
- summaryPromptText += "\n\n## Current Todo List\n\n"
- for _, t := range currentSession.Todos {
- summaryPromptText += fmt.Sprintf("- [%s] %s\n", t.Status, t.Content)
- }
- summaryPromptText += "\nInclude these tasks and their statuses in your summary. "
- summaryPromptText += "Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks."
- }
+ summaryPromptText := buildSummaryPrompt(currentSession.Todos)
resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
Prompt: summaryPromptText,
@@ -883,14 +876,17 @@ func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session,
}
func (a *sessionAgent) Cancel(sessionID string) {
- // Cancel regular requests.
- if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
+ // Cancel regular requests. Don't use Take() here - we need the entry to
+ // remain in activeRequests so IsBusy() returns true until the goroutine
+ // fully completes (including error handling that may access the DB).
+ // The defer in processRequest will clean up the entry.
+ if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
slog.Info("Request cancellation initiated", "session_id", sessionID)
cancel()
}
// Also check for summarize requests.
- if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
+ if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
slog.Info("Summarize cancellation initiated", "session_id", sessionID)
cancel()
}
@@ -972,6 +968,10 @@ func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
a.tools = tools
}
+func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
+ a.systemPrompt = systemPrompt
+}
+
func (a *sessionAgent) Model() Model {
return a.largeModel
}
@@ -1101,3 +1101,18 @@ func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Mes
return convertedMessages
}
+
+// buildSummaryPrompt constructs the prompt text for session summarization.
+func buildSummaryPrompt(todos []session.Todo) string {
+ var sb strings.Builder
+ sb.WriteString("Provide a detailed summary of our conversation above.")
+ if len(todos) > 0 {
+ sb.WriteString("\n\n## Current Todo List\n\n")
+ for _, t := range todos {
+ fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
+ }
+ sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
+ sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
+ }
+ return sb.String()
+}
@@ -1,6 +1,7 @@
package agent
import (
+ "fmt"
"os"
"path/filepath"
"runtime"
@@ -11,6 +12,7 @@ import (
"charm.land/x/vcr"
"github.com/charmbracelet/crush/internal/agent/tools"
"github.com/charmbracelet/crush/internal/message"
+ "github.com/charmbracelet/crush/internal/session"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -619,3 +621,37 @@ func TestCoderAgent(t *testing.T) {
})
}
}
+
+func makeTestTodos(n int) []session.Todo {
+ todos := make([]session.Todo, n)
+ for i := range n {
+ todos[i] = session.Todo{
+ Status: session.TodoStatusPending,
+ Content: fmt.Sprintf("Task %d: Implement feature with some description that makes it realistic", i),
+ }
+ }
+ return todos
+}
+
+func BenchmarkBuildSummaryPrompt(b *testing.B) {
+ cases := []struct {
+ name string
+ numTodos int
+ }{
+ {"0todos", 0},
+ {"5todos", 5},
+ {"10todos", 10},
+ {"50todos", 50},
+ }
+
+ for _, tc := range cases {
+ todos := makeTestTodos(tc.numTodos)
+
+ b.Run(tc.name, func(b *testing.B) {
+ b.ReportAllocs()
+ for range b.N {
+ _ = buildSummaryPrompt(todos)
+ }
+ })
+ }
+}
@@ -79,7 +79,7 @@ func (c *coordinator) agenticFetchTool(_ context.Context, client *http.Client) (
description = "Search the web and analyze results"
}
- p := c.permissions.Request(
+ p, err := c.permissions.Request(ctx,
permission.CreatePermissionRequest{
SessionID: validationResult.SessionID,
Path: c.cfg.WorkingDir(),
@@ -90,7 +90,9 @@ func (c *coordinator) agenticFetchTool(_ context.Context, client *http.Client) (
Params: tools.AgenticFetchPermissionsParams(params),
},
)
-
+ if err != nil {
+ return fantasy.ToolResponse{}, err
+ }
if !p {
return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
}
@@ -322,17 +322,12 @@ func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, age
return nil, err
}
- systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), *c.cfg)
- if err != nil {
- return nil, err
- }
-
largeProviderCfg, _ := c.cfg.Providers.Get(large.ModelCfg.Provider)
result := NewSessionAgent(SessionAgentOptions{
large,
small,
largeProviderCfg.SystemPromptPrefix,
- systemPrompt,
+ "",
isSubAgent,
c.cfg.Options.DisableAutoSummarize,
c.permissions.SkipRequests(),
@@ -340,6 +335,16 @@ func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, age
c.messages,
nil,
})
+
+ c.readyWg.Go(func() error {
+ systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), *c.cfg)
+ if err != nil {
+ return err
+ }
+ result.SetSystemPrompt(systemPrompt)
+ return nil
+ })
+
c.readyWg.Go(func() error {
tools, err := c.buildTools(ctx, agent)
if err != nil {
@@ -215,7 +215,7 @@ func NewBashTool(permissions permission.Service, workingDir string, attribution
return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for executing shell command")
}
if !isSafeReadOnly {
- p := permissions.Request(
+ p, err := permissions.Request(ctx,
permission.CreatePermissionRequest{
SessionID: sessionID,
Path: execWorkingDir,
@@ -226,6 +226,9 @@ func NewBashTool(permissions permission.Service, workingDir string, attribution
Params: BashPermissionsParams(params),
},
)
+ if err != nil {
+ return fantasy.ToolResponse{}, err
+ }
if !p {
return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
}
@@ -70,7 +70,7 @@ func NewDownloadTool(permissions permission.Service, workingDir string, client *
return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for downloading files")
}
- p := permissions.Request(
+ p, err := permissions.Request(ctx,
permission.CreatePermissionRequest{
SessionID: sessionID,
Path: filePath,
@@ -80,7 +80,9 @@ func NewDownloadTool(permissions permission.Service, workingDir string, client *
Params: DownloadPermissionsParams(params),
},
)
-
+ if err != nil {
+ return fantasy.ToolResponse{}, err
+ }
if !p {
return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
}
@@ -122,7 +122,7 @@ func createNewFile(edit editContext, filePath, content string, call fantasy.Tool
content,
strings.TrimPrefix(filePath, edit.workingDir),
)
- p := edit.permissions.Request(
+ p, err := edit.permissions.Request(edit.ctx,
permission.CreatePermissionRequest{
SessionID: sessionID,
Path: fsext.PathOrPrefix(filePath, edit.workingDir),
@@ -137,6 +137,9 @@ func createNewFile(edit editContext, filePath, content string, call fantasy.Tool
},
},
)
+ if err != nil {
+ return fantasy.ToolResponse{}, err
+ }
if !p {
return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
}
@@ -243,7 +246,7 @@ func deleteContent(edit editContext, filePath, oldString string, replaceAll bool
strings.TrimPrefix(filePath, edit.workingDir),
)
- p := edit.permissions.Request(
+ p, err := edit.permissions.Request(edit.ctx,
permission.CreatePermissionRequest{
SessionID: sessionID,
Path: fsext.PathOrPrefix(filePath, edit.workingDir),
@@ -258,6 +261,9 @@ func deleteContent(edit editContext, filePath, oldString string, replaceAll bool
},
},
)
+ if err != nil {
+ return fantasy.ToolResponse{}, err
+ }
if !p {
return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
}
@@ -378,7 +384,7 @@ func replaceContent(edit editContext, filePath, oldString, newString string, rep
strings.TrimPrefix(filePath, edit.workingDir),
)
- p := edit.permissions.Request(
+ p, err := edit.permissions.Request(edit.ctx,
permission.CreatePermissionRequest{
SessionID: sessionID,
Path: fsext.PathOrPrefix(filePath, edit.workingDir),
@@ -393,6 +399,9 @@ func replaceContent(edit editContext, filePath, oldString, newString string, rep
},
},
)
+ if err != nil {
+ return fantasy.ToolResponse{}, err
+ }
if !p {
return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
}
@@ -55,7 +55,7 @@ func NewFetchTool(permissions permission.Service, workingDir string, client *htt
return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
}
- p := permissions.Request(
+ p, err := permissions.Request(ctx,
permission.CreatePermissionRequest{
SessionID: sessionID,
Path: workingDir,
@@ -66,7 +66,9 @@ func NewFetchTool(permissions permission.Service, workingDir string, client *htt
Params: FetchPermissionsParams(params),
},
)
-
+ if err != nil {
+ return fantasy.ToolResponse{}, err
+ }
if !p {
return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
}
@@ -79,7 +79,7 @@ func NewLsTool(permissions permission.Service, workingDir string, lsConfig confi
return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for accessing directories outside working directory")
}
- granted := permissions.Request(
+ granted, err := permissions.Request(ctx,
permission.CreatePermissionRequest{
SessionID: sessionID,
Path: absSearchPath,
@@ -90,7 +90,9 @@ func NewLsTool(permissions permission.Service, workingDir string, lsConfig confi
Params: LSPermissionsParams(params),
},
)
-
+ if err != nil {
+ return fantasy.ToolResponse{}, err
+ }
if !granted {
return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
}
@@ -89,7 +89,7 @@ func (m *Tool) Run(ctx context.Context, params fantasy.ToolCall) (fantasy.ToolRe
return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
}
permissionDescription := fmt.Sprintf("execute %s with the following parameters:", m.Info().Name)
- p := m.permissions.Request(
+ p, err := m.permissions.Request(ctx,
permission.CreatePermissionRequest{
SessionID: sessionID,
ToolCallID: params.ID,
@@ -100,6 +100,9 @@ func (m *Tool) Run(ctx context.Context, params fantasy.ToolCall) (fantasy.ToolRe
Params: params.Input,
},
)
+ if err != nil {
+ return fantasy.ToolResponse{}, err
+ }
if !p {
return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
}
@@ -107,29 +107,28 @@ func GetState(name string) (ClientInfo, bool) {
// Close closes all MCP clients. This should be called during application shutdown.
func Close() error {
- var errs []error
var wg sync.WaitGroup
- for name, session := range sessions.Seq2() {
- wg.Go(func() {
- done := make(chan bool, 1)
- go func() {
+ done := make(chan struct{}, 1)
+ go func() {
+ for name, session := range sessions.Seq2() {
+ wg.Go(func() {
if err := session.Close(); err != nil &&
!errors.Is(err, io.EOF) &&
!errors.Is(err, context.Canceled) &&
err.Error() != "signal: killed" {
- errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
+ slog.Warn("Failed to shutdown MCP client", "name", name, "error", err)
}
- done <- true
- }()
- select {
- case <-done:
- case <-time.After(time.Millisecond * 250):
- }
- })
+ })
+ }
+ wg.Wait()
+ done <- struct{}{}
+ }()
+ select {
+ case <-done:
+ case <-time.After(5 * time.Second):
}
- wg.Wait()
broker.Shutdown()
- return errors.Join(errs...)
+ return nil
}
// Initialize initializes MCP clients based on the provided configuration.
@@ -173,7 +173,7 @@ func processMultiEditWithCreation(edit editContext, params MultiEditParams, call
} else {
description = fmt.Sprintf("Create file %s with %d edits", params.FilePath, editsApplied)
}
- p := edit.permissions.Request(permission.CreatePermissionRequest{
+ p, err := edit.permissions.Request(edit.ctx, permission.CreatePermissionRequest{
SessionID: sessionID,
Path: fsext.PathOrPrefix(params.FilePath, edit.workingDir),
ToolCallID: call.ID,
@@ -186,12 +186,15 @@ func processMultiEditWithCreation(edit editContext, params MultiEditParams, call
NewContent: currentContent,
},
})
+ if err != nil {
+ return fantasy.ToolResponse{}, err
+ }
if !p {
return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
}
// Write the file
- err := os.WriteFile(params.FilePath, []byte(currentContent), 0o644)
+ err = os.WriteFile(params.FilePath, []byte(currentContent), 0o644)
if err != nil {
return fantasy.ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
}
@@ -314,7 +317,7 @@ func processMultiEditExistingFile(edit editContext, params MultiEditParams, call
} else {
description = fmt.Sprintf("Apply %d edits to file %s", editsApplied, params.FilePath)
}
- p := edit.permissions.Request(permission.CreatePermissionRequest{
+ p, err := edit.permissions.Request(edit.ctx, permission.CreatePermissionRequest{
SessionID: sessionID,
Path: fsext.PathOrPrefix(params.FilePath, edit.workingDir),
ToolCallID: call.ID,
@@ -327,6 +330,9 @@ func processMultiEditExistingFile(edit editContext, params MultiEditParams, call
NewContent: currentContent,
},
})
+ if err != nil {
+ return fantasy.ToolResponse{}, err
+ }
if !p {
return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
}
@@ -19,8 +19,8 @@ type mockPermissionService struct {
*pubsub.Broker[permission.PermissionRequest]
}
-func (m *mockPermissionService) Request(req permission.CreatePermissionRequest) bool {
- return true
+func (m *mockPermissionService) Request(ctx context.Context, req permission.CreatePermissionRequest) (bool, error) {
+ return true, nil
}
func (m *mockPermissionService) Grant(req permission.PermissionRequest) {}
@@ -88,7 +88,7 @@ func NewViewTool(lspClients *csync.Map[string, *lsp.Client], permissions permiss
return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for accessing files outside working directory")
}
- granted := permissions.Request(
+ granted, err := permissions.Request(ctx,
permission.CreatePermissionRequest{
SessionID: sessionID,
Path: absFilePath,
@@ -99,7 +99,9 @@ func NewViewTool(lspClients *csync.Map[string, *lsp.Client], permissions permiss
Params: ViewPermissionsParams(params),
},
)
-
+ if err != nil {
+ return fantasy.ToolResponse{}, err
+ }
if !granted {
return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
}
@@ -111,7 +111,7 @@ func NewWriteTool(lspClients *csync.Map[string, *lsp.Client], permissions permis
strings.TrimPrefix(filePath, workingDir),
)
- p := permissions.Request(
+ p, err := permissions.Request(ctx,
permission.CreatePermissionRequest{
SessionID: sessionID,
Path: fsext.PathOrPrefix(filePath, workingDir),
@@ -126,6 +126,9 @@ func NewWriteTool(lspClients *csync.Map[string, *lsp.Client], permissions permis
},
},
)
+ if err != nil {
+ return fantasy.ToolResponse{}, err
+ }
if !p {
return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
}
@@ -392,25 +392,31 @@ func (app *App) Subscribe(program *tea.Program) {
func (app *App) Shutdown() {
start := time.Now()
defer func() { slog.Info("Shutdown took " + time.Since(start).String()) }()
- var wg sync.WaitGroup
+
+ // First, cancel all agents and wait for them to finish. This must complete
+ // before closing the DB so agents can finish writing their state.
if app.AgentCoordinator != nil {
- wg.Go(func() {
- app.AgentCoordinator.CancelAll()
- })
+ app.AgentCoordinator.CancelAll()
}
+ // Now run remaining cleanup tasks in parallel.
+ var wg sync.WaitGroup
+
// Kill all background shells.
wg.Go(func() {
shell.GetBackgroundShellManager().KillAll()
})
// Shutdown all LSP clients.
+ shutdownCtx, cancel := context.WithTimeout(app.globalCtx, 5*time.Second)
+ defer cancel()
for name, client := range app.LSPClients.Seq2() {
wg.Go(func() {
- shutdownCtx, cancel := context.WithTimeout(app.globalCtx, 5*time.Second)
- defer cancel()
- if err := client.Close(shutdownCtx); err != nil {
- slog.Error("Failed to shutdown LSP client", "name", name, "error", err)
+ if err := client.Close(shutdownCtx); err != nil &&
+ !errors.Is(err, io.EOF) &&
+ !errors.Is(err, context.Canceled) &&
+ err.Error() != "signal: killed" {
+ slog.Warn("Failed to shutdown LSP client", "name", name, "error", err)
}
})
}
@@ -1,8 +1,6 @@
package config
import (
- "io"
- "strings"
"testing"
"github.com/stretchr/testify/require"
@@ -83,7 +81,7 @@ func TestAttributionMigration(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
- cfg, err := loadFromReaders([]io.Reader{strings.NewReader(tt.configJSON)})
+ cfg, err := loadFromBytes([][]byte{[]byte(tt.configJSON)})
require.NoError(t, err)
cfg.setDefaults(t.TempDir(), "")
@@ -253,6 +253,7 @@ type Options struct {
DataDirectory string `json:"data_directory,omitempty" jsonschema:"description=Directory for storing application data (relative to working directory),default=.crush,example=.crush"` // Relative to the cwd
DisabledTools []string `json:"disabled_tools,omitempty" jsonschema:"description=List of built-in tools to disable and hide from the agent,example=bash,example=sourcegraph"`
DisableProviderAutoUpdate bool `json:"disable_provider_auto_update,omitempty" jsonschema:"description=Disable providers auto-update,default=false"`
+ DisableDefaultProviders bool `json:"disable_default_providers,omitempty" jsonschema:"description=Ignore all default/embedded providers. When enabled, providers must be fully specified in the config file with base_url, models, and api_key - no merging with defaults occurs,default=false"`
Attribution *Attribution `json:"attribution,omitempty" jsonschema:"description=Attribution settings for generated content"`
DisableMetrics bool `json:"disable_metrics,omitempty" jsonschema:"description=Disable sending metrics,default=false"`
InitializeAs string `json:"initialize_as,omitempty" jsonschema:"description=Name of the context file to create/update during project initialization,default=AGENTS.md,example=AGENTS.md,example=CRUSH.md,example=CLAUDE.md,example=docs/LLMs.md"`
@@ -811,21 +812,21 @@ func (c *ProviderConfig) TestConnection(resolver VariableResolver) error {
for k, v := range c.ExtraHeaders {
req.Header.Set(k, v)
}
- b, err := client.Do(req)
+ resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
}
+ defer resp.Body.Close()
if c.ID == string(catwalk.InferenceProviderZAI) {
- if b.StatusCode == http.StatusUnauthorized {
- // for z.ai just check if the http response is not 401
- return fmt.Errorf("failed to connect to provider %s: %s", c.ID, b.Status)
+ if resp.StatusCode == http.StatusUnauthorized {
+ // For z.ai just check if the http response is not 401.
+ return fmt.Errorf("failed to connect to provider %s: %s", c.ID, resp.Status)
}
} else {
- if b.StatusCode != http.StatusOK {
- return fmt.Errorf("failed to connect to provider %s: %s", c.ID, b.Status)
+ if resp.StatusCode != http.StatusOK {
+ return fmt.Errorf("failed to connect to provider %s: %s", c.ID, resp.Status)
}
}
- _ = b.Body.Close()
return nil
}
@@ -5,7 +5,6 @@ import (
"context"
"encoding/json"
"fmt"
- "io"
"log/slog"
"maps"
"os"
@@ -25,25 +24,11 @@ import (
"github.com/charmbracelet/crush/internal/home"
"github.com/charmbracelet/crush/internal/log"
powernapConfig "github.com/charmbracelet/x/powernap/pkg/config"
+ "github.com/qjebbs/go-jsons"
)
const defaultCatwalkURL = "https://catwalk.charm.sh"
-// LoadReader config via io.Reader.
-func LoadReader(fd io.Reader) (*Config, error) {
- data, err := io.ReadAll(fd)
- if err != nil {
- return nil, err
- }
-
- var config Config
- err = json.Unmarshal(data, &config)
- if err != nil {
- return nil, err
- }
- return &config, err
-}
-
// Load loads the configuration from the default paths.
func Load(workingDir, dataDir string, debug bool) (*Config, error) {
configPaths := lookupConfigs(workingDir)
@@ -137,6 +122,14 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
restore := PushPopCrushEnv()
defer restore()
+ // When disable_default_providers is enabled, skip all default/embedded
+ // providers entirely. Users must fully specify any providers they want.
+ // We skip to the custom provider validation loop which handles all
+ // user-configured providers uniformly.
+ if c.Options.DisableDefaultProviders {
+ knownProviders = nil
+ }
+
for _, p := range knownProviders {
knownProviderNames[string(p.ID)] = true
config, configExists := c.Providers.Get(string(p.ID))
@@ -377,6 +370,10 @@ func (c *Config) setDefaults(workingDir, dataDir string) {
c.Options.DisableProviderAutoUpdate, _ = strconv.ParseBool(str)
}
+ if str, ok := os.LookupEnv("CRUSH_DISABLE_DEFAULT_PROVIDERS"); ok {
+ c.Options.DisableDefaultProviders, _ = strconv.ParseBool(str)
+ }
+
if c.Options.Attribution == nil {
c.Options.Attribution = &Attribution{
TrailerStyle: TrailerStyleAssistedBy,
@@ -632,35 +629,39 @@ func lookupConfigs(cwd string) []string {
}
func loadFromConfigPaths(configPaths []string) (*Config, error) {
- var configs []io.Reader
+ var configs [][]byte
for _, path := range configPaths {
- fd, err := os.Open(path)
+ data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
continue
}
return nil, fmt.Errorf("failed to open config file %s: %w", path, err)
}
- defer fd.Close()
-
- configs = append(configs, fd)
+ if len(data) == 0 {
+ continue
+ }
+ configs = append(configs, data)
}
- return loadFromReaders(configs)
+ return loadFromBytes(configs)
}
-func loadFromReaders(readers []io.Reader) (*Config, error) {
- if len(readers) == 0 {
+func loadFromBytes(configs [][]byte) (*Config, error) {
+ if len(configs) == 0 {
return &Config{}, nil
}
- merged, err := Merge(readers)
+ data, err := jsons.Merge(configs)
if err != nil {
- return nil, fmt.Errorf("failed to merge configuration readers: %w", err)
+ return nil, err
}
-
- return LoadReader(merged)
+ var config Config
+ if err := json.Unmarshal(data, &config); err != nil {
+ return nil, err
+ }
+ return &config, nil
}
func hasVertexCredentials(env env.Env) bool {
@@ -0,0 +1,103 @@
+package config
+
+import (
+ "os"
+ "path/filepath"
+ "testing"
+)
+
+func BenchmarkLoadFromConfigPaths(b *testing.B) {
+ // Create temp config files with realistic content.
+ tmpDir := b.TempDir()
+
+ globalConfig := filepath.Join(tmpDir, "global.json")
+ localConfig := filepath.Join(tmpDir, "local.json")
+
+ globalContent := []byte(`{
+ "providers": {
+ "openai": {
+ "api_key": "$OPENAI_API_KEY",
+ "base_url": "https://api.openai.com/v1"
+ },
+ "anthropic": {
+ "api_key": "$ANTHROPIC_API_KEY",
+ "base_url": "https://api.anthropic.com"
+ }
+ },
+ "options": {
+ "tui": {
+ "theme": "dark"
+ }
+ }
+ }`)
+
+ localContent := []byte(`{
+ "providers": {
+ "openai": {
+ "api_key": "sk-override-key"
+ }
+ },
+ "options": {
+ "context_paths": ["README.md", "AGENTS.md"]
+ }
+ }`)
+
+ if err := os.WriteFile(globalConfig, globalContent, 0o644); err != nil {
+ b.Fatal(err)
+ }
+ if err := os.WriteFile(localConfig, localContent, 0o644); err != nil {
+ b.Fatal(err)
+ }
+
+ configPaths := []string{globalConfig, localConfig}
+
+ b.ReportAllocs()
+ for b.Loop() {
+ _, err := loadFromConfigPaths(configPaths)
+ if err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+func BenchmarkLoadFromConfigPaths_MissingFiles(b *testing.B) {
+ // Test with mix of existing and non-existing paths.
+ tmpDir := b.TempDir()
+
+ existingConfig := filepath.Join(tmpDir, "exists.json")
+ content := []byte(`{"options": {"tui": {"theme": "dark"}}}`)
+ if err := os.WriteFile(existingConfig, content, 0o644); err != nil {
+ b.Fatal(err)
+ }
+
+ configPaths := []string{
+ filepath.Join(tmpDir, "nonexistent1.json"),
+ existingConfig,
+ filepath.Join(tmpDir, "nonexistent2.json"),
+ }
+
+ b.ReportAllocs()
+ for b.Loop() {
+ _, err := loadFromConfigPaths(configPaths)
+ if err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+func BenchmarkLoadFromConfigPaths_Empty(b *testing.B) {
+ // Test with no config files.
+ tmpDir := b.TempDir()
+ configPaths := []string{
+ filepath.Join(tmpDir, "nonexistent1.json"),
+ filepath.Join(tmpDir, "nonexistent2.json"),
+ }
+
+ b.ReportAllocs()
+ for b.Loop() {
+ _, err := loadFromConfigPaths(configPaths)
+ if err != nil {
+ b.Fatal(err)
+ }
+ }
+}
@@ -5,7 +5,6 @@ import (
"log/slog"
"os"
"path/filepath"
- "strings"
"testing"
"github.com/charmbracelet/catwalk/pkg/catwalk"
@@ -22,12 +21,12 @@ func TestMain(m *testing.M) {
os.Exit(exitVal)
}
-func TestConfig_LoadFromReaders(t *testing.T) {
- data1 := strings.NewReader(`{"providers": {"openai": {"api_key": "key1", "base_url": "https://api.openai.com/v1"}}}`)
- data2 := strings.NewReader(`{"providers": {"openai": {"api_key": "key2", "base_url": "https://api.openai.com/v2"}}}`)
- data3 := strings.NewReader(`{"providers": {"openai": {}}}`)
+func TestConfig_LoadFromBytes(t *testing.T) {
+ data1 := []byte(`{"providers": {"openai": {"api_key": "key1", "base_url": "https://api.openai.com/v1"}}}`)
+ data2 := []byte(`{"providers": {"openai": {"api_key": "key2", "base_url": "https://api.openai.com/v2"}}}`)
+ data3 := []byte(`{"providers": {"openai": {}}}`)
- loadedConfig, err := loadFromReaders([]io.Reader{data1, data2, data3})
+ loadedConfig, err := loadFromBytes([][]byte{data1, data2, data3})
require.NoError(t, err)
require.NotNil(t, loadedConfig)
@@ -1095,6 +1094,217 @@ func TestConfig_defaultModelSelection(t *testing.T) {
})
}
+func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) {
+ t.Run("when enabled, ignores all default providers and requires full specification", func(t *testing.T) {
+ knownProviders := []catwalk.Provider{
+ {
+ ID: "openai",
+ APIKey: "$OPENAI_API_KEY",
+ APIEndpoint: "https://api.openai.com/v1",
+ Models: []catwalk.Model{{
+ ID: "gpt-4",
+ }},
+ },
+ }
+
+ // User references openai but doesn't fully specify it (no base_url, no
+ // models). This should be rejected because disable_default_providers
+ // treats all providers as custom.
+ cfg := &Config{
+ Options: &Options{
+ DisableDefaultProviders: true,
+ },
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
+ "openai": {
+ APIKey: "$OPENAI_API_KEY",
+ },
+ }),
+ }
+ cfg.setDefaults("/tmp", "")
+
+ env := env.NewFromMap(map[string]string{
+ "OPENAI_API_KEY": "test-key",
+ })
+ resolver := NewEnvironmentVariableResolver(env)
+ err := cfg.configureProviders(env, resolver, knownProviders)
+ require.NoError(t, err)
+
+ // openai should NOT be present because it lacks base_url and models.
+ require.Equal(t, 0, cfg.Providers.Len())
+ _, exists := cfg.Providers.Get("openai")
+ require.False(t, exists, "openai should not be present without full specification")
+ })
+
+ t.Run("when enabled, fully specified providers work", func(t *testing.T) {
+ knownProviders := []catwalk.Provider{
+ {
+ ID: "openai",
+ APIKey: "$OPENAI_API_KEY",
+ APIEndpoint: "https://api.openai.com/v1",
+ Models: []catwalk.Model{{
+ ID: "gpt-4",
+ }},
+ },
+ }
+
+ // User fully specifies their provider.
+ cfg := &Config{
+ Options: &Options{
+ DisableDefaultProviders: true,
+ },
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
+ "my-llm": {
+ APIKey: "$MY_API_KEY",
+ BaseURL: "https://my-llm.example.com/v1",
+ Models: []catwalk.Model{{
+ ID: "my-model",
+ }},
+ },
+ }),
+ }
+ cfg.setDefaults("/tmp", "")
+
+ env := env.NewFromMap(map[string]string{
+ "MY_API_KEY": "test-key",
+ "OPENAI_API_KEY": "test-key",
+ })
+ resolver := NewEnvironmentVariableResolver(env)
+ err := cfg.configureProviders(env, resolver, knownProviders)
+ require.NoError(t, err)
+
+ // Only fully specified provider should be present.
+ require.Equal(t, 1, cfg.Providers.Len())
+ provider, exists := cfg.Providers.Get("my-llm")
+ require.True(t, exists, "my-llm should be present")
+ require.Equal(t, "https://my-llm.example.com/v1", provider.BaseURL)
+ require.Len(t, provider.Models, 1)
+
+ // Default openai should NOT be present.
+ _, exists = cfg.Providers.Get("openai")
+ require.False(t, exists, "openai should not be present")
+ })
+
+ t.Run("when disabled, includes all known providers with valid credentials", func(t *testing.T) {
+ knownProviders := []catwalk.Provider{
+ {
+ ID: "openai",
+ APIKey: "$OPENAI_API_KEY",
+ APIEndpoint: "https://api.openai.com/v1",
+ Models: []catwalk.Model{{
+ ID: "gpt-4",
+ }},
+ },
+ {
+ ID: "anthropic",
+ APIKey: "$ANTHROPIC_API_KEY",
+ APIEndpoint: "https://api.anthropic.com/v1",
+ Models: []catwalk.Model{{
+ ID: "claude-3",
+ }},
+ },
+ }
+
+ // User only configures openai, both API keys are available, but option
+ // is disabled.
+ cfg := &Config{
+ Options: &Options{
+ DisableDefaultProviders: false,
+ },
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
+ "openai": {
+ APIKey: "$OPENAI_API_KEY",
+ },
+ }),
+ }
+ cfg.setDefaults("/tmp", "")
+
+ env := env.NewFromMap(map[string]string{
+ "OPENAI_API_KEY": "test-key",
+ "ANTHROPIC_API_KEY": "test-key",
+ })
+ resolver := NewEnvironmentVariableResolver(env)
+ err := cfg.configureProviders(env, resolver, knownProviders)
+ require.NoError(t, err)
+
+ // Both providers should be present.
+ require.Equal(t, 2, cfg.Providers.Len())
+ _, exists := cfg.Providers.Get("openai")
+ require.True(t, exists, "openai should be present")
+ _, exists = cfg.Providers.Get("anthropic")
+ require.True(t, exists, "anthropic should be present")
+ })
+
+ t.Run("when enabled, provider missing models is rejected", func(t *testing.T) {
+ cfg := &Config{
+ Options: &Options{
+ DisableDefaultProviders: true,
+ },
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
+ "my-llm": {
+ APIKey: "test-key",
+ BaseURL: "https://my-llm.example.com/v1",
+ Models: []catwalk.Model{}, // No models.
+ },
+ }),
+ }
+ cfg.setDefaults("/tmp", "")
+
+ env := env.NewFromMap(map[string]string{})
+ resolver := NewEnvironmentVariableResolver(env)
+ err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
+ require.NoError(t, err)
+
+ // Provider should be rejected for missing models.
+ require.Equal(t, 0, cfg.Providers.Len())
+ })
+
+ t.Run("when enabled, provider missing base_url is rejected", func(t *testing.T) {
+ cfg := &Config{
+ Options: &Options{
+ DisableDefaultProviders: true,
+ },
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
+ "my-llm": {
+ APIKey: "test-key",
+ Models: []catwalk.Model{{ID: "model"}},
+ // No BaseURL.
+ },
+ }),
+ }
+ cfg.setDefaults("/tmp", "")
+
+ env := env.NewFromMap(map[string]string{})
+ resolver := NewEnvironmentVariableResolver(env)
+ err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
+ require.NoError(t, err)
+
+ // Provider should be rejected for missing base_url.
+ require.Equal(t, 0, cfg.Providers.Len())
+ })
+}
+
+func TestConfig_setDefaultsDisableDefaultProvidersEnvVar(t *testing.T) {
+ t.Run("sets option from environment variable", func(t *testing.T) {
+ t.Setenv("CRUSH_DISABLE_DEFAULT_PROVIDERS", "true")
+
+ cfg := &Config{}
+ cfg.setDefaults("/tmp", "")
+
+ require.True(t, cfg.Options.DisableDefaultProviders)
+ })
+
+ t.Run("does not override when env var is not set", func(t *testing.T) {
+ cfg := &Config{
+ Options: &Options{
+ DisableDefaultProviders: true,
+ },
+ }
+ cfg.setDefaults("/tmp", "")
+
+ require.True(t, cfg.Options.DisableDefaultProviders)
+ })
+}
+
func TestConfig_configureSelectedModels(t *testing.T) {
t.Run("should override defaults", func(t *testing.T) {
knownProviders := []catwalk.Provider{
@@ -1,16 +0,0 @@
-package config
-
-import (
- "bytes"
- "io"
-
- "github.com/qjebbs/go-jsons"
-)
-
-func Merge(data []io.Reader) (io.Reader, error) {
- got, err := jsons.Merge(data)
- if err != nil {
- return nil, err
- }
- return bytes.NewReader(got), nil
-}
@@ -1,27 +0,0 @@
-package config
-
-import (
- "io"
- "strings"
- "testing"
-)
-
-func TestMerge(t *testing.T) {
- data1 := strings.NewReader(`{"foo": "bar"}`)
- data2 := strings.NewReader(`{"baz": "qux"}`)
-
- merged, err := Merge([]io.Reader{data1, data2})
- if err != nil {
- t.Fatalf("expected no error, got %v", err)
- }
-
- expected := `{"foo":"bar","baz":"qux"}`
- got, err := io.ReadAll(merged)
- if err != nil {
- t.Fatalf("expected no error reading merged data, got %v", err)
- }
-
- if string(got) != expected {
- t.Errorf("expected %s, got %s", expected, string(got))
- }
-}
@@ -33,8 +33,8 @@ func NewLazyMap[K comparable, V any](load func() map[K]V) *Map[K, V] {
m := &Map[K, V]{}
m.mu.Lock()
go func() {
+ defer m.mu.Unlock()
m.inner = load()
- m.mu.Unlock()
}()
return m
}
@@ -153,10 +153,6 @@ func (c *Client) Initialize(ctx context.Context, workspaceDir string) (*protocol
// Close closes the LSP client.
func (c *Client) Close(ctx context.Context) error {
- // Try to close all open files first
- ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
- defer cancel()
-
c.CloseAllFiles(ctx)
// Shutdown and exit the client
@@ -437,23 +437,27 @@ func (m *Message) AddBinary(mimeType string, data []byte) {
}
func PromptWithTextAttachments(prompt string, attachments []Attachment) string {
+ var sb strings.Builder
+ sb.WriteString(prompt)
addedAttachments := false
for _, content := range attachments {
if !content.IsText() {
continue
}
if !addedAttachments {
- prompt += "\n<system_info>The files below have been attached by the user, consider them in your response</system_info>\n"
+ sb.WriteString("\n<system_info>The files below have been attached by the user, consider them in your response</system_info>\n")
addedAttachments = true
}
- tag := `<file>\n`
if content.FilePath != "" {
- tag = fmt.Sprintf("<file path='%s'>\n", content.FilePath)
+ fmt.Fprintf(&sb, "<file path='%s'>\n", content.FilePath)
+ } else {
+ sb.WriteString("<file>\n")
}
- prompt += tag
- prompt += "\n" + string(content.Content) + "\n</file>\n"
+ sb.WriteString("\n")
+ sb.Write(content.Content)
+ sb.WriteString("\n</file>\n")
}
- return prompt
+ return sb.String()
}
func (m *Message) ToAIMessage() []fantasy.Message {
@@ -0,0 +1,45 @@
+package message
+
+import (
+ "fmt"
+ "strings"
+ "testing"
+)
+
+func makeTestAttachments(n int, contentSize int) []Attachment {
+ attachments := make([]Attachment, n)
+ content := []byte(strings.Repeat("x", contentSize))
+ for i := range n {
+ attachments[i] = Attachment{
+ FilePath: fmt.Sprintf("/path/to/file%d.txt", i),
+ MimeType: "text/plain",
+ Content: content,
+ }
+ }
+ return attachments
+}
+
+func BenchmarkPromptWithTextAttachments(b *testing.B) {
+ cases := []struct {
+ name string
+ numFiles int
+ contentSize int
+ }{
+ {"1file_100bytes", 1, 100},
+ {"5files_1KB", 5, 1024},
+ {"10files_10KB", 10, 10 * 1024},
+ {"20files_50KB", 20, 50 * 1024},
+ }
+
+ for _, tc := range cases {
+ attachments := makeTestAttachments(tc.numFiles, tc.contentSize)
+ prompt := "Process these files"
+
+ b.Run(tc.name, func(b *testing.B) {
+ b.ReportAllocs()
+ for range b.N {
+ _ = PromptWithTextAttachments(prompt, attachments)
+ }
+ })
+ }
+}
@@ -47,7 +47,7 @@ type Service interface {
GrantPersistent(permission PermissionRequest)
Grant(permission PermissionRequest)
Deny(permission PermissionRequest)
- Request(opts CreatePermissionRequest) bool
+ Request(ctx context.Context, opts CreatePermissionRequest) (bool, error)
AutoApproveSession(sessionID string)
SetSkipRequests(skip bool)
SkipRequests() bool
@@ -122,9 +122,9 @@ func (s *permissionService) Deny(permission PermissionRequest) {
}
}
-func (s *permissionService) Request(opts CreatePermissionRequest) bool {
+func (s *permissionService) Request(ctx context.Context, opts CreatePermissionRequest) (bool, error) {
if s.skip {
- return true
+ return true, nil
}
// tell the UI that a permission was requested
@@ -137,7 +137,7 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool {
// Check if the tool/action combination is in the allowlist
commandKey := opts.ToolName + ":" + opts.Action
if slices.Contains(s.allowedTools, commandKey) || slices.Contains(s.allowedTools, opts.ToolName) {
- return true
+ return true, nil
}
s.autoApproveSessionsMu.RLock()
@@ -145,7 +145,7 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool {
s.autoApproveSessionsMu.RUnlock()
if autoApprove {
- return true
+ return true, nil
}
fileInfo, err := os.Stat(opts.Path)
@@ -176,7 +176,7 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool {
for _, p := range s.sessionPermissions {
if p.ToolName == permission.ToolName && p.Action == permission.Action && p.SessionID == permission.SessionID && p.Path == permission.Path {
s.sessionPermissionsMu.RUnlock()
- return true
+ return true, nil
}
}
s.sessionPermissionsMu.RUnlock()
@@ -185,7 +185,7 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool {
for _, p := range s.sessionPermissions {
if p.ToolName == permission.ToolName && p.Action == permission.Action && p.SessionID == permission.SessionID && p.Path == permission.Path {
s.sessionPermissionsMu.RUnlock()
- return true
+ return true, nil
}
}
s.sessionPermissionsMu.RUnlock()
@@ -199,7 +199,12 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool {
// Publish the request
s.Publish(pubsub.CreatedEvent, permission)
- return <-respCh
+ select {
+ case <-ctx.Done():
+ return false, ctx.Err()
+ case granted := <-respCh:
+ return granted, nil
+ }
}
func (s *permissionService) AutoApproveSession(sessionID string) {
@@ -5,6 +5,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
)
func TestPermissionService_AllowedCommands(t *testing.T) {
@@ -81,14 +82,16 @@ func TestPermissionService_AllowedCommands(t *testing.T) {
func TestPermissionService_SkipMode(t *testing.T) {
service := NewPermissionService("/tmp", true, []string{})
- result := service.Request(CreatePermissionRequest{
+ result, err := service.Request(t.Context(), CreatePermissionRequest{
SessionID: "test-session",
ToolName: "bash",
Action: "execute",
Description: "test command",
Path: "/tmp",
})
-
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ }
if !result {
t.Error("expected permission to be granted in skip mode")
}
@@ -115,7 +118,7 @@ func TestPermissionService_SequentialProperties(t *testing.T) {
go func() {
defer wg.Done()
- result1 = service.Request(req1)
+ result1, _ = service.Request(t.Context(), req1)
}()
var permissionReq PermissionRequest
@@ -136,7 +139,8 @@ func TestPermissionService_SequentialProperties(t *testing.T) {
Params: map[string]string{"file": "test.txt"},
Path: "/tmp/test.txt",
}
- result2 := service.Request(req2)
+ result2, err := service.Request(t.Context(), req2)
+ require.NoError(t, err)
assert.True(t, result2, "Second request should be auto-approved")
})
t.Run("Sequential requests with temporary grants", func(t *testing.T) {
@@ -156,7 +160,7 @@ func TestPermissionService_SequentialProperties(t *testing.T) {
var wg sync.WaitGroup
wg.Go(func() {
- result1 = service.Request(req)
+ result1, _ = service.Request(t.Context(), req)
})
var permissionReq PermissionRequest
@@ -170,7 +174,7 @@ func TestPermissionService_SequentialProperties(t *testing.T) {
var result2 bool
wg.Go(func() {
- result2 = service.Request(req)
+ result2, _ = service.Request(t.Context(), req)
})
event = <-events
@@ -215,7 +219,8 @@ func TestPermissionService_SequentialProperties(t *testing.T) {
wg.Add(1)
go func(index int, request CreatePermissionRequest) {
defer wg.Done()
- results = append(results, service.Request(request))
+ result, _ := service.Request(t.Context(), request)
+ results = append(results, result)
}(i, req)
}
@@ -241,7 +246,8 @@ func TestPermissionService_SequentialProperties(t *testing.T) {
assert.Equal(t, 2, grantedCount, "Should have 2 granted and 1 denied")
secondReq := requests[1]
secondReq.Description = "Repeat of second request"
- result := service.Request(secondReq)
+ result, err := service.Request(t.Context(), secondReq)
+ require.NoError(t, err)
assert.True(t, result, "Repeated request should be auto-approved due to persistent permission")
})
}
@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"fmt"
+ "slices"
"sync"
"sync/atomic"
"time"
@@ -163,15 +164,26 @@ func (m *BackgroundShellManager) Cleanup() int {
// KillAll terminates all background shells.
func (m *BackgroundShellManager) KillAll() {
- shells := make([]*BackgroundShell, 0, m.shells.Len())
- for shell := range m.shells.Seq() {
- shells = append(shells, shell)
- }
+ shells := slices.Collect(m.shells.Seq())
m.shells.Reset(map[string]*BackgroundShell{})
+ done := make(chan struct{}, 1)
+ go func() {
+ var wg sync.WaitGroup
+ for _, shell := range shells {
+ wg.Go(func() {
+ shell.cancel()
+ <-shell.done
+ })
+ }
+ wg.Wait()
+ done <- struct{}{}
+ }()
- for _, shell := range shells {
- shell.cancel()
- <-shell.done
+ select {
+ case <-done:
+ return
+ case <-time.After(time.Second * 5):
+ return
}
}
@@ -247,12 +247,14 @@ func (s *Shell) newInterp(stdout, stderr io.Writer) (*interp.Runner, error) {
)
}
-// updateShellFromRunner updates the shell from the interpreter after execution
+// updateShellFromRunner updates the shell from the interpreter after execution.
func (s *Shell) updateShellFromRunner(runner *interp.Runner) {
s.cwd = runner.Dir
- s.env = nil
+ s.env = s.env[:0]
for name, vr := range runner.Vars {
- s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
+ if vr.Exported {
+ s.env = append(s.env, name+"="+vr.Str)
+ }
}
}
@@ -147,7 +147,7 @@ func Discover(paths []string) []*Skill {
slog.Warn("Skill validation failed", "path", path, "error", err)
return nil
}
- slog.Info("Successfully loaded skill", "name", skill.Name, "path", path)
+ slog.Debug("Successfully loaded skill", "name", skill.Name, "path", path)
mu.Lock()
skills = append(skills, skill)
mu.Unlock()
@@ -421,6 +421,11 @@
"description": "Disable providers auto-update",
"default": false
},
+ "disable_default_providers": {
+ "type": "boolean",
+ "description": "Ignore all default/embedded providers. When enabled",
+ "default": false
+ },
"attribution": {
"$ref": "#/$defs/Attribution",
"description": "Attribution settings for generated content"