From 87df08c4088afb265a3886380db96ca6387c518d Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Tue, 13 Jan 2026 11:29:21 -0500 Subject: [PATCH 01/13] fix: race condition where title might not be generated (#1844) --- internal/agent/agent.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 8d2fa40fd427143bf988587ef7faa3a89c3e23b1..198159d53adbcbba8f8598bf24a8eef55825acfc 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -183,6 +183,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy a.generateTitle(titleCtx, call.SessionID, call.Prompt) }) } + defer wg.Wait() // Add the user message to the session. _, err = a.createUserMessage(ctx, call) @@ -491,7 +492,6 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy } return nil, err } - wg.Wait() if shouldSummarize { a.activeRequests.Del(call.SessionID) From 3fd9d970149e706149c47e86f3d007512472ff12 Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Tue, 13 Jan 2026 14:34:19 -0300 Subject: [PATCH 02/13] ci(sec): add more security jobs, improve build, enable race detector (#1849) Signed-off-by: Carlos Alexandro Becker Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .github/workflows/build.yml | 28 ++++++-- .github/workflows/security.yml | 88 ++++++++++++++++++++++++++ Taskfile.yaml | 10 +-- go.mod | 2 +- go.sum | 4 +- internal/csync/maps_test.go | 7 +- internal/permission/permission.go | 13 +++- internal/permission/permission_test.go | 4 +- internal/shell/background.go | 43 +++++++++++-- internal/shell/background_test.go | 16 ++--- internal/tui/styles/theme.go | 32 ++++++---- 11 files changed, 199 insertions(+), 48 deletions(-) create mode 100644 .github/workflows/security.yml diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 3511f7fe0c4f487eb3fc9009795361ada8e2eff7..39b5923298e2f7fa8d5452327a6e8b2a08f0df97 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1,11 +1,27 @@ name: build on: [push, pull_request] +permissions: + contents: read + +concurrency: + group: build-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + jobs: build: - uses: charmbracelet/meta/.github/workflows/build.yml@main - with: - go-version: "" - go-version-file: ./go.mod - secrets: - gh_pat: "${{ secrets.PERSONAL_ACCESS_TOKEN }}" + strategy: + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + with: + persist-credentials: false + - uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # v6.1.0 + with: + go-version-file: go.mod + - run: go mod tidy + - run: git diff --exit-code + - run: go build -race ./... + - run: go test -race -failfast ./... diff --git a/.github/workflows/security.yml b/.github/workflows/security.yml new file mode 100644 index 0000000000000000000000000000000000000000..8fc56fa39e7b47d1fe5ba84c0f0e7cb65733a264 --- /dev/null +++ b/.github/workflows/security.yml @@ -0,0 +1,88 @@ +name: "security" + +on: + pull_request: + push: + branches: [main] + schedule: + - cron: "0 2 * * *" + +permissions: + contents: read + +concurrency: + group: security-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + codeql: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + language: ["go", "actions"] + permissions: + actions: read + contents: read + pull-requests: read + security-events: write + steps: + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + with: + persist-credentials: false + - uses: github/codeql-action/init@cf1bb45a277cb3c205638b2cd5c984db1c46a412 # v4.31.7 + with: + languages: ${{ matrix.language }} + - uses: github/codeql-action/autobuild@cf1bb45a277cb3c205638b2cd5c984db1c46a412 # v4.31.7 + - uses: github/codeql-action/analyze@cf1bb45a277cb3c205638b2cd5c984db1c46a412 # v4.31.7 + + grype: + runs-on: ubuntu-latest + permissions: + security-events: write + actions: read + contents: read + steps: + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + with: + persist-credentials: false + - uses: anchore/scan-action@40a61b52209e9d50e87917c5b901783d546b12d0 # v7.2.1 + id: scan + with: + path: "." + fail-build: true + severity-cutoff: critical + - uses: github/codeql-action/upload-sarif@cf1bb45a277cb3c205638b2cd5c984db1c46a412 # v4.31.7 + with: + sarif_file: ${{ steps.scan.outputs.sarif }} + + govulncheck: + runs-on: ubuntu-latest + permissions: + security-events: write + contents: read + steps: + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + with: + persist-credentials: false + - uses: golang/govulncheck-action@b625fbe08f3bccbe446d94fbf87fcc875a4f50ee # v1.0.4 + with: + output-format: sarif + output-file: results.sarif + - uses: github/codeql-action/upload-sarif@cf1bb45a277cb3c205638b2cd5c984db1c46a412 # v4.31.7 + with: + sarif_file: results.sarif + + dependency-review: + runs-on: ubuntu-latest + if: github.event_name == 'pull_request' + permissions: + contents: read + steps: + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + with: + persist-credentials: false + - uses: actions/dependency-review-action@3c4e3dcb1aa7874d2c16be7d79418e9b7efd6261 # v4.8.2 + with: + fail-on-severity: critical + allow-licenses: BSD-2-Clause, BSD-3-Clause, MIT, Apache-2.0, MPL-2.0, ISC diff --git a/Taskfile.yaml b/Taskfile.yaml index 68c805c599314cadde5c86fc37a0e3d1a6184f4e..0043f4f033e455a5800da2431848e620c37a0f5a 100644 --- a/Taskfile.yaml +++ b/Taskfile.yaml @@ -5,6 +5,8 @@ version: "3" vars: VERSION: sh: git describe --long 2>/dev/null || echo "" + RACE: + sh: test -f race.log && echo "1" || echo "" env: CGO_ENABLED: 0 @@ -37,20 +39,20 @@ tasks: vars: LDFLAGS: '{{if .VERSION}}-ldflags="-X github.com/charmbracelet/crush/internal/version.Version={{.VERSION}}"{{end}}' cmds: - - go build {{.LDFLAGS}} . + - "go build {{if .RACE}}-race{{end}} {{.LDFLAGS}} ." generates: - crush run: desc: Run build cmds: - - go build -o crush . - - ./crush {{.CLI_ARGS}} + - task: build + - "./crush {{.CLI_ARGS}} {{if .RACE}}2>race.log{{end}}" test: desc: Run tests cmds: - - go test ./... {{.CLI_ARGS}} + - go test -race -failfast ./... {{.CLI_ARGS}} test:record: desc: Run tests and record all VCR cassettes again diff --git a/go.mod b/go.mod index 8959596f7feca0d6df1ce50e88712e8cc1058fa9..fe3497de825754c0e60835a079f9d6014d9c603a 100644 --- a/go.mod +++ b/go.mod @@ -29,7 +29,7 @@ require ( github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f github.com/charmbracelet/x/exp/ordered v0.1.0 github.com/charmbracelet/x/exp/slice v0.0.0-20251201173703-9f73bfd934ff - github.com/charmbracelet/x/powernap v0.0.0-20251015113943-25f979b54ad4 + github.com/charmbracelet/x/powernap v0.0.0-20260113142046-c1fa3de7983b github.com/charmbracelet/x/term v0.2.2 github.com/denisbrodbeck/machineid v1.0.1 github.com/disintegration/imageorient v0.0.0-20180920195336-8147d86e83ec diff --git a/go.sum b/go.sum index 70582b7c92f86af89a03d9f9a43382e27235d2ca..d3d7696e9729d1a20dc45c7122a847e384bb72df 100644 --- a/go.sum +++ b/go.sum @@ -118,8 +118,8 @@ github.com/charmbracelet/x/exp/slice v0.0.0-20251201173703-9f73bfd934ff h1:Uwr+/ github.com/charmbracelet/x/exp/slice v0.0.0-20251201173703-9f73bfd934ff/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA= github.com/charmbracelet/x/json v0.2.0 h1:DqB+ZGx2h+Z+1s98HOuOyli+i97wsFQIxP2ZQANTPrQ= github.com/charmbracelet/x/json v0.2.0/go.mod h1:opFIflx2YgXgi49xVUu8gEQ21teFAxyMwvOiZhIvWNM= -github.com/charmbracelet/x/powernap v0.0.0-20251015113943-25f979b54ad4 h1:i/XilBPYK4L1Yo/mc9FPx0SyJzIsN0y4sj1MWq9Sscc= -github.com/charmbracelet/x/powernap v0.0.0-20251015113943-25f979b54ad4/go.mod h1:cmdl5zlP5mR8TF2Y68UKc7hdGUDiSJ2+4hk0h04Hsx4= +github.com/charmbracelet/x/powernap v0.0.0-20260113142046-c1fa3de7983b h1:5ye9hzBKH623bMVz5auIuY6K21loCdxpRmFle2O9R/8= +github.com/charmbracelet/x/powernap v0.0.0-20260113142046-c1fa3de7983b/go.mod h1:cmdl5zlP5mR8TF2Y68UKc7hdGUDiSJ2+4hk0h04Hsx4= github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk= github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI= github.com/charmbracelet/x/termios v0.1.1 h1:o3Q2bT8eqzGnGPOYheoYS8eEleT5ZVNYNy8JawjaNZY= diff --git a/internal/csync/maps_test.go b/internal/csync/maps_test.go index 4c590f008dad91e8dcbc40d1b90d87ef1b3e5750..31e6fa0c3aef18a04c61ea3d4d36b5187228c3ff 100644 --- a/internal/csync/maps_test.go +++ b/internal/csync/maps_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "maps" "sync" + "sync/atomic" "testing" "testing/synctest" "time" @@ -46,12 +47,12 @@ func TestNewLazyMap(t *testing.T) { waiter := sync.Mutex{} waiter.Lock() - loadCalled := false + var loadCalled atomic.Bool loadFunc := func() map[string]int { waiter.Lock() defer waiter.Unlock() - loadCalled = true + loadCalled.Store(true) return map[string]int{ "key1": 1, "key2": 2, @@ -63,7 +64,7 @@ func TestNewLazyMap(t *testing.T) { waiter.Unlock() // Allow the load function to proceed time.Sleep(100 * time.Millisecond) - require.True(t, loadCalled) + require.True(t, loadCalled.Load()) require.Equal(t, 2, m.Len()) value, ok := m.Get("key1") diff --git a/internal/permission/permission.go b/internal/permission/permission.go index 9dc85e976238fdbe1ff2d3689b2a2c4160608760..e1bf1bae14b8473989b1c0890c58188591123d71 100644 --- a/internal/permission/permission.go +++ b/internal/permission/permission.go @@ -68,8 +68,9 @@ type permissionService struct { allowedTools []string // used to make sure we only process one request at a time - requestMu sync.Mutex - activeRequest *PermissionRequest + requestMu sync.Mutex + activeRequest *PermissionRequest + activeRequestMu sync.Mutex } func (s *permissionService) GrantPersistent(permission PermissionRequest) { @@ -86,9 +87,11 @@ func (s *permissionService) GrantPersistent(permission PermissionRequest) { s.sessionPermissions = append(s.sessionPermissions, permission) s.sessionPermissionsMu.Unlock() + s.activeRequestMu.Lock() if s.activeRequest != nil && s.activeRequest.ID == permission.ID { s.activeRequest = nil } + s.activeRequestMu.Unlock() } func (s *permissionService) Grant(permission PermissionRequest) { @@ -101,9 +104,11 @@ func (s *permissionService) Grant(permission PermissionRequest) { respCh <- true } + s.activeRequestMu.Lock() if s.activeRequest != nil && s.activeRequest.ID == permission.ID { s.activeRequest = nil } + s.activeRequestMu.Unlock() } func (s *permissionService) Deny(permission PermissionRequest) { @@ -117,9 +122,11 @@ func (s *permissionService) Deny(permission PermissionRequest) { respCh <- false } + s.activeRequestMu.Lock() if s.activeRequest != nil && s.activeRequest.ID == permission.ID { s.activeRequest = nil } + s.activeRequestMu.Unlock() } func (s *permissionService) Request(ctx context.Context, opts CreatePermissionRequest) (bool, error) { @@ -190,7 +197,9 @@ func (s *permissionService) Request(ctx context.Context, opts CreatePermissionRe } s.sessionPermissionsMu.RUnlock() + s.activeRequestMu.Lock() s.activeRequest = &permission + s.activeRequestMu.Unlock() respCh := make(chan bool, 1) s.pendingRequests.Set(permission.ID, respCh) diff --git a/internal/permission/permission_test.go b/internal/permission/permission_test.go index 89e06916024cd1669f5e0d0a263d4a71548c8a97..79930f3ae1e2ef15257f09724fef64d3ea28dada 100644 --- a/internal/permission/permission_test.go +++ b/internal/permission/permission_test.go @@ -189,7 +189,7 @@ func TestPermissionService_SequentialProperties(t *testing.T) { events := service.Subscribe(t.Context()) var wg sync.WaitGroup - results := make([]bool, 0) + results := make([]bool, 3) requests := []CreatePermissionRequest{ { @@ -220,7 +220,7 @@ func TestPermissionService_SequentialProperties(t *testing.T) { go func(index int, request CreatePermissionRequest) { defer wg.Done() result, _ := service.Request(t.Context(), request) - results = append(results, result) + results[index] = result }(i, req) } diff --git a/internal/shell/background.go b/internal/shell/background.go index bc81369ec877586c92fa9bc701d8b78b669f23d5..cb1855836f64bdd56a90802c2bbb939a5a514100 100644 --- a/internal/shell/background.go +++ b/internal/shell/background.go @@ -19,6 +19,30 @@ const ( CompletedJobRetentionMinutes = 8 * 60 ) +// syncBuffer is a thread-safe wrapper around bytes.Buffer. +type syncBuffer struct { + buf bytes.Buffer + mu sync.RWMutex +} + +func (sb *syncBuffer) Write(p []byte) (n int, err error) { + sb.mu.Lock() + defer sb.mu.Unlock() + return sb.buf.Write(p) +} + +func (sb *syncBuffer) WriteString(s string) (n int, err error) { + sb.mu.Lock() + defer sb.mu.Unlock() + return sb.buf.WriteString(s) +} + +func (sb *syncBuffer) String() string { + sb.mu.RLock() + defer sb.mu.RUnlock() + return sb.buf.String() +} + // BackgroundShell represents a shell running in the background. type BackgroundShell struct { ID string @@ -28,8 +52,8 @@ type BackgroundShell struct { WorkingDir string ctx context.Context cancel context.CancelFunc - stdout *bytes.Buffer - stderr *bytes.Buffer + stdout *syncBuffer + stderr *syncBuffer done chan struct{} exitErr error completedAt int64 // Unix timestamp when job completed (0 if still running) @@ -46,12 +70,17 @@ var ( idCounter atomic.Uint64 ) +// newBackgroundShellManager creates a new BackgroundShellManager instance. +func newBackgroundShellManager() *BackgroundShellManager { + return &BackgroundShellManager{ + shells: csync.NewMap[string, *BackgroundShell](), + } +} + // GetBackgroundShellManager returns the singleton background shell manager. func GetBackgroundShellManager() *BackgroundShellManager { backgroundManagerOnce.Do(func() { - backgroundManager = &BackgroundShellManager{ - shells: csync.NewMap[string, *BackgroundShell](), - } + backgroundManager = newBackgroundShellManager() }) return backgroundManager } @@ -80,8 +109,8 @@ func (m *BackgroundShellManager) Start(ctx context.Context, workingDir string, b Shell: shell, ctx: shellCtx, cancel: cancel, - stdout: &bytes.Buffer{}, - stderr: &bytes.Buffer{}, + stdout: &syncBuffer{}, + stderr: &syncBuffer{}, done: make(chan struct{}), } diff --git a/internal/shell/background_test.go b/internal/shell/background_test.go index 5149861d94e457e8a78650c48d9c6765a57d369e..7c521bc1477b07775cffb69f310fa83d710d4634 100644 --- a/internal/shell/background_test.go +++ b/internal/shell/background_test.go @@ -14,7 +14,7 @@ func TestBackgroundShellManager_Start(t *testing.T) { ctx := context.Background() workingDir := t.TempDir() - manager := GetBackgroundShellManager() + manager := newBackgroundShellManager() bgShell, err := manager.Start(ctx, workingDir, nil, "echo 'hello world'", "") if err != nil { @@ -51,7 +51,7 @@ func TestBackgroundShellManager_Get(t *testing.T) { ctx := context.Background() workingDir := t.TempDir() - manager := GetBackgroundShellManager() + manager := newBackgroundShellManager() bgShell, err := manager.Start(ctx, workingDir, nil, "echo 'test'", "") if err != nil { @@ -77,7 +77,7 @@ func TestBackgroundShellManager_Kill(t *testing.T) { ctx := context.Background() workingDir := t.TempDir() - manager := GetBackgroundShellManager() + manager := newBackgroundShellManager() // Start a long-running command bgShell, err := manager.Start(ctx, workingDir, nil, "sleep 10", "") @@ -106,7 +106,7 @@ func TestBackgroundShellManager_Kill(t *testing.T) { func TestBackgroundShellManager_KillNonExistent(t *testing.T) { t.Parallel() - manager := GetBackgroundShellManager() + manager := newBackgroundShellManager() err := manager.Kill("non-existent-id") if err == nil { @@ -119,7 +119,7 @@ func TestBackgroundShell_IsDone(t *testing.T) { ctx := context.Background() workingDir := t.TempDir() - manager := GetBackgroundShellManager() + manager := newBackgroundShellManager() bgShell, err := manager.Start(ctx, workingDir, nil, "echo 'quick'", "") if err != nil { @@ -142,7 +142,7 @@ func TestBackgroundShell_WithBlockFuncs(t *testing.T) { ctx := context.Background() workingDir := t.TempDir() - manager := GetBackgroundShellManager() + manager := newBackgroundShellManager() blockFuncs := []BlockFunc{ CommandsBlocker([]string{"curl", "wget"}), @@ -180,7 +180,7 @@ func TestBackgroundShellManager_List(t *testing.T) { ctx := context.Background() workingDir := t.TempDir() - manager := GetBackgroundShellManager() + manager := newBackgroundShellManager() // Start two shells bgShell1, err := manager.Start(ctx, workingDir, nil, "sleep 1", "") @@ -224,7 +224,7 @@ func TestBackgroundShellManager_KillAll(t *testing.T) { ctx := context.Background() workingDir := t.TempDir() - manager := GetBackgroundShellManager() + manager := newBackgroundShellManager() // Start multiple long-running shells shell1, err := manager.Start(ctx, workingDir, nil, "sleep 10", "") diff --git a/internal/tui/styles/theme.go b/internal/tui/styles/theme.go index f87ffd9de8b324cec4dcfd8b7cee61f71e0390eb..b03603c57439f5f950f9860d3287b0f9d13742e5 100644 --- a/internal/tui/styles/theme.go +++ b/internal/tui/styles/theme.go @@ -4,6 +4,7 @@ import ( "fmt" "image/color" "strings" + "sync" "charm.land/bubbles/v2/filepicker" "charm.land/bubbles/v2/help" @@ -97,7 +98,8 @@ type Theme struct { AuthBorderUnselected lipgloss.Style AuthTextUnselected lipgloss.Style - styles *Styles + styles *Styles + stylesOnce sync.Once } type Styles struct { @@ -134,9 +136,9 @@ type Styles struct { } func (t *Theme) S() *Styles { - if t.styles == nil { + t.stylesOnce.Do(func() { t.styles = t.buildStyles() - } + }) return t.styles } @@ -500,27 +502,31 @@ type Manager struct { current *Theme } -var defaultManager *Manager +var ( + defaultManager *Manager + defaultManagerOnce sync.Once +) + +func initDefaultManager() *Manager { + defaultManagerOnce.Do(func() { + defaultManager = newManager() + }) + return defaultManager +} func SetDefaultManager(m *Manager) { defaultManager = m } func DefaultManager() *Manager { - if defaultManager == nil { - defaultManager = NewManager() - } - return defaultManager + return initDefaultManager() } func CurrentTheme() *Theme { - if defaultManager == nil { - defaultManager = NewManager() - } - return defaultManager.Current() + return initDefaultManager().Current() } -func NewManager() *Manager { +func newManager() *Manager { m := &Manager{ themes: make(map[string]*Theme), } From c57cbc653124b23be8ba9ebede33c124611709d8 Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Tue, 13 Jan 2026 16:45:49 -0300 Subject: [PATCH 03/13] ci: fix govulncheck Signed-off-by: Carlos Alexandro Becker --- .github/workflows/security.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/security.yml b/.github/workflows/security.yml index 8fc56fa39e7b47d1fe5ba84c0f0e7cb65733a264..9c70b17828afe559a096b5a55da1eeed762b9ee9 100644 --- a/.github/workflows/security.yml +++ b/.github/workflows/security.yml @@ -69,6 +69,7 @@ jobs: with: output-format: sarif output-file: results.sarif + go-version-input: 1.26.0-rc.1 # change to "stable" once Go 1.26 is released - uses: github/codeql-action/upload-sarif@cf1bb45a277cb3c205638b2cd5c984db1c46a412 # v4.31.7 with: sarif_file: results.sarif From 151e063dd2502d25f763b8e14bfeba8c870b21bf Mon Sep 17 00:00:00 2001 From: Ayman Bagabas Date: Tue, 13 Jan 2026 17:22:22 -0500 Subject: [PATCH 05/13] fix(ci): security: allow Google Patent License for Go modules --- .github/workflows/security.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/security.yml b/.github/workflows/security.yml index 9c70b17828afe559a096b5a55da1eeed762b9ee9..857184e7cca015984d36b0e08c6762d3570c12f2 100644 --- a/.github/workflows/security.yml +++ b/.github/workflows/security.yml @@ -86,4 +86,4 @@ jobs: - uses: actions/dependency-review-action@3c4e3dcb1aa7874d2c16be7d79418e9b7efd6261 # v4.8.2 with: fail-on-severity: critical - allow-licenses: BSD-2-Clause, BSD-3-Clause, MIT, Apache-2.0, MPL-2.0, ISC + allow-licenses: BSD-2-Clause, BSD-3-Clause, MIT, Apache-2.0, MPL-2.0, ISC, LicenseRef-scancode-google-patent-license-golang From 340defd5675f4fec003e5c4578a79a29fca4ea73 Mon Sep 17 00:00:00 2001 From: Ayman Bagabas Date: Tue, 13 Jan 2026 17:34:57 -0500 Subject: [PATCH 06/13] fix(ci): update security workflow to use setup-go and install govulncheck --- .github/workflows/security.yml | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/.github/workflows/security.yml b/.github/workflows/security.yml index 857184e7cca015984d36b0e08c6762d3570c12f2..3a90ea316c3d86f5b2f93224fd2b35eaa572e704 100644 --- a/.github/workflows/security.yml +++ b/.github/workflows/security.yml @@ -65,11 +65,14 @@ jobs: - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: persist-credentials: false - - uses: golang/govulncheck-action@b625fbe08f3bccbe446d94fbf87fcc875a4f50ee # v1.0.4 + - uses: actions/setup-go@7a3fe6cf4cb3a834922a1244abfce67bcef6a0c5 # v6.2.0 with: - output-format: sarif - output-file: results.sarif - go-version-input: 1.26.0-rc.1 # change to "stable" once Go 1.26 is released + go-version: 1.26.0-rc.1 # change to "stable" once Go 1.26 is released + - name: Install govulncheck + run: go install golang.org/x/vuln/cmd/govulncheck@latest + - name: Run govulncheck + run: | + govulncheck -C . -format sarif ./... > results.sarif - uses: github/codeql-action/upload-sarif@cf1bb45a277cb3c205638b2cd5c984db1c46a412 # v4.31.7 with: sarif_file: results.sarif From 67437151e91296db7c25908ef0e69031339de56a Mon Sep 17 00:00:00 2001 From: kslamph <15257433+kslamph@users.noreply.github.com> Date: Wed, 14 Jan 2026 06:46:07 +0800 Subject: [PATCH 07/13] feat: add clipboard image paste functionality to chat editor (#181) (#1151) Co-authored-by: Ayman Bagabas --- go.mod | 6 +- go.sum | 12 ++- internal/tui/components/chat/editor/editor.go | 79 +++++++++++++++++++ internal/tui/components/chat/editor/keys.go | 6 ++ 4 files changed, 97 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index fe3497de825754c0e60835a079f9d6014d9c603a..72f53ccfbd743a333730b34ac28cc0edffe5aa50 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( github.com/PuerkitoBio/goquery v1.11.0 github.com/alecthomas/chroma/v2 v2.22.0 github.com/atotto/clipboard v0.1.4 + github.com/aymanbagabas/go-nativeclipboard v0.1.2 github.com/aymanbagabas/go-udiff v0.3.1 github.com/bmatcuk/doublestar/v4 v4.9.2 github.com/charlievieth/fastwalk v1.0.14 @@ -100,13 +101,14 @@ require ( github.com/charmbracelet/x/json v0.2.0 // indirect github.com/charmbracelet/x/termios v0.1.1 // indirect github.com/charmbracelet/x/windows v0.2.2 // indirect - github.com/clipperhouse/displaywidth v0.6.1 // indirect + github.com/clipperhouse/displaywidth v0.6.2 // indirect github.com/clipperhouse/stringish v0.1.1 // indirect github.com/clipperhouse/uax29/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/disintegration/gift v1.1.2 // indirect github.com/dlclark/regexp2 v1.11.5 // indirect github.com/dustin/go-humanize v1.0.1 // indirect + github.com/ebitengine/purego v0.10.0-alpha.3.0.20260102153238-200df6041cff // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e // indirect @@ -171,7 +173,7 @@ require ( go.yaml.in/yaml/v4 v4.0.0-rc.3 // indirect golang.org/x/crypto v0.47.0 // indirect golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b // indirect - golang.org/x/image v0.27.0 // indirect + golang.org/x/image v0.34.0 // indirect golang.org/x/oauth2 v0.34.0 // indirect golang.org/x/sys v0.40.0 // indirect golang.org/x/term v0.39.0 // indirect diff --git a/go.sum b/go.sum index d3d7696e9729d1a20dc45c7122a847e384bb72df..8973fcdc1b227d0b30aad691220103319afa93ca 100644 --- a/go.sum +++ b/go.sum @@ -78,6 +78,8 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.41.6 h1:5fFjR/ToSOzB2OQ/XqWpZBmNvmP/ github.com/aws/aws-sdk-go-v2/service/sts v1.41.6/go.mod h1:qgFDZQSD/Kys7nJnVqYlWKnh0SSdMjAi0uSwON4wgYQ= github.com/aws/smithy-go v1.24.0 h1:LpilSUItNPFr1eY85RYgTIg5eIEPtvFbskaFcmmIUnk= github.com/aws/smithy-go v1.24.0/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= +github.com/aymanbagabas/go-nativeclipboard v0.1.2 h1:Z2iVRWQ4IynMLWM6a+lWH2Nk5gPyEtPRMuBIyZ2dECM= +github.com/aymanbagabas/go-nativeclipboard v0.1.2/go.mod h1:BVJhN7hs5DieCzUB2Atf4Yk9Y9kFe62E95+gOjpJq6Q= github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= github.com/aymanbagabas/go-udiff v0.3.1 h1:LV+qyBQ2pqe0u42ZsUEtPiCaUoqgA9gYRDs3vj1nolY= @@ -126,8 +128,8 @@ github.com/charmbracelet/x/termios v0.1.1 h1:o3Q2bT8eqzGnGPOYheoYS8eEleT5ZVNYNy8 github.com/charmbracelet/x/termios v0.1.1/go.mod h1:rB7fnv1TgOPOyyKRJ9o+AsTU/vK5WHJ2ivHeut/Pcwo= github.com/charmbracelet/x/windows v0.2.2 h1:IofanmuvaxnKHuV04sC0eBy/smG6kIKrWG2/jYn2GuM= github.com/charmbracelet/x/windows v0.2.2/go.mod h1:/8XtdKZzedat74NQFn0NGlGL4soHB0YQZrETF96h75k= -github.com/clipperhouse/displaywidth v0.6.1 h1:/zMlAezfDzT2xy6acHBzwIfyu2ic0hgkT83UX5EY2gY= -github.com/clipperhouse/displaywidth v0.6.1/go.mod h1:R+kHuzaYWFkTm7xoMmK1lFydbci4X2CicfbGstSGg0o= +github.com/clipperhouse/displaywidth v0.6.2 h1:ZDpTkFfpHOKte4RG5O/BOyf3ysnvFswpyYrV7z2uAKo= +github.com/clipperhouse/displaywidth v0.6.2/go.mod h1:R+kHuzaYWFkTm7xoMmK1lFydbci4X2CicfbGstSGg0o= github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs= github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA= github.com/clipperhouse/uax29/v2 v2.3.0 h1:SNdx9DVUqMoBuBoW3iLOj4FQv3dN5mDtuqwuhIGpJy4= @@ -150,6 +152,8 @@ github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZ github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/ebitengine/purego v0.10.0-alpha.3.0.20260102153238-200df6041cff h1:vAcU1VsCRstZ9ty11yD/L0WDyT73S/gVfmuWvcWX5DA= +github.com/ebitengine/purego v0.10.0-alpha.3.0.20260102153238-200df6041cff/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ= github.com/envoyproxy/go-control-plane v0.13.4 h1:zEqyPVyku6IvWCFwux4x9RxkLOMUL+1vC9xUFv5l2/M= github.com/envoyproxy/go-control-plane/envoy v1.32.4 h1:jb83lalDRZSpPWW2Z7Mck/8kXZ5CQAFYVjQcdVIr83A= github.com/envoyproxy/go-control-plane/envoy v1.32.4/go.mod h1:Gzjc5k8JcJswLjAx1Zm+wSYE20UrLtt7JZMWiWQXQEw= @@ -389,8 +393,8 @@ golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o= golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8= -golang.org/x/image v0.27.0 h1:C8gA4oWU/tKkdCfYT6T2u4faJu3MeNS5O8UPWlPF61w= -golang.org/x/image v0.27.0/go.mod h1:xbdrClrAUway1MUTEZDq9mz/UpRwYAkFFNUslZtcB+g= +golang.org/x/image v0.34.0 h1:33gCkyw9hmwbZJeZkct8XyR11yH889EQt/QH4VmXMn8= +golang.org/x/image v0.34.0/go.mod h1:2RNFBZRB+vnwwFil8GkMdRvrJOFd1AzdZI6vOY+eJVU= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= diff --git a/internal/tui/components/chat/editor/editor.go b/internal/tui/components/chat/editor/editor.go index 01badb98d37eb848ccf5962e01793ecaa3fc0f59..b5cadb8cde8a1ced8543d01eb7abd28d906f1597 100644 --- a/internal/tui/components/chat/editor/editor.go +++ b/internal/tui/components/chat/editor/editor.go @@ -16,6 +16,7 @@ import ( "charm.land/bubbles/v2/textarea" tea "charm.land/bubbletea/v2" "charm.land/lipgloss/v2" + nativeclipboard "github.com/aymanbagabas/go-nativeclipboard" "github.com/charmbracelet/crush/internal/app" "github.com/charmbracelet/crush/internal/filetracker" "github.com/charmbracelet/crush/internal/fsext" @@ -338,6 +339,84 @@ func (m *editorCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { m.textarea.InsertRune('\n') cmds = append(cmds, util.CmdHandler(completions.CloseCompletionsMsg{})) } + // Handle image paste from clipboard + if key.Matches(msg, m.keyMap.PasteImage) { + imageData, err := nativeclipboard.Image.Read() + + if err != nil || len(imageData) == 0 { + // If no image data found, try to get text data (could be file path) + var textData []byte + textData, err = nativeclipboard.Text.Read() + if err != nil || len(textData) == 0 { + // If clipboard is empty, show a warning + return m, util.ReportWarn("No data found in clipboard. Note: Some terminals may not support reading image data from clipboard directly.") + } + + // Check if the text data is a file path + textStr := string(textData) + // First, try to interpret as a file path (existing functionality) + path := strings.ReplaceAll(textStr, "\\ ", " ") + path, err = filepath.Abs(strings.TrimSpace(path)) + if err == nil { + isAllowedType := false + for _, ext := range filepicker.AllowedTypes { + if strings.HasSuffix(path, ext) { + isAllowedType = true + break + } + } + if isAllowedType { + tooBig, _ := filepicker.IsFileTooBig(path, filepicker.MaxAttachmentSize) + if !tooBig { + content, err := os.ReadFile(path) + if err == nil { + mimeBufferSize := min(512, len(content)) + mimeType := http.DetectContentType(content[:mimeBufferSize]) + fileName := filepath.Base(path) + attachment := message.Attachment{FilePath: path, FileName: fileName, MimeType: mimeType, Content: content} + return m, util.CmdHandler(filepicker.FilePickedMsg{ + Attachment: attachment, + }) + } + } + } + } + + // If not a valid file path, show a warning + return m, util.ReportWarn("No image found in clipboard") + } else { + // We have image data from the clipboard + // Create a temporary file to store the clipboard image data + tempFile, err := os.CreateTemp("", "clipboard_image_crush_*") + if err != nil { + return m, util.ReportError(err) + } + defer tempFile.Close() + + // Write clipboard content to the temporary file + _, err = tempFile.Write(imageData) + if err != nil { + return m, util.ReportError(err) + } + + // Determine the file extension based on the image data + mimeBufferSize := min(512, len(imageData)) + mimeType := http.DetectContentType(imageData[:mimeBufferSize]) + + // Create an attachment from the temporary file + fileName := filepath.Base(tempFile.Name()) + attachment := message.Attachment{ + FilePath: tempFile.Name(), + FileName: fileName, + MimeType: mimeType, + Content: imageData, + } + + return m, util.CmdHandler(filepicker.FilePickedMsg{ + Attachment: attachment, + }) + } + } // Handle Enter key if m.textarea.Focused() && key.Matches(msg, m.keyMap.SendMessage) { value := m.textarea.Value() diff --git a/internal/tui/components/chat/editor/keys.go b/internal/tui/components/chat/editor/keys.go index 0ba4571888e547b1c4a85e7ee9dd73ff07ce13d2..c20df5cc1c071deab83754430543b9be2381127c 100644 --- a/internal/tui/components/chat/editor/keys.go +++ b/internal/tui/components/chat/editor/keys.go @@ -9,6 +9,7 @@ type EditorKeyMap struct { SendMessage key.Binding OpenEditor key.Binding Newline key.Binding + PasteImage key.Binding } func DefaultEditorKeyMap() EditorKeyMap { @@ -32,6 +33,10 @@ func DefaultEditorKeyMap() EditorKeyMap { // to reflect that. key.WithHelp("ctrl+j", "newline"), ), + PasteImage: key.NewBinding( + key.WithKeys("ctrl+v"), + key.WithHelp("ctrl+v", "paste image from clipboard"), + ), } } @@ -42,6 +47,7 @@ func (k EditorKeyMap) KeyBindings() []key.Binding { k.SendMessage, k.OpenEditor, k.Newline, + k.PasteImage, AttachmentsKeyMaps.AttachmentDeleteMode, AttachmentsKeyMaps.DeleteAllAttachments, AttachmentsKeyMaps.Escape, From a13019640e124872c3671139aa2d06cbf96c3b5f Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Tue, 13 Jan 2026 21:25:33 -0500 Subject: [PATCH 08/13] chore(README): update crush art (#1861) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 929b77425f1b42452a4e38d8cfa540773dd54a79..cfcd765ee150d181c00cf649a5ba15055b6bdbae 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # Crush

- Charm Crush Logo
+ Charm Crush Logo
Latest Release Build Status

From b8d88ddb8e4590b5d0ed434b7f2cdb69dcc8ad5b Mon Sep 17 00:00:00 2001 From: Charm <124303983+charmcli@users.noreply.github.com> Date: Wed, 14 Jan 2026 01:29:05 -0300 Subject: [PATCH 09/13] chore(legal): @kuxoapp has signed the CLA --- .github/cla-signatures.json | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/.github/cla-signatures.json b/.github/cla-signatures.json index cf21b7c02c3ecb20d01ac8250cee76e2727b81b2..29ea40e9d11b164aacbbcc4539c541bdf2ba1214 100644 --- a/.github/cla-signatures.json +++ b/.github/cla-signatures.json @@ -1055,6 +1055,14 @@ "created_at": "2026-01-12T22:16:05Z", "repoId": 987670088, "pullRequestNo": 1841 + }, + { + "name": "kuxoapp", + "id": 254052994, + "comment_id": 3747622477, + "created_at": "2026-01-14T04:18:44Z", + "repoId": 987670088, + "pullRequestNo": 1864 } ] } \ No newline at end of file From 8f2ae5ceea85c17e11eeda685db90c64f24632c5 Mon Sep 17 00:00:00 2001 From: Charm <124303983+charmcli@users.noreply.github.com> Date: Wed, 14 Jan 2026 11:02:16 -0300 Subject: [PATCH 10/13] chore(legal): @mhpenta has signed the CLA --- .github/cla-signatures.json | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/.github/cla-signatures.json b/.github/cla-signatures.json index 29ea40e9d11b164aacbbcc4539c541bdf2ba1214..5929987f916594da1109eee2082c154620edf660 100644 --- a/.github/cla-signatures.json +++ b/.github/cla-signatures.json @@ -1063,6 +1063,14 @@ "created_at": "2026-01-14T04:18:44Z", "repoId": 987670088, "pullRequestNo": 1864 + }, + { + "name": "mhpenta", + "id": 183146177, + "comment_id": 3749703014, + "created_at": "2026-01-14T14:02:04Z", + "repoId": 987670088, + "pullRequestNo": 1870 } ] } \ No newline at end of file From f66762b917c3ac133ca8f1fc430c5cbd524412f0 Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Wed, 14 Jan 2026 12:16:31 -0300 Subject: [PATCH 11/13] fix: race in agent.go (#1853) Signed-off-by: Carlos Alexandro Becker --- internal/agent/agent.go | 107 +++++++++++++++++++--------------- internal/agent/event.go | 12 ++-- internal/csync/slices.go | 44 ++++---------- internal/csync/slices_test.go | 57 ------------------ internal/csync/value.go | 44 ++++++++++++++ internal/csync/value_test.go | 99 +++++++++++++++++++++++++++++++ 6 files changed, 218 insertions(+), 145 deletions(-) create mode 100644 internal/csync/value.go create mode 100644 internal/csync/value_test.go diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 198159d53adbcbba8f8598bf24a8eef55825acfc..c0b9080bb640085c6fd0fdbde8db0fbfe7f476dd 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -87,12 +87,13 @@ type Model struct { } type sessionAgent struct { - largeModel Model - smallModel Model - systemPromptPrefix string - systemPrompt string + largeModel *csync.Value[Model] + smallModel *csync.Value[Model] + systemPromptPrefix *csync.Value[string] + systemPrompt *csync.Value[string] + tools *csync.Slice[fantasy.AgentTool] + isSubAgent bool - tools []fantasy.AgentTool sessions session.Service messages message.Service disableAutoSummarize bool @@ -119,15 +120,15 @@ func NewSessionAgent( opts SessionAgentOptions, ) SessionAgent { return &sessionAgent{ - largeModel: opts.LargeModel, - smallModel: opts.SmallModel, - systemPromptPrefix: opts.SystemPromptPrefix, - systemPrompt: opts.SystemPrompt, + largeModel: csync.NewValue(opts.LargeModel), + smallModel: csync.NewValue(opts.SmallModel), + systemPromptPrefix: csync.NewValue(opts.SystemPromptPrefix), + systemPrompt: csync.NewValue(opts.SystemPrompt), isSubAgent: opts.IsSubAgent, sessions: opts.Sessions, messages: opts.Messages, disableAutoSummarize: opts.DisableAutoSummarize, - tools: opts.Tools, + tools: csync.NewSliceFrom(opts.Tools), isYolo: opts.IsYolo, messageQueue: csync.NewMap[string, []SessionAgentCall](), activeRequests: csync.NewMap[string, context.CancelFunc](), @@ -153,15 +154,21 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy return nil, nil } - if len(a.tools) > 0 { + // Copy mutable fields under lock to avoid races with SetTools/SetModels. + agentTools := a.tools.Copy() + largeModel := a.largeModel.Get() + systemPrompt := a.systemPrompt.Get() + promptPrefix := a.systemPromptPrefix.Get() + + if len(agentTools) > 0 { // Add Anthropic caching to the last tool. - a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions()) + agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions()) } agent := fantasy.NewAgent( - a.largeModel.Model, - fantasy.WithSystemPrompt(a.systemPrompt), - fantasy.WithTools(a.tools...), + largeModel.Model, + fantasy.WithSystemPrompt(systemPrompt), + fantasy.WithTools(agentTools...), ) sessionLock := sync.Mutex{} @@ -234,7 +241,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...) } - prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages) + prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel) lastSystemRoleInx := 0 systemMessageUpdated := false @@ -252,7 +259,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy } } - if promptPrefix := a.promptPrefix(); promptPrefix != "" { + if promptPrefix != "" { prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...) } @@ -260,15 +267,15 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{ Role: message.Assistant, Parts: []message.ContentPart{}, - Model: a.largeModel.ModelCfg.Model, - Provider: a.largeModel.ModelCfg.Provider, + Model: largeModel.ModelCfg.Model, + Provider: largeModel.ModelCfg.Provider, }) if err != nil { return callContext, prepared, err } callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID) - callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages) - callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name) + callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages) + callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name) currentAssistant = &assistantMsg return callContext, prepared, err }, @@ -362,7 +369,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy sessionLock.Unlock() return getSessionErr } - a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata)) + a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata)) _, sessionErr := a.sessions.Save(genCtx, updatedSession) sessionLock.Unlock() if sessionErr != nil { @@ -372,7 +379,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy }, StopWhen: []fantasy.StopCondition{ func(_ []fantasy.StepResult) bool { - cw := int64(a.largeModel.CatwalkCfg.ContextWindow) + cw := int64(largeModel.CatwalkCfg.ContextWindow) tokens := currentSession.CompletionTokens + currentSession.PromptTokens remaining := cw - tokens var threshold int64 @@ -474,7 +481,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy currentAssistant.AddFinish( message.FinishReasonError, "Copilot model not enabled", - fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", a.largeModel.CatwalkCfg.Name, link), + fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link), ) } else { currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message) @@ -529,6 +536,10 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan return ErrSessionBusy } + // Copy mutable fields under lock to avoid races with SetModels. + largeModel := a.largeModel.Get() + systemPromptPrefix := a.systemPromptPrefix.Get() + currentSession, err := a.sessions.Get(ctx, sessionID) if err != nil { return fmt.Errorf("failed to get session: %w", err) @@ -549,13 +560,13 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan defer a.activeRequests.Del(sessionID) defer cancel() - agent := fantasy.NewAgent(a.largeModel.Model, + agent := fantasy.NewAgent(largeModel.Model, fantasy.WithSystemPrompt(string(summaryPrompt)), ) summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ Role: message.Assistant, - Model: a.largeModel.Model.Model(), - Provider: a.largeModel.Model.Provider(), + Model: largeModel.Model.Model(), + Provider: largeModel.Model.Provider(), IsSummaryMessage: true, }) if err != nil { @@ -570,8 +581,8 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan ProviderOptions: opts, PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) { prepared.Messages = options.Messages - if a.systemPromptPrefix != "" { - prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...) + if systemPromptPrefix != "" { + prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...) } return callContext, prepared, nil }, @@ -622,7 +633,7 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan } } - a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage, openrouterCost) + a.updateSessionUsage(largeModel, ¤tSession, resp.TotalUsage, openrouterCost) // Just in case, get just the last usage info. usage := resp.Response.Usage @@ -730,9 +741,13 @@ func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, user return } + smallModel := a.smallModel.Get() + largeModel := a.largeModel.Get() + systemPromptPrefix := a.systemPromptPrefix.Get() + var maxOutputTokens int64 = 40 - if a.smallModel.CatwalkCfg.CanReason { - maxOutputTokens = a.smallModel.CatwalkCfg.DefaultMaxTokens + if smallModel.CatwalkCfg.CanReason { + maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens } newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent { @@ -746,9 +761,9 @@ func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, user Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n \n\n", userPrompt), PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) { prepared.Messages = opts.Messages - if a.systemPromptPrefix != "" { + if systemPromptPrefix != "" { prepared.Messages = append([]fantasy.Message{ - fantasy.NewSystemMessage(a.systemPromptPrefix), + fantasy.NewSystemMessage(systemPromptPrefix), }, prepared.Messages...) } return callCtx, prepared, nil @@ -756,7 +771,7 @@ func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, user } // Use the small model to generate the title. - model := &a.smallModel + model := smallModel agent := newAgent(model.Model, titlePrompt, maxOutputTokens) resp, err := agent.Stream(ctx, streamCall) if err == nil { @@ -765,7 +780,7 @@ func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, user } else { // It didn't work. Let's try with the big model. slog.Error("error generating title with small model; trying big model", "err", err) - model = &a.largeModel + model = largeModel agent = newAgent(model.Model, titlePrompt, maxOutputTokens) resp, err = agent.Stream(ctx, streamCall) if err == nil { @@ -960,24 +975,20 @@ func (a *sessionAgent) QueuedPromptsList(sessionID string) []string { } func (a *sessionAgent) SetModels(large Model, small Model) { - a.largeModel = large - a.smallModel = small + a.largeModel.Set(large) + a.smallModel.Set(small) } func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) { - a.tools = tools + a.tools.SetSlice(tools) } func (a *sessionAgent) SetSystemPrompt(systemPrompt string) { - a.systemPrompt = systemPrompt + a.systemPrompt.Set(systemPrompt) } func (a *sessionAgent) Model() Model { - return a.largeModel -} - -func (a *sessionAgent) promptPrefix() string { - return a.systemPromptPrefix + return a.largeModel.Get() } // convertToToolResult converts a fantasy tool result to a message tool result. @@ -1034,9 +1045,9 @@ func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) mes // // BEFORE: [tool result: image data] // AFTER: [tool result: "Image loaded - see attached"], [user: image attachment] -func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message { - providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) || - a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock) +func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message { + providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) || + largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock) if providerSupportsMedia { return messages diff --git a/internal/agent/event.go b/internal/agent/event.go index bf36ec84bf4270bd2e63ae0efae0440474288565..3f6c640f6a983c515034e0698676632d0cb57824 100644 --- a/internal/agent/event.go +++ b/internal/agent/event.go @@ -7,23 +7,23 @@ import ( "github.com/charmbracelet/crush/internal/event" ) -func (a sessionAgent) eventPromptSent(sessionID string) { +func (a *sessionAgent) eventPromptSent(sessionID string) { event.PromptSent( - a.eventCommon(sessionID, a.largeModel)..., + a.eventCommon(sessionID, a.largeModel.Get())..., ) } -func (a sessionAgent) eventPromptResponded(sessionID string, duration time.Duration) { +func (a *sessionAgent) eventPromptResponded(sessionID string, duration time.Duration) { event.PromptResponded( append( - a.eventCommon(sessionID, a.largeModel), + a.eventCommon(sessionID, a.largeModel.Get()), "prompt duration pretty", duration.String(), "prompt duration in seconds", int64(duration.Seconds()), )..., ) } -func (a sessionAgent) eventTokensUsed(sessionID string, model Model, usage fantasy.Usage, cost float64) { +func (a *sessionAgent) eventTokensUsed(sessionID string, model Model, usage fantasy.Usage, cost float64) { event.TokensUsed( append( a.eventCommon(sessionID, model), @@ -37,7 +37,7 @@ func (a sessionAgent) eventTokensUsed(sessionID string, model Model, usage fanta ) } -func (a sessionAgent) eventCommon(sessionID string, model Model) []any { +func (a *sessionAgent) eventCommon(sessionID string, model Model) []any { m := model.ModelCfg return []any{ diff --git a/internal/csync/slices.go b/internal/csync/slices.go index c5c635683e70046694f1cdf647aac8cb425abd24..fcce9881b6e27021adcc9462b123f49d469dcd9f 100644 --- a/internal/csync/slices.go +++ b/internal/csync/slices.go @@ -2,7 +2,6 @@ package csync import ( "iter" - "slices" "sync" ) @@ -63,24 +62,6 @@ func (s *Slice[T]) Append(items ...T) { s.inner = append(s.inner, items...) } -// Prepend adds an element to the beginning of the slice. -func (s *Slice[T]) Prepend(item T) { - s.mu.Lock() - defer s.mu.Unlock() - s.inner = append([]T{item}, s.inner...) -} - -// Delete removes the element at the specified index. -func (s *Slice[T]) Delete(index int) bool { - s.mu.Lock() - defer s.mu.Unlock() - if index < 0 || index >= len(s.inner) { - return false - } - s.inner = slices.Delete(s.inner, index, index+1) - return true -} - // Get returns the element at the specified index. func (s *Slice[T]) Get(index int) (T, bool) { s.mu.RLock() @@ -92,17 +73,6 @@ func (s *Slice[T]) Get(index int) (T, bool) { return s.inner[index], true } -// Set updates the element at the specified index. -func (s *Slice[T]) Set(index int, item T) bool { - s.mu.Lock() - defer s.mu.Unlock() - if index < 0 || index >= len(s.inner) { - return false - } - s.inner[index] = item - return true -} - // Len returns the number of elements in the slice. func (s *Slice[T]) Len() int { s.mu.RLock() @@ -131,10 +101,7 @@ func (s *Slice[T]) Seq() iter.Seq[T] { // Seq2 returns an iterator that yields index-value pairs from the slice. func (s *Slice[T]) Seq2() iter.Seq2[int, T] { - s.mu.RLock() - items := make([]T, len(s.inner)) - copy(items, s.inner) - s.mu.RUnlock() + items := s.Copy() return func(yield func(int, T) bool) { for i, v := range items { if !yield(i, v) { @@ -143,3 +110,12 @@ func (s *Slice[T]) Seq2() iter.Seq2[int, T] { } } } + +// Copy returns a copy of the inner slice. +func (s *Slice[T]) Copy() []T { + s.mu.RLock() + defer s.mu.RUnlock() + items := make([]T, len(s.inner)) + copy(items, s.inner) + return items +} diff --git a/internal/csync/slices_test.go b/internal/csync/slices_test.go index 85aedbaba40103ff9a8979e5c70299223f74591f..c7946ac6f1a84614def05b7b6e7e9b0ed11b3a73 100644 --- a/internal/csync/slices_test.go +++ b/internal/csync/slices_test.go @@ -109,44 +109,6 @@ func TestSlice(t *testing.T) { require.Equal(t, "world", val) }) - t.Run("Prepend", func(t *testing.T) { - s := NewSlice[string]() - s.Append("world") - s.Prepend("hello") - - require.Equal(t, 2, s.Len()) - val, ok := s.Get(0) - require.True(t, ok) - require.Equal(t, "hello", val) - - val, ok = s.Get(1) - require.True(t, ok) - require.Equal(t, "world", val) - }) - - t.Run("Delete", func(t *testing.T) { - s := NewSliceFrom([]int{1, 2, 3, 4, 5}) - - // Delete middle element - ok := s.Delete(2) - require.True(t, ok) - require.Equal(t, 4, s.Len()) - - expected := []int{1, 2, 4, 5} - actual := slices.Collect(s.Seq()) - require.Equal(t, expected, actual) - - // Delete out of bounds - ok = s.Delete(10) - require.False(t, ok) - require.Equal(t, 4, s.Len()) - - // Delete negative index - ok = s.Delete(-1) - require.False(t, ok) - require.Equal(t, 4, s.Len()) - }) - t.Run("Get", func(t *testing.T) { s := NewSliceFrom([]string{"a", "b", "c"}) @@ -163,25 +125,6 @@ func TestSlice(t *testing.T) { require.False(t, ok) }) - t.Run("Set", func(t *testing.T) { - s := NewSliceFrom([]string{"a", "b", "c"}) - - ok := s.Set(1, "modified") - require.True(t, ok) - - val, ok := s.Get(1) - require.True(t, ok) - require.Equal(t, "modified", val) - - // Out of bounds - ok = s.Set(10, "invalid") - require.False(t, ok) - - // Negative index - ok = s.Set(-1, "invalid") - require.False(t, ok) - }) - t.Run("SetSlice", func(t *testing.T) { s := NewSlice[int]() s.Append(1) diff --git a/internal/csync/value.go b/internal/csync/value.go new file mode 100644 index 0000000000000000000000000000000000000000..17528a281e0d34d49b206a7c3901b892370c18ba --- /dev/null +++ b/internal/csync/value.go @@ -0,0 +1,44 @@ +package csync + +import ( + "reflect" + "sync" +) + +// Value is a generic thread-safe wrapper for any value type. +// +// For slices, use [Slice]. For maps, use [Map]. Pointers are not supported. +type Value[T any] struct { + v T + mu sync.RWMutex +} + +// NewValue creates a new Value with the given initial value. +// +// Panics if t is a pointer, slice, or map. Use the dedicated types for those. +func NewValue[T any](t T) *Value[T] { + v := reflect.ValueOf(t) + switch v.Kind() { + case reflect.Pointer: + panic("csync.Value does not support pointer types") + case reflect.Slice: + panic("csync.Value does not support slice types; use csync.Slice") + case reflect.Map: + panic("csync.Value does not support map types; use csync.Map") + } + return &Value[T]{v: t} +} + +// Get returns the current value. +func (v *Value[T]) Get() T { + v.mu.RLock() + defer v.mu.RUnlock() + return v.v +} + +// Set updates the value. +func (v *Value[T]) Set(t T) { + v.mu.Lock() + defer v.mu.Unlock() + v.v = t +} diff --git a/internal/csync/value_test.go b/internal/csync/value_test.go new file mode 100644 index 0000000000000000000000000000000000000000..3fa41d85144ea9373c7d440238c0321f52286330 --- /dev/null +++ b/internal/csync/value_test.go @@ -0,0 +1,99 @@ +package csync + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestValue_GetSet(t *testing.T) { + t.Parallel() + + v := NewValue(42) + require.Equal(t, 42, v.Get()) + + v.Set(100) + require.Equal(t, 100, v.Get()) +} + +func TestValue_ZeroValue(t *testing.T) { + t.Parallel() + + v := NewValue("") + require.Equal(t, "", v.Get()) + + v.Set("hello") + require.Equal(t, "hello", v.Get()) +} + +func TestValue_Struct(t *testing.T) { + t.Parallel() + + type config struct { + Name string + Count int + } + + v := NewValue(config{Name: "test", Count: 1}) + require.Equal(t, config{Name: "test", Count: 1}, v.Get()) + + v.Set(config{Name: "updated", Count: 2}) + require.Equal(t, config{Name: "updated", Count: 2}, v.Get()) +} + +func TestValue_PointerPanics(t *testing.T) { + t.Parallel() + + require.Panics(t, func() { + NewValue(&struct{}{}) + }) +} + +func TestValue_SlicePanics(t *testing.T) { + t.Parallel() + + require.Panics(t, func() { + NewValue([]string{"a", "b"}) + }) +} + +func TestValue_MapPanics(t *testing.T) { + t.Parallel() + + require.Panics(t, func() { + NewValue(map[string]int{"a": 1}) + }) +} + +func TestValue_ConcurrentAccess(t *testing.T) { + t.Parallel() + + v := NewValue(0) + var wg sync.WaitGroup + + // Concurrent writers. + for i := range 100 { + wg.Add(1) + go func(val int) { + defer wg.Done() + v.Set(val) + }(i) + } + + // Concurrent readers. + for range 100 { + wg.Add(1) + go func() { + defer wg.Done() + _ = v.Get() + }() + } + + wg.Wait() + + // Value should be one of the set values (0-99). + got := v.Get() + require.GreaterOrEqual(t, got, 0) + require.Less(t, got, 100) +} From f7de0d5d9ec1721bd371e77e20b3b9281eb829d5 Mon Sep 17 00:00:00 2001 From: Andrey Nering Date: Wed, 14 Jan 2026 16:48:04 -0300 Subject: [PATCH 12/13] chore(deps): update catwalk to being openai gpt 5.2 codex --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 72f53ccfbd743a333730b34ac28cc0edffe5aa50..7d96d6bca5a5c659814a1911c11800f7b2e71c61 100644 --- a/go.mod +++ b/go.mod @@ -19,7 +19,7 @@ require ( github.com/aymanbagabas/go-udiff v0.3.1 github.com/bmatcuk/doublestar/v4 v4.9.2 github.com/charlievieth/fastwalk v1.0.14 - github.com/charmbracelet/catwalk v0.13.0 + github.com/charmbracelet/catwalk v0.14.1 github.com/charmbracelet/colorprofile v0.4.1 github.com/charmbracelet/fang v0.4.4 github.com/charmbracelet/ultraviolet v0.0.0-20251212194010-b927aa605560 diff --git a/go.sum b/go.sum index 8973fcdc1b227d0b30aad691220103319afa93ca..5a2b20e02e02085bf7f8559d946bced27a20cc27 100644 --- a/go.sum +++ b/go.sum @@ -96,8 +96,8 @@ github.com/charlievieth/fastwalk v1.0.14 h1:3Eh5uaFGwHZd8EGwTjJnSpBkfwfsak9h6ICg github.com/charlievieth/fastwalk v1.0.14/go.mod h1:diVcUreiU1aQ4/Wu3NbxxH4/KYdKpLDojrQ1Bb2KgNY= github.com/charmbracelet/anthropic-sdk-go v0.0.0-20251024181547-21d6f3d9a904 h1:rwLdEpG9wE6kL69KkEKDiWprO8pQOZHZXeod6+9K+mw= github.com/charmbracelet/anthropic-sdk-go v0.0.0-20251024181547-21d6f3d9a904/go.mod h1:8TIYxZxsuCqqeJ0lga/b91tBwrbjoHDC66Sq5t8N2R4= -github.com/charmbracelet/catwalk v0.13.0 h1:L+chddP+PJvX3Vl+hqlWW5HAwBErlkL/friQXih1JQI= -github.com/charmbracelet/catwalk v0.13.0/go.mod h1:qg+Yl9oaZTkTvRscqbxfttzOFQ4v0pOT5XwC7b5O0NQ= +github.com/charmbracelet/catwalk v0.14.1 h1:n16H880MHW8PPgQeh0dorP77AJMxw5JcOUPuC3FFhaQ= +github.com/charmbracelet/catwalk v0.14.1/go.mod h1:qg+Yl9oaZTkTvRscqbxfttzOFQ4v0pOT5XwC7b5O0NQ= github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco3ej+ngLk1kk= github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk= github.com/charmbracelet/fang v0.4.4 h1:G4qKxF6or/eTPgmAolwPuRNyuci3hTUGGX1rj1YkHJY= From f2f63d1dfcd528a8fb64387f38f716cb0ca7378f Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Thu, 15 Jan 2026 12:27:10 +0100 Subject: [PATCH 13/13] Refactor Custom Command Arguments Dialog (#1869) --- internal/commands/commands.go | 64 +++-- internal/ui/dialog/actions.go | 24 +- internal/ui/dialog/api_key_input.go | 8 +- internal/ui/dialog/arguments.go | 399 ++++++++++++++++++++++++++++ internal/ui/dialog/commands.go | 106 +++++--- internal/ui/dialog/dialog.go | 27 +- internal/ui/model/ui.go | 133 ++++++++-- internal/ui/styles/styles.go | 20 ++ 8 files changed, 684 insertions(+), 97 deletions(-) create mode 100644 internal/ui/dialog/arguments.go diff --git a/internal/commands/commands.go b/internal/commands/commands.go index 169b789abd224b774592032c02a5156b91efb3a5..b3fd3915182fa293aefc1fe60ec54e5b369fa591 100644 --- a/internal/commands/commands.go +++ b/internal/commands/commands.go @@ -1,6 +1,7 @@ package commands import ( + "context" "io/fs" "os" "path/filepath" @@ -19,18 +20,22 @@ const ( projectCommandPrefix = "project:" ) -// Argument represents a command argument with its name and required status. +// Argument represents a command argument with its metadata. type Argument struct { - Name string - Required bool + ID string + Title string + Description string + Required bool } -// MCPCustomCommand represents a custom command loaded from an MCP server. -type MCPCustomCommand struct { - ID string - Name string - Client string - Arguments []Argument +// MCPPrompt represents a custom command loaded from an MCP server. +type MCPPrompt struct { + ID string + Title string + Description string + PromptID string + ClientID string + Arguments []Argument } // CustomCommand represents a user-defined custom command loaded from markdown files. @@ -52,22 +57,32 @@ func LoadCustomCommands(cfg *config.Config) ([]CustomCommand, error) { return loadAll(buildCommandSources(cfg)) } -// LoadMCPCustomCommands loads custom commands from available MCP servers. -func LoadMCPCustomCommands() ([]MCPCustomCommand, error) { - var commands []MCPCustomCommand +// LoadMCPPrompts loads custom commands from available MCP servers. +func LoadMCPPrompts() ([]MCPPrompt, error) { + var commands []MCPPrompt for mcpName, prompts := range mcp.Prompts() { for _, prompt := range prompts { key := mcpName + ":" + prompt.Name var args []Argument for _, arg := range prompt.Arguments { - args = append(args, Argument{Name: arg.Name, Required: arg.Required}) + title := arg.Title + if title == "" { + title = arg.Name + } + args = append(args, Argument{ + ID: arg.Name, + Title: title, + Description: arg.Description, + Required: arg.Required, + }) } - - commands = append(commands, MCPCustomCommand{ - ID: key, - Name: prompt.Name, - Client: mcpName, - Arguments: args, + commands = append(commands, MCPPrompt{ + ID: key, + Title: prompt.Title, + Description: prompt.Description, + PromptID: prompt.Name, + ClientID: mcpName, + Arguments: args, }) } } @@ -168,7 +183,7 @@ func extractArgNames(content string) []Argument { if !seen[arg] { seen[arg] = true // for normal custom commands, all args are required - args = append(args, Argument{Name: arg, Required: true}) + args = append(args, Argument{ID: arg, Title: arg, Required: true}) } } @@ -211,3 +226,12 @@ func ensureDir(path string) error { func isMarkdownFile(name string) bool { return strings.HasSuffix(strings.ToLower(name), ".md") } + +func GetMCPPrompt(clientID, promptID string, args map[string]string) (string, error) { + // TODO: we should pass the context down + result, err := mcp.GetPromptMessages(context.Background(), clientID, promptID, args) + if err != nil { + return "", err + } + return strings.Join(result, " "), nil +} diff --git a/internal/ui/dialog/actions.go b/internal/ui/dialog/actions.go index 2fe0513ec56bc70ed3ec5bbe1eb9dde365408cdf..81911f9919be6c94ac158052b4a4e9b2236342a0 100644 --- a/internal/ui/dialog/actions.go +++ b/internal/ui/dialog/actions.go @@ -51,20 +51,18 @@ type ( } // ActionRunCustomCommand is a message to run a custom command. ActionRunCustomCommand struct { - CommandID string - // Used when running a user-defined command - Content string - // Used when running a prompt from MCP - Client string - } - // ActionOpenCustomCommandArgumentsDialog is a message to open the custom command arguments dialog. - ActionOpenCustomCommandArgumentsDialog struct { - CommandID string - // Used when running a user-defined command - Content string - // Used when running a prompt from MCP - Client string + Content string Arguments []commands.Argument + Args map[string]string // Actual argument values + } + // ActionRunMCPPrompt is a message to run a custom command. + ActionRunMCPPrompt struct { + Title string + Description string + PromptID string + ClientID string + Arguments []commands.Argument + Args map[string]string // Actual argument values } ) diff --git a/internal/ui/dialog/api_key_input.go b/internal/ui/dialog/api_key_input.go index bbc3d3746b26e51103ce8545eca5fe3ebaaf977a..e28dea2b823143d176d796c8775e8024df61d0bb 100644 --- a/internal/ui/dialog/api_key_input.go +++ b/internal/ui/dialog/api_key_input.go @@ -95,7 +95,7 @@ func (m *APIKeyInput) ID() string { return APIKeyInputID } -// Update implements tea.Model. +// HandleMsg implements [Dialog]. func (m *APIKeyInput) HandleMsg(msg tea.Msg) Action { switch msg := msg.(type) { case ActionChangeAPIKeyState: @@ -149,7 +149,7 @@ func (m *APIKeyInput) HandleMsg(msg tea.Msg) Action { return nil } -// View implements tea.Model. +// Draw implements [Dialog]. func (m *APIKeyInput) Draw(scr uv.Screen, area uv.Rectangle) *tea.Cursor { t := m.com.Styles @@ -239,8 +239,8 @@ func (m *APIKeyInput) inputView() string { } // Cursor returns the cursor position relative to the dialog. -func (c *APIKeyInput) Cursor() *tea.Cursor { - return InputCursor(c.com.Styles, c.input.Cursor()) +func (m *APIKeyInput) Cursor() *tea.Cursor { + return InputCursor(m.com.Styles, m.input.Cursor()) } // FullHelp returns the full help view. diff --git a/internal/ui/dialog/arguments.go b/internal/ui/dialog/arguments.go new file mode 100644 index 0000000000000000000000000000000000000000..c016b7de6ec77e6e333d2b0f18ae5930ba0912fc --- /dev/null +++ b/internal/ui/dialog/arguments.go @@ -0,0 +1,399 @@ +package dialog + +import ( + "strings" + + "charm.land/bubbles/v2/help" + "charm.land/bubbles/v2/key" + "charm.land/bubbles/v2/spinner" + "charm.land/bubbles/v2/textinput" + "charm.land/bubbles/v2/viewport" + tea "charm.land/bubbletea/v2" + "charm.land/lipgloss/v2" + "golang.org/x/text/cases" + "golang.org/x/text/language" + + "github.com/charmbracelet/crush/internal/commands" + "github.com/charmbracelet/crush/internal/ui/common" + "github.com/charmbracelet/crush/internal/uiutil" + uv "github.com/charmbracelet/ultraviolet" +) + +// ArgumentsID is the identifier for the arguments dialog. +const ArgumentsID = "arguments" + +// Dialog sizing for arguments. +const ( + maxInputWidth = 120 + minInputWidth = 30 + maxViewportHeight = 20 + argumentsFieldHeight = 3 // label + input + spacing per field +) + +// Arguments represents a dialog for collecting command arguments. +type Arguments struct { + com *common.Common + title string + arguments []commands.Argument + inputs []textinput.Model + focused int + spinner spinner.Model + loading bool + + description string + resultAction Action + + help help.Model + keyMap struct { + Confirm, + Next, + Previous, + ScrollUp, + ScrollDown, + Close key.Binding + } + + viewport viewport.Model +} + +var _ Dialog = (*Arguments)(nil) + +// NewArguments creates a new arguments dialog. +func NewArguments(com *common.Common, title, description string, arguments []commands.Argument, resultAction Action) *Arguments { + a := &Arguments{ + com: com, + title: title, + description: description, + arguments: arguments, + resultAction: resultAction, + } + + a.help = help.New() + a.help.Styles = com.Styles.DialogHelpStyles() + + a.keyMap.Confirm = key.NewBinding( + key.WithKeys("enter"), + key.WithHelp("enter", "confirm"), + ) + a.keyMap.Next = key.NewBinding( + key.WithKeys("down", "tab"), + key.WithHelp("↓/tab", "next"), + ) + a.keyMap.Previous = key.NewBinding( + key.WithKeys("up", "shift+tab"), + key.WithHelp("↑/shift+tab", "previous"), + ) + a.keyMap.Close = CloseKey + + // Create input fields for each argument. + a.inputs = make([]textinput.Model, len(arguments)) + for i, arg := range arguments { + input := textinput.New() + input.SetVirtualCursor(false) + input.SetStyles(com.Styles.TextInput) + input.Prompt = "> " + // Use description as placeholder if available, otherwise title + if arg.Description != "" { + input.Placeholder = arg.Description + } else { + input.Placeholder = arg.Title + } + + if i == 0 { + input.Focus() + } else { + input.Blur() + } + + a.inputs[i] = input + } + s := spinner.New() + s.Spinner = spinner.Dot + s.Style = com.Styles.Dialog.Spinner + a.spinner = s + + return a +} + +// ID implements Dialog. +func (a *Arguments) ID() string { + return ArgumentsID +} + +// focusInput changes focus to a new input by index with wrap-around. +func (a *Arguments) focusInput(newIndex int) { + a.inputs[a.focused].Blur() + + // Wrap around: Go's modulo can return negative, so add len first. + n := len(a.inputs) + a.focused = ((newIndex % n) + n) % n + + a.inputs[a.focused].Focus() + + // Ensure the newly focused field is visible in the viewport + a.ensureFieldVisible(a.focused) +} + +// isFieldVisible checks if a field at the given index is visible in the viewport. +func (a *Arguments) isFieldVisible(fieldIndex int) bool { + fieldStart := fieldIndex * argumentsFieldHeight + fieldEnd := fieldStart + argumentsFieldHeight - 1 + viewportTop := a.viewport.YOffset() + viewportBottom := viewportTop + a.viewport.Height() - 1 + + return fieldStart >= viewportTop && fieldEnd <= viewportBottom +} + +// ensureFieldVisible scrolls the viewport to make the field visible. +func (a *Arguments) ensureFieldVisible(fieldIndex int) { + if a.isFieldVisible(fieldIndex) { + return + } + + fieldStart := fieldIndex * argumentsFieldHeight + fieldEnd := fieldStart + argumentsFieldHeight - 1 + viewportTop := a.viewport.YOffset() + viewportHeight := a.viewport.Height() + + // If field is above viewport, scroll up to show it at top + if fieldStart < viewportTop { + a.viewport.SetYOffset(fieldStart) + return + } + + // If field is below viewport, scroll down to show it at bottom + if fieldEnd > viewportTop+viewportHeight-1 { + a.viewport.SetYOffset(fieldEnd - viewportHeight + 1) + } +} + +// findVisibleFieldByOffset returns the field index closest to the given viewport offset. +func (a *Arguments) findVisibleFieldByOffset(fromTop bool) int { + offset := a.viewport.YOffset() + if !fromTop { + offset += a.viewport.Height() - 1 + } + + fieldIndex := offset / argumentsFieldHeight + if fieldIndex >= len(a.inputs) { + return len(a.inputs) - 1 + } + return fieldIndex +} + +// HandleMsg implements Dialog. +func (a *Arguments) HandleMsg(msg tea.Msg) Action { + switch msg := msg.(type) { + case spinner.TickMsg: + if a.loading { + var cmd tea.Cmd + a.spinner, cmd = a.spinner.Update(msg) + return ActionCmd{Cmd: cmd} + } + case tea.KeyPressMsg: + switch { + case key.Matches(msg, a.keyMap.Close): + return ActionClose{} + case key.Matches(msg, a.keyMap.Confirm): + // If we're on the last input or there's only one input, submit. + if a.focused == len(a.inputs)-1 || len(a.inputs) == 1 { + args := make(map[string]string) + var warning tea.Cmd + for i, arg := range a.arguments { + args[arg.ID] = a.inputs[i].Value() + if arg.Required && strings.TrimSpace(a.inputs[i].Value()) == "" { + warning = uiutil.ReportWarn("Required argument '" + arg.Title + "' is missing.") + break + } + } + if warning != nil { + return ActionCmd{Cmd: warning} + } + + switch action := a.resultAction.(type) { + case ActionRunCustomCommand: + action.Args = args + return action + case ActionRunMCPPrompt: + action.Args = args + return action + } + } + a.focusInput(a.focused + 1) + case key.Matches(msg, a.keyMap.Next): + a.focusInput(a.focused + 1) + case key.Matches(msg, a.keyMap.Previous): + a.focusInput(a.focused - 1) + default: + var cmd tea.Cmd + a.inputs[a.focused], cmd = a.inputs[a.focused].Update(msg) + return ActionCmd{Cmd: cmd} + } + case tea.MouseWheelMsg: + a.viewport, _ = a.viewport.Update(msg) + // If focused field scrolled out of view, focus the visible field + if !a.isFieldVisible(a.focused) { + a.focusInput(a.findVisibleFieldByOffset(msg.Button == tea.MouseWheelDown)) + } + case tea.PasteMsg: + var cmd tea.Cmd + a.inputs[a.focused], cmd = a.inputs[a.focused].Update(msg) + return ActionCmd{Cmd: cmd} + } + return nil +} + +// Cursor returns the cursor position relative to the dialog. +// we pass the description height to offset the cursor correctly. +func (a *Arguments) Cursor(descriptionHeight int) *tea.Cursor { + cursor := InputCursor(a.com.Styles, a.inputs[a.focused].Cursor()) + if cursor == nil { + return nil + } + cursor.Y += descriptionHeight + a.focused*argumentsFieldHeight - a.viewport.YOffset() + 1 + return cursor +} + +// Draw implements Dialog. +func (a *Arguments) Draw(scr uv.Screen, area uv.Rectangle) *tea.Cursor { + s := a.com.Styles + + dialogContentStyle := s.Dialog.Arguments.Content + possibleWidth := area.Dx() - s.Dialog.View.GetHorizontalFrameSize() - dialogContentStyle.GetHorizontalFrameSize() + // Build fields with label and input. + caser := cases.Title(language.English) + + var fields []string + for i, arg := range a.arguments { + isFocused := i == a.focused + + // Try to pretty up the title for the label. + title := strings.ReplaceAll(arg.Title, "_", " ") + title = strings.ReplaceAll(title, "-", " ") + titleParts := strings.Fields(title) + for i, part := range titleParts { + titleParts[i] = caser.String(strings.ToLower(part)) + } + labelText := strings.Join(titleParts, " ") + + markRequiredStyle := s.Dialog.Arguments.InputRequiredMarkBlurred + + labelStyle := s.Dialog.Arguments.InputLabelBlurred + if isFocused { + labelStyle = s.Dialog.Arguments.InputLabelFocused + markRequiredStyle = s.Dialog.Arguments.InputRequiredMarkFocused + } + if arg.Required { + labelText += markRequiredStyle.String() + } + label := labelStyle.Render(labelText) + + labelWidth := lipgloss.Width(labelText) + placeholderWidth := lipgloss.Width(a.inputs[i].Placeholder) + + inputWidth := max(placeholderWidth, labelWidth, minInputWidth) + inputWidth = min(inputWidth, min(possibleWidth, maxInputWidth)) + a.inputs[i].SetWidth(inputWidth) + + inputLine := a.inputs[i].View() + + field := lipgloss.JoinVertical(lipgloss.Left, label, inputLine, "") + fields = append(fields, field) + } + + renderedFields := lipgloss.JoinVertical(lipgloss.Left, fields...) + + // Anchor width to the longest field, capped at maxInputWidth. + const scrollbarWidth = 1 + width := lipgloss.Width(renderedFields) + height := lipgloss.Height(renderedFields) + + // Use standard header + titleStyle := s.Dialog.Title + + titleText := a.title + if titleText == "" { + titleText = "Arguments" + } + + header := common.DialogTitle(s, titleText, width) + + // Add description if available. + var description string + if a.description != "" { + descStyle := s.Dialog.Arguments.Description.Width(width) + description = descStyle.Render(a.description) + } + + helpView := s.Dialog.HelpView.Width(width).Render(a.help.View(a)) + if a.loading { + helpView = s.Dialog.HelpView.Width(width).Render(a.spinner.View() + " Generating Prompt...") + } + + availableHeight := area.Dy() - s.Dialog.View.GetVerticalFrameSize() - dialogContentStyle.GetVerticalFrameSize() - lipgloss.Height(header) - lipgloss.Height(description) - lipgloss.Height(helpView) - 2 // extra spacing + viewportHeight := min(height, maxViewportHeight, availableHeight) + + a.viewport.SetWidth(width) // -1 for scrollbar + a.viewport.SetHeight(viewportHeight) + a.viewport.SetContent(renderedFields) + + scrollbar := common.Scrollbar(s, viewportHeight, a.viewport.TotalLineCount(), viewportHeight, a.viewport.YOffset()) + content := a.viewport.View() + if scrollbar != "" { + content = lipgloss.JoinHorizontal(lipgloss.Top, content, scrollbar) + } + contentParts := []string{} + if description != "" { + contentParts = append(contentParts, description) + } + contentParts = append(contentParts, content) + + view := lipgloss.JoinVertical( + lipgloss.Left, + titleStyle.Render(header), + dialogContentStyle.Render(lipgloss.JoinVertical(lipgloss.Left, contentParts...)), + helpView, + ) + + dialog := s.Dialog.View.Render(view) + + descriptionHeight := 0 + if a.description != "" { + descriptionHeight = lipgloss.Height(description) + } + cur := a.Cursor(descriptionHeight) + + DrawCenterCursor(scr, area, dialog, cur) + return cur +} + +// StartLoading implements [LoadingDialog]. +func (a *Arguments) StartLoading() tea.Cmd { + if a.loading { + return nil + } + a.loading = true + return a.spinner.Tick +} + +// StopLoading implements [LoadingDialog]. +func (a *Arguments) StopLoading() { + a.loading = false +} + +// ShortHelp implements help.KeyMap. +func (a *Arguments) ShortHelp() []key.Binding { + return []key.Binding{ + a.keyMap.Confirm, + a.keyMap.Next, + a.keyMap.Close, + } +} + +// FullHelp implements help.KeyMap. +func (a *Arguments) FullHelp() [][]key.Binding { + return [][]key.Binding{ + {a.keyMap.Confirm, a.keyMap.Next, a.keyMap.Previous}, + {a.keyMap.Close}, + } +} diff --git a/internal/ui/dialog/commands.go b/internal/ui/dialog/commands.go index 03707a54775992992a36e90e6857b0f55ce3c8e3..a6861a5c87707d7c0717ec4d3c50c1d995a528af 100644 --- a/internal/ui/dialog/commands.go +++ b/internal/ui/dialog/commands.go @@ -6,6 +6,7 @@ import ( "charm.land/bubbles/v2/help" "charm.land/bubbles/v2/key" + "charm.land/bubbles/v2/spinner" "charm.land/bubbles/v2/textinput" tea "charm.land/bubbletea/v2" "charm.land/lipgloss/v2" @@ -52,26 +53,29 @@ type Commands struct { sessionID string // can be empty for non-session-specific commands selected CommandType + spinner spinner.Model + loading bool + help help.Model input textinput.Model list *list.FilterableList windowWidth int - customCommands []commands.CustomCommand - mcpCustomCommands []commands.MCPCustomCommand + customCommands []commands.CustomCommand + mcpPrompts []commands.MCPPrompt } var _ Dialog = (*Commands)(nil) // NewCommands creates a new commands dialog. -func NewCommands(com *common.Common, sessionID string, customCommands []commands.CustomCommand, mcpCustomCommands []commands.MCPCustomCommand) (*Commands, error) { +func NewCommands(com *common.Common, sessionID string, customCommands []commands.CustomCommand, mcpPrompts []commands.MCPPrompt) (*Commands, error) { c := &Commands{ - com: com, - selected: SystemCommands, - sessionID: sessionID, - customCommands: customCommands, - mcpCustomCommands: mcpCustomCommands, + com: com, + selected: SystemCommands, + sessionID: sessionID, + customCommands: customCommands, + mcpPrompts: mcpPrompts, } help := help.New() @@ -120,6 +124,11 @@ func NewCommands(com *common.Common, sessionID string, customCommands []commands // Set initial commands c.setCommandItems(c.selected) + s := spinner.New() + s.Spinner = spinner.Dot + s.Style = com.Styles.Dialog.Spinner + c.spinner = s + return c, nil } @@ -128,9 +137,15 @@ func (c *Commands) ID() string { return CommandsID } -// HandleMsg implements Dialog. +// HandleMsg implements [Dialog]. func (c *Commands) HandleMsg(msg tea.Msg) Action { switch msg := msg.(type) { + case spinner.TickMsg: + if c.loading { + var cmd tea.Cmd + c.spinner, cmd = c.spinner.Update(msg) + return ActionCmd{Cmd: cmd} + } case tea.KeyPressMsg: switch { case key.Matches(msg, c.keyMap.Close): @@ -160,12 +175,12 @@ func (c *Commands) HandleMsg(msg tea.Msg) Action { } } case key.Matches(msg, c.keyMap.Tab): - if len(c.customCommands) > 0 || len(c.mcpCustomCommands) > 0 { + if len(c.customCommands) > 0 || len(c.mcpPrompts) > 0 { c.selected = c.nextCommandType() c.setCommandItems(c.selected) } case key.Matches(msg, c.keyMap.ShiftTab): - if len(c.customCommands) > 0 || len(c.mcpCustomCommands) > 0 { + if len(c.customCommands) > 0 || len(c.mcpPrompts) > 0 { c.selected = c.previousCommandType() c.setCommandItems(c.selected) } @@ -242,12 +257,16 @@ func (c *Commands) Draw(scr uv.Screen, area uv.Rectangle) *tea.Cursor { c.list.SetSize(innerWidth, height-heightOffset) c.help.SetWidth(innerWidth) - radio := commandsRadioView(t, c.selected, len(c.customCommands) > 0, len(c.mcpCustomCommands) > 0) + radio := commandsRadioView(t, c.selected, len(c.customCommands) > 0, len(c.mcpPrompts) > 0) titleStyle := t.Dialog.Title dialogStyle := t.Dialog.View.Width(width) headerOffset := lipgloss.Width(radio) + titleStyle.GetHorizontalFrameSize() + dialogStyle.GetHorizontalFrameSize() helpView := ansi.Truncate(c.help.View(c), innerWidth, "") header := common.DialogTitle(t, "Commands", width-headerOffset) + radio + + if c.loading { + helpView = t.Dialog.HelpView.Width(width).Render(c.spinner.View() + " Generating Prompt...") + } view := HeaderInputListHelpView(t, width, c.list.Height(), header, c.input.View(), c.list.Render(), helpView) @@ -281,12 +300,12 @@ func (c *Commands) nextCommandType() CommandType { if len(c.customCommands) > 0 { return UserCommands } - if len(c.mcpCustomCommands) > 0 { + if len(c.mcpPrompts) > 0 { return MCPPrompts } fallthrough case UserCommands: - if len(c.mcpCustomCommands) > 0 { + if len(c.mcpPrompts) > 0 { return MCPPrompts } fallthrough @@ -301,7 +320,7 @@ func (c *Commands) nextCommandType() CommandType { func (c *Commands) previousCommandType() CommandType { switch c.selected { case SystemCommands: - if len(c.mcpCustomCommands) > 0 { + if len(c.mcpPrompts) > 0 { return MCPPrompts } if len(c.customCommands) > 0 { @@ -332,37 +351,22 @@ func (c *Commands) setCommandItems(commandType CommandType) { } case UserCommands: for _, cmd := range c.customCommands { - var action Action - if len(cmd.Arguments) > 0 { - action = ActionOpenCustomCommandArgumentsDialog{ - CommandID: cmd.ID, - Content: cmd.Content, - Arguments: cmd.Arguments, - } - } else { - action = ActionRunCustomCommand{ - CommandID: cmd.ID, - Content: cmd.Content, - } + action := ActionRunCustomCommand{ + Content: cmd.Content, + Arguments: cmd.Arguments, } commandItems = append(commandItems, NewCommandItem(c.com.Styles, "custom_"+cmd.ID, cmd.Name, "", action)) } case MCPPrompts: - for _, cmd := range c.mcpCustomCommands { - var action Action - if len(cmd.Arguments) > 0 { - action = ActionOpenCustomCommandArgumentsDialog{ - CommandID: cmd.ID, - Client: cmd.Client, - Arguments: cmd.Arguments, - } - } else { - action = ActionRunCustomCommand{ - CommandID: cmd.ID, - Client: cmd.Client, - } + for _, cmd := range c.mcpPrompts { + action := ActionRunMCPPrompt{ + Title: cmd.Title, + Description: cmd.Description, + PromptID: cmd.PromptID, + ClientID: cmd.ClientID, + Arguments: cmd.Arguments, } - commandItems = append(commandItems, NewCommandItem(c.com.Styles, "mcp_"+cmd.ID, cmd.Name, "", action)) + commandItems = append(commandItems, NewCommandItem(c.com.Styles, "mcp_"+cmd.ID, cmd.PromptID, "", action)) } } @@ -448,10 +452,24 @@ func (c *Commands) SetCustomCommands(customCommands []commands.CustomCommand) { } } -// SetMCPCustomCommands sets the MCP custom commands and refreshes the view if MCP prompts are currently displayed. -func (c *Commands) SetMCPCustomCommands(mcpCustomCommands []commands.MCPCustomCommand) { - c.mcpCustomCommands = mcpCustomCommands +// SetMCPPrompts sets the MCP prompts and refreshes the view if MCP prompts are currently displayed. +func (c *Commands) SetMCPPrompts(mcpPrompts []commands.MCPPrompt) { + c.mcpPrompts = mcpPrompts if c.selected == MCPPrompts { c.setCommandItems(c.selected) } } + +// StartLoading implements [LoadingDialog]. +func (a *Commands) StartLoading() tea.Cmd { + if a.loading { + return nil + } + a.loading = true + return a.spinner.Tick +} + +// StopLoading implements [LoadingDialog]. +func (a *Commands) StopLoading() { + a.loading = false +} diff --git a/internal/ui/dialog/dialog.go b/internal/ui/dialog/dialog.go index 68eb313d4ec83cf8d098fcfccb5ebf27de8bd0d1..7a3db40128fb1e5543a94a93faa4ae9aeec5f947 100644 --- a/internal/ui/dialog/dialog.go +++ b/internal/ui/dialog/dialog.go @@ -27,7 +27,7 @@ var CloseKey = key.NewBinding( ) // Action represents an action taken in a dialog after handling a message. -type Action interface{} +type Action any // Dialog is a component that can be displayed on top of the UI. type Dialog interface { @@ -41,6 +41,12 @@ type Dialog interface { Draw(scr uv.Screen, area uv.Rectangle) *tea.Cursor } +// LoadingDialog is a dialog that can show a loading state. +type LoadingDialog interface { + StartLoading() tea.Cmd + StopLoading() +} + // Overlay manages multiple dialogs as an overlay. type Overlay struct { dialogs []Dialog @@ -136,6 +142,25 @@ func (d *Overlay) Update(msg tea.Msg) tea.Msg { return dialog.HandleMsg(msg) } +// StartLoading starts the loading state for the front dialog if it +// implements [LoadingDialog]. +func (d *Overlay) StartLoading() tea.Cmd { + dialog := d.DialogLast() + if ld, ok := dialog.(LoadingDialog); ok { + return ld.StartLoading() + } + return nil +} + +// StopLoading stops the loading state for the front dialog if it +// implements [LoadingDialog]. +func (d *Overlay) StopLoading() { + dialog := d.DialogLast() + if ld, ok := dialog.(LoadingDialog); ok { + ld.StopLoading() + } +} + // DrawCenterCursor draws the given string view centered in the screen area and // adjusts the cursor position accordingly. func DrawCenterCursor(scr uv.Screen, area uv.Rectangle, view string, cur *tea.Cursor) { diff --git a/internal/ui/model/ui.go b/internal/ui/model/ui.go index 6a92f5bb9c7c4f856cd83b38272a63120fea929f..c3c6edebbccab27b1072b9376e3c5169d3137894 100644 --- a/internal/ui/model/ui.go +++ b/internal/ui/model/ui.go @@ -18,6 +18,7 @@ import ( "charm.land/bubbles/v2/help" "charm.land/bubbles/v2/key" + "charm.land/bubbles/v2/spinner" "charm.land/bubbles/v2/textarea" tea "charm.land/bubbletea/v2" "charm.land/lipgloss/v2" @@ -93,10 +94,19 @@ type ( userCommandsLoadedMsg struct { Commands []commands.CustomCommand } - // mcpCustomCommandsLoadedMsg is sent when mcp prompts are loaded. - mcpCustomCommandsLoadedMsg struct { - Prompts []commands.MCPCustomCommand + // mcpPromptsLoadedMsg is sent when mcp prompts are loaded. + mcpPromptsLoadedMsg struct { + Prompts []commands.MCPPrompt } + // sendMessageMsg is sent to send a message. + // currently only used for mcp prompts. + sendMessageMsg struct { + Content string + Attachments []message.Attachment + } + + // closeDialogMsg is sent to close the current dialog. + closeDialogMsg struct{} ) // UI represents the main user interface model. @@ -167,8 +177,8 @@ type UI struct { sidebarLogo string // custom commands & mcp commands - customCommands []commands.CustomCommand - mcpCustomCommands []commands.MCPCustomCommand + customCommands []commands.CustomCommand + mcpPrompts []commands.MCPPrompt // forceCompactMode tracks whether compact mode is forced by user toggle forceCompactMode bool @@ -282,15 +292,15 @@ func (m *UI) loadCustomCommands() tea.Cmd { // loadMCPrompts loads the MCP prompts asynchronously. func (m *UI) loadMCPrompts() tea.Cmd { return func() tea.Msg { - prompts, err := commands.LoadMCPCustomCommands() + prompts, err := commands.LoadMCPPrompts() if err != nil { slog.Error("failed to load mcp prompts", "error", err) } if prompts == nil { // flag them as loaded even if there is none or an error - prompts = []commands.MCPCustomCommand{} + prompts = []commands.MCPPrompt{} } - return mcpCustomCommandsLoadedMsg{Prompts: prompts} + return mcpPromptsLoadedMsg{Prompts: prompts} } } @@ -319,6 +329,9 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) { cmds = append(cmds, cmd) } + case sendMessageMsg: + cmds = append(cmds, m.sendMessage(msg.Content, msg.Attachments...)) + case userCommandsLoadedMsg: m.customCommands = msg.Commands dia := m.dialog.Dialog(dialog.CommandsID) @@ -330,8 +343,8 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if ok { commands.SetCustomCommands(m.customCommands) } - case mcpCustomCommandsLoadedMsg: - m.mcpCustomCommands = msg.Prompts + case mcpPromptsLoadedMsg: + m.mcpPrompts = msg.Prompts dia := m.dialog.Dialog(dialog.CommandsID) if dia == nil { break @@ -339,9 +352,12 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) { commands, ok := dia.(*dialog.Commands) if ok { - commands.SetMCPCustomCommands(m.mcpCustomCommands) + commands.SetMCPPrompts(m.mcpPrompts) } + case closeDialogMsg: + m.dialog.CloseFrontDialog() + case pubsub.Event[message.Message]: // Check if this is a child session message for an agent tool. if m.session == nil { @@ -374,7 +390,7 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) { break } } - if initialized && m.mcpCustomCommands == nil { + if initialized && m.mcpPrompts == nil { cmds = append(cmds, m.loadMCPrompts()) } case pubsub.Event[permission.PermissionRequest]: @@ -492,6 +508,14 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) { cmds = append(cmds, cmd) } } + case spinner.TickMsg: + if m.dialog.HasDialogs() { + // route to dialog + if cmd := m.handleDialogMsg(msg); cmd != nil { + cmds = append(cmds, cmd) + } + } + case tea.KeyPressMsg: if cmd := m.handleKeyPressMsg(msg); cmd != nil { cmds = append(cmds, cmd) @@ -645,6 +669,11 @@ func (m *UI) loadNestedToolCalls(items []chat.MessageItem) { // if the message is a tool result it will update the corresponding tool call message func (m *UI) appendSessionMessage(msg message.Message) tea.Cmd { var cmds []tea.Cmd + existing := m.chat.MessageItem(msg.ID) + if existing != nil { + // message already exists, skip + return nil + } switch msg.Role { case message.User, message.Assistant: items := chat.ExtractMessageItems(m.com.Styles, &msg, nil) @@ -920,6 +949,44 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { case dialog.PermissionDeny: m.com.App.Permissions.Deny(msg.Permission) } + + case dialog.ActionRunCustomCommand: + if len(msg.Arguments) > 0 && msg.Args == nil { + m.dialog.CloseFrontDialog() + argsDialog := dialog.NewArguments( + m.com, + "Custom Command Arguments", + "", + msg.Arguments, + msg, // Pass the action as the result + ) + m.dialog.OpenDialog(argsDialog) + break + } + content := msg.Content + if msg.Args != nil { + content = substituteArgs(content, msg.Args) + } + cmds = append(cmds, m.sendMessage(content)) + m.dialog.CloseFrontDialog() + case dialog.ActionRunMCPPrompt: + if len(msg.Arguments) > 0 && msg.Args == nil { + m.dialog.CloseFrontDialog() + title := msg.Title + if title == "" { + title = "MCP Prompt Arguments" + } + argsDialog := dialog.NewArguments( + m.com, + title, + msg.Description, + msg.Arguments, + msg, // Pass the action as the result + ) + m.dialog.OpenDialog(argsDialog) + break + } + cmds = append(cmds, m.runMCPPrompt(msg.ClientID, msg.PromptID, msg.Args)) default: cmds = append(cmds, uiutil.CmdHandler(msg)) } @@ -927,6 +994,15 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { return tea.Batch(cmds...) } +// substituteArgs replaces $ARG_NAME placeholders in content with actual values. +func substituteArgs(content string, args map[string]string) string { + for name, value := range args { + placeholder := "$" + name + content = strings.ReplaceAll(content, placeholder, value) + } + return content +} + // openAPIKeyInputDialog opens the API key input dialog. func (m *UI) openAPIKeyInputDialog(provider catwalk.Provider, model config.SelectedModel, modelType config.SelectedModelType) tea.Cmd { if m.dialog.ContainsDialog(dialog.APIKeyInputID) { @@ -1055,7 +1131,7 @@ func (m *UI) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd { m.randomizePlaceholders() - return m.sendMessage(value, attachments) + return m.sendMessage(value, attachments...) case key.Matches(msg, m.keyMap.Chat.NewSession): if m.session == nil || m.session.ID == "" { break @@ -2013,7 +2089,7 @@ func (m *UI) renderSidebarLogo(width int) { } // sendMessage sends a message with the given content and attachments. -func (m *UI) sendMessage(content string, attachments []message.Attachment) tea.Cmd { +func (m *UI) sendMessage(content string, attachments ...message.Attachment) tea.Cmd { if m.com.App.AgentCoordinator == nil { return uiutil.ReportError(fmt.Errorf("coder agent is not initialized")) } @@ -2165,7 +2241,7 @@ func (m *UI) openCommandsDialog() tea.Cmd { sessionID = m.session.ID } - commands, err := dialog.NewCommands(m.com, sessionID, m.customCommands, m.mcpCustomCommands) + commands, err := dialog.NewCommands(m.com, sessionID, m.customCommands, m.mcpPrompts) if err != nil { return uiutil.ReportError(err) } @@ -2393,6 +2469,33 @@ func (m *UI) drawSessionDetails(scr uv.Screen, area uv.Rectangle) { ).Draw(scr, area) } +func (m *UI) runMCPPrompt(clientID, promptID string, arguments map[string]string) tea.Cmd { + load := func() tea.Msg { + prompt, err := commands.GetMCPPrompt(clientID, promptID, arguments) + if err != nil { + // TODO: make this better + return uiutil.ReportError(err)() + } + + if prompt == "" { + return nil + } + return sendMessageMsg{ + Content: prompt, + } + } + + var cmds []tea.Cmd + if cmd := m.dialog.StartLoading(); cmd != nil { + cmds = append(cmds, cmd) + } + cmds = append(cmds, load, func() tea.Msg { + return closeDialogMsg{} + }) + + return tea.Sequence(cmds...) +} + // renderLogo renders the Crush logo with the given styles and dimensions. func renderLogo(t *styles.Styles, compact bool, width int) string { return logo.Render(version.Version, compact, logo.Opts{ diff --git a/internal/ui/styles/styles.go b/internal/ui/styles/styles.go index 442e3f78a449baae2c99868ae9434d69debce40e..878ed83eaf7c0eaaa490dc11546a72f0a9a8a539 100644 --- a/internal/ui/styles/styles.go +++ b/internal/ui/styles/styles.go @@ -333,6 +333,8 @@ type Styles struct { List lipgloss.Style + Spinner lipgloss.Style + // ContentPanel is used for content blocks with subtle background. ContentPanel lipgloss.Style @@ -340,6 +342,16 @@ type Styles struct { ScrollbarThumb lipgloss.Style ScrollbarTrack lipgloss.Style + // Arguments + Arguments struct { + Content lipgloss.Style + Description lipgloss.Style + InputLabelBlurred lipgloss.Style + InputLabelFocused lipgloss.Style + InputRequiredMarkBlurred lipgloss.Style + InputRequiredMarkFocused lipgloss.Style + } + Commands struct{} } @@ -1205,9 +1217,17 @@ func DefaultStyles() Styles { s.Dialog.List = base.Margin(0, 0, 1, 0) s.Dialog.ContentPanel = base.Background(bgSubtle).Foreground(fgBase).Padding(1, 2) + s.Dialog.Spinner = base.Foreground(secondary) s.Dialog.ScrollbarThumb = base.Foreground(secondary) s.Dialog.ScrollbarTrack = base.Foreground(border) + s.Dialog.Arguments.Content = base.Padding(1) + s.Dialog.Arguments.Description = base.MarginBottom(1).MaxHeight(3) + s.Dialog.Arguments.InputLabelBlurred = base.Foreground(fgMuted) + s.Dialog.Arguments.InputLabelFocused = base.Bold(true) + s.Dialog.Arguments.InputRequiredMarkBlurred = base.Foreground(fgMuted).SetString("*") + s.Dialog.Arguments.InputRequiredMarkFocused = base.Foreground(primary).Bold(true).SetString("*") + s.Status.Help = lipgloss.NewStyle().Padding(0, 1) s.Status.SuccessIndicator = base.Foreground(bgSubtle).Background(green).Padding(0, 1).Bold(true).SetString("OKAY!") s.Status.InfoIndicator = s.Status.SuccessIndicator