diff --git a/.github/cla-signatures.json b/.github/cla-signatures.json
index cf21b7c02c3ecb20d01ac8250cee76e2727b81b2..5929987f916594da1109eee2082c154620edf660 100644
--- a/.github/cla-signatures.json
+++ b/.github/cla-signatures.json
@@ -1055,6 +1055,22 @@
"created_at": "2026-01-12T22:16:05Z",
"repoId": 987670088,
"pullRequestNo": 1841
+ },
+ {
+ "name": "kuxoapp",
+ "id": 254052994,
+ "comment_id": 3747622477,
+ "created_at": "2026-01-14T04:18:44Z",
+ "repoId": 987670088,
+ "pullRequestNo": 1864
+ },
+ {
+ "name": "mhpenta",
+ "id": 183146177,
+ "comment_id": 3749703014,
+ "created_at": "2026-01-14T14:02:04Z",
+ "repoId": 987670088,
+ "pullRequestNo": 1870
}
]
}
\ No newline at end of file
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 3511f7fe0c4f487eb3fc9009795361ada8e2eff7..39b5923298e2f7fa8d5452327a6e8b2a08f0df97 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -1,11 +1,27 @@
name: build
on: [push, pull_request]
+permissions:
+ contents: read
+
+concurrency:
+ group: build-${{ github.event.pull_request.number || github.ref }}
+ cancel-in-progress: true
+
jobs:
build:
- uses: charmbracelet/meta/.github/workflows/build.yml@main
- with:
- go-version: ""
- go-version-file: ./go.mod
- secrets:
- gh_pat: "${{ secrets.PERSONAL_ACCESS_TOKEN }}"
+ strategy:
+ matrix:
+ os: [ubuntu-latest, macos-latest, windows-latest]
+ runs-on: ${{ matrix.os }}
+ steps:
+ - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
+ with:
+ persist-credentials: false
+ - uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # v6.1.0
+ with:
+ go-version-file: go.mod
+ - run: go mod tidy
+ - run: git diff --exit-code
+ - run: go build -race ./...
+ - run: go test -race -failfast ./...
diff --git a/.github/workflows/security.yml b/.github/workflows/security.yml
new file mode 100644
index 0000000000000000000000000000000000000000..3a90ea316c3d86f5b2f93224fd2b35eaa572e704
--- /dev/null
+++ b/.github/workflows/security.yml
@@ -0,0 +1,92 @@
+name: "security"
+
+on:
+ pull_request:
+ push:
+ branches: [main]
+ schedule:
+ - cron: "0 2 * * *"
+
+permissions:
+ contents: read
+
+concurrency:
+ group: security-${{ github.event.pull_request.number || github.ref }}
+ cancel-in-progress: true
+
+jobs:
+ codeql:
+ runs-on: ubuntu-latest
+ strategy:
+ fail-fast: false
+ matrix:
+ language: ["go", "actions"]
+ permissions:
+ actions: read
+ contents: read
+ pull-requests: read
+ security-events: write
+ steps:
+ - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
+ with:
+ persist-credentials: false
+ - uses: github/codeql-action/init@cf1bb45a277cb3c205638b2cd5c984db1c46a412 # v4.31.7
+ with:
+ languages: ${{ matrix.language }}
+ - uses: github/codeql-action/autobuild@cf1bb45a277cb3c205638b2cd5c984db1c46a412 # v4.31.7
+ - uses: github/codeql-action/analyze@cf1bb45a277cb3c205638b2cd5c984db1c46a412 # v4.31.7
+
+ grype:
+ runs-on: ubuntu-latest
+ permissions:
+ security-events: write
+ actions: read
+ contents: read
+ steps:
+ - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
+ with:
+ persist-credentials: false
+ - uses: anchore/scan-action@40a61b52209e9d50e87917c5b901783d546b12d0 # v7.2.1
+ id: scan
+ with:
+ path: "."
+ fail-build: true
+ severity-cutoff: critical
+ - uses: github/codeql-action/upload-sarif@cf1bb45a277cb3c205638b2cd5c984db1c46a412 # v4.31.7
+ with:
+ sarif_file: ${{ steps.scan.outputs.sarif }}
+
+ govulncheck:
+ runs-on: ubuntu-latest
+ permissions:
+ security-events: write
+ contents: read
+ steps:
+ - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
+ with:
+ persist-credentials: false
+ - uses: actions/setup-go@7a3fe6cf4cb3a834922a1244abfce67bcef6a0c5 # v6.2.0
+ with:
+ go-version: 1.26.0-rc.1 # change to "stable" once Go 1.26 is released
+ - name: Install govulncheck
+ run: go install golang.org/x/vuln/cmd/govulncheck@latest
+ - name: Run govulncheck
+ run: |
+ govulncheck -C . -format sarif ./... > results.sarif
+ - uses: github/codeql-action/upload-sarif@cf1bb45a277cb3c205638b2cd5c984db1c46a412 # v4.31.7
+ with:
+ sarif_file: results.sarif
+
+ dependency-review:
+ runs-on: ubuntu-latest
+ if: github.event_name == 'pull_request'
+ permissions:
+ contents: read
+ steps:
+ - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
+ with:
+ persist-credentials: false
+ - uses: actions/dependency-review-action@3c4e3dcb1aa7874d2c16be7d79418e9b7efd6261 # v4.8.2
+ with:
+ fail-on-severity: critical
+ allow-licenses: BSD-2-Clause, BSD-3-Clause, MIT, Apache-2.0, MPL-2.0, ISC, LicenseRef-scancode-google-patent-license-golang
diff --git a/README.md b/README.md
index 929b77425f1b42452a4e38d8cfa540773dd54a79..cfcd765ee150d181c00cf649a5ba15055b6bdbae 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,7 @@
# Crush
- 
+ 
diff --git a/Taskfile.yaml b/Taskfile.yaml
index 68c805c599314cadde5c86fc37a0e3d1a6184f4e..0043f4f033e455a5800da2431848e620c37a0f5a 100644
--- a/Taskfile.yaml
+++ b/Taskfile.yaml
@@ -5,6 +5,8 @@ version: "3"
vars:
VERSION:
sh: git describe --long 2>/dev/null || echo ""
+ RACE:
+ sh: test -f race.log && echo "1" || echo ""
env:
CGO_ENABLED: 0
@@ -37,20 +39,20 @@ tasks:
vars:
LDFLAGS: '{{if .VERSION}}-ldflags="-X github.com/charmbracelet/crush/internal/version.Version={{.VERSION}}"{{end}}'
cmds:
- - go build {{.LDFLAGS}} .
+ - "go build {{if .RACE}}-race{{end}} {{.LDFLAGS}} ."
generates:
- crush
run:
desc: Run build
cmds:
- - go build -o crush .
- - ./crush {{.CLI_ARGS}}
+ - task: build
+ - "./crush {{.CLI_ARGS}} {{if .RACE}}2>race.log{{end}}"
test:
desc: Run tests
cmds:
- - go test ./... {{.CLI_ARGS}}
+ - go test -race -failfast ./... {{.CLI_ARGS}}
test:record:
desc: Run tests and record all VCR cassettes again
diff --git a/go.mod b/go.mod
index 67c5089c7dd01fc3e4bf66e60ac50bdbe95b7767..0b253c7052b9813f10b721d763ade79aba356624 100644
--- a/go.mod
+++ b/go.mod
@@ -15,10 +15,11 @@ require (
github.com/PuerkitoBio/goquery v1.11.0
github.com/alecthomas/chroma/v2 v2.22.0
github.com/atotto/clipboard v0.1.4
+ github.com/aymanbagabas/go-nativeclipboard v0.1.2
github.com/aymanbagabas/go-udiff v0.3.1
github.com/bmatcuk/doublestar/v4 v4.9.2
github.com/charlievieth/fastwalk v1.0.14
- github.com/charmbracelet/catwalk v0.13.0
+ github.com/charmbracelet/catwalk v0.14.1
github.com/charmbracelet/colorprofile v0.4.1
github.com/charmbracelet/fang v0.4.4
github.com/charmbracelet/ultraviolet v0.0.0-20251212194010-b927aa605560
@@ -30,7 +31,7 @@ require (
github.com/charmbracelet/x/exp/ordered v0.1.0
github.com/charmbracelet/x/exp/slice v0.0.0-20251201173703-9f73bfd934ff
github.com/charmbracelet/x/mosaic v0.0.0-20251215102626-e0db08df7383
- github.com/charmbracelet/x/powernap v0.0.0-20251015113943-25f979b54ad4
+ github.com/charmbracelet/x/powernap v0.0.0-20260113142046-c1fa3de7983b
github.com/charmbracelet/x/term v0.2.2
github.com/denisbrodbeck/machineid v1.0.1
github.com/disintegration/imageorient v0.0.0-20180920195336-8147d86e83ec
@@ -103,12 +104,13 @@ require (
github.com/charmbracelet/x/json v0.2.0 // indirect
github.com/charmbracelet/x/termios v0.1.1 // indirect
github.com/charmbracelet/x/windows v0.2.2 // indirect
- github.com/clipperhouse/displaywidth v0.6.1 // indirect
+ github.com/clipperhouse/displaywidth v0.6.2 // indirect
github.com/clipperhouse/stringish v0.1.1 // indirect
github.com/clipperhouse/uax29/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/disintegration/gift v1.1.2 // indirect
github.com/dlclark/regexp2 v1.11.5 // indirect
+ github.com/ebitengine/purego v0.10.0-alpha.3.0.20260102153238-200df6041cff // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e // indirect
diff --git a/go.sum b/go.sum
index b18aab1b57cd3caf93fee69f1c73b565432ae37b..6346b736a83815afe4daff8390a03b941539ad6d 100644
--- a/go.sum
+++ b/go.sum
@@ -78,6 +78,8 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.41.6 h1:5fFjR/ToSOzB2OQ/XqWpZBmNvmP/
github.com/aws/aws-sdk-go-v2/service/sts v1.41.6/go.mod h1:qgFDZQSD/Kys7nJnVqYlWKnh0SSdMjAi0uSwON4wgYQ=
github.com/aws/smithy-go v1.24.0 h1:LpilSUItNPFr1eY85RYgTIg5eIEPtvFbskaFcmmIUnk=
github.com/aws/smithy-go v1.24.0/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0=
+github.com/aymanbagabas/go-nativeclipboard v0.1.2 h1:Z2iVRWQ4IynMLWM6a+lWH2Nk5gPyEtPRMuBIyZ2dECM=
+github.com/aymanbagabas/go-nativeclipboard v0.1.2/go.mod h1:BVJhN7hs5DieCzUB2Atf4Yk9Y9kFe62E95+gOjpJq6Q=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/aymanbagabas/go-udiff v0.3.1 h1:LV+qyBQ2pqe0u42ZsUEtPiCaUoqgA9gYRDs3vj1nolY=
@@ -94,8 +96,8 @@ github.com/charlievieth/fastwalk v1.0.14 h1:3Eh5uaFGwHZd8EGwTjJnSpBkfwfsak9h6ICg
github.com/charlievieth/fastwalk v1.0.14/go.mod h1:diVcUreiU1aQ4/Wu3NbxxH4/KYdKpLDojrQ1Bb2KgNY=
github.com/charmbracelet/anthropic-sdk-go v0.0.0-20251024181547-21d6f3d9a904 h1:rwLdEpG9wE6kL69KkEKDiWprO8pQOZHZXeod6+9K+mw=
github.com/charmbracelet/anthropic-sdk-go v0.0.0-20251024181547-21d6f3d9a904/go.mod h1:8TIYxZxsuCqqeJ0lga/b91tBwrbjoHDC66Sq5t8N2R4=
-github.com/charmbracelet/catwalk v0.13.0 h1:L+chddP+PJvX3Vl+hqlWW5HAwBErlkL/friQXih1JQI=
-github.com/charmbracelet/catwalk v0.13.0/go.mod h1:qg+Yl9oaZTkTvRscqbxfttzOFQ4v0pOT5XwC7b5O0NQ=
+github.com/charmbracelet/catwalk v0.14.1 h1:n16H880MHW8PPgQeh0dorP77AJMxw5JcOUPuC3FFhaQ=
+github.com/charmbracelet/catwalk v0.14.1/go.mod h1:qg+Yl9oaZTkTvRscqbxfttzOFQ4v0pOT5XwC7b5O0NQ=
github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco3ej+ngLk1kk=
github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk=
github.com/charmbracelet/fang v0.4.4 h1:G4qKxF6or/eTPgmAolwPuRNyuci3hTUGGX1rj1YkHJY=
@@ -120,16 +122,16 @@ github.com/charmbracelet/x/json v0.2.0 h1:DqB+ZGx2h+Z+1s98HOuOyli+i97wsFQIxP2ZQA
github.com/charmbracelet/x/json v0.2.0/go.mod h1:opFIflx2YgXgi49xVUu8gEQ21teFAxyMwvOiZhIvWNM=
github.com/charmbracelet/x/mosaic v0.0.0-20251215102626-e0db08df7383 h1:YpTd2/abobMn/dCRM6Vo+G7JO/VS6RW0Ln3YkVJih8Y=
github.com/charmbracelet/x/mosaic v0.0.0-20251215102626-e0db08df7383/go.mod h1:r+fiJS0jb0Z5XKO+1mgKbwbPWzTy8e2dMjBMqa+XqsY=
-github.com/charmbracelet/x/powernap v0.0.0-20251015113943-25f979b54ad4 h1:i/XilBPYK4L1Yo/mc9FPx0SyJzIsN0y4sj1MWq9Sscc=
-github.com/charmbracelet/x/powernap v0.0.0-20251015113943-25f979b54ad4/go.mod h1:cmdl5zlP5mR8TF2Y68UKc7hdGUDiSJ2+4hk0h04Hsx4=
+github.com/charmbracelet/x/powernap v0.0.0-20260113142046-c1fa3de7983b h1:5ye9hzBKH623bMVz5auIuY6K21loCdxpRmFle2O9R/8=
+github.com/charmbracelet/x/powernap v0.0.0-20260113142046-c1fa3de7983b/go.mod h1:cmdl5zlP5mR8TF2Y68UKc7hdGUDiSJ2+4hk0h04Hsx4=
github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk=
github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI=
github.com/charmbracelet/x/termios v0.1.1 h1:o3Q2bT8eqzGnGPOYheoYS8eEleT5ZVNYNy8JawjaNZY=
github.com/charmbracelet/x/termios v0.1.1/go.mod h1:rB7fnv1TgOPOyyKRJ9o+AsTU/vK5WHJ2ivHeut/Pcwo=
github.com/charmbracelet/x/windows v0.2.2 h1:IofanmuvaxnKHuV04sC0eBy/smG6kIKrWG2/jYn2GuM=
github.com/charmbracelet/x/windows v0.2.2/go.mod h1:/8XtdKZzedat74NQFn0NGlGL4soHB0YQZrETF96h75k=
-github.com/clipperhouse/displaywidth v0.6.1 h1:/zMlAezfDzT2xy6acHBzwIfyu2ic0hgkT83UX5EY2gY=
-github.com/clipperhouse/displaywidth v0.6.1/go.mod h1:R+kHuzaYWFkTm7xoMmK1lFydbci4X2CicfbGstSGg0o=
+github.com/clipperhouse/displaywidth v0.6.2 h1:ZDpTkFfpHOKte4RG5O/BOyf3ysnvFswpyYrV7z2uAKo=
+github.com/clipperhouse/displaywidth v0.6.2/go.mod h1:R+kHuzaYWFkTm7xoMmK1lFydbci4X2CicfbGstSGg0o=
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
github.com/clipperhouse/uax29/v2 v2.3.0 h1:SNdx9DVUqMoBuBoW3iLOj4FQv3dN5mDtuqwuhIGpJy4=
@@ -154,6 +156,8 @@ github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZ
github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
+github.com/ebitengine/purego v0.10.0-alpha.3.0.20260102153238-200df6041cff h1:vAcU1VsCRstZ9ty11yD/L0WDyT73S/gVfmuWvcWX5DA=
+github.com/ebitengine/purego v0.10.0-alpha.3.0.20260102153238-200df6041cff/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
github.com/envoyproxy/go-control-plane v0.13.4 h1:zEqyPVyku6IvWCFwux4x9RxkLOMUL+1vC9xUFv5l2/M=
github.com/envoyproxy/go-control-plane/envoy v1.32.4 h1:jb83lalDRZSpPWW2Z7Mck/8kXZ5CQAFYVjQcdVIr83A=
github.com/envoyproxy/go-control-plane/envoy v1.32.4/go.mod h1:Gzjc5k8JcJswLjAx1Zm+wSYE20UrLtt7JZMWiWQXQEw=
diff --git a/internal/agent/agent.go b/internal/agent/agent.go
index 8d2fa40fd427143bf988587ef7faa3a89c3e23b1..c0b9080bb640085c6fd0fdbde8db0fbfe7f476dd 100644
--- a/internal/agent/agent.go
+++ b/internal/agent/agent.go
@@ -87,12 +87,13 @@ type Model struct {
}
type sessionAgent struct {
- largeModel Model
- smallModel Model
- systemPromptPrefix string
- systemPrompt string
+ largeModel *csync.Value[Model]
+ smallModel *csync.Value[Model]
+ systemPromptPrefix *csync.Value[string]
+ systemPrompt *csync.Value[string]
+ tools *csync.Slice[fantasy.AgentTool]
+
isSubAgent bool
- tools []fantasy.AgentTool
sessions session.Service
messages message.Service
disableAutoSummarize bool
@@ -119,15 +120,15 @@ func NewSessionAgent(
opts SessionAgentOptions,
) SessionAgent {
return &sessionAgent{
- largeModel: opts.LargeModel,
- smallModel: opts.SmallModel,
- systemPromptPrefix: opts.SystemPromptPrefix,
- systemPrompt: opts.SystemPrompt,
+ largeModel: csync.NewValue(opts.LargeModel),
+ smallModel: csync.NewValue(opts.SmallModel),
+ systemPromptPrefix: csync.NewValue(opts.SystemPromptPrefix),
+ systemPrompt: csync.NewValue(opts.SystemPrompt),
isSubAgent: opts.IsSubAgent,
sessions: opts.Sessions,
messages: opts.Messages,
disableAutoSummarize: opts.DisableAutoSummarize,
- tools: opts.Tools,
+ tools: csync.NewSliceFrom(opts.Tools),
isYolo: opts.IsYolo,
messageQueue: csync.NewMap[string, []SessionAgentCall](),
activeRequests: csync.NewMap[string, context.CancelFunc](),
@@ -153,15 +154,21 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
return nil, nil
}
- if len(a.tools) > 0 {
+ // Copy mutable fields under lock to avoid races with SetTools/SetModels.
+ agentTools := a.tools.Copy()
+ largeModel := a.largeModel.Get()
+ systemPrompt := a.systemPrompt.Get()
+ promptPrefix := a.systemPromptPrefix.Get()
+
+ if len(agentTools) > 0 {
// Add Anthropic caching to the last tool.
- a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
+ agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
}
agent := fantasy.NewAgent(
- a.largeModel.Model,
- fantasy.WithSystemPrompt(a.systemPrompt),
- fantasy.WithTools(a.tools...),
+ largeModel.Model,
+ fantasy.WithSystemPrompt(systemPrompt),
+ fantasy.WithTools(agentTools...),
)
sessionLock := sync.Mutex{}
@@ -183,6 +190,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
a.generateTitle(titleCtx, call.SessionID, call.Prompt)
})
}
+ defer wg.Wait()
// Add the user message to the session.
_, err = a.createUserMessage(ctx, call)
@@ -233,7 +241,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
}
- prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
+ prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)
lastSystemRoleInx := 0
systemMessageUpdated := false
@@ -251,7 +259,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
}
}
- if promptPrefix := a.promptPrefix(); promptPrefix != "" {
+ if promptPrefix != "" {
prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
}
@@ -259,15 +267,15 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
Role: message.Assistant,
Parts: []message.ContentPart{},
- Model: a.largeModel.ModelCfg.Model,
- Provider: a.largeModel.ModelCfg.Provider,
+ Model: largeModel.ModelCfg.Model,
+ Provider: largeModel.ModelCfg.Provider,
})
if err != nil {
return callContext, prepared, err
}
callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
- callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
- callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
+ callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
+ callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
currentAssistant = &assistantMsg
return callContext, prepared, err
},
@@ -361,7 +369,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
sessionLock.Unlock()
return getSessionErr
}
- a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
+ a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
_, sessionErr := a.sessions.Save(genCtx, updatedSession)
sessionLock.Unlock()
if sessionErr != nil {
@@ -371,7 +379,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
},
StopWhen: []fantasy.StopCondition{
func(_ []fantasy.StepResult) bool {
- cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
+ cw := int64(largeModel.CatwalkCfg.ContextWindow)
tokens := currentSession.CompletionTokens + currentSession.PromptTokens
remaining := cw - tokens
var threshold int64
@@ -473,7 +481,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
currentAssistant.AddFinish(
message.FinishReasonError,
"Copilot model not enabled",
- fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", a.largeModel.CatwalkCfg.Name, link),
+ fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
)
} else {
currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
@@ -491,7 +499,6 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
}
return nil, err
}
- wg.Wait()
if shouldSummarize {
a.activeRequests.Del(call.SessionID)
@@ -529,6 +536,10 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan
return ErrSessionBusy
}
+ // Copy mutable fields under lock to avoid races with SetModels.
+ largeModel := a.largeModel.Get()
+ systemPromptPrefix := a.systemPromptPrefix.Get()
+
currentSession, err := a.sessions.Get(ctx, sessionID)
if err != nil {
return fmt.Errorf("failed to get session: %w", err)
@@ -549,13 +560,13 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan
defer a.activeRequests.Del(sessionID)
defer cancel()
- agent := fantasy.NewAgent(a.largeModel.Model,
+ agent := fantasy.NewAgent(largeModel.Model,
fantasy.WithSystemPrompt(string(summaryPrompt)),
)
summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
Role: message.Assistant,
- Model: a.largeModel.Model.Model(),
- Provider: a.largeModel.Model.Provider(),
+ Model: largeModel.Model.Model(),
+ Provider: largeModel.Model.Provider(),
IsSummaryMessage: true,
})
if err != nil {
@@ -570,8 +581,8 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan
ProviderOptions: opts,
PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
prepared.Messages = options.Messages
- if a.systemPromptPrefix != "" {
- prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
+ if systemPromptPrefix != "" {
+ prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
}
return callContext, prepared, nil
},
@@ -622,7 +633,7 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan
}
}
- a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
+ a.updateSessionUsage(largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
// Just in case, get just the last usage info.
usage := resp.Response.Usage
@@ -730,9 +741,13 @@ func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, user
return
}
+ smallModel := a.smallModel.Get()
+ largeModel := a.largeModel.Get()
+ systemPromptPrefix := a.systemPromptPrefix.Get()
+
var maxOutputTokens int64 = 40
- if a.smallModel.CatwalkCfg.CanReason {
- maxOutputTokens = a.smallModel.CatwalkCfg.DefaultMaxTokens
+ if smallModel.CatwalkCfg.CanReason {
+ maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
}
newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
@@ -746,9 +761,9 @@ func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, user
Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n \n\n", userPrompt),
PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
prepared.Messages = opts.Messages
- if a.systemPromptPrefix != "" {
+ if systemPromptPrefix != "" {
prepared.Messages = append([]fantasy.Message{
- fantasy.NewSystemMessage(a.systemPromptPrefix),
+ fantasy.NewSystemMessage(systemPromptPrefix),
}, prepared.Messages...)
}
return callCtx, prepared, nil
@@ -756,7 +771,7 @@ func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, user
}
// Use the small model to generate the title.
- model := &a.smallModel
+ model := smallModel
agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
resp, err := agent.Stream(ctx, streamCall)
if err == nil {
@@ -765,7 +780,7 @@ func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, user
} else {
// It didn't work. Let's try with the big model.
slog.Error("error generating title with small model; trying big model", "err", err)
- model = &a.largeModel
+ model = largeModel
agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
resp, err = agent.Stream(ctx, streamCall)
if err == nil {
@@ -960,24 +975,20 @@ func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
}
func (a *sessionAgent) SetModels(large Model, small Model) {
- a.largeModel = large
- a.smallModel = small
+ a.largeModel.Set(large)
+ a.smallModel.Set(small)
}
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
- a.tools = tools
+ a.tools.SetSlice(tools)
}
func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
- a.systemPrompt = systemPrompt
+ a.systemPrompt.Set(systemPrompt)
}
func (a *sessionAgent) Model() Model {
- return a.largeModel
-}
-
-func (a *sessionAgent) promptPrefix() string {
- return a.systemPromptPrefix
+ return a.largeModel.Get()
}
// convertToToolResult converts a fantasy tool result to a message tool result.
@@ -1034,9 +1045,9 @@ func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) mes
//
// BEFORE: [tool result: image data]
// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
-func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
- providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
- a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
+func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
+ providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
+ largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
if providerSupportsMedia {
return messages
diff --git a/internal/agent/event.go b/internal/agent/event.go
index bf36ec84bf4270bd2e63ae0efae0440474288565..3f6c640f6a983c515034e0698676632d0cb57824 100644
--- a/internal/agent/event.go
+++ b/internal/agent/event.go
@@ -7,23 +7,23 @@ import (
"github.com/charmbracelet/crush/internal/event"
)
-func (a sessionAgent) eventPromptSent(sessionID string) {
+func (a *sessionAgent) eventPromptSent(sessionID string) {
event.PromptSent(
- a.eventCommon(sessionID, a.largeModel)...,
+ a.eventCommon(sessionID, a.largeModel.Get())...,
)
}
-func (a sessionAgent) eventPromptResponded(sessionID string, duration time.Duration) {
+func (a *sessionAgent) eventPromptResponded(sessionID string, duration time.Duration) {
event.PromptResponded(
append(
- a.eventCommon(sessionID, a.largeModel),
+ a.eventCommon(sessionID, a.largeModel.Get()),
"prompt duration pretty", duration.String(),
"prompt duration in seconds", int64(duration.Seconds()),
)...,
)
}
-func (a sessionAgent) eventTokensUsed(sessionID string, model Model, usage fantasy.Usage, cost float64) {
+func (a *sessionAgent) eventTokensUsed(sessionID string, model Model, usage fantasy.Usage, cost float64) {
event.TokensUsed(
append(
a.eventCommon(sessionID, model),
@@ -37,7 +37,7 @@ func (a sessionAgent) eventTokensUsed(sessionID string, model Model, usage fanta
)
}
-func (a sessionAgent) eventCommon(sessionID string, model Model) []any {
+func (a *sessionAgent) eventCommon(sessionID string, model Model) []any {
m := model.ModelCfg
return []any{
diff --git a/internal/commands/commands.go b/internal/commands/commands.go
index 169b789abd224b774592032c02a5156b91efb3a5..b3fd3915182fa293aefc1fe60ec54e5b369fa591 100644
--- a/internal/commands/commands.go
+++ b/internal/commands/commands.go
@@ -1,6 +1,7 @@
package commands
import (
+ "context"
"io/fs"
"os"
"path/filepath"
@@ -19,18 +20,22 @@ const (
projectCommandPrefix = "project:"
)
-// Argument represents a command argument with its name and required status.
+// Argument represents a command argument with its metadata.
type Argument struct {
- Name string
- Required bool
+ ID string
+ Title string
+ Description string
+ Required bool
}
-// MCPCustomCommand represents a custom command loaded from an MCP server.
-type MCPCustomCommand struct {
- ID string
- Name string
- Client string
- Arguments []Argument
+// MCPPrompt represents a custom command loaded from an MCP server.
+type MCPPrompt struct {
+ ID string
+ Title string
+ Description string
+ PromptID string
+ ClientID string
+ Arguments []Argument
}
// CustomCommand represents a user-defined custom command loaded from markdown files.
@@ -52,22 +57,32 @@ func LoadCustomCommands(cfg *config.Config) ([]CustomCommand, error) {
return loadAll(buildCommandSources(cfg))
}
-// LoadMCPCustomCommands loads custom commands from available MCP servers.
-func LoadMCPCustomCommands() ([]MCPCustomCommand, error) {
- var commands []MCPCustomCommand
+// LoadMCPPrompts loads custom commands from available MCP servers.
+func LoadMCPPrompts() ([]MCPPrompt, error) {
+ var commands []MCPPrompt
for mcpName, prompts := range mcp.Prompts() {
for _, prompt := range prompts {
key := mcpName + ":" + prompt.Name
var args []Argument
for _, arg := range prompt.Arguments {
- args = append(args, Argument{Name: arg.Name, Required: arg.Required})
+ title := arg.Title
+ if title == "" {
+ title = arg.Name
+ }
+ args = append(args, Argument{
+ ID: arg.Name,
+ Title: title,
+ Description: arg.Description,
+ Required: arg.Required,
+ })
}
-
- commands = append(commands, MCPCustomCommand{
- ID: key,
- Name: prompt.Name,
- Client: mcpName,
- Arguments: args,
+ commands = append(commands, MCPPrompt{
+ ID: key,
+ Title: prompt.Title,
+ Description: prompt.Description,
+ PromptID: prompt.Name,
+ ClientID: mcpName,
+ Arguments: args,
})
}
}
@@ -168,7 +183,7 @@ func extractArgNames(content string) []Argument {
if !seen[arg] {
seen[arg] = true
// for normal custom commands, all args are required
- args = append(args, Argument{Name: arg, Required: true})
+ args = append(args, Argument{ID: arg, Title: arg, Required: true})
}
}
@@ -211,3 +226,12 @@ func ensureDir(path string) error {
func isMarkdownFile(name string) bool {
return strings.HasSuffix(strings.ToLower(name), ".md")
}
+
+func GetMCPPrompt(clientID, promptID string, args map[string]string) (string, error) {
+ // TODO: we should pass the context down
+ result, err := mcp.GetPromptMessages(context.Background(), clientID, promptID, args)
+ if err != nil {
+ return "", err
+ }
+ return strings.Join(result, " "), nil
+}
diff --git a/internal/csync/maps_test.go b/internal/csync/maps_test.go
index 4c590f008dad91e8dcbc40d1b90d87ef1b3e5750..31e6fa0c3aef18a04c61ea3d4d36b5187228c3ff 100644
--- a/internal/csync/maps_test.go
+++ b/internal/csync/maps_test.go
@@ -4,6 +4,7 @@ import (
"encoding/json"
"maps"
"sync"
+ "sync/atomic"
"testing"
"testing/synctest"
"time"
@@ -46,12 +47,12 @@ func TestNewLazyMap(t *testing.T) {
waiter := sync.Mutex{}
waiter.Lock()
- loadCalled := false
+ var loadCalled atomic.Bool
loadFunc := func() map[string]int {
waiter.Lock()
defer waiter.Unlock()
- loadCalled = true
+ loadCalled.Store(true)
return map[string]int{
"key1": 1,
"key2": 2,
@@ -63,7 +64,7 @@ func TestNewLazyMap(t *testing.T) {
waiter.Unlock() // Allow the load function to proceed
time.Sleep(100 * time.Millisecond)
- require.True(t, loadCalled)
+ require.True(t, loadCalled.Load())
require.Equal(t, 2, m.Len())
value, ok := m.Get("key1")
diff --git a/internal/csync/slices.go b/internal/csync/slices.go
index c5c635683e70046694f1cdf647aac8cb425abd24..fcce9881b6e27021adcc9462b123f49d469dcd9f 100644
--- a/internal/csync/slices.go
+++ b/internal/csync/slices.go
@@ -2,7 +2,6 @@ package csync
import (
"iter"
- "slices"
"sync"
)
@@ -63,24 +62,6 @@ func (s *Slice[T]) Append(items ...T) {
s.inner = append(s.inner, items...)
}
-// Prepend adds an element to the beginning of the slice.
-func (s *Slice[T]) Prepend(item T) {
- s.mu.Lock()
- defer s.mu.Unlock()
- s.inner = append([]T{item}, s.inner...)
-}
-
-// Delete removes the element at the specified index.
-func (s *Slice[T]) Delete(index int) bool {
- s.mu.Lock()
- defer s.mu.Unlock()
- if index < 0 || index >= len(s.inner) {
- return false
- }
- s.inner = slices.Delete(s.inner, index, index+1)
- return true
-}
-
// Get returns the element at the specified index.
func (s *Slice[T]) Get(index int) (T, bool) {
s.mu.RLock()
@@ -92,17 +73,6 @@ func (s *Slice[T]) Get(index int) (T, bool) {
return s.inner[index], true
}
-// Set updates the element at the specified index.
-func (s *Slice[T]) Set(index int, item T) bool {
- s.mu.Lock()
- defer s.mu.Unlock()
- if index < 0 || index >= len(s.inner) {
- return false
- }
- s.inner[index] = item
- return true
-}
-
// Len returns the number of elements in the slice.
func (s *Slice[T]) Len() int {
s.mu.RLock()
@@ -131,10 +101,7 @@ func (s *Slice[T]) Seq() iter.Seq[T] {
// Seq2 returns an iterator that yields index-value pairs from the slice.
func (s *Slice[T]) Seq2() iter.Seq2[int, T] {
- s.mu.RLock()
- items := make([]T, len(s.inner))
- copy(items, s.inner)
- s.mu.RUnlock()
+ items := s.Copy()
return func(yield func(int, T) bool) {
for i, v := range items {
if !yield(i, v) {
@@ -143,3 +110,12 @@ func (s *Slice[T]) Seq2() iter.Seq2[int, T] {
}
}
}
+
+// Copy returns a copy of the inner slice.
+func (s *Slice[T]) Copy() []T {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ items := make([]T, len(s.inner))
+ copy(items, s.inner)
+ return items
+}
diff --git a/internal/csync/slices_test.go b/internal/csync/slices_test.go
index 85aedbaba40103ff9a8979e5c70299223f74591f..c7946ac6f1a84614def05b7b6e7e9b0ed11b3a73 100644
--- a/internal/csync/slices_test.go
+++ b/internal/csync/slices_test.go
@@ -109,44 +109,6 @@ func TestSlice(t *testing.T) {
require.Equal(t, "world", val)
})
- t.Run("Prepend", func(t *testing.T) {
- s := NewSlice[string]()
- s.Append("world")
- s.Prepend("hello")
-
- require.Equal(t, 2, s.Len())
- val, ok := s.Get(0)
- require.True(t, ok)
- require.Equal(t, "hello", val)
-
- val, ok = s.Get(1)
- require.True(t, ok)
- require.Equal(t, "world", val)
- })
-
- t.Run("Delete", func(t *testing.T) {
- s := NewSliceFrom([]int{1, 2, 3, 4, 5})
-
- // Delete middle element
- ok := s.Delete(2)
- require.True(t, ok)
- require.Equal(t, 4, s.Len())
-
- expected := []int{1, 2, 4, 5}
- actual := slices.Collect(s.Seq())
- require.Equal(t, expected, actual)
-
- // Delete out of bounds
- ok = s.Delete(10)
- require.False(t, ok)
- require.Equal(t, 4, s.Len())
-
- // Delete negative index
- ok = s.Delete(-1)
- require.False(t, ok)
- require.Equal(t, 4, s.Len())
- })
-
t.Run("Get", func(t *testing.T) {
s := NewSliceFrom([]string{"a", "b", "c"})
@@ -163,25 +125,6 @@ func TestSlice(t *testing.T) {
require.False(t, ok)
})
- t.Run("Set", func(t *testing.T) {
- s := NewSliceFrom([]string{"a", "b", "c"})
-
- ok := s.Set(1, "modified")
- require.True(t, ok)
-
- val, ok := s.Get(1)
- require.True(t, ok)
- require.Equal(t, "modified", val)
-
- // Out of bounds
- ok = s.Set(10, "invalid")
- require.False(t, ok)
-
- // Negative index
- ok = s.Set(-1, "invalid")
- require.False(t, ok)
- })
-
t.Run("SetSlice", func(t *testing.T) {
s := NewSlice[int]()
s.Append(1)
diff --git a/internal/csync/value.go b/internal/csync/value.go
new file mode 100644
index 0000000000000000000000000000000000000000..17528a281e0d34d49b206a7c3901b892370c18ba
--- /dev/null
+++ b/internal/csync/value.go
@@ -0,0 +1,44 @@
+package csync
+
+import (
+ "reflect"
+ "sync"
+)
+
+// Value is a generic thread-safe wrapper for any value type.
+//
+// For slices, use [Slice]. For maps, use [Map]. Pointers are not supported.
+type Value[T any] struct {
+ v T
+ mu sync.RWMutex
+}
+
+// NewValue creates a new Value with the given initial value.
+//
+// Panics if t is a pointer, slice, or map. Use the dedicated types for those.
+func NewValue[T any](t T) *Value[T] {
+ v := reflect.ValueOf(t)
+ switch v.Kind() {
+ case reflect.Pointer:
+ panic("csync.Value does not support pointer types")
+ case reflect.Slice:
+ panic("csync.Value does not support slice types; use csync.Slice")
+ case reflect.Map:
+ panic("csync.Value does not support map types; use csync.Map")
+ }
+ return &Value[T]{v: t}
+}
+
+// Get returns the current value.
+func (v *Value[T]) Get() T {
+ v.mu.RLock()
+ defer v.mu.RUnlock()
+ return v.v
+}
+
+// Set updates the value.
+func (v *Value[T]) Set(t T) {
+ v.mu.Lock()
+ defer v.mu.Unlock()
+ v.v = t
+}
diff --git a/internal/csync/value_test.go b/internal/csync/value_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..3fa41d85144ea9373c7d440238c0321f52286330
--- /dev/null
+++ b/internal/csync/value_test.go
@@ -0,0 +1,99 @@
+package csync
+
+import (
+ "sync"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestValue_GetSet(t *testing.T) {
+ t.Parallel()
+
+ v := NewValue(42)
+ require.Equal(t, 42, v.Get())
+
+ v.Set(100)
+ require.Equal(t, 100, v.Get())
+}
+
+func TestValue_ZeroValue(t *testing.T) {
+ t.Parallel()
+
+ v := NewValue("")
+ require.Equal(t, "", v.Get())
+
+ v.Set("hello")
+ require.Equal(t, "hello", v.Get())
+}
+
+func TestValue_Struct(t *testing.T) {
+ t.Parallel()
+
+ type config struct {
+ Name string
+ Count int
+ }
+
+ v := NewValue(config{Name: "test", Count: 1})
+ require.Equal(t, config{Name: "test", Count: 1}, v.Get())
+
+ v.Set(config{Name: "updated", Count: 2})
+ require.Equal(t, config{Name: "updated", Count: 2}, v.Get())
+}
+
+func TestValue_PointerPanics(t *testing.T) {
+ t.Parallel()
+
+ require.Panics(t, func() {
+ NewValue(&struct{}{})
+ })
+}
+
+func TestValue_SlicePanics(t *testing.T) {
+ t.Parallel()
+
+ require.Panics(t, func() {
+ NewValue([]string{"a", "b"})
+ })
+}
+
+func TestValue_MapPanics(t *testing.T) {
+ t.Parallel()
+
+ require.Panics(t, func() {
+ NewValue(map[string]int{"a": 1})
+ })
+}
+
+func TestValue_ConcurrentAccess(t *testing.T) {
+ t.Parallel()
+
+ v := NewValue(0)
+ var wg sync.WaitGroup
+
+ // Concurrent writers.
+ for i := range 100 {
+ wg.Add(1)
+ go func(val int) {
+ defer wg.Done()
+ v.Set(val)
+ }(i)
+ }
+
+ // Concurrent readers.
+ for range 100 {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ _ = v.Get()
+ }()
+ }
+
+ wg.Wait()
+
+ // Value should be one of the set values (0-99).
+ got := v.Get()
+ require.GreaterOrEqual(t, got, 0)
+ require.Less(t, got, 100)
+}
diff --git a/internal/permission/permission.go b/internal/permission/permission.go
index 9dc85e976238fdbe1ff2d3689b2a2c4160608760..e1bf1bae14b8473989b1c0890c58188591123d71 100644
--- a/internal/permission/permission.go
+++ b/internal/permission/permission.go
@@ -68,8 +68,9 @@ type permissionService struct {
allowedTools []string
// used to make sure we only process one request at a time
- requestMu sync.Mutex
- activeRequest *PermissionRequest
+ requestMu sync.Mutex
+ activeRequest *PermissionRequest
+ activeRequestMu sync.Mutex
}
func (s *permissionService) GrantPersistent(permission PermissionRequest) {
@@ -86,9 +87,11 @@ func (s *permissionService) GrantPersistent(permission PermissionRequest) {
s.sessionPermissions = append(s.sessionPermissions, permission)
s.sessionPermissionsMu.Unlock()
+ s.activeRequestMu.Lock()
if s.activeRequest != nil && s.activeRequest.ID == permission.ID {
s.activeRequest = nil
}
+ s.activeRequestMu.Unlock()
}
func (s *permissionService) Grant(permission PermissionRequest) {
@@ -101,9 +104,11 @@ func (s *permissionService) Grant(permission PermissionRequest) {
respCh <- true
}
+ s.activeRequestMu.Lock()
if s.activeRequest != nil && s.activeRequest.ID == permission.ID {
s.activeRequest = nil
}
+ s.activeRequestMu.Unlock()
}
func (s *permissionService) Deny(permission PermissionRequest) {
@@ -117,9 +122,11 @@ func (s *permissionService) Deny(permission PermissionRequest) {
respCh <- false
}
+ s.activeRequestMu.Lock()
if s.activeRequest != nil && s.activeRequest.ID == permission.ID {
s.activeRequest = nil
}
+ s.activeRequestMu.Unlock()
}
func (s *permissionService) Request(ctx context.Context, opts CreatePermissionRequest) (bool, error) {
@@ -190,7 +197,9 @@ func (s *permissionService) Request(ctx context.Context, opts CreatePermissionRe
}
s.sessionPermissionsMu.RUnlock()
+ s.activeRequestMu.Lock()
s.activeRequest = &permission
+ s.activeRequestMu.Unlock()
respCh := make(chan bool, 1)
s.pendingRequests.Set(permission.ID, respCh)
diff --git a/internal/permission/permission_test.go b/internal/permission/permission_test.go
index 89e06916024cd1669f5e0d0a263d4a71548c8a97..79930f3ae1e2ef15257f09724fef64d3ea28dada 100644
--- a/internal/permission/permission_test.go
+++ b/internal/permission/permission_test.go
@@ -189,7 +189,7 @@ func TestPermissionService_SequentialProperties(t *testing.T) {
events := service.Subscribe(t.Context())
var wg sync.WaitGroup
- results := make([]bool, 0)
+ results := make([]bool, 3)
requests := []CreatePermissionRequest{
{
@@ -220,7 +220,7 @@ func TestPermissionService_SequentialProperties(t *testing.T) {
go func(index int, request CreatePermissionRequest) {
defer wg.Done()
result, _ := service.Request(t.Context(), request)
- results = append(results, result)
+ results[index] = result
}(i, req)
}
diff --git a/internal/shell/background.go b/internal/shell/background.go
index bc81369ec877586c92fa9bc701d8b78b669f23d5..cb1855836f64bdd56a90802c2bbb939a5a514100 100644
--- a/internal/shell/background.go
+++ b/internal/shell/background.go
@@ -19,6 +19,30 @@ const (
CompletedJobRetentionMinutes = 8 * 60
)
+// syncBuffer is a thread-safe wrapper around bytes.Buffer.
+type syncBuffer struct {
+ buf bytes.Buffer
+ mu sync.RWMutex
+}
+
+func (sb *syncBuffer) Write(p []byte) (n int, err error) {
+ sb.mu.Lock()
+ defer sb.mu.Unlock()
+ return sb.buf.Write(p)
+}
+
+func (sb *syncBuffer) WriteString(s string) (n int, err error) {
+ sb.mu.Lock()
+ defer sb.mu.Unlock()
+ return sb.buf.WriteString(s)
+}
+
+func (sb *syncBuffer) String() string {
+ sb.mu.RLock()
+ defer sb.mu.RUnlock()
+ return sb.buf.String()
+}
+
// BackgroundShell represents a shell running in the background.
type BackgroundShell struct {
ID string
@@ -28,8 +52,8 @@ type BackgroundShell struct {
WorkingDir string
ctx context.Context
cancel context.CancelFunc
- stdout *bytes.Buffer
- stderr *bytes.Buffer
+ stdout *syncBuffer
+ stderr *syncBuffer
done chan struct{}
exitErr error
completedAt int64 // Unix timestamp when job completed (0 if still running)
@@ -46,12 +70,17 @@ var (
idCounter atomic.Uint64
)
+// newBackgroundShellManager creates a new BackgroundShellManager instance.
+func newBackgroundShellManager() *BackgroundShellManager {
+ return &BackgroundShellManager{
+ shells: csync.NewMap[string, *BackgroundShell](),
+ }
+}
+
// GetBackgroundShellManager returns the singleton background shell manager.
func GetBackgroundShellManager() *BackgroundShellManager {
backgroundManagerOnce.Do(func() {
- backgroundManager = &BackgroundShellManager{
- shells: csync.NewMap[string, *BackgroundShell](),
- }
+ backgroundManager = newBackgroundShellManager()
})
return backgroundManager
}
@@ -80,8 +109,8 @@ func (m *BackgroundShellManager) Start(ctx context.Context, workingDir string, b
Shell: shell,
ctx: shellCtx,
cancel: cancel,
- stdout: &bytes.Buffer{},
- stderr: &bytes.Buffer{},
+ stdout: &syncBuffer{},
+ stderr: &syncBuffer{},
done: make(chan struct{}),
}
diff --git a/internal/shell/background_test.go b/internal/shell/background_test.go
index 5149861d94e457e8a78650c48d9c6765a57d369e..7c521bc1477b07775cffb69f310fa83d710d4634 100644
--- a/internal/shell/background_test.go
+++ b/internal/shell/background_test.go
@@ -14,7 +14,7 @@ func TestBackgroundShellManager_Start(t *testing.T) {
ctx := context.Background()
workingDir := t.TempDir()
- manager := GetBackgroundShellManager()
+ manager := newBackgroundShellManager()
bgShell, err := manager.Start(ctx, workingDir, nil, "echo 'hello world'", "")
if err != nil {
@@ -51,7 +51,7 @@ func TestBackgroundShellManager_Get(t *testing.T) {
ctx := context.Background()
workingDir := t.TempDir()
- manager := GetBackgroundShellManager()
+ manager := newBackgroundShellManager()
bgShell, err := manager.Start(ctx, workingDir, nil, "echo 'test'", "")
if err != nil {
@@ -77,7 +77,7 @@ func TestBackgroundShellManager_Kill(t *testing.T) {
ctx := context.Background()
workingDir := t.TempDir()
- manager := GetBackgroundShellManager()
+ manager := newBackgroundShellManager()
// Start a long-running command
bgShell, err := manager.Start(ctx, workingDir, nil, "sleep 10", "")
@@ -106,7 +106,7 @@ func TestBackgroundShellManager_Kill(t *testing.T) {
func TestBackgroundShellManager_KillNonExistent(t *testing.T) {
t.Parallel()
- manager := GetBackgroundShellManager()
+ manager := newBackgroundShellManager()
err := manager.Kill("non-existent-id")
if err == nil {
@@ -119,7 +119,7 @@ func TestBackgroundShell_IsDone(t *testing.T) {
ctx := context.Background()
workingDir := t.TempDir()
- manager := GetBackgroundShellManager()
+ manager := newBackgroundShellManager()
bgShell, err := manager.Start(ctx, workingDir, nil, "echo 'quick'", "")
if err != nil {
@@ -142,7 +142,7 @@ func TestBackgroundShell_WithBlockFuncs(t *testing.T) {
ctx := context.Background()
workingDir := t.TempDir()
- manager := GetBackgroundShellManager()
+ manager := newBackgroundShellManager()
blockFuncs := []BlockFunc{
CommandsBlocker([]string{"curl", "wget"}),
@@ -180,7 +180,7 @@ func TestBackgroundShellManager_List(t *testing.T) {
ctx := context.Background()
workingDir := t.TempDir()
- manager := GetBackgroundShellManager()
+ manager := newBackgroundShellManager()
// Start two shells
bgShell1, err := manager.Start(ctx, workingDir, nil, "sleep 1", "")
@@ -224,7 +224,7 @@ func TestBackgroundShellManager_KillAll(t *testing.T) {
ctx := context.Background()
workingDir := t.TempDir()
- manager := GetBackgroundShellManager()
+ manager := newBackgroundShellManager()
// Start multiple long-running shells
shell1, err := manager.Start(ctx, workingDir, nil, "sleep 10", "")
diff --git a/internal/tui/components/chat/editor/editor.go b/internal/tui/components/chat/editor/editor.go
index 01badb98d37eb848ccf5962e01793ecaa3fc0f59..b5cadb8cde8a1ced8543d01eb7abd28d906f1597 100644
--- a/internal/tui/components/chat/editor/editor.go
+++ b/internal/tui/components/chat/editor/editor.go
@@ -16,6 +16,7 @@ import (
"charm.land/bubbles/v2/textarea"
tea "charm.land/bubbletea/v2"
"charm.land/lipgloss/v2"
+ nativeclipboard "github.com/aymanbagabas/go-nativeclipboard"
"github.com/charmbracelet/crush/internal/app"
"github.com/charmbracelet/crush/internal/filetracker"
"github.com/charmbracelet/crush/internal/fsext"
@@ -338,6 +339,84 @@ func (m *editorCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
m.textarea.InsertRune('\n')
cmds = append(cmds, util.CmdHandler(completions.CloseCompletionsMsg{}))
}
+ // Handle image paste from clipboard
+ if key.Matches(msg, m.keyMap.PasteImage) {
+ imageData, err := nativeclipboard.Image.Read()
+
+ if err != nil || len(imageData) == 0 {
+ // If no image data found, try to get text data (could be file path)
+ var textData []byte
+ textData, err = nativeclipboard.Text.Read()
+ if err != nil || len(textData) == 0 {
+ // If clipboard is empty, show a warning
+ return m, util.ReportWarn("No data found in clipboard. Note: Some terminals may not support reading image data from clipboard directly.")
+ }
+
+ // Check if the text data is a file path
+ textStr := string(textData)
+ // First, try to interpret as a file path (existing functionality)
+ path := strings.ReplaceAll(textStr, "\\ ", " ")
+ path, err = filepath.Abs(strings.TrimSpace(path))
+ if err == nil {
+ isAllowedType := false
+ for _, ext := range filepicker.AllowedTypes {
+ if strings.HasSuffix(path, ext) {
+ isAllowedType = true
+ break
+ }
+ }
+ if isAllowedType {
+ tooBig, _ := filepicker.IsFileTooBig(path, filepicker.MaxAttachmentSize)
+ if !tooBig {
+ content, err := os.ReadFile(path)
+ if err == nil {
+ mimeBufferSize := min(512, len(content))
+ mimeType := http.DetectContentType(content[:mimeBufferSize])
+ fileName := filepath.Base(path)
+ attachment := message.Attachment{FilePath: path, FileName: fileName, MimeType: mimeType, Content: content}
+ return m, util.CmdHandler(filepicker.FilePickedMsg{
+ Attachment: attachment,
+ })
+ }
+ }
+ }
+ }
+
+ // If not a valid file path, show a warning
+ return m, util.ReportWarn("No image found in clipboard")
+ } else {
+ // We have image data from the clipboard
+ // Create a temporary file to store the clipboard image data
+ tempFile, err := os.CreateTemp("", "clipboard_image_crush_*")
+ if err != nil {
+ return m, util.ReportError(err)
+ }
+ defer tempFile.Close()
+
+ // Write clipboard content to the temporary file
+ _, err = tempFile.Write(imageData)
+ if err != nil {
+ return m, util.ReportError(err)
+ }
+
+ // Determine the file extension based on the image data
+ mimeBufferSize := min(512, len(imageData))
+ mimeType := http.DetectContentType(imageData[:mimeBufferSize])
+
+ // Create an attachment from the temporary file
+ fileName := filepath.Base(tempFile.Name())
+ attachment := message.Attachment{
+ FilePath: tempFile.Name(),
+ FileName: fileName,
+ MimeType: mimeType,
+ Content: imageData,
+ }
+
+ return m, util.CmdHandler(filepicker.FilePickedMsg{
+ Attachment: attachment,
+ })
+ }
+ }
// Handle Enter key
if m.textarea.Focused() && key.Matches(msg, m.keyMap.SendMessage) {
value := m.textarea.Value()
diff --git a/internal/tui/components/chat/editor/keys.go b/internal/tui/components/chat/editor/keys.go
index 0ba4571888e547b1c4a85e7ee9dd73ff07ce13d2..c20df5cc1c071deab83754430543b9be2381127c 100644
--- a/internal/tui/components/chat/editor/keys.go
+++ b/internal/tui/components/chat/editor/keys.go
@@ -9,6 +9,7 @@ type EditorKeyMap struct {
SendMessage key.Binding
OpenEditor key.Binding
Newline key.Binding
+ PasteImage key.Binding
}
func DefaultEditorKeyMap() EditorKeyMap {
@@ -32,6 +33,10 @@ func DefaultEditorKeyMap() EditorKeyMap {
// to reflect that.
key.WithHelp("ctrl+j", "newline"),
),
+ PasteImage: key.NewBinding(
+ key.WithKeys("ctrl+v"),
+ key.WithHelp("ctrl+v", "paste image from clipboard"),
+ ),
}
}
@@ -42,6 +47,7 @@ func (k EditorKeyMap) KeyBindings() []key.Binding {
k.SendMessage,
k.OpenEditor,
k.Newline,
+ k.PasteImage,
AttachmentsKeyMaps.AttachmentDeleteMode,
AttachmentsKeyMaps.DeleteAllAttachments,
AttachmentsKeyMaps.Escape,
diff --git a/internal/tui/styles/theme.go b/internal/tui/styles/theme.go
index f87ffd9de8b324cec4dcfd8b7cee61f71e0390eb..b03603c57439f5f950f9860d3287b0f9d13742e5 100644
--- a/internal/tui/styles/theme.go
+++ b/internal/tui/styles/theme.go
@@ -4,6 +4,7 @@ import (
"fmt"
"image/color"
"strings"
+ "sync"
"charm.land/bubbles/v2/filepicker"
"charm.land/bubbles/v2/help"
@@ -97,7 +98,8 @@ type Theme struct {
AuthBorderUnselected lipgloss.Style
AuthTextUnselected lipgloss.Style
- styles *Styles
+ styles *Styles
+ stylesOnce sync.Once
}
type Styles struct {
@@ -134,9 +136,9 @@ type Styles struct {
}
func (t *Theme) S() *Styles {
- if t.styles == nil {
+ t.stylesOnce.Do(func() {
t.styles = t.buildStyles()
- }
+ })
return t.styles
}
@@ -500,27 +502,31 @@ type Manager struct {
current *Theme
}
-var defaultManager *Manager
+var (
+ defaultManager *Manager
+ defaultManagerOnce sync.Once
+)
+
+func initDefaultManager() *Manager {
+ defaultManagerOnce.Do(func() {
+ defaultManager = newManager()
+ })
+ return defaultManager
+}
func SetDefaultManager(m *Manager) {
defaultManager = m
}
func DefaultManager() *Manager {
- if defaultManager == nil {
- defaultManager = NewManager()
- }
- return defaultManager
+ return initDefaultManager()
}
func CurrentTheme() *Theme {
- if defaultManager == nil {
- defaultManager = NewManager()
- }
- return defaultManager.Current()
+ return initDefaultManager().Current()
}
-func NewManager() *Manager {
+func newManager() *Manager {
m := &Manager{
themes: make(map[string]*Theme),
}
diff --git a/internal/ui/dialog/actions.go b/internal/ui/dialog/actions.go
index 29faff93557730bfb0d60bd8dac3dc9bcca84828..a63d14ac6fbe66bb70d365c9b03a58a0da932fd5 100644
--- a/internal/ui/dialog/actions.go
+++ b/internal/ui/dialog/actions.go
@@ -59,20 +59,18 @@ type (
}
// ActionRunCustomCommand is a message to run a custom command.
ActionRunCustomCommand struct {
- CommandID string
- // Used when running a user-defined command
- Content string
- // Used when running a prompt from MCP
- Client string
- }
- // ActionOpenCustomCommandArgumentsDialog is a message to open the custom command arguments dialog.
- ActionOpenCustomCommandArgumentsDialog struct {
- CommandID string
- // Used when running a user-defined command
- Content string
- // Used when running a prompt from MCP
- Client string
+ Content string
Arguments []commands.Argument
+ Args map[string]string // Actual argument values
+ }
+ // ActionRunMCPPrompt is a message to run a custom command.
+ ActionRunMCPPrompt struct {
+ Title string
+ Description string
+ PromptID string
+ ClientID string
+ Arguments []commands.Argument
+ Args map[string]string // Actual argument values
}
)
diff --git a/internal/ui/dialog/api_key_input.go b/internal/ui/dialog/api_key_input.go
index bbc3d3746b26e51103ce8545eca5fe3ebaaf977a..e28dea2b823143d176d796c8775e8024df61d0bb 100644
--- a/internal/ui/dialog/api_key_input.go
+++ b/internal/ui/dialog/api_key_input.go
@@ -95,7 +95,7 @@ func (m *APIKeyInput) ID() string {
return APIKeyInputID
}
-// Update implements tea.Model.
+// HandleMsg implements [Dialog].
func (m *APIKeyInput) HandleMsg(msg tea.Msg) Action {
switch msg := msg.(type) {
case ActionChangeAPIKeyState:
@@ -149,7 +149,7 @@ func (m *APIKeyInput) HandleMsg(msg tea.Msg) Action {
return nil
}
-// View implements tea.Model.
+// Draw implements [Dialog].
func (m *APIKeyInput) Draw(scr uv.Screen, area uv.Rectangle) *tea.Cursor {
t := m.com.Styles
@@ -239,8 +239,8 @@ func (m *APIKeyInput) inputView() string {
}
// Cursor returns the cursor position relative to the dialog.
-func (c *APIKeyInput) Cursor() *tea.Cursor {
- return InputCursor(c.com.Styles, c.input.Cursor())
+func (m *APIKeyInput) Cursor() *tea.Cursor {
+ return InputCursor(m.com.Styles, m.input.Cursor())
}
// FullHelp returns the full help view.
diff --git a/internal/ui/dialog/arguments.go b/internal/ui/dialog/arguments.go
new file mode 100644
index 0000000000000000000000000000000000000000..c016b7de6ec77e6e333d2b0f18ae5930ba0912fc
--- /dev/null
+++ b/internal/ui/dialog/arguments.go
@@ -0,0 +1,399 @@
+package dialog
+
+import (
+ "strings"
+
+ "charm.land/bubbles/v2/help"
+ "charm.land/bubbles/v2/key"
+ "charm.land/bubbles/v2/spinner"
+ "charm.land/bubbles/v2/textinput"
+ "charm.land/bubbles/v2/viewport"
+ tea "charm.land/bubbletea/v2"
+ "charm.land/lipgloss/v2"
+ "golang.org/x/text/cases"
+ "golang.org/x/text/language"
+
+ "github.com/charmbracelet/crush/internal/commands"
+ "github.com/charmbracelet/crush/internal/ui/common"
+ "github.com/charmbracelet/crush/internal/uiutil"
+ uv "github.com/charmbracelet/ultraviolet"
+)
+
+// ArgumentsID is the identifier for the arguments dialog.
+const ArgumentsID = "arguments"
+
+// Dialog sizing for arguments.
+const (
+ maxInputWidth = 120
+ minInputWidth = 30
+ maxViewportHeight = 20
+ argumentsFieldHeight = 3 // label + input + spacing per field
+)
+
+// Arguments represents a dialog for collecting command arguments.
+type Arguments struct {
+ com *common.Common
+ title string
+ arguments []commands.Argument
+ inputs []textinput.Model
+ focused int
+ spinner spinner.Model
+ loading bool
+
+ description string
+ resultAction Action
+
+ help help.Model
+ keyMap struct {
+ Confirm,
+ Next,
+ Previous,
+ ScrollUp,
+ ScrollDown,
+ Close key.Binding
+ }
+
+ viewport viewport.Model
+}
+
+var _ Dialog = (*Arguments)(nil)
+
+// NewArguments creates a new arguments dialog.
+func NewArguments(com *common.Common, title, description string, arguments []commands.Argument, resultAction Action) *Arguments {
+ a := &Arguments{
+ com: com,
+ title: title,
+ description: description,
+ arguments: arguments,
+ resultAction: resultAction,
+ }
+
+ a.help = help.New()
+ a.help.Styles = com.Styles.DialogHelpStyles()
+
+ a.keyMap.Confirm = key.NewBinding(
+ key.WithKeys("enter"),
+ key.WithHelp("enter", "confirm"),
+ )
+ a.keyMap.Next = key.NewBinding(
+ key.WithKeys("down", "tab"),
+ key.WithHelp("↓/tab", "next"),
+ )
+ a.keyMap.Previous = key.NewBinding(
+ key.WithKeys("up", "shift+tab"),
+ key.WithHelp("↑/shift+tab", "previous"),
+ )
+ a.keyMap.Close = CloseKey
+
+ // Create input fields for each argument.
+ a.inputs = make([]textinput.Model, len(arguments))
+ for i, arg := range arguments {
+ input := textinput.New()
+ input.SetVirtualCursor(false)
+ input.SetStyles(com.Styles.TextInput)
+ input.Prompt = "> "
+ // Use description as placeholder if available, otherwise title
+ if arg.Description != "" {
+ input.Placeholder = arg.Description
+ } else {
+ input.Placeholder = arg.Title
+ }
+
+ if i == 0 {
+ input.Focus()
+ } else {
+ input.Blur()
+ }
+
+ a.inputs[i] = input
+ }
+ s := spinner.New()
+ s.Spinner = spinner.Dot
+ s.Style = com.Styles.Dialog.Spinner
+ a.spinner = s
+
+ return a
+}
+
+// ID implements Dialog.
+func (a *Arguments) ID() string {
+ return ArgumentsID
+}
+
+// focusInput changes focus to a new input by index with wrap-around.
+func (a *Arguments) focusInput(newIndex int) {
+ a.inputs[a.focused].Blur()
+
+ // Wrap around: Go's modulo can return negative, so add len first.
+ n := len(a.inputs)
+ a.focused = ((newIndex % n) + n) % n
+
+ a.inputs[a.focused].Focus()
+
+ // Ensure the newly focused field is visible in the viewport
+ a.ensureFieldVisible(a.focused)
+}
+
+// isFieldVisible checks if a field at the given index is visible in the viewport.
+func (a *Arguments) isFieldVisible(fieldIndex int) bool {
+ fieldStart := fieldIndex * argumentsFieldHeight
+ fieldEnd := fieldStart + argumentsFieldHeight - 1
+ viewportTop := a.viewport.YOffset()
+ viewportBottom := viewportTop + a.viewport.Height() - 1
+
+ return fieldStart >= viewportTop && fieldEnd <= viewportBottom
+}
+
+// ensureFieldVisible scrolls the viewport to make the field visible.
+func (a *Arguments) ensureFieldVisible(fieldIndex int) {
+ if a.isFieldVisible(fieldIndex) {
+ return
+ }
+
+ fieldStart := fieldIndex * argumentsFieldHeight
+ fieldEnd := fieldStart + argumentsFieldHeight - 1
+ viewportTop := a.viewport.YOffset()
+ viewportHeight := a.viewport.Height()
+
+ // If field is above viewport, scroll up to show it at top
+ if fieldStart < viewportTop {
+ a.viewport.SetYOffset(fieldStart)
+ return
+ }
+
+ // If field is below viewport, scroll down to show it at bottom
+ if fieldEnd > viewportTop+viewportHeight-1 {
+ a.viewport.SetYOffset(fieldEnd - viewportHeight + 1)
+ }
+}
+
+// findVisibleFieldByOffset returns the field index closest to the given viewport offset.
+func (a *Arguments) findVisibleFieldByOffset(fromTop bool) int {
+ offset := a.viewport.YOffset()
+ if !fromTop {
+ offset += a.viewport.Height() - 1
+ }
+
+ fieldIndex := offset / argumentsFieldHeight
+ if fieldIndex >= len(a.inputs) {
+ return len(a.inputs) - 1
+ }
+ return fieldIndex
+}
+
+// HandleMsg implements Dialog.
+func (a *Arguments) HandleMsg(msg tea.Msg) Action {
+ switch msg := msg.(type) {
+ case spinner.TickMsg:
+ if a.loading {
+ var cmd tea.Cmd
+ a.spinner, cmd = a.spinner.Update(msg)
+ return ActionCmd{Cmd: cmd}
+ }
+ case tea.KeyPressMsg:
+ switch {
+ case key.Matches(msg, a.keyMap.Close):
+ return ActionClose{}
+ case key.Matches(msg, a.keyMap.Confirm):
+ // If we're on the last input or there's only one input, submit.
+ if a.focused == len(a.inputs)-1 || len(a.inputs) == 1 {
+ args := make(map[string]string)
+ var warning tea.Cmd
+ for i, arg := range a.arguments {
+ args[arg.ID] = a.inputs[i].Value()
+ if arg.Required && strings.TrimSpace(a.inputs[i].Value()) == "" {
+ warning = uiutil.ReportWarn("Required argument '" + arg.Title + "' is missing.")
+ break
+ }
+ }
+ if warning != nil {
+ return ActionCmd{Cmd: warning}
+ }
+
+ switch action := a.resultAction.(type) {
+ case ActionRunCustomCommand:
+ action.Args = args
+ return action
+ case ActionRunMCPPrompt:
+ action.Args = args
+ return action
+ }
+ }
+ a.focusInput(a.focused + 1)
+ case key.Matches(msg, a.keyMap.Next):
+ a.focusInput(a.focused + 1)
+ case key.Matches(msg, a.keyMap.Previous):
+ a.focusInput(a.focused - 1)
+ default:
+ var cmd tea.Cmd
+ a.inputs[a.focused], cmd = a.inputs[a.focused].Update(msg)
+ return ActionCmd{Cmd: cmd}
+ }
+ case tea.MouseWheelMsg:
+ a.viewport, _ = a.viewport.Update(msg)
+ // If focused field scrolled out of view, focus the visible field
+ if !a.isFieldVisible(a.focused) {
+ a.focusInput(a.findVisibleFieldByOffset(msg.Button == tea.MouseWheelDown))
+ }
+ case tea.PasteMsg:
+ var cmd tea.Cmd
+ a.inputs[a.focused], cmd = a.inputs[a.focused].Update(msg)
+ return ActionCmd{Cmd: cmd}
+ }
+ return nil
+}
+
+// Cursor returns the cursor position relative to the dialog.
+// we pass the description height to offset the cursor correctly.
+func (a *Arguments) Cursor(descriptionHeight int) *tea.Cursor {
+ cursor := InputCursor(a.com.Styles, a.inputs[a.focused].Cursor())
+ if cursor == nil {
+ return nil
+ }
+ cursor.Y += descriptionHeight + a.focused*argumentsFieldHeight - a.viewport.YOffset() + 1
+ return cursor
+}
+
+// Draw implements Dialog.
+func (a *Arguments) Draw(scr uv.Screen, area uv.Rectangle) *tea.Cursor {
+ s := a.com.Styles
+
+ dialogContentStyle := s.Dialog.Arguments.Content
+ possibleWidth := area.Dx() - s.Dialog.View.GetHorizontalFrameSize() - dialogContentStyle.GetHorizontalFrameSize()
+ // Build fields with label and input.
+ caser := cases.Title(language.English)
+
+ var fields []string
+ for i, arg := range a.arguments {
+ isFocused := i == a.focused
+
+ // Try to pretty up the title for the label.
+ title := strings.ReplaceAll(arg.Title, "_", " ")
+ title = strings.ReplaceAll(title, "-", " ")
+ titleParts := strings.Fields(title)
+ for i, part := range titleParts {
+ titleParts[i] = caser.String(strings.ToLower(part))
+ }
+ labelText := strings.Join(titleParts, " ")
+
+ markRequiredStyle := s.Dialog.Arguments.InputRequiredMarkBlurred
+
+ labelStyle := s.Dialog.Arguments.InputLabelBlurred
+ if isFocused {
+ labelStyle = s.Dialog.Arguments.InputLabelFocused
+ markRequiredStyle = s.Dialog.Arguments.InputRequiredMarkFocused
+ }
+ if arg.Required {
+ labelText += markRequiredStyle.String()
+ }
+ label := labelStyle.Render(labelText)
+
+ labelWidth := lipgloss.Width(labelText)
+ placeholderWidth := lipgloss.Width(a.inputs[i].Placeholder)
+
+ inputWidth := max(placeholderWidth, labelWidth, minInputWidth)
+ inputWidth = min(inputWidth, min(possibleWidth, maxInputWidth))
+ a.inputs[i].SetWidth(inputWidth)
+
+ inputLine := a.inputs[i].View()
+
+ field := lipgloss.JoinVertical(lipgloss.Left, label, inputLine, "")
+ fields = append(fields, field)
+ }
+
+ renderedFields := lipgloss.JoinVertical(lipgloss.Left, fields...)
+
+ // Anchor width to the longest field, capped at maxInputWidth.
+ const scrollbarWidth = 1
+ width := lipgloss.Width(renderedFields)
+ height := lipgloss.Height(renderedFields)
+
+ // Use standard header
+ titleStyle := s.Dialog.Title
+
+ titleText := a.title
+ if titleText == "" {
+ titleText = "Arguments"
+ }
+
+ header := common.DialogTitle(s, titleText, width)
+
+ // Add description if available.
+ var description string
+ if a.description != "" {
+ descStyle := s.Dialog.Arguments.Description.Width(width)
+ description = descStyle.Render(a.description)
+ }
+
+ helpView := s.Dialog.HelpView.Width(width).Render(a.help.View(a))
+ if a.loading {
+ helpView = s.Dialog.HelpView.Width(width).Render(a.spinner.View() + " Generating Prompt...")
+ }
+
+ availableHeight := area.Dy() - s.Dialog.View.GetVerticalFrameSize() - dialogContentStyle.GetVerticalFrameSize() - lipgloss.Height(header) - lipgloss.Height(description) - lipgloss.Height(helpView) - 2 // extra spacing
+ viewportHeight := min(height, maxViewportHeight, availableHeight)
+
+ a.viewport.SetWidth(width) // -1 for scrollbar
+ a.viewport.SetHeight(viewportHeight)
+ a.viewport.SetContent(renderedFields)
+
+ scrollbar := common.Scrollbar(s, viewportHeight, a.viewport.TotalLineCount(), viewportHeight, a.viewport.YOffset())
+ content := a.viewport.View()
+ if scrollbar != "" {
+ content = lipgloss.JoinHorizontal(lipgloss.Top, content, scrollbar)
+ }
+ contentParts := []string{}
+ if description != "" {
+ contentParts = append(contentParts, description)
+ }
+ contentParts = append(contentParts, content)
+
+ view := lipgloss.JoinVertical(
+ lipgloss.Left,
+ titleStyle.Render(header),
+ dialogContentStyle.Render(lipgloss.JoinVertical(lipgloss.Left, contentParts...)),
+ helpView,
+ )
+
+ dialog := s.Dialog.View.Render(view)
+
+ descriptionHeight := 0
+ if a.description != "" {
+ descriptionHeight = lipgloss.Height(description)
+ }
+ cur := a.Cursor(descriptionHeight)
+
+ DrawCenterCursor(scr, area, dialog, cur)
+ return cur
+}
+
+// StartLoading implements [LoadingDialog].
+func (a *Arguments) StartLoading() tea.Cmd {
+ if a.loading {
+ return nil
+ }
+ a.loading = true
+ return a.spinner.Tick
+}
+
+// StopLoading implements [LoadingDialog].
+func (a *Arguments) StopLoading() {
+ a.loading = false
+}
+
+// ShortHelp implements help.KeyMap.
+func (a *Arguments) ShortHelp() []key.Binding {
+ return []key.Binding{
+ a.keyMap.Confirm,
+ a.keyMap.Next,
+ a.keyMap.Close,
+ }
+}
+
+// FullHelp implements help.KeyMap.
+func (a *Arguments) FullHelp() [][]key.Binding {
+ return [][]key.Binding{
+ {a.keyMap.Confirm, a.keyMap.Next, a.keyMap.Previous},
+ {a.keyMap.Close},
+ }
+}
diff --git a/internal/ui/dialog/commands.go b/internal/ui/dialog/commands.go
index 03707a54775992992a36e90e6857b0f55ce3c8e3..a6861a5c87707d7c0717ec4d3c50c1d995a528af 100644
--- a/internal/ui/dialog/commands.go
+++ b/internal/ui/dialog/commands.go
@@ -6,6 +6,7 @@ import (
"charm.land/bubbles/v2/help"
"charm.land/bubbles/v2/key"
+ "charm.land/bubbles/v2/spinner"
"charm.land/bubbles/v2/textinput"
tea "charm.land/bubbletea/v2"
"charm.land/lipgloss/v2"
@@ -52,26 +53,29 @@ type Commands struct {
sessionID string // can be empty for non-session-specific commands
selected CommandType
+ spinner spinner.Model
+ loading bool
+
help help.Model
input textinput.Model
list *list.FilterableList
windowWidth int
- customCommands []commands.CustomCommand
- mcpCustomCommands []commands.MCPCustomCommand
+ customCommands []commands.CustomCommand
+ mcpPrompts []commands.MCPPrompt
}
var _ Dialog = (*Commands)(nil)
// NewCommands creates a new commands dialog.
-func NewCommands(com *common.Common, sessionID string, customCommands []commands.CustomCommand, mcpCustomCommands []commands.MCPCustomCommand) (*Commands, error) {
+func NewCommands(com *common.Common, sessionID string, customCommands []commands.CustomCommand, mcpPrompts []commands.MCPPrompt) (*Commands, error) {
c := &Commands{
- com: com,
- selected: SystemCommands,
- sessionID: sessionID,
- customCommands: customCommands,
- mcpCustomCommands: mcpCustomCommands,
+ com: com,
+ selected: SystemCommands,
+ sessionID: sessionID,
+ customCommands: customCommands,
+ mcpPrompts: mcpPrompts,
}
help := help.New()
@@ -120,6 +124,11 @@ func NewCommands(com *common.Common, sessionID string, customCommands []commands
// Set initial commands
c.setCommandItems(c.selected)
+ s := spinner.New()
+ s.Spinner = spinner.Dot
+ s.Style = com.Styles.Dialog.Spinner
+ c.spinner = s
+
return c, nil
}
@@ -128,9 +137,15 @@ func (c *Commands) ID() string {
return CommandsID
}
-// HandleMsg implements Dialog.
+// HandleMsg implements [Dialog].
func (c *Commands) HandleMsg(msg tea.Msg) Action {
switch msg := msg.(type) {
+ case spinner.TickMsg:
+ if c.loading {
+ var cmd tea.Cmd
+ c.spinner, cmd = c.spinner.Update(msg)
+ return ActionCmd{Cmd: cmd}
+ }
case tea.KeyPressMsg:
switch {
case key.Matches(msg, c.keyMap.Close):
@@ -160,12 +175,12 @@ func (c *Commands) HandleMsg(msg tea.Msg) Action {
}
}
case key.Matches(msg, c.keyMap.Tab):
- if len(c.customCommands) > 0 || len(c.mcpCustomCommands) > 0 {
+ if len(c.customCommands) > 0 || len(c.mcpPrompts) > 0 {
c.selected = c.nextCommandType()
c.setCommandItems(c.selected)
}
case key.Matches(msg, c.keyMap.ShiftTab):
- if len(c.customCommands) > 0 || len(c.mcpCustomCommands) > 0 {
+ if len(c.customCommands) > 0 || len(c.mcpPrompts) > 0 {
c.selected = c.previousCommandType()
c.setCommandItems(c.selected)
}
@@ -242,12 +257,16 @@ func (c *Commands) Draw(scr uv.Screen, area uv.Rectangle) *tea.Cursor {
c.list.SetSize(innerWidth, height-heightOffset)
c.help.SetWidth(innerWidth)
- radio := commandsRadioView(t, c.selected, len(c.customCommands) > 0, len(c.mcpCustomCommands) > 0)
+ radio := commandsRadioView(t, c.selected, len(c.customCommands) > 0, len(c.mcpPrompts) > 0)
titleStyle := t.Dialog.Title
dialogStyle := t.Dialog.View.Width(width)
headerOffset := lipgloss.Width(radio) + titleStyle.GetHorizontalFrameSize() + dialogStyle.GetHorizontalFrameSize()
helpView := ansi.Truncate(c.help.View(c), innerWidth, "")
header := common.DialogTitle(t, "Commands", width-headerOffset) + radio
+
+ if c.loading {
+ helpView = t.Dialog.HelpView.Width(width).Render(c.spinner.View() + " Generating Prompt...")
+ }
view := HeaderInputListHelpView(t, width, c.list.Height(), header,
c.input.View(), c.list.Render(), helpView)
@@ -281,12 +300,12 @@ func (c *Commands) nextCommandType() CommandType {
if len(c.customCommands) > 0 {
return UserCommands
}
- if len(c.mcpCustomCommands) > 0 {
+ if len(c.mcpPrompts) > 0 {
return MCPPrompts
}
fallthrough
case UserCommands:
- if len(c.mcpCustomCommands) > 0 {
+ if len(c.mcpPrompts) > 0 {
return MCPPrompts
}
fallthrough
@@ -301,7 +320,7 @@ func (c *Commands) nextCommandType() CommandType {
func (c *Commands) previousCommandType() CommandType {
switch c.selected {
case SystemCommands:
- if len(c.mcpCustomCommands) > 0 {
+ if len(c.mcpPrompts) > 0 {
return MCPPrompts
}
if len(c.customCommands) > 0 {
@@ -332,37 +351,22 @@ func (c *Commands) setCommandItems(commandType CommandType) {
}
case UserCommands:
for _, cmd := range c.customCommands {
- var action Action
- if len(cmd.Arguments) > 0 {
- action = ActionOpenCustomCommandArgumentsDialog{
- CommandID: cmd.ID,
- Content: cmd.Content,
- Arguments: cmd.Arguments,
- }
- } else {
- action = ActionRunCustomCommand{
- CommandID: cmd.ID,
- Content: cmd.Content,
- }
+ action := ActionRunCustomCommand{
+ Content: cmd.Content,
+ Arguments: cmd.Arguments,
}
commandItems = append(commandItems, NewCommandItem(c.com.Styles, "custom_"+cmd.ID, cmd.Name, "", action))
}
case MCPPrompts:
- for _, cmd := range c.mcpCustomCommands {
- var action Action
- if len(cmd.Arguments) > 0 {
- action = ActionOpenCustomCommandArgumentsDialog{
- CommandID: cmd.ID,
- Client: cmd.Client,
- Arguments: cmd.Arguments,
- }
- } else {
- action = ActionRunCustomCommand{
- CommandID: cmd.ID,
- Client: cmd.Client,
- }
+ for _, cmd := range c.mcpPrompts {
+ action := ActionRunMCPPrompt{
+ Title: cmd.Title,
+ Description: cmd.Description,
+ PromptID: cmd.PromptID,
+ ClientID: cmd.ClientID,
+ Arguments: cmd.Arguments,
}
- commandItems = append(commandItems, NewCommandItem(c.com.Styles, "mcp_"+cmd.ID, cmd.Name, "", action))
+ commandItems = append(commandItems, NewCommandItem(c.com.Styles, "mcp_"+cmd.ID, cmd.PromptID, "", action))
}
}
@@ -448,10 +452,24 @@ func (c *Commands) SetCustomCommands(customCommands []commands.CustomCommand) {
}
}
-// SetMCPCustomCommands sets the MCP custom commands and refreshes the view if MCP prompts are currently displayed.
-func (c *Commands) SetMCPCustomCommands(mcpCustomCommands []commands.MCPCustomCommand) {
- c.mcpCustomCommands = mcpCustomCommands
+// SetMCPPrompts sets the MCP prompts and refreshes the view if MCP prompts are currently displayed.
+func (c *Commands) SetMCPPrompts(mcpPrompts []commands.MCPPrompt) {
+ c.mcpPrompts = mcpPrompts
if c.selected == MCPPrompts {
c.setCommandItems(c.selected)
}
}
+
+// StartLoading implements [LoadingDialog].
+func (a *Commands) StartLoading() tea.Cmd {
+ if a.loading {
+ return nil
+ }
+ a.loading = true
+ return a.spinner.Tick
+}
+
+// StopLoading implements [LoadingDialog].
+func (a *Commands) StopLoading() {
+ a.loading = false
+}
diff --git a/internal/ui/dialog/dialog.go b/internal/ui/dialog/dialog.go
index 68eb313d4ec83cf8d098fcfccb5ebf27de8bd0d1..7a3db40128fb1e5543a94a93faa4ae9aeec5f947 100644
--- a/internal/ui/dialog/dialog.go
+++ b/internal/ui/dialog/dialog.go
@@ -27,7 +27,7 @@ var CloseKey = key.NewBinding(
)
// Action represents an action taken in a dialog after handling a message.
-type Action interface{}
+type Action any
// Dialog is a component that can be displayed on top of the UI.
type Dialog interface {
@@ -41,6 +41,12 @@ type Dialog interface {
Draw(scr uv.Screen, area uv.Rectangle) *tea.Cursor
}
+// LoadingDialog is a dialog that can show a loading state.
+type LoadingDialog interface {
+ StartLoading() tea.Cmd
+ StopLoading()
+}
+
// Overlay manages multiple dialogs as an overlay.
type Overlay struct {
dialogs []Dialog
@@ -136,6 +142,25 @@ func (d *Overlay) Update(msg tea.Msg) tea.Msg {
return dialog.HandleMsg(msg)
}
+// StartLoading starts the loading state for the front dialog if it
+// implements [LoadingDialog].
+func (d *Overlay) StartLoading() tea.Cmd {
+ dialog := d.DialogLast()
+ if ld, ok := dialog.(LoadingDialog); ok {
+ return ld.StartLoading()
+ }
+ return nil
+}
+
+// StopLoading stops the loading state for the front dialog if it
+// implements [LoadingDialog].
+func (d *Overlay) StopLoading() {
+ dialog := d.DialogLast()
+ if ld, ok := dialog.(LoadingDialog); ok {
+ ld.StopLoading()
+ }
+}
+
// DrawCenterCursor draws the given string view centered in the screen area and
// adjusts the cursor position accordingly.
func DrawCenterCursor(scr uv.Screen, area uv.Rectangle, view string, cur *tea.Cursor) {
diff --git a/internal/ui/model/ui.go b/internal/ui/model/ui.go
index 201f440d3820d72707800820e4f5c4e8bfa8af25..f11ad75454dd18f3a710f925da52309935290c37 100644
--- a/internal/ui/model/ui.go
+++ b/internal/ui/model/ui.go
@@ -18,6 +18,7 @@ import (
"charm.land/bubbles/v2/help"
"charm.land/bubbles/v2/key"
+ "charm.land/bubbles/v2/spinner"
"charm.land/bubbles/v2/textarea"
tea "charm.land/bubbletea/v2"
"charm.land/lipgloss/v2"
@@ -88,10 +89,19 @@ type (
userCommandsLoadedMsg struct {
Commands []commands.CustomCommand
}
- // mcpCustomCommandsLoadedMsg is sent when mcp prompts are loaded.
- mcpCustomCommandsLoadedMsg struct {
- Prompts []commands.MCPCustomCommand
+ // mcpPromptsLoadedMsg is sent when mcp prompts are loaded.
+ mcpPromptsLoadedMsg struct {
+ Prompts []commands.MCPPrompt
}
+ // sendMessageMsg is sent to send a message.
+ // currently only used for mcp prompts.
+ sendMessageMsg struct {
+ Content string
+ Attachments []message.Attachment
+ }
+
+ // closeDialogMsg is sent to close the current dialog.
+ closeDialogMsg struct{}
)
// UI represents the main user interface model.
@@ -165,8 +175,8 @@ type UI struct {
imgCaps timage.Capabilities
// custom commands & mcp commands
- customCommands []commands.CustomCommand
- mcpCustomCommands []commands.MCPCustomCommand
+ customCommands []commands.CustomCommand
+ mcpPrompts []commands.MCPPrompt
// forceCompactMode tracks whether compact mode is forced by user toggle
forceCompactMode bool
@@ -285,15 +295,15 @@ func (m *UI) loadCustomCommands() tea.Cmd {
// loadMCPrompts loads the MCP prompts asynchronously.
func (m *UI) loadMCPrompts() tea.Cmd {
return func() tea.Msg {
- prompts, err := commands.LoadMCPCustomCommands()
+ prompts, err := commands.LoadMCPPrompts()
if err != nil {
slog.Error("failed to load mcp prompts", "error", err)
}
if prompts == nil {
// flag them as loaded even if there is none or an error
- prompts = []commands.MCPCustomCommand{}
+ prompts = []commands.MCPPrompt{}
}
- return mcpCustomCommandsLoadedMsg{Prompts: prompts}
+ return mcpPromptsLoadedMsg{Prompts: prompts}
}
}
@@ -322,6 +332,9 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
cmds = append(cmds, cmd)
}
+ case sendMessageMsg:
+ cmds = append(cmds, m.sendMessage(msg.Content, msg.Attachments...))
+
case userCommandsLoadedMsg:
m.customCommands = msg.Commands
dia := m.dialog.Dialog(dialog.CommandsID)
@@ -333,8 +346,8 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if ok {
commands.SetCustomCommands(m.customCommands)
}
- case mcpCustomCommandsLoadedMsg:
- m.mcpCustomCommands = msg.Prompts
+ case mcpPromptsLoadedMsg:
+ m.mcpPrompts = msg.Prompts
dia := m.dialog.Dialog(dialog.CommandsID)
if dia == nil {
break
@@ -342,9 +355,12 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
commands, ok := dia.(*dialog.Commands)
if ok {
- commands.SetMCPCustomCommands(m.mcpCustomCommands)
+ commands.SetMCPPrompts(m.mcpPrompts)
}
+ case closeDialogMsg:
+ m.dialog.CloseFrontDialog()
+
case pubsub.Event[message.Message]:
// Check if this is a child session message for an agent tool.
if m.session == nil {
@@ -377,7 +393,7 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
break
}
}
- if initialized && m.mcpCustomCommands == nil {
+ if initialized && m.mcpPrompts == nil {
cmds = append(cmds, m.loadMCPrompts())
}
case pubsub.Event[permission.PermissionRequest]:
@@ -497,6 +513,14 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
cmds = append(cmds, cmd)
}
}
+ case spinner.TickMsg:
+ if m.dialog.HasDialogs() {
+ // route to dialog
+ if cmd := m.handleDialogMsg(msg); cmd != nil {
+ cmds = append(cmds, cmd)
+ }
+ }
+
case tea.KeyPressMsg:
if cmd := m.handleKeyPressMsg(msg); cmd != nil {
cmds = append(cmds, cmd)
@@ -660,6 +684,11 @@ func (m *UI) loadNestedToolCalls(items []chat.MessageItem) {
// if the message is a tool result it will update the corresponding tool call message
func (m *UI) appendSessionMessage(msg message.Message) tea.Cmd {
var cmds []tea.Cmd
+ existing := m.chat.MessageItem(msg.ID)
+ if existing != nil {
+ // message already exists, skip
+ return nil
+ }
switch msg.Role {
case message.User, message.Assistant:
items := chat.ExtractMessageItems(m.com.Styles, &msg, nil)
@@ -945,6 +974,43 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd {
},
))
+ case dialog.ActionRunCustomCommand:
+ if len(msg.Arguments) > 0 && msg.Args == nil {
+ m.dialog.CloseFrontDialog()
+ argsDialog := dialog.NewArguments(
+ m.com,
+ "Custom Command Arguments",
+ "",
+ msg.Arguments,
+ msg, // Pass the action as the result
+ )
+ m.dialog.OpenDialog(argsDialog)
+ break
+ }
+ content := msg.Content
+ if msg.Args != nil {
+ content = substituteArgs(content, msg.Args)
+ }
+ cmds = append(cmds, m.sendMessage(content))
+ m.dialog.CloseFrontDialog()
+ case dialog.ActionRunMCPPrompt:
+ if len(msg.Arguments) > 0 && msg.Args == nil {
+ m.dialog.CloseFrontDialog()
+ title := msg.Title
+ if title == "" {
+ title = "MCP Prompt Arguments"
+ }
+ argsDialog := dialog.NewArguments(
+ m.com,
+ title,
+ msg.Description,
+ msg.Arguments,
+ msg, // Pass the action as the result
+ )
+ m.dialog.OpenDialog(argsDialog)
+ break
+ }
+ cmds = append(cmds, m.runMCPPrompt(msg.ClientID, msg.PromptID, msg.Args))
default:
cmds = append(cmds, uiutil.CmdHandler(msg))
}
@@ -952,6 +1018,15 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd {
return tea.Batch(cmds...)
}
+// substituteArgs replaces $ARG_NAME placeholders in content with actual values.
+func substituteArgs(content string, args map[string]string) string {
+ for name, value := range args {
+ placeholder := "$" + name
+ content = strings.ReplaceAll(content, placeholder, value)
+ }
+ return content
+}
+
// openAPIKeyInputDialog opens the API key input dialog.
func (m *UI) openAPIKeyInputDialog(provider catwalk.Provider, model config.SelectedModel, modelType config.SelectedModelType) tea.Cmd {
if m.dialog.ContainsDialog(dialog.APIKeyInputID) {
@@ -1085,7 +1160,7 @@ func (m *UI) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd {
m.randomizePlaceholders()
- return m.sendMessage(value, attachments)
+ return m.sendMessage(value, attachments...)
case key.Matches(msg, m.keyMap.Chat.NewSession):
if m.session == nil || m.session.ID == "" {
break
@@ -2043,7 +2118,7 @@ func (m *UI) renderSidebarLogo(width int) {
}
// sendMessage sends a message with the given content and attachments.
-func (m *UI) sendMessage(content string, attachments []message.Attachment) tea.Cmd {
+func (m *UI) sendMessage(content string, attachments ...message.Attachment) tea.Cmd {
if m.com.App.AgentCoordinator == nil {
return uiutil.ReportError(fmt.Errorf("coder agent is not initialized"))
}
@@ -2195,7 +2270,7 @@ func (m *UI) openCommandsDialog() tea.Cmd {
sessionID = m.session.ID
}
- commands, err := dialog.NewCommands(m.com, sessionID, m.customCommands, m.mcpCustomCommands)
+ commands, err := dialog.NewCommands(m.com, sessionID, m.customCommands, m.mcpPrompts)
if err != nil {
return uiutil.ReportError(err)
}
@@ -2442,6 +2517,33 @@ func (m *UI) drawSessionDetails(scr uv.Screen, area uv.Rectangle) {
).Draw(scr, area)
}
+func (m *UI) runMCPPrompt(clientID, promptID string, arguments map[string]string) tea.Cmd {
+ load := func() tea.Msg {
+ prompt, err := commands.GetMCPPrompt(clientID, promptID, arguments)
+ if err != nil {
+ // TODO: make this better
+ return uiutil.ReportError(err)()
+ }
+
+ if prompt == "" {
+ return nil
+ }
+ return sendMessageMsg{
+ Content: prompt,
+ }
+ }
+
+ var cmds []tea.Cmd
+ if cmd := m.dialog.StartLoading(); cmd != nil {
+ cmds = append(cmds, cmd)
+ }
+ cmds = append(cmds, load, func() tea.Msg {
+ return closeDialogMsg{}
+ })
+
+ return tea.Sequence(cmds...)
+}
+
// renderLogo renders the Crush logo with the given styles and dimensions.
func renderLogo(t *styles.Styles, compact bool, width int) string {
return logo.Render(version.Version, compact, logo.Opts{
diff --git a/internal/ui/styles/styles.go b/internal/ui/styles/styles.go
index a2d45a3265d170bf545d09b8617ff376be9898b8..5c822b862bf76f1826ad4e367b1665a09ce467a5 100644
--- a/internal/ui/styles/styles.go
+++ b/internal/ui/styles/styles.go
@@ -333,6 +333,8 @@ type Styles struct {
List lipgloss.Style
+ Spinner lipgloss.Style
+
// ContentPanel is used for content blocks with subtle background.
ContentPanel lipgloss.Style
@@ -340,6 +342,16 @@ type Styles struct {
ScrollbarThumb lipgloss.Style
ScrollbarTrack lipgloss.Style
+ // Arguments
+ Arguments struct {
+ Content lipgloss.Style
+ Description lipgloss.Style
+ InputLabelBlurred lipgloss.Style
+ InputLabelFocused lipgloss.Style
+ InputRequiredMarkBlurred lipgloss.Style
+ InputRequiredMarkFocused lipgloss.Style
+ }
+
Commands struct{}
ImagePreview lipgloss.Style
@@ -1207,11 +1219,19 @@ func DefaultStyles() Styles {
s.Dialog.List = base.Margin(0, 0, 1, 0)
s.Dialog.ContentPanel = base.Background(bgSubtle).Foreground(fgBase).Padding(1, 2)
+ s.Dialog.Spinner = base.Foreground(secondary)
s.Dialog.ScrollbarThumb = base.Foreground(secondary)
s.Dialog.ScrollbarTrack = base.Foreground(border)
s.Dialog.ImagePreview = lipgloss.NewStyle().Padding(0, 1).Foreground(fgSubtle)
+ s.Dialog.Arguments.Content = base.Padding(1)
+ s.Dialog.Arguments.Description = base.MarginBottom(1).MaxHeight(3)
+ s.Dialog.Arguments.InputLabelBlurred = base.Foreground(fgMuted)
+ s.Dialog.Arguments.InputLabelFocused = base.Bold(true)
+ s.Dialog.Arguments.InputRequiredMarkBlurred = base.Foreground(fgMuted).SetString("*")
+ s.Dialog.Arguments.InputRequiredMarkFocused = base.Foreground(primary).Bold(true).SetString("*")
+
s.Status.Help = lipgloss.NewStyle().Padding(0, 1)
s.Status.SuccessIndicator = base.Foreground(bgSubtle).Background(green).Padding(0, 1).Bold(true).SetString("OKAY!")
s.Status.InfoIndicator = s.Status.SuccessIndicator