diff --git a/.github/cla-signatures.json b/.github/cla-signatures.json index cf21b7c02c3ecb20d01ac8250cee76e2727b81b2..5929987f916594da1109eee2082c154620edf660 100644 --- a/.github/cla-signatures.json +++ b/.github/cla-signatures.json @@ -1055,6 +1055,22 @@ "created_at": "2026-01-12T22:16:05Z", "repoId": 987670088, "pullRequestNo": 1841 + }, + { + "name": "kuxoapp", + "id": 254052994, + "comment_id": 3747622477, + "created_at": "2026-01-14T04:18:44Z", + "repoId": 987670088, + "pullRequestNo": 1864 + }, + { + "name": "mhpenta", + "id": 183146177, + "comment_id": 3749703014, + "created_at": "2026-01-14T14:02:04Z", + "repoId": 987670088, + "pullRequestNo": 1870 } ] } \ No newline at end of file diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 3511f7fe0c4f487eb3fc9009795361ada8e2eff7..39b5923298e2f7fa8d5452327a6e8b2a08f0df97 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1,11 +1,27 @@ name: build on: [push, pull_request] +permissions: + contents: read + +concurrency: + group: build-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + jobs: build: - uses: charmbracelet/meta/.github/workflows/build.yml@main - with: - go-version: "" - go-version-file: ./go.mod - secrets: - gh_pat: "${{ secrets.PERSONAL_ACCESS_TOKEN }}" + strategy: + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + with: + persist-credentials: false + - uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # v6.1.0 + with: + go-version-file: go.mod + - run: go mod tidy + - run: git diff --exit-code + - run: go build -race ./... + - run: go test -race -failfast ./... diff --git a/.github/workflows/security.yml b/.github/workflows/security.yml new file mode 100644 index 0000000000000000000000000000000000000000..3a90ea316c3d86f5b2f93224fd2b35eaa572e704 --- /dev/null +++ b/.github/workflows/security.yml @@ -0,0 +1,92 @@ +name: "security" + +on: + pull_request: + push: + branches: [main] + schedule: + - cron: "0 2 * * *" + +permissions: + contents: read + +concurrency: + group: security-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + codeql: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + language: ["go", "actions"] + permissions: + actions: read + contents: read + pull-requests: read + security-events: write + steps: + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + with: + persist-credentials: false + - uses: github/codeql-action/init@cf1bb45a277cb3c205638b2cd5c984db1c46a412 # v4.31.7 + with: + languages: ${{ matrix.language }} + - uses: github/codeql-action/autobuild@cf1bb45a277cb3c205638b2cd5c984db1c46a412 # v4.31.7 + - uses: github/codeql-action/analyze@cf1bb45a277cb3c205638b2cd5c984db1c46a412 # v4.31.7 + + grype: + runs-on: ubuntu-latest + permissions: + security-events: write + actions: read + contents: read + steps: + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + with: + persist-credentials: false + - uses: anchore/scan-action@40a61b52209e9d50e87917c5b901783d546b12d0 # v7.2.1 + id: scan + with: + path: "." + fail-build: true + severity-cutoff: critical + - uses: github/codeql-action/upload-sarif@cf1bb45a277cb3c205638b2cd5c984db1c46a412 # v4.31.7 + with: + sarif_file: ${{ steps.scan.outputs.sarif }} + + govulncheck: + runs-on: ubuntu-latest + permissions: + security-events: write + contents: read + steps: + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + with: + persist-credentials: false + - uses: actions/setup-go@7a3fe6cf4cb3a834922a1244abfce67bcef6a0c5 # v6.2.0 + with: + go-version: 1.26.0-rc.1 # change to "stable" once Go 1.26 is released + - name: Install govulncheck + run: go install golang.org/x/vuln/cmd/govulncheck@latest + - name: Run govulncheck + run: | + govulncheck -C . -format sarif ./... > results.sarif + - uses: github/codeql-action/upload-sarif@cf1bb45a277cb3c205638b2cd5c984db1c46a412 # v4.31.7 + with: + sarif_file: results.sarif + + dependency-review: + runs-on: ubuntu-latest + if: github.event_name == 'pull_request' + permissions: + contents: read + steps: + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + with: + persist-credentials: false + - uses: actions/dependency-review-action@3c4e3dcb1aa7874d2c16be7d79418e9b7efd6261 # v4.8.2 + with: + fail-on-severity: critical + allow-licenses: BSD-2-Clause, BSD-3-Clause, MIT, Apache-2.0, MPL-2.0, ISC, LicenseRef-scancode-google-patent-license-golang diff --git a/README.md b/README.md index 929b77425f1b42452a4e38d8cfa540773dd54a79..cfcd765ee150d181c00cf649a5ba15055b6bdbae 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # Crush

- Charm Crush Logo
+ Charm Crush Logo
Latest Release Build Status

diff --git a/Taskfile.yaml b/Taskfile.yaml index 68c805c599314cadde5c86fc37a0e3d1a6184f4e..0043f4f033e455a5800da2431848e620c37a0f5a 100644 --- a/Taskfile.yaml +++ b/Taskfile.yaml @@ -5,6 +5,8 @@ version: "3" vars: VERSION: sh: git describe --long 2>/dev/null || echo "" + RACE: + sh: test -f race.log && echo "1" || echo "" env: CGO_ENABLED: 0 @@ -37,20 +39,20 @@ tasks: vars: LDFLAGS: '{{if .VERSION}}-ldflags="-X github.com/charmbracelet/crush/internal/version.Version={{.VERSION}}"{{end}}' cmds: - - go build {{.LDFLAGS}} . + - "go build {{if .RACE}}-race{{end}} {{.LDFLAGS}} ." generates: - crush run: desc: Run build cmds: - - go build -o crush . - - ./crush {{.CLI_ARGS}} + - task: build + - "./crush {{.CLI_ARGS}} {{if .RACE}}2>race.log{{end}}" test: desc: Run tests cmds: - - go test ./... {{.CLI_ARGS}} + - go test -race -failfast ./... {{.CLI_ARGS}} test:record: desc: Run tests and record all VCR cassettes again diff --git a/go.mod b/go.mod index b607a70975a383b3c9ba5e2c945fcada2f27c125..0ab5bd264dfdbab157a0fe38ba5308d84a4521a2 100644 --- a/go.mod +++ b/go.mod @@ -15,10 +15,11 @@ require ( github.com/PuerkitoBio/goquery v1.11.0 github.com/alecthomas/chroma/v2 v2.22.0 github.com/atotto/clipboard v0.1.4 + github.com/aymanbagabas/go-nativeclipboard v0.1.2 github.com/aymanbagabas/go-udiff v0.3.1 github.com/bmatcuk/doublestar/v4 v4.9.2 github.com/charlievieth/fastwalk v1.0.14 - github.com/charmbracelet/catwalk v0.13.0 + github.com/charmbracelet/catwalk v0.14.1 github.com/charmbracelet/colorprofile v0.4.1 github.com/charmbracelet/fang v0.4.4 github.com/charmbracelet/ultraviolet v0.0.0-20251212194010-b927aa605560 @@ -29,7 +30,7 @@ require ( github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f github.com/charmbracelet/x/exp/ordered v0.1.0 github.com/charmbracelet/x/exp/slice v0.0.0-20251201173703-9f73bfd934ff - github.com/charmbracelet/x/powernap v0.0.0-20251015113943-25f979b54ad4 + github.com/charmbracelet/x/powernap v0.0.0-20260113142046-c1fa3de7983b github.com/charmbracelet/x/term v0.2.2 github.com/denisbrodbeck/machineid v1.0.1 github.com/disintegration/imageorient v0.0.0-20180920195336-8147d86e83ec @@ -101,12 +102,13 @@ require ( github.com/charmbracelet/x/json v0.2.0 // indirect github.com/charmbracelet/x/termios v0.1.1 // indirect github.com/charmbracelet/x/windows v0.2.2 // indirect - github.com/clipperhouse/displaywidth v0.6.1 // indirect + github.com/clipperhouse/displaywidth v0.6.2 // indirect github.com/clipperhouse/stringish v0.1.1 // indirect github.com/clipperhouse/uax29/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/disintegration/gift v1.1.2 // indirect github.com/dlclark/regexp2 v1.11.5 // indirect + github.com/ebitengine/purego v0.10.0-alpha.3.0.20260102153238-200df6041cff // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e // indirect @@ -171,7 +173,7 @@ require ( go.yaml.in/yaml/v4 v4.0.0-rc.3 // indirect golang.org/x/crypto v0.47.0 // indirect golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b // indirect - golang.org/x/image v0.27.0 // indirect + golang.org/x/image v0.34.0 // indirect golang.org/x/oauth2 v0.34.0 // indirect golang.org/x/sys v0.40.0 // indirect golang.org/x/term v0.39.0 // indirect diff --git a/go.sum b/go.sum index 70582b7c92f86af89a03d9f9a43382e27235d2ca..5a2b20e02e02085bf7f8559d946bced27a20cc27 100644 --- a/go.sum +++ b/go.sum @@ -78,6 +78,8 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.41.6 h1:5fFjR/ToSOzB2OQ/XqWpZBmNvmP/ github.com/aws/aws-sdk-go-v2/service/sts v1.41.6/go.mod h1:qgFDZQSD/Kys7nJnVqYlWKnh0SSdMjAi0uSwON4wgYQ= github.com/aws/smithy-go v1.24.0 h1:LpilSUItNPFr1eY85RYgTIg5eIEPtvFbskaFcmmIUnk= github.com/aws/smithy-go v1.24.0/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= +github.com/aymanbagabas/go-nativeclipboard v0.1.2 h1:Z2iVRWQ4IynMLWM6a+lWH2Nk5gPyEtPRMuBIyZ2dECM= +github.com/aymanbagabas/go-nativeclipboard v0.1.2/go.mod h1:BVJhN7hs5DieCzUB2Atf4Yk9Y9kFe62E95+gOjpJq6Q= github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= github.com/aymanbagabas/go-udiff v0.3.1 h1:LV+qyBQ2pqe0u42ZsUEtPiCaUoqgA9gYRDs3vj1nolY= @@ -94,8 +96,8 @@ github.com/charlievieth/fastwalk v1.0.14 h1:3Eh5uaFGwHZd8EGwTjJnSpBkfwfsak9h6ICg github.com/charlievieth/fastwalk v1.0.14/go.mod h1:diVcUreiU1aQ4/Wu3NbxxH4/KYdKpLDojrQ1Bb2KgNY= github.com/charmbracelet/anthropic-sdk-go v0.0.0-20251024181547-21d6f3d9a904 h1:rwLdEpG9wE6kL69KkEKDiWprO8pQOZHZXeod6+9K+mw= github.com/charmbracelet/anthropic-sdk-go v0.0.0-20251024181547-21d6f3d9a904/go.mod h1:8TIYxZxsuCqqeJ0lga/b91tBwrbjoHDC66Sq5t8N2R4= -github.com/charmbracelet/catwalk v0.13.0 h1:L+chddP+PJvX3Vl+hqlWW5HAwBErlkL/friQXih1JQI= -github.com/charmbracelet/catwalk v0.13.0/go.mod h1:qg+Yl9oaZTkTvRscqbxfttzOFQ4v0pOT5XwC7b5O0NQ= +github.com/charmbracelet/catwalk v0.14.1 h1:n16H880MHW8PPgQeh0dorP77AJMxw5JcOUPuC3FFhaQ= +github.com/charmbracelet/catwalk v0.14.1/go.mod h1:qg+Yl9oaZTkTvRscqbxfttzOFQ4v0pOT5XwC7b5O0NQ= github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco3ej+ngLk1kk= github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk= github.com/charmbracelet/fang v0.4.4 h1:G4qKxF6or/eTPgmAolwPuRNyuci3hTUGGX1rj1YkHJY= @@ -118,16 +120,16 @@ github.com/charmbracelet/x/exp/slice v0.0.0-20251201173703-9f73bfd934ff h1:Uwr+/ github.com/charmbracelet/x/exp/slice v0.0.0-20251201173703-9f73bfd934ff/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA= github.com/charmbracelet/x/json v0.2.0 h1:DqB+ZGx2h+Z+1s98HOuOyli+i97wsFQIxP2ZQANTPrQ= github.com/charmbracelet/x/json v0.2.0/go.mod h1:opFIflx2YgXgi49xVUu8gEQ21teFAxyMwvOiZhIvWNM= -github.com/charmbracelet/x/powernap v0.0.0-20251015113943-25f979b54ad4 h1:i/XilBPYK4L1Yo/mc9FPx0SyJzIsN0y4sj1MWq9Sscc= -github.com/charmbracelet/x/powernap v0.0.0-20251015113943-25f979b54ad4/go.mod h1:cmdl5zlP5mR8TF2Y68UKc7hdGUDiSJ2+4hk0h04Hsx4= +github.com/charmbracelet/x/powernap v0.0.0-20260113142046-c1fa3de7983b h1:5ye9hzBKH623bMVz5auIuY6K21loCdxpRmFle2O9R/8= +github.com/charmbracelet/x/powernap v0.0.0-20260113142046-c1fa3de7983b/go.mod h1:cmdl5zlP5mR8TF2Y68UKc7hdGUDiSJ2+4hk0h04Hsx4= github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk= github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI= github.com/charmbracelet/x/termios v0.1.1 h1:o3Q2bT8eqzGnGPOYheoYS8eEleT5ZVNYNy8JawjaNZY= github.com/charmbracelet/x/termios v0.1.1/go.mod h1:rB7fnv1TgOPOyyKRJ9o+AsTU/vK5WHJ2ivHeut/Pcwo= github.com/charmbracelet/x/windows v0.2.2 h1:IofanmuvaxnKHuV04sC0eBy/smG6kIKrWG2/jYn2GuM= github.com/charmbracelet/x/windows v0.2.2/go.mod h1:/8XtdKZzedat74NQFn0NGlGL4soHB0YQZrETF96h75k= -github.com/clipperhouse/displaywidth v0.6.1 h1:/zMlAezfDzT2xy6acHBzwIfyu2ic0hgkT83UX5EY2gY= -github.com/clipperhouse/displaywidth v0.6.1/go.mod h1:R+kHuzaYWFkTm7xoMmK1lFydbci4X2CicfbGstSGg0o= +github.com/clipperhouse/displaywidth v0.6.2 h1:ZDpTkFfpHOKte4RG5O/BOyf3ysnvFswpyYrV7z2uAKo= +github.com/clipperhouse/displaywidth v0.6.2/go.mod h1:R+kHuzaYWFkTm7xoMmK1lFydbci4X2CicfbGstSGg0o= github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs= github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA= github.com/clipperhouse/uax29/v2 v2.3.0 h1:SNdx9DVUqMoBuBoW3iLOj4FQv3dN5mDtuqwuhIGpJy4= @@ -150,6 +152,8 @@ github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZ github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/ebitengine/purego v0.10.0-alpha.3.0.20260102153238-200df6041cff h1:vAcU1VsCRstZ9ty11yD/L0WDyT73S/gVfmuWvcWX5DA= +github.com/ebitengine/purego v0.10.0-alpha.3.0.20260102153238-200df6041cff/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ= github.com/envoyproxy/go-control-plane v0.13.4 h1:zEqyPVyku6IvWCFwux4x9RxkLOMUL+1vC9xUFv5l2/M= github.com/envoyproxy/go-control-plane/envoy v1.32.4 h1:jb83lalDRZSpPWW2Z7Mck/8kXZ5CQAFYVjQcdVIr83A= github.com/envoyproxy/go-control-plane/envoy v1.32.4/go.mod h1:Gzjc5k8JcJswLjAx1Zm+wSYE20UrLtt7JZMWiWQXQEw= @@ -389,8 +393,8 @@ golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o= golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8= -golang.org/x/image v0.27.0 h1:C8gA4oWU/tKkdCfYT6T2u4faJu3MeNS5O8UPWlPF61w= -golang.org/x/image v0.27.0/go.mod h1:xbdrClrAUway1MUTEZDq9mz/UpRwYAkFFNUslZtcB+g= +golang.org/x/image v0.34.0 h1:33gCkyw9hmwbZJeZkct8XyR11yH889EQt/QH4VmXMn8= +golang.org/x/image v0.34.0/go.mod h1:2RNFBZRB+vnwwFil8GkMdRvrJOFd1AzdZI6vOY+eJVU= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 8d2fa40fd427143bf988587ef7faa3a89c3e23b1..c0b9080bb640085c6fd0fdbde8db0fbfe7f476dd 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -87,12 +87,13 @@ type Model struct { } type sessionAgent struct { - largeModel Model - smallModel Model - systemPromptPrefix string - systemPrompt string + largeModel *csync.Value[Model] + smallModel *csync.Value[Model] + systemPromptPrefix *csync.Value[string] + systemPrompt *csync.Value[string] + tools *csync.Slice[fantasy.AgentTool] + isSubAgent bool - tools []fantasy.AgentTool sessions session.Service messages message.Service disableAutoSummarize bool @@ -119,15 +120,15 @@ func NewSessionAgent( opts SessionAgentOptions, ) SessionAgent { return &sessionAgent{ - largeModel: opts.LargeModel, - smallModel: opts.SmallModel, - systemPromptPrefix: opts.SystemPromptPrefix, - systemPrompt: opts.SystemPrompt, + largeModel: csync.NewValue(opts.LargeModel), + smallModel: csync.NewValue(opts.SmallModel), + systemPromptPrefix: csync.NewValue(opts.SystemPromptPrefix), + systemPrompt: csync.NewValue(opts.SystemPrompt), isSubAgent: opts.IsSubAgent, sessions: opts.Sessions, messages: opts.Messages, disableAutoSummarize: opts.DisableAutoSummarize, - tools: opts.Tools, + tools: csync.NewSliceFrom(opts.Tools), isYolo: opts.IsYolo, messageQueue: csync.NewMap[string, []SessionAgentCall](), activeRequests: csync.NewMap[string, context.CancelFunc](), @@ -153,15 +154,21 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy return nil, nil } - if len(a.tools) > 0 { + // Copy mutable fields under lock to avoid races with SetTools/SetModels. + agentTools := a.tools.Copy() + largeModel := a.largeModel.Get() + systemPrompt := a.systemPrompt.Get() + promptPrefix := a.systemPromptPrefix.Get() + + if len(agentTools) > 0 { // Add Anthropic caching to the last tool. - a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions()) + agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions()) } agent := fantasy.NewAgent( - a.largeModel.Model, - fantasy.WithSystemPrompt(a.systemPrompt), - fantasy.WithTools(a.tools...), + largeModel.Model, + fantasy.WithSystemPrompt(systemPrompt), + fantasy.WithTools(agentTools...), ) sessionLock := sync.Mutex{} @@ -183,6 +190,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy a.generateTitle(titleCtx, call.SessionID, call.Prompt) }) } + defer wg.Wait() // Add the user message to the session. _, err = a.createUserMessage(ctx, call) @@ -233,7 +241,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...) } - prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages) + prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel) lastSystemRoleInx := 0 systemMessageUpdated := false @@ -251,7 +259,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy } } - if promptPrefix := a.promptPrefix(); promptPrefix != "" { + if promptPrefix != "" { prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...) } @@ -259,15 +267,15 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{ Role: message.Assistant, Parts: []message.ContentPart{}, - Model: a.largeModel.ModelCfg.Model, - Provider: a.largeModel.ModelCfg.Provider, + Model: largeModel.ModelCfg.Model, + Provider: largeModel.ModelCfg.Provider, }) if err != nil { return callContext, prepared, err } callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID) - callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages) - callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name) + callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages) + callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name) currentAssistant = &assistantMsg return callContext, prepared, err }, @@ -361,7 +369,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy sessionLock.Unlock() return getSessionErr } - a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata)) + a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata)) _, sessionErr := a.sessions.Save(genCtx, updatedSession) sessionLock.Unlock() if sessionErr != nil { @@ -371,7 +379,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy }, StopWhen: []fantasy.StopCondition{ func(_ []fantasy.StepResult) bool { - cw := int64(a.largeModel.CatwalkCfg.ContextWindow) + cw := int64(largeModel.CatwalkCfg.ContextWindow) tokens := currentSession.CompletionTokens + currentSession.PromptTokens remaining := cw - tokens var threshold int64 @@ -473,7 +481,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy currentAssistant.AddFinish( message.FinishReasonError, "Copilot model not enabled", - fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", a.largeModel.CatwalkCfg.Name, link), + fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link), ) } else { currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message) @@ -491,7 +499,6 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy } return nil, err } - wg.Wait() if shouldSummarize { a.activeRequests.Del(call.SessionID) @@ -529,6 +536,10 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan return ErrSessionBusy } + // Copy mutable fields under lock to avoid races with SetModels. + largeModel := a.largeModel.Get() + systemPromptPrefix := a.systemPromptPrefix.Get() + currentSession, err := a.sessions.Get(ctx, sessionID) if err != nil { return fmt.Errorf("failed to get session: %w", err) @@ -549,13 +560,13 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan defer a.activeRequests.Del(sessionID) defer cancel() - agent := fantasy.NewAgent(a.largeModel.Model, + agent := fantasy.NewAgent(largeModel.Model, fantasy.WithSystemPrompt(string(summaryPrompt)), ) summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ Role: message.Assistant, - Model: a.largeModel.Model.Model(), - Provider: a.largeModel.Model.Provider(), + Model: largeModel.Model.Model(), + Provider: largeModel.Model.Provider(), IsSummaryMessage: true, }) if err != nil { @@ -570,8 +581,8 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan ProviderOptions: opts, PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) { prepared.Messages = options.Messages - if a.systemPromptPrefix != "" { - prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...) + if systemPromptPrefix != "" { + prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...) } return callContext, prepared, nil }, @@ -622,7 +633,7 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan } } - a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage, openrouterCost) + a.updateSessionUsage(largeModel, ¤tSession, resp.TotalUsage, openrouterCost) // Just in case, get just the last usage info. usage := resp.Response.Usage @@ -730,9 +741,13 @@ func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, user return } + smallModel := a.smallModel.Get() + largeModel := a.largeModel.Get() + systemPromptPrefix := a.systemPromptPrefix.Get() + var maxOutputTokens int64 = 40 - if a.smallModel.CatwalkCfg.CanReason { - maxOutputTokens = a.smallModel.CatwalkCfg.DefaultMaxTokens + if smallModel.CatwalkCfg.CanReason { + maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens } newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent { @@ -746,9 +761,9 @@ func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, user Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n \n\n", userPrompt), PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) { prepared.Messages = opts.Messages - if a.systemPromptPrefix != "" { + if systemPromptPrefix != "" { prepared.Messages = append([]fantasy.Message{ - fantasy.NewSystemMessage(a.systemPromptPrefix), + fantasy.NewSystemMessage(systemPromptPrefix), }, prepared.Messages...) } return callCtx, prepared, nil @@ -756,7 +771,7 @@ func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, user } // Use the small model to generate the title. - model := &a.smallModel + model := smallModel agent := newAgent(model.Model, titlePrompt, maxOutputTokens) resp, err := agent.Stream(ctx, streamCall) if err == nil { @@ -765,7 +780,7 @@ func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, user } else { // It didn't work. Let's try with the big model. slog.Error("error generating title with small model; trying big model", "err", err) - model = &a.largeModel + model = largeModel agent = newAgent(model.Model, titlePrompt, maxOutputTokens) resp, err = agent.Stream(ctx, streamCall) if err == nil { @@ -960,24 +975,20 @@ func (a *sessionAgent) QueuedPromptsList(sessionID string) []string { } func (a *sessionAgent) SetModels(large Model, small Model) { - a.largeModel = large - a.smallModel = small + a.largeModel.Set(large) + a.smallModel.Set(small) } func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) { - a.tools = tools + a.tools.SetSlice(tools) } func (a *sessionAgent) SetSystemPrompt(systemPrompt string) { - a.systemPrompt = systemPrompt + a.systemPrompt.Set(systemPrompt) } func (a *sessionAgent) Model() Model { - return a.largeModel -} - -func (a *sessionAgent) promptPrefix() string { - return a.systemPromptPrefix + return a.largeModel.Get() } // convertToToolResult converts a fantasy tool result to a message tool result. @@ -1034,9 +1045,9 @@ func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) mes // // BEFORE: [tool result: image data] // AFTER: [tool result: "Image loaded - see attached"], [user: image attachment] -func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message { - providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) || - a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock) +func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message { + providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) || + largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock) if providerSupportsMedia { return messages diff --git a/internal/agent/event.go b/internal/agent/event.go index bf36ec84bf4270bd2e63ae0efae0440474288565..3f6c640f6a983c515034e0698676632d0cb57824 100644 --- a/internal/agent/event.go +++ b/internal/agent/event.go @@ -7,23 +7,23 @@ import ( "github.com/charmbracelet/crush/internal/event" ) -func (a sessionAgent) eventPromptSent(sessionID string) { +func (a *sessionAgent) eventPromptSent(sessionID string) { event.PromptSent( - a.eventCommon(sessionID, a.largeModel)..., + a.eventCommon(sessionID, a.largeModel.Get())..., ) } -func (a sessionAgent) eventPromptResponded(sessionID string, duration time.Duration) { +func (a *sessionAgent) eventPromptResponded(sessionID string, duration time.Duration) { event.PromptResponded( append( - a.eventCommon(sessionID, a.largeModel), + a.eventCommon(sessionID, a.largeModel.Get()), "prompt duration pretty", duration.String(), "prompt duration in seconds", int64(duration.Seconds()), )..., ) } -func (a sessionAgent) eventTokensUsed(sessionID string, model Model, usage fantasy.Usage, cost float64) { +func (a *sessionAgent) eventTokensUsed(sessionID string, model Model, usage fantasy.Usage, cost float64) { event.TokensUsed( append( a.eventCommon(sessionID, model), @@ -37,7 +37,7 @@ func (a sessionAgent) eventTokensUsed(sessionID string, model Model, usage fanta ) } -func (a sessionAgent) eventCommon(sessionID string, model Model) []any { +func (a *sessionAgent) eventCommon(sessionID string, model Model) []any { m := model.ModelCfg return []any{ diff --git a/internal/csync/maps_test.go b/internal/csync/maps_test.go index 4c590f008dad91e8dcbc40d1b90d87ef1b3e5750..31e6fa0c3aef18a04c61ea3d4d36b5187228c3ff 100644 --- a/internal/csync/maps_test.go +++ b/internal/csync/maps_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "maps" "sync" + "sync/atomic" "testing" "testing/synctest" "time" @@ -46,12 +47,12 @@ func TestNewLazyMap(t *testing.T) { waiter := sync.Mutex{} waiter.Lock() - loadCalled := false + var loadCalled atomic.Bool loadFunc := func() map[string]int { waiter.Lock() defer waiter.Unlock() - loadCalled = true + loadCalled.Store(true) return map[string]int{ "key1": 1, "key2": 2, @@ -63,7 +64,7 @@ func TestNewLazyMap(t *testing.T) { waiter.Unlock() // Allow the load function to proceed time.Sleep(100 * time.Millisecond) - require.True(t, loadCalled) + require.True(t, loadCalled.Load()) require.Equal(t, 2, m.Len()) value, ok := m.Get("key1") diff --git a/internal/csync/slices.go b/internal/csync/slices.go index c5c635683e70046694f1cdf647aac8cb425abd24..fcce9881b6e27021adcc9462b123f49d469dcd9f 100644 --- a/internal/csync/slices.go +++ b/internal/csync/slices.go @@ -2,7 +2,6 @@ package csync import ( "iter" - "slices" "sync" ) @@ -63,24 +62,6 @@ func (s *Slice[T]) Append(items ...T) { s.inner = append(s.inner, items...) } -// Prepend adds an element to the beginning of the slice. -func (s *Slice[T]) Prepend(item T) { - s.mu.Lock() - defer s.mu.Unlock() - s.inner = append([]T{item}, s.inner...) -} - -// Delete removes the element at the specified index. -func (s *Slice[T]) Delete(index int) bool { - s.mu.Lock() - defer s.mu.Unlock() - if index < 0 || index >= len(s.inner) { - return false - } - s.inner = slices.Delete(s.inner, index, index+1) - return true -} - // Get returns the element at the specified index. func (s *Slice[T]) Get(index int) (T, bool) { s.mu.RLock() @@ -92,17 +73,6 @@ func (s *Slice[T]) Get(index int) (T, bool) { return s.inner[index], true } -// Set updates the element at the specified index. -func (s *Slice[T]) Set(index int, item T) bool { - s.mu.Lock() - defer s.mu.Unlock() - if index < 0 || index >= len(s.inner) { - return false - } - s.inner[index] = item - return true -} - // Len returns the number of elements in the slice. func (s *Slice[T]) Len() int { s.mu.RLock() @@ -131,10 +101,7 @@ func (s *Slice[T]) Seq() iter.Seq[T] { // Seq2 returns an iterator that yields index-value pairs from the slice. func (s *Slice[T]) Seq2() iter.Seq2[int, T] { - s.mu.RLock() - items := make([]T, len(s.inner)) - copy(items, s.inner) - s.mu.RUnlock() + items := s.Copy() return func(yield func(int, T) bool) { for i, v := range items { if !yield(i, v) { @@ -143,3 +110,12 @@ func (s *Slice[T]) Seq2() iter.Seq2[int, T] { } } } + +// Copy returns a copy of the inner slice. +func (s *Slice[T]) Copy() []T { + s.mu.RLock() + defer s.mu.RUnlock() + items := make([]T, len(s.inner)) + copy(items, s.inner) + return items +} diff --git a/internal/csync/slices_test.go b/internal/csync/slices_test.go index 85aedbaba40103ff9a8979e5c70299223f74591f..c7946ac6f1a84614def05b7b6e7e9b0ed11b3a73 100644 --- a/internal/csync/slices_test.go +++ b/internal/csync/slices_test.go @@ -109,44 +109,6 @@ func TestSlice(t *testing.T) { require.Equal(t, "world", val) }) - t.Run("Prepend", func(t *testing.T) { - s := NewSlice[string]() - s.Append("world") - s.Prepend("hello") - - require.Equal(t, 2, s.Len()) - val, ok := s.Get(0) - require.True(t, ok) - require.Equal(t, "hello", val) - - val, ok = s.Get(1) - require.True(t, ok) - require.Equal(t, "world", val) - }) - - t.Run("Delete", func(t *testing.T) { - s := NewSliceFrom([]int{1, 2, 3, 4, 5}) - - // Delete middle element - ok := s.Delete(2) - require.True(t, ok) - require.Equal(t, 4, s.Len()) - - expected := []int{1, 2, 4, 5} - actual := slices.Collect(s.Seq()) - require.Equal(t, expected, actual) - - // Delete out of bounds - ok = s.Delete(10) - require.False(t, ok) - require.Equal(t, 4, s.Len()) - - // Delete negative index - ok = s.Delete(-1) - require.False(t, ok) - require.Equal(t, 4, s.Len()) - }) - t.Run("Get", func(t *testing.T) { s := NewSliceFrom([]string{"a", "b", "c"}) @@ -163,25 +125,6 @@ func TestSlice(t *testing.T) { require.False(t, ok) }) - t.Run("Set", func(t *testing.T) { - s := NewSliceFrom([]string{"a", "b", "c"}) - - ok := s.Set(1, "modified") - require.True(t, ok) - - val, ok := s.Get(1) - require.True(t, ok) - require.Equal(t, "modified", val) - - // Out of bounds - ok = s.Set(10, "invalid") - require.False(t, ok) - - // Negative index - ok = s.Set(-1, "invalid") - require.False(t, ok) - }) - t.Run("SetSlice", func(t *testing.T) { s := NewSlice[int]() s.Append(1) diff --git a/internal/csync/value.go b/internal/csync/value.go new file mode 100644 index 0000000000000000000000000000000000000000..17528a281e0d34d49b206a7c3901b892370c18ba --- /dev/null +++ b/internal/csync/value.go @@ -0,0 +1,44 @@ +package csync + +import ( + "reflect" + "sync" +) + +// Value is a generic thread-safe wrapper for any value type. +// +// For slices, use [Slice]. For maps, use [Map]. Pointers are not supported. +type Value[T any] struct { + v T + mu sync.RWMutex +} + +// NewValue creates a new Value with the given initial value. +// +// Panics if t is a pointer, slice, or map. Use the dedicated types for those. +func NewValue[T any](t T) *Value[T] { + v := reflect.ValueOf(t) + switch v.Kind() { + case reflect.Pointer: + panic("csync.Value does not support pointer types") + case reflect.Slice: + panic("csync.Value does not support slice types; use csync.Slice") + case reflect.Map: + panic("csync.Value does not support map types; use csync.Map") + } + return &Value[T]{v: t} +} + +// Get returns the current value. +func (v *Value[T]) Get() T { + v.mu.RLock() + defer v.mu.RUnlock() + return v.v +} + +// Set updates the value. +func (v *Value[T]) Set(t T) { + v.mu.Lock() + defer v.mu.Unlock() + v.v = t +} diff --git a/internal/csync/value_test.go b/internal/csync/value_test.go new file mode 100644 index 0000000000000000000000000000000000000000..3fa41d85144ea9373c7d440238c0321f52286330 --- /dev/null +++ b/internal/csync/value_test.go @@ -0,0 +1,99 @@ +package csync + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestValue_GetSet(t *testing.T) { + t.Parallel() + + v := NewValue(42) + require.Equal(t, 42, v.Get()) + + v.Set(100) + require.Equal(t, 100, v.Get()) +} + +func TestValue_ZeroValue(t *testing.T) { + t.Parallel() + + v := NewValue("") + require.Equal(t, "", v.Get()) + + v.Set("hello") + require.Equal(t, "hello", v.Get()) +} + +func TestValue_Struct(t *testing.T) { + t.Parallel() + + type config struct { + Name string + Count int + } + + v := NewValue(config{Name: "test", Count: 1}) + require.Equal(t, config{Name: "test", Count: 1}, v.Get()) + + v.Set(config{Name: "updated", Count: 2}) + require.Equal(t, config{Name: "updated", Count: 2}, v.Get()) +} + +func TestValue_PointerPanics(t *testing.T) { + t.Parallel() + + require.Panics(t, func() { + NewValue(&struct{}{}) + }) +} + +func TestValue_SlicePanics(t *testing.T) { + t.Parallel() + + require.Panics(t, func() { + NewValue([]string{"a", "b"}) + }) +} + +func TestValue_MapPanics(t *testing.T) { + t.Parallel() + + require.Panics(t, func() { + NewValue(map[string]int{"a": 1}) + }) +} + +func TestValue_ConcurrentAccess(t *testing.T) { + t.Parallel() + + v := NewValue(0) + var wg sync.WaitGroup + + // Concurrent writers. + for i := range 100 { + wg.Add(1) + go func(val int) { + defer wg.Done() + v.Set(val) + }(i) + } + + // Concurrent readers. + for range 100 { + wg.Add(1) + go func() { + defer wg.Done() + _ = v.Get() + }() + } + + wg.Wait() + + // Value should be one of the set values (0-99). + got := v.Get() + require.GreaterOrEqual(t, got, 0) + require.Less(t, got, 100) +} diff --git a/internal/permission/permission.go b/internal/permission/permission.go index 9dc85e976238fdbe1ff2d3689b2a2c4160608760..e1bf1bae14b8473989b1c0890c58188591123d71 100644 --- a/internal/permission/permission.go +++ b/internal/permission/permission.go @@ -68,8 +68,9 @@ type permissionService struct { allowedTools []string // used to make sure we only process one request at a time - requestMu sync.Mutex - activeRequest *PermissionRequest + requestMu sync.Mutex + activeRequest *PermissionRequest + activeRequestMu sync.Mutex } func (s *permissionService) GrantPersistent(permission PermissionRequest) { @@ -86,9 +87,11 @@ func (s *permissionService) GrantPersistent(permission PermissionRequest) { s.sessionPermissions = append(s.sessionPermissions, permission) s.sessionPermissionsMu.Unlock() + s.activeRequestMu.Lock() if s.activeRequest != nil && s.activeRequest.ID == permission.ID { s.activeRequest = nil } + s.activeRequestMu.Unlock() } func (s *permissionService) Grant(permission PermissionRequest) { @@ -101,9 +104,11 @@ func (s *permissionService) Grant(permission PermissionRequest) { respCh <- true } + s.activeRequestMu.Lock() if s.activeRequest != nil && s.activeRequest.ID == permission.ID { s.activeRequest = nil } + s.activeRequestMu.Unlock() } func (s *permissionService) Deny(permission PermissionRequest) { @@ -117,9 +122,11 @@ func (s *permissionService) Deny(permission PermissionRequest) { respCh <- false } + s.activeRequestMu.Lock() if s.activeRequest != nil && s.activeRequest.ID == permission.ID { s.activeRequest = nil } + s.activeRequestMu.Unlock() } func (s *permissionService) Request(ctx context.Context, opts CreatePermissionRequest) (bool, error) { @@ -190,7 +197,9 @@ func (s *permissionService) Request(ctx context.Context, opts CreatePermissionRe } s.sessionPermissionsMu.RUnlock() + s.activeRequestMu.Lock() s.activeRequest = &permission + s.activeRequestMu.Unlock() respCh := make(chan bool, 1) s.pendingRequests.Set(permission.ID, respCh) diff --git a/internal/permission/permission_test.go b/internal/permission/permission_test.go index 89e06916024cd1669f5e0d0a263d4a71548c8a97..79930f3ae1e2ef15257f09724fef64d3ea28dada 100644 --- a/internal/permission/permission_test.go +++ b/internal/permission/permission_test.go @@ -189,7 +189,7 @@ func TestPermissionService_SequentialProperties(t *testing.T) { events := service.Subscribe(t.Context()) var wg sync.WaitGroup - results := make([]bool, 0) + results := make([]bool, 3) requests := []CreatePermissionRequest{ { @@ -220,7 +220,7 @@ func TestPermissionService_SequentialProperties(t *testing.T) { go func(index int, request CreatePermissionRequest) { defer wg.Done() result, _ := service.Request(t.Context(), request) - results = append(results, result) + results[index] = result }(i, req) } diff --git a/internal/shell/background.go b/internal/shell/background.go index bc81369ec877586c92fa9bc701d8b78b669f23d5..cb1855836f64bdd56a90802c2bbb939a5a514100 100644 --- a/internal/shell/background.go +++ b/internal/shell/background.go @@ -19,6 +19,30 @@ const ( CompletedJobRetentionMinutes = 8 * 60 ) +// syncBuffer is a thread-safe wrapper around bytes.Buffer. +type syncBuffer struct { + buf bytes.Buffer + mu sync.RWMutex +} + +func (sb *syncBuffer) Write(p []byte) (n int, err error) { + sb.mu.Lock() + defer sb.mu.Unlock() + return sb.buf.Write(p) +} + +func (sb *syncBuffer) WriteString(s string) (n int, err error) { + sb.mu.Lock() + defer sb.mu.Unlock() + return sb.buf.WriteString(s) +} + +func (sb *syncBuffer) String() string { + sb.mu.RLock() + defer sb.mu.RUnlock() + return sb.buf.String() +} + // BackgroundShell represents a shell running in the background. type BackgroundShell struct { ID string @@ -28,8 +52,8 @@ type BackgroundShell struct { WorkingDir string ctx context.Context cancel context.CancelFunc - stdout *bytes.Buffer - stderr *bytes.Buffer + stdout *syncBuffer + stderr *syncBuffer done chan struct{} exitErr error completedAt int64 // Unix timestamp when job completed (0 if still running) @@ -46,12 +70,17 @@ var ( idCounter atomic.Uint64 ) +// newBackgroundShellManager creates a new BackgroundShellManager instance. +func newBackgroundShellManager() *BackgroundShellManager { + return &BackgroundShellManager{ + shells: csync.NewMap[string, *BackgroundShell](), + } +} + // GetBackgroundShellManager returns the singleton background shell manager. func GetBackgroundShellManager() *BackgroundShellManager { backgroundManagerOnce.Do(func() { - backgroundManager = &BackgroundShellManager{ - shells: csync.NewMap[string, *BackgroundShell](), - } + backgroundManager = newBackgroundShellManager() }) return backgroundManager } @@ -80,8 +109,8 @@ func (m *BackgroundShellManager) Start(ctx context.Context, workingDir string, b Shell: shell, ctx: shellCtx, cancel: cancel, - stdout: &bytes.Buffer{}, - stderr: &bytes.Buffer{}, + stdout: &syncBuffer{}, + stderr: &syncBuffer{}, done: make(chan struct{}), } diff --git a/internal/shell/background_test.go b/internal/shell/background_test.go index 5149861d94e457e8a78650c48d9c6765a57d369e..7c521bc1477b07775cffb69f310fa83d710d4634 100644 --- a/internal/shell/background_test.go +++ b/internal/shell/background_test.go @@ -14,7 +14,7 @@ func TestBackgroundShellManager_Start(t *testing.T) { ctx := context.Background() workingDir := t.TempDir() - manager := GetBackgroundShellManager() + manager := newBackgroundShellManager() bgShell, err := manager.Start(ctx, workingDir, nil, "echo 'hello world'", "") if err != nil { @@ -51,7 +51,7 @@ func TestBackgroundShellManager_Get(t *testing.T) { ctx := context.Background() workingDir := t.TempDir() - manager := GetBackgroundShellManager() + manager := newBackgroundShellManager() bgShell, err := manager.Start(ctx, workingDir, nil, "echo 'test'", "") if err != nil { @@ -77,7 +77,7 @@ func TestBackgroundShellManager_Kill(t *testing.T) { ctx := context.Background() workingDir := t.TempDir() - manager := GetBackgroundShellManager() + manager := newBackgroundShellManager() // Start a long-running command bgShell, err := manager.Start(ctx, workingDir, nil, "sleep 10", "") @@ -106,7 +106,7 @@ func TestBackgroundShellManager_Kill(t *testing.T) { func TestBackgroundShellManager_KillNonExistent(t *testing.T) { t.Parallel() - manager := GetBackgroundShellManager() + manager := newBackgroundShellManager() err := manager.Kill("non-existent-id") if err == nil { @@ -119,7 +119,7 @@ func TestBackgroundShell_IsDone(t *testing.T) { ctx := context.Background() workingDir := t.TempDir() - manager := GetBackgroundShellManager() + manager := newBackgroundShellManager() bgShell, err := manager.Start(ctx, workingDir, nil, "echo 'quick'", "") if err != nil { @@ -142,7 +142,7 @@ func TestBackgroundShell_WithBlockFuncs(t *testing.T) { ctx := context.Background() workingDir := t.TempDir() - manager := GetBackgroundShellManager() + manager := newBackgroundShellManager() blockFuncs := []BlockFunc{ CommandsBlocker([]string{"curl", "wget"}), @@ -180,7 +180,7 @@ func TestBackgroundShellManager_List(t *testing.T) { ctx := context.Background() workingDir := t.TempDir() - manager := GetBackgroundShellManager() + manager := newBackgroundShellManager() // Start two shells bgShell1, err := manager.Start(ctx, workingDir, nil, "sleep 1", "") @@ -224,7 +224,7 @@ func TestBackgroundShellManager_KillAll(t *testing.T) { ctx := context.Background() workingDir := t.TempDir() - manager := GetBackgroundShellManager() + manager := newBackgroundShellManager() // Start multiple long-running shells shell1, err := manager.Start(ctx, workingDir, nil, "sleep 10", "") diff --git a/internal/tui/components/chat/editor/editor.go b/internal/tui/components/chat/editor/editor.go index 01badb98d37eb848ccf5962e01793ecaa3fc0f59..b5cadb8cde8a1ced8543d01eb7abd28d906f1597 100644 --- a/internal/tui/components/chat/editor/editor.go +++ b/internal/tui/components/chat/editor/editor.go @@ -16,6 +16,7 @@ import ( "charm.land/bubbles/v2/textarea" tea "charm.land/bubbletea/v2" "charm.land/lipgloss/v2" + nativeclipboard "github.com/aymanbagabas/go-nativeclipboard" "github.com/charmbracelet/crush/internal/app" "github.com/charmbracelet/crush/internal/filetracker" "github.com/charmbracelet/crush/internal/fsext" @@ -338,6 +339,84 @@ func (m *editorCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { m.textarea.InsertRune('\n') cmds = append(cmds, util.CmdHandler(completions.CloseCompletionsMsg{})) } + // Handle image paste from clipboard + if key.Matches(msg, m.keyMap.PasteImage) { + imageData, err := nativeclipboard.Image.Read() + + if err != nil || len(imageData) == 0 { + // If no image data found, try to get text data (could be file path) + var textData []byte + textData, err = nativeclipboard.Text.Read() + if err != nil || len(textData) == 0 { + // If clipboard is empty, show a warning + return m, util.ReportWarn("No data found in clipboard. Note: Some terminals may not support reading image data from clipboard directly.") + } + + // Check if the text data is a file path + textStr := string(textData) + // First, try to interpret as a file path (existing functionality) + path := strings.ReplaceAll(textStr, "\\ ", " ") + path, err = filepath.Abs(strings.TrimSpace(path)) + if err == nil { + isAllowedType := false + for _, ext := range filepicker.AllowedTypes { + if strings.HasSuffix(path, ext) { + isAllowedType = true + break + } + } + if isAllowedType { + tooBig, _ := filepicker.IsFileTooBig(path, filepicker.MaxAttachmentSize) + if !tooBig { + content, err := os.ReadFile(path) + if err == nil { + mimeBufferSize := min(512, len(content)) + mimeType := http.DetectContentType(content[:mimeBufferSize]) + fileName := filepath.Base(path) + attachment := message.Attachment{FilePath: path, FileName: fileName, MimeType: mimeType, Content: content} + return m, util.CmdHandler(filepicker.FilePickedMsg{ + Attachment: attachment, + }) + } + } + } + } + + // If not a valid file path, show a warning + return m, util.ReportWarn("No image found in clipboard") + } else { + // We have image data from the clipboard + // Create a temporary file to store the clipboard image data + tempFile, err := os.CreateTemp("", "clipboard_image_crush_*") + if err != nil { + return m, util.ReportError(err) + } + defer tempFile.Close() + + // Write clipboard content to the temporary file + _, err = tempFile.Write(imageData) + if err != nil { + return m, util.ReportError(err) + } + + // Determine the file extension based on the image data + mimeBufferSize := min(512, len(imageData)) + mimeType := http.DetectContentType(imageData[:mimeBufferSize]) + + // Create an attachment from the temporary file + fileName := filepath.Base(tempFile.Name()) + attachment := message.Attachment{ + FilePath: tempFile.Name(), + FileName: fileName, + MimeType: mimeType, + Content: imageData, + } + + return m, util.CmdHandler(filepicker.FilePickedMsg{ + Attachment: attachment, + }) + } + } // Handle Enter key if m.textarea.Focused() && key.Matches(msg, m.keyMap.SendMessage) { value := m.textarea.Value() diff --git a/internal/tui/components/chat/editor/keys.go b/internal/tui/components/chat/editor/keys.go index 0ba4571888e547b1c4a85e7ee9dd73ff07ce13d2..c20df5cc1c071deab83754430543b9be2381127c 100644 --- a/internal/tui/components/chat/editor/keys.go +++ b/internal/tui/components/chat/editor/keys.go @@ -9,6 +9,7 @@ type EditorKeyMap struct { SendMessage key.Binding OpenEditor key.Binding Newline key.Binding + PasteImage key.Binding } func DefaultEditorKeyMap() EditorKeyMap { @@ -32,6 +33,10 @@ func DefaultEditorKeyMap() EditorKeyMap { // to reflect that. key.WithHelp("ctrl+j", "newline"), ), + PasteImage: key.NewBinding( + key.WithKeys("ctrl+v"), + key.WithHelp("ctrl+v", "paste image from clipboard"), + ), } } @@ -42,6 +47,7 @@ func (k EditorKeyMap) KeyBindings() []key.Binding { k.SendMessage, k.OpenEditor, k.Newline, + k.PasteImage, AttachmentsKeyMaps.AttachmentDeleteMode, AttachmentsKeyMaps.DeleteAllAttachments, AttachmentsKeyMaps.Escape, diff --git a/internal/tui/styles/theme.go b/internal/tui/styles/theme.go index f87ffd9de8b324cec4dcfd8b7cee61f71e0390eb..b03603c57439f5f950f9860d3287b0f9d13742e5 100644 --- a/internal/tui/styles/theme.go +++ b/internal/tui/styles/theme.go @@ -4,6 +4,7 @@ import ( "fmt" "image/color" "strings" + "sync" "charm.land/bubbles/v2/filepicker" "charm.land/bubbles/v2/help" @@ -97,7 +98,8 @@ type Theme struct { AuthBorderUnselected lipgloss.Style AuthTextUnselected lipgloss.Style - styles *Styles + styles *Styles + stylesOnce sync.Once } type Styles struct { @@ -134,9 +136,9 @@ type Styles struct { } func (t *Theme) S() *Styles { - if t.styles == nil { + t.stylesOnce.Do(func() { t.styles = t.buildStyles() - } + }) return t.styles } @@ -500,27 +502,31 @@ type Manager struct { current *Theme } -var defaultManager *Manager +var ( + defaultManager *Manager + defaultManagerOnce sync.Once +) + +func initDefaultManager() *Manager { + defaultManagerOnce.Do(func() { + defaultManager = newManager() + }) + return defaultManager +} func SetDefaultManager(m *Manager) { defaultManager = m } func DefaultManager() *Manager { - if defaultManager == nil { - defaultManager = NewManager() - } - return defaultManager + return initDefaultManager() } func CurrentTheme() *Theme { - if defaultManager == nil { - defaultManager = NewManager() - } - return defaultManager.Current() + return initDefaultManager().Current() } -func NewManager() *Manager { +func newManager() *Manager { m := &Manager{ themes: make(map[string]*Theme), }