diff --git a/.config/hakari.toml b/.config/hakari.toml index 982542ca397e072d83af67608ea31a3415360a8e..2050065cc2d6be2a27ec012dcd125af992793eeb 100644 --- a/.config/hakari.toml +++ b/.config/hakari.toml @@ -23,6 +23,8 @@ workspace-members = [ ] third-party = [ { name = "reqwest", version = "0.11.27" }, + # build of remote_server should not include scap / its x11 dependency + { name = "scap", git = "https://github.com/zed-industries/scap", rev = "808aa5c45b41e8f44729d02e38fd00a2fe2722e7" }, ] [final-excludes] diff --git a/.github/actionlint.yml b/.github/actionlint.yml new file mode 100644 index 0000000000000000000000000000000000000000..ad0954590267f6c4d0450f7fe8ce3ccf21409560 --- /dev/null +++ b/.github/actionlint.yml @@ -0,0 +1,29 @@ +# Configuration related to self-hosted runner. +self-hosted-runner: + # Labels of self-hosted runner in array of strings. + labels: + # GitHub-hosted Runners + - github-8vcpu-ubuntu-2404 + - github-16vcpu-ubuntu-2404 + - github-32vcpu-ubuntu-2404 + - github-8vcpu-ubuntu-2204 + - github-16vcpu-ubuntu-2204 + - github-32vcpu-ubuntu-2204 + - github-16vcpu-ubuntu-2204-arm + - windows-2025-16 + - windows-2025-32 + - windows-2025-64 + # Namespace Ubuntu 20.04 (Release builds) + - namespace-profile-16x32-ubuntu-2004 + - namespace-profile-32x64-ubuntu-2004 + - namespace-profile-16x32-ubuntu-2004-arm + - namespace-profile-32x64-ubuntu-2004-arm + # Namespace Ubuntu 22.04 (Everything else) + - namespace-profile-2x4-ubuntu-2204 + - namespace-profile-4x8-ubuntu-2204 + - namespace-profile-8x16-ubuntu-2204 + - namespace-profile-16x32-ubuntu-2204 + - namespace-profile-32x64-ubuntu-2204 + # Self Hosted Runners + - self-mini-macos + - self-32vcpu-windows-2022 diff --git a/.github/actions/build_docs/action.yml b/.github/actions/build_docs/action.yml index 9a2d7e1ec718fd73cff7a32a6573e6b9b0f8ddd4..d2e62d5b22ee49c7dcb9b42085a648098fbdb6bb 100644 --- a/.github/actions/build_docs/action.yml +++ b/.github/actions/build_docs/action.yml @@ -13,13 +13,13 @@ runs: uses: swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2 with: save-if: ${{ github.ref == 'refs/heads/main' }} - cache-provider: "buildjet" + # cache-provider: "buildjet" - name: Install Linux dependencies shell: bash -euxo pipefail {0} run: ./script/linux - - name: Check for broken links + - name: Check for broken links (in MD) uses: lycheeverse/lychee-action@82202e5e9c2f4ef1a55a3d02563e1cb6041e5332 # v2.4.1 with: args: --no-progress --exclude '^http' './docs/src/**/*' @@ -30,3 +30,9 @@ runs: run: | mkdir -p target/deploy mdbook build ./docs --dest-dir=../target/deploy/docs/ + + - name: Check for broken links (in HTML) + uses: lycheeverse/lychee-action@82202e5e9c2f4ef1a55a3d02563e1cb6041e5332 # v2.4.1 + with: + args: --no-progress --exclude '^http' 'target/deploy/docs/' + fail: true diff --git a/.github/actions/install_trusted_signing/action.yml b/.github/actions/install_trusted_signing/action.yml deleted file mode 100644 index a99ff08eb1eb1f1b92cdea2c374a62b2384b2237..0000000000000000000000000000000000000000 --- a/.github/actions/install_trusted_signing/action.yml +++ /dev/null @@ -1,64 +0,0 @@ -name: "Trusted Signing on Windows" -description: "Install trusted signing on Windows." 
- -# Modified from https://github.com/Azure/trusted-signing-action -runs: - using: "composite" - steps: - - name: Set variables - id: set-variables - shell: "pwsh" - run: | - $defaultPath = $env:PSModulePath -split ';' | Select-Object -First 1 - "PSMODULEPATH=$defaultPath" | Out-File -FilePath $env:GITHUB_OUTPUT -Append - - "TRUSTED_SIGNING_MODULE_VERSION=0.5.3" | Out-File -FilePath $env:GITHUB_OUTPUT -Append - "BUILD_TOOLS_NUGET_VERSION=10.0.22621.3233" | Out-File -FilePath $env:GITHUB_OUTPUT -Append - "TRUSTED_SIGNING_NUGET_VERSION=1.0.53" | Out-File -FilePath $env:GITHUB_OUTPUT -Append - "DOTNET_SIGNCLI_NUGET_VERSION=0.9.1-beta.24469.1" | Out-File -FilePath $env:GITHUB_OUTPUT -Append - - - name: Cache TrustedSigning PowerShell module - id: cache-module - uses: actions/cache@v4 - env: - cache-name: cache-module - with: - path: ${{ steps.set-variables.outputs.PSMODULEPATH }}\TrustedSigning\${{ steps.set-variables.outputs.TRUSTED_SIGNING_MODULE_VERSION }} - key: TrustedSigning-${{ steps.set-variables.outputs.TRUSTED_SIGNING_MODULE_VERSION }} - if: ${{ inputs.cache-dependencies == 'true' }} - - - name: Cache Microsoft.Windows.SDK.BuildTools NuGet package - id: cache-buildtools - uses: actions/cache@v4 - env: - cache-name: cache-buildtools - with: - path: ~\AppData\Local\TrustedSigning\Microsoft.Windows.SDK.BuildTools\Microsoft.Windows.SDK.BuildTools.${{ steps.set-variables.outputs.BUILD_TOOLS_NUGET_VERSION }} - key: Microsoft.Windows.SDK.BuildTools-${{ steps.set-variables.outputs.BUILD_TOOLS_NUGET_VERSION }} - if: ${{ inputs.cache-dependencies == 'true' }} - - - name: Cache Microsoft.Trusted.Signing.Client NuGet package - id: cache-tsclient - uses: actions/cache@v4 - env: - cache-name: cache-tsclient - with: - path: ~\AppData\Local\TrustedSigning\Microsoft.Trusted.Signing.Client\Microsoft.Trusted.Signing.Client.${{ steps.set-variables.outputs.TRUSTED_SIGNING_NUGET_VERSION }} - key: Microsoft.Trusted.Signing.Client-${{ steps.set-variables.outputs.TRUSTED_SIGNING_NUGET_VERSION }} - if: ${{ inputs.cache-dependencies == 'true' }} - - - name: Cache SignCli NuGet package - id: cache-signcli - uses: actions/cache@v4 - env: - cache-name: cache-signcli - with: - path: ~\AppData\Local\TrustedSigning\sign\sign.${{ steps.set-variables.outputs.DOTNET_SIGNCLI_NUGET_VERSION }} - key: SignCli-${{ steps.set-variables.outputs.DOTNET_SIGNCLI_NUGET_VERSION }} - if: ${{ inputs.cache-dependencies == 'true' }} - - - name: Install Trusted Signing module - shell: "pwsh" - run: | - Install-Module -Name TrustedSigning -RequiredVersion ${{ steps.set-variables.outputs.TRUSTED_SIGNING_MODULE_VERSION }} -Force -Repository PSGallery - if: ${{ inputs.cache-dependencies != 'true' || steps.cache-module.outputs.cache-hit != 'true' }} diff --git a/.github/workflows/bump_patch_version.yml b/.github/workflows/bump_patch_version.yml index 02857a9151cc3ea88914113f9c792bb2d6b7a811..bfaf7a271b5e31b60c999c7dcf8d17538d135355 100644 --- a/.github/workflows/bump_patch_version.yml +++ b/.github/workflows/bump_patch_version.yml @@ -16,7 +16,7 @@ jobs: bump_patch_version: if: github.repository_owner == 'zed-industries' runs-on: - - buildjet-16vcpu-ubuntu-2204 + - namespace-profile-16x32-ubuntu-2204 steps: - name: Checkout code uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 @@ -28,7 +28,7 @@ jobs: run: | set -eux - channel=$(cat crates/zed/RELEASE_CHANNEL) + channel="$(cat crates/zed/RELEASE_CHANNEL)" tag_suffix="" case $channel in @@ -43,9 +43,9 @@ jobs: ;; esac which cargo-set-version > /dev/null || cargo install 
cargo-edit - output=$(cargo set-version -p zed --bump patch 2>&1 | sed 's/.* //') + output="$(cargo set-version -p zed --bump patch 2>&1 | sed 's/.* //')" export GIT_COMMITTER_NAME="Zed Bot" export GIT_COMMITTER_EMAIL="hi@zed.dev" git commit -am "Bump to $output for @$GITHUB_ACTOR" --author "Zed Bot " - git tag v${output}${tag_suffix} - git push origin HEAD v${output}${tag_suffix} + git tag "v${output}${tag_suffix}" + git push origin HEAD "v${output}${tag_suffix}" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ea30df1cc1984c6f0a3104ccd64dce24bb86c57e..84907351fe287c93c8ef0e50c66a3ba51e610ce7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,6 +21,10 @@ env: CARGO_TERM_COLOR: always CARGO_INCREMENTAL: 0 RUST_BACKTRACE: 1 + DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }} + DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }} + ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }} + ZED_MINIDUMP_ENDPOINT: ${{ secrets.ZED_SENTRY_MINIDUMP_ENDPOINT }} jobs: job_spec: @@ -31,6 +35,7 @@ jobs: run_license: ${{ steps.filter.outputs.run_license }} run_docs: ${{ steps.filter.outputs.run_docs }} run_nix: ${{ steps.filter.outputs.run_nix }} + run_actionlint: ${{ steps.filter.outputs.run_actionlint }} runs-on: - ubuntu-latest steps: @@ -44,39 +49,40 @@ jobs: run: | if [ -z "$GITHUB_BASE_REF" ]; then echo "Not in a PR context (i.e., push to main/stable/preview)" - COMPARE_REV=$(git rev-parse HEAD~1) + COMPARE_REV="$(git rev-parse HEAD~1)" else echo "In a PR context comparing to pull_request.base.ref" git fetch origin "$GITHUB_BASE_REF" --depth=350 - COMPARE_REV=$(git merge-base "origin/${GITHUB_BASE_REF}" HEAD) + COMPARE_REV="$(git merge-base "origin/${GITHUB_BASE_REF}" HEAD)" fi - # Specify anything which should skip full CI in this regex: + CHANGED_FILES="$(git diff --name-only "$COMPARE_REV" ${{ github.sha }})" + + # Specify anything which should potentially skip full test suite in this regex: # - docs/ # - script/update_top_ranking_issues/ # - .github/ISSUE_TEMPLATE/ # - .github/workflows/ (except .github/workflows/ci.yml) SKIP_REGEX='^(docs/|script/update_top_ranking_issues/|\.github/(ISSUE_TEMPLATE|workflows/(?!ci)))' - if [[ $(git diff --name-only $COMPARE_REV ${{ github.sha }} | grep -vP "$SKIP_REGEX") ]]; then - echo "run_tests=true" >> $GITHUB_OUTPUT - else - echo "run_tests=false" >> $GITHUB_OUTPUT - fi - if [[ $(git diff --name-only $COMPARE_REV ${{ github.sha }} | grep '^docs/') ]]; then - echo "run_docs=true" >> $GITHUB_OUTPUT - else - echo "run_docs=false" >> $GITHUB_OUTPUT - fi - if [[ $(git diff --name-only $COMPARE_REV ${{ github.sha }} | grep -P '^(Cargo.lock|script/.*licenses)') ]]; then - echo "run_license=true" >> $GITHUB_OUTPUT - else - echo "run_license=false" >> $GITHUB_OUTPUT - fi - NIX_REGEX='^(nix/|flake\.|Cargo\.|rust-toolchain.toml|\.cargo/config.toml)' - if [[ $(git diff --name-only $COMPARE_REV ${{ github.sha }} | grep -P "$NIX_REGEX") ]]; then - echo "run_nix=true" >> $GITHUB_OUTPUT - else - echo "run_nix=false" >> $GITHUB_OUTPUT - fi + + echo "$CHANGED_FILES" | grep -qvP "$SKIP_REGEX" && \ + echo "run_tests=true" >> "$GITHUB_OUTPUT" || \ + echo "run_tests=false" >> "$GITHUB_OUTPUT" + + echo "$CHANGED_FILES" | grep -qP '^docs/' && \ + echo "run_docs=true" >> "$GITHUB_OUTPUT" || \ + echo "run_docs=false" >> "$GITHUB_OUTPUT" + + echo "$CHANGED_FILES" | grep -qP '^\.github/(workflows/|actions/|actionlint.yml)' && \ + echo "run_actionlint=true" >> 
"$GITHUB_OUTPUT" || \ + echo "run_actionlint=false" >> "$GITHUB_OUTPUT" + + echo "$CHANGED_FILES" | grep -qP '^(Cargo.lock|script/.*licenses)' && \ + echo "run_license=true" >> "$GITHUB_OUTPUT" || \ + echo "run_license=false" >> "$GITHUB_OUTPUT" + + echo "$CHANGED_FILES" | grep -qP '^(nix/|flake\.|Cargo\.|rust-toolchain.toml|\.cargo/config.toml)' && \ + echo "run_nix=true" >> "$GITHUB_OUTPUT" || \ + echo "run_nix=false" >> "$GITHUB_OUTPUT" migration_checks: name: Check Postgres and Protobuf migrations, mergability @@ -86,8 +92,7 @@ jobs: needs.job_spec.outputs.run_tests == 'true' timeout-minutes: 60 runs-on: - - self-hosted - - macOS + - self-mini-macos steps: - name: Checkout repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 @@ -109,11 +114,11 @@ jobs: run: | if [ -z "$GITHUB_BASE_REF" ]; then - echo "BUF_BASE_BRANCH=$(git merge-base origin/main HEAD)" >> $GITHUB_ENV + echo "BUF_BASE_BRANCH=$(git merge-base origin/main HEAD)" >> "$GITHUB_ENV" else git checkout -B temp - git merge -q origin/$GITHUB_BASE_REF -m "merge main into temp" - echo "BUF_BASE_BRANCH=$GITHUB_BASE_REF" >> $GITHUB_ENV + git merge -q "origin/$GITHUB_BASE_REF" -m "merge main into temp" + echo "BUF_BASE_BRANCH=$GITHUB_BASE_REF" >> "$GITHUB_ENV" fi - uses: bufbuild/buf-setup-action@v1 @@ -132,12 +137,12 @@ jobs: github.repository_owner == 'zed-industries' && needs.job_spec.outputs.run_tests == 'true' runs-on: - - buildjet-8vcpu-ubuntu-2204 + - namespace-profile-8x16-ubuntu-2204 steps: - name: Checkout repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 - name: Add Rust to the PATH - run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH + run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH" - name: Install cargo-hakari uses: clechasseur/rs-cargo@8435b10f6e71c2e3d4d3b7573003a8ce4bfc6386 # v2 with: @@ -163,7 +168,7 @@ jobs: needs: [job_spec] if: github.repository_owner == 'zed-industries' runs-on: - - buildjet-8vcpu-ubuntu-2204 + - namespace-profile-4x8-ubuntu-2204 steps: - name: Checkout repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 @@ -175,7 +180,7 @@ jobs: - name: Prettier Check on /docs working-directory: ./docs run: | - pnpm dlx prettier@${PRETTIER_VERSION} . --check || { + pnpm dlx "prettier@${PRETTIER_VERSION}" . --check || { echo "To fix, run from the root of the Zed repo:" echo " cd docs && pnpm dlx prettier@${PRETTIER_VERSION} . --write && cd .." 
false @@ -185,7 +190,7 @@ jobs: - name: Prettier Check on default.json run: | - pnpm dlx prettier@${PRETTIER_VERSION} assets/settings/default.json --check || { + pnpm dlx "prettier@${PRETTIER_VERSION}" assets/settings/default.json --check || { echo "To fix, run from the root of the Zed repo:" echo " pnpm dlx prettier@${PRETTIER_VERSION} assets/settings/default.json --write" false @@ -216,7 +221,7 @@ jobs: github.repository_owner == 'zed-industries' && (needs.job_spec.outputs.run_tests == 'true' || needs.job_spec.outputs.run_docs == 'true') runs-on: - - buildjet-8vcpu-ubuntu-2204 + - namespace-profile-8x16-ubuntu-2204 steps: - name: Checkout repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 @@ -231,6 +236,20 @@ jobs: - name: Build docs uses: ./.github/actions/build_docs + actionlint: + runs-on: ubuntu-latest + if: github.repository_owner == 'zed-industries' && needs.job_spec.outputs.run_actionlint == 'true' + needs: [job_spec] + steps: + - uses: actions/checkout@v4 + - name: Download actionlint + id: get_actionlint + run: bash <(curl https://raw.githubusercontent.com/rhysd/actionlint/main/scripts/download-actionlint.bash) + shell: bash + - name: Check workflow files + run: ${{ steps.get_actionlint.outputs.executable }} -color + shell: bash + macos_tests: timeout-minutes: 60 name: (macOS) Run Clippy and tests @@ -239,8 +258,7 @@ jobs: github.repository_owner == 'zed-industries' && needs.job_spec.outputs.run_tests == 'true' runs-on: - - self-hosted - - macOS + - self-mini-macos steps: - name: Checkout repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 @@ -252,6 +270,10 @@ jobs: mkdir -p ./../.cargo cp ./.cargo/ci-config.toml ./../.cargo/config.toml + - name: Check that Cargo.lock is up to date + run: | + cargo update --locked --workspace + - name: cargo clippy run: ./script/clippy @@ -306,10 +328,10 @@ jobs: github.repository_owner == 'zed-industries' && needs.job_spec.outputs.run_tests == 'true' runs-on: - - buildjet-16vcpu-ubuntu-2204 + - namespace-profile-16x32-ubuntu-2204 steps: - name: Add Rust to the PATH - run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH + run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH" - name: Checkout repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 @@ -320,7 +342,7 @@ jobs: uses: swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2 with: save-if: ${{ github.ref == 'refs/heads/main' }} - cache-provider: "buildjet" + # cache-provider: "buildjet" - name: Install Linux dependencies run: ./script/linux @@ -358,10 +380,10 @@ jobs: github.repository_owner == 'zed-industries' && needs.job_spec.outputs.run_tests == 'true' runs-on: - - buildjet-8vcpu-ubuntu-2204 + - namespace-profile-16x32-ubuntu-2204 steps: - name: Add Rust to the PATH - run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH + run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH" - name: Checkout repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 @@ -372,7 +394,7 @@ jobs: uses: swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2 with: save-if: ${{ github.ref == 'refs/heads/main' }} - cache-provider: "buildjet" + # cache-provider: "buildjet" - name: Install Clang & Mold run: ./script/remote-server && ./script/install-mold 2.34.0 @@ -441,6 +463,7 @@ jobs: - job_spec - style - check_docs + - actionlint - migration_checks # run_tests: If adding required tests, add them here and to script below. 
- workspace_hack @@ -462,6 +485,11 @@ jobs: if [[ "${{ needs.job_spec.outputs.run_docs }}" == "true" ]]; then [[ "${{ needs.check_docs.result }}" != 'success' ]] && { RET_CODE=1; echo "docs checks failed"; } fi + + if [[ "${{ needs.job_spec.outputs.run_actionlint }}" == "true" ]]; then + [[ "${{ needs.actionlint.result }}" != 'success' ]] && { RET_CODE=1; echo "actionlint checks failed"; } + fi + # Only check test jobs if they were supposed to run if [[ "${{ needs.job_spec.outputs.run_tests }}" == "true" ]]; then [[ "${{ needs.workspace_hack.result }}" != 'success' ]] && { RET_CODE=1; echo "Workspace Hack failed"; } @@ -481,8 +509,7 @@ jobs: timeout-minutes: 120 name: Create a macOS bundle runs-on: - - self-hosted - - bundle + - self-mini-macos if: | startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling') @@ -493,9 +520,6 @@ jobs: APPLE_NOTARIZATION_KEY: ${{ secrets.APPLE_NOTARIZATION_KEY }} APPLE_NOTARIZATION_KEY_ID: ${{ secrets.APPLE_NOTARIZATION_KEY_ID }} APPLE_NOTARIZATION_ISSUER_ID: ${{ secrets.APPLE_NOTARIZATION_ISSUER_ID }} - ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }} - DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }} - DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }} steps: - name: Install Node uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4 @@ -573,15 +597,11 @@ jobs: timeout-minutes: 60 name: Linux x86_x64 release bundle runs-on: - - buildjet-16vcpu-ubuntu-2004 # ubuntu 20.04 for minimal glibc + - namespace-profile-16x32-ubuntu-2004 # ubuntu 20.04 for minimal glibc if: | startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling') needs: [linux_tests] - env: - ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }} - DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }} - DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }} steps: - name: Checkout repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 @@ -630,15 +650,11 @@ jobs: timeout-minutes: 60 name: Linux arm64 release bundle runs-on: - - buildjet-16vcpu-ubuntu-2204-arm + - namespace-profile-32x64-ubuntu-2004-arm # ubuntu 20.04 for minimal glibc if: | startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling') needs: [linux_tests] - env: - ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }} - DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }} - DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }} steps: - name: Checkout repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 @@ -687,20 +703,18 @@ jobs: timeout-minutes: 60 runs-on: github-8vcpu-ubuntu-2404 if: | + false && ( startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling') + ) needs: [linux_tests] name: Build Zed on FreeBSD - # env: - # MYTOKEN : ${{ secrets.MYTOKEN }} - # MYTOKEN2: "value2" steps: - uses: actions/checkout@v4 - name: Build FreeBSD remote-server id: freebsd-build uses: vmactions/freebsd-vm@c3ae29a132c8ef1924775414107a97cac042aad5 # v1.2.0 with: - # envs: "MYTOKEN MYTOKEN2" usesh: true release: 13.5 copyback: true @@ -758,7 +772,8 @@ jobs: timeout-minutes: 120 name: Create a Windows installer runs-on: [self-hosted, Windows, X64] - if: ${{ startsWith(github.ref, 'refs/tags/v') || 
contains(github.event.pull_request.labels.*.name, 'run-bundling') }} + if: contains(github.event.pull_request.labels.*.name, 'run-bundling') + # if: (startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling')) needs: [windows_tests] env: AZURE_TENANT_ID: ${{ secrets.AZURE_SIGNING_TENANT_ID }} @@ -767,8 +782,6 @@ jobs: ACCOUNT_NAME: ${{ vars.AZURE_SIGNING_ACCOUNT_NAME }} CERT_PROFILE_NAME: ${{ vars.AZURE_SIGNING_CERT_PROFILE_NAME }} ENDPOINT: ${{ vars.AZURE_SIGNING_ENDPOINT }} - DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }} - DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }} FILE_DIGEST: SHA256 TIMESTAMP_DIGEST: SHA256 TIMESTAMP_SERVER: "http://timestamp.acs.microsoft.com" @@ -785,9 +798,6 @@ jobs: # This exports RELEASE_CHANNEL into env (GITHUB_ENV) script/determine-release-channel.ps1 - - name: Install trusted signing - uses: ./.github/actions/install_trusted_signing - - name: Build Zed installer working-directory: ${{ env.ZED_WORKSPACE }} run: script/bundle-windows.ps1 @@ -802,7 +812,7 @@ jobs: - name: Upload Artifacts to release uses: softprops/action-gh-release@de2c0eb89ae2a093876385947365aca7b0e5f844 # v1 # Re-enable when we are ready to publish windows preview releases - if: false && ${{ !(contains(github.event.pull_request.labels.*.name, 'run-bundling')) && env.RELEASE_CHANNEL == 'preview' }} # upload only preview + if: ${{ !(contains(github.event.pull_request.labels.*.name, 'run-bundling')) && env.RELEASE_CHANNEL == 'preview' }} # upload only preview with: draft: true prerelease: ${{ env.RELEASE_CHANNEL == 'preview' }} @@ -815,12 +825,11 @@ jobs: if: | startsWith(github.ref, 'refs/tags/v') && endsWith(github.ref, '-pre') && !endsWith(github.ref, '.0-pre') - needs: [bundle-mac, bundle-linux-x86_x64, bundle-linux-aarch64, bundle-windows-x64, freebsd] + needs: [bundle-mac, bundle-linux-x86_x64, bundle-linux-aarch64, bundle-windows-x64] runs-on: - - self-hosted - - bundle + - self-mini-macos steps: - name: gh release - run: gh release edit $GITHUB_REF_NAME --draft=false + run: gh release edit "$GITHUB_REF_NAME" --draft=false env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/community_release_actions.yml b/.github/workflows/community_release_actions.yml index 3e253978b7a54c2a12ea281f84f0d3dbf45a8561..31dda1fa6d005ee16eb9d13aec6277ebf9a3ab94 100644 --- a/.github/workflows/community_release_actions.yml +++ b/.github/workflows/community_release_actions.yml @@ -18,7 +18,7 @@ jobs: URL="https://zed.dev/releases/stable/latest" fi - echo "URL=$URL" >> $GITHUB_OUTPUT + echo "URL=$URL" >> "$GITHUB_OUTPUT" - name: Get content uses: 2428392/gh-truncate-string-action@b3ff790d21cf42af3ca7579146eedb93c8fb0757 # v1.4.1 id: get-content @@ -50,9 +50,9 @@ jobs: PREVIEW_TAG="${VERSION}-pre" if git rev-parse "$PREVIEW_TAG" > /dev/null 2>&1; then - echo "was_promoted_from_preview=true" >> $GITHUB_OUTPUT + echo "was_promoted_from_preview=true" >> "$GITHUB_OUTPUT" else - echo "was_promoted_from_preview=false" >> $GITHUB_OUTPUT + echo "was_promoted_from_preview=false" >> "$GITHUB_OUTPUT" fi - name: Send release notes email diff --git a/.github/workflows/deploy_cloudflare.yml b/.github/workflows/deploy_cloudflare.yml index fe443d493e3d70e6dec15b6dbdab745fd475d2ee..df35d44ca9ceb00a0503e941110c472c0b418fa2 100644 --- a/.github/workflows/deploy_cloudflare.yml +++ b/.github/workflows/deploy_cloudflare.yml @@ -9,7 +9,7 @@ jobs: deploy-docs: name: Deploy Docs if: 
github.repository_owner == 'zed-industries' - runs-on: buildjet-16vcpu-ubuntu-2204 + runs-on: namespace-profile-16x32-ubuntu-2204 steps: - name: Checkout repo diff --git a/.github/workflows/deploy_collab.yml b/.github/workflows/deploy_collab.yml index cfd455f92092d773dc68fccf08c39fe7d5147c0f..ff2a3589e4c5482089536919618f1bbff982c63c 100644 --- a/.github/workflows/deploy_collab.yml +++ b/.github/workflows/deploy_collab.yml @@ -61,7 +61,7 @@ jobs: - style - tests runs-on: - - buildjet-16vcpu-ubuntu-2204 + - namespace-profile-16x32-ubuntu-2204 steps: - name: Install doctl uses: digitalocean/action-doctl@v2 @@ -79,12 +79,12 @@ jobs: - name: Build docker image run: | docker build -f Dockerfile-collab \ - --build-arg GITHUB_SHA=$GITHUB_SHA \ - --tag registry.digitalocean.com/zed/collab:$GITHUB_SHA \ + --build-arg "GITHUB_SHA=$GITHUB_SHA" \ + --tag "registry.digitalocean.com/zed/collab:$GITHUB_SHA" \ . - name: Publish docker image - run: docker push registry.digitalocean.com/zed/collab:${GITHUB_SHA} + run: docker push "registry.digitalocean.com/zed/collab:${GITHUB_SHA}" - name: Prune Docker system run: docker system prune --filter 'until=72h' -f @@ -94,7 +94,7 @@ jobs: needs: - publish runs-on: - - buildjet-16vcpu-ubuntu-2204 + - namespace-profile-16x32-ubuntu-2204 steps: - name: Checkout repo @@ -131,17 +131,20 @@ jobs: source script/lib/deploy-helpers.sh export_vars_for_environment $ZED_KUBE_NAMESPACE - export ZED_DO_CERTIFICATE_ID=$(doctl compute certificate list --format ID --no-header) + ZED_DO_CERTIFICATE_ID="$(doctl compute certificate list --format ID --no-header)" + export ZED_DO_CERTIFICATE_ID export ZED_IMAGE_ID="registry.digitalocean.com/zed/collab:${GITHUB_SHA}" export ZED_SERVICE_NAME=collab export ZED_LOAD_BALANCER_SIZE_UNIT=$ZED_COLLAB_LOAD_BALANCER_SIZE_UNIT + export DATABASE_MAX_CONNECTIONS=850 envsubst < crates/collab/k8s/collab.template.yml | kubectl apply -f - kubectl -n "$ZED_KUBE_NAMESPACE" rollout status deployment/$ZED_SERVICE_NAME --watch echo "deployed ${ZED_SERVICE_NAME} to ${ZED_KUBE_NAMESPACE}" export ZED_SERVICE_NAME=api export ZED_LOAD_BALANCER_SIZE_UNIT=$ZED_API_LOAD_BALANCER_SIZE_UNIT + export DATABASE_MAX_CONNECTIONS=60 envsubst < crates/collab/k8s/collab.template.yml | kubectl apply -f - kubectl -n "$ZED_KUBE_NAMESPACE" rollout status deployment/$ZED_SERVICE_NAME --watch echo "deployed ${ZED_SERVICE_NAME} to ${ZED_KUBE_NAMESPACE}" diff --git a/.github/workflows/eval.yml b/.github/workflows/eval.yml index 6eefdfea954c58919850baabe013d3d8676b54f9..b5da9e7b7c8e293fb565f4de269a1ae266c19692 100644 --- a/.github/workflows/eval.yml +++ b/.github/workflows/eval.yml @@ -32,10 +32,10 @@ jobs: github.repository_owner == 'zed-industries' && (github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'run-eval')) runs-on: - - buildjet-16vcpu-ubuntu-2204 + - namespace-profile-16x32-ubuntu-2204 steps: - name: Add Rust to the PATH - run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH + run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH" - name: Checkout repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 @@ -46,7 +46,7 @@ jobs: uses: swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2 with: save-if: ${{ github.ref == 'refs/heads/main' }} - cache-provider: "buildjet" + # cache-provider: "buildjet" - name: Install Linux dependencies run: ./script/linux diff --git a/.github/workflows/nix.yml b/.github/workflows/nix.yml index 155fc484f57b593dbdca1811f571d97384ceb3c0..e682ce5890b86e8a3cf181be2d302d66025572c2 100644 --- 
a/.github/workflows/nix.yml +++ b/.github/workflows/nix.yml @@ -20,7 +20,7 @@ jobs: matrix: system: - os: x86 Linux - runner: buildjet-16vcpu-ubuntu-2204 + runner: namespace-profile-16x32-ubuntu-2204 install_nix: true - os: arm Mac runner: [macOS, ARM64, test] @@ -29,6 +29,7 @@ jobs: runs-on: ${{ matrix.system.runner }} env: ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }} + ZED_MINIDUMP_ENDPOINT: ${{ secrets.ZED_SENTRY_MINIDUMP_ENDPOINT }} ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: ${{ secrets.ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON }} GIT_LFS_SKIP_SMUDGE: 1 # breaks the livekit rust sdk examples which we don't actually depend on steps: @@ -43,8 +44,8 @@ jobs: - name: Set path if: ${{ ! matrix.system.install_nix }} run: | - echo "/nix/var/nix/profiles/default/bin" >> $GITHUB_PATH - echo "/Users/administrator/.nix-profile/bin" >> $GITHUB_PATH + echo "/nix/var/nix/profiles/default/bin" >> "$GITHUB_PATH" + echo "/Users/administrator/.nix-profile/bin" >> "$GITHUB_PATH" - uses: cachix/install-nix-action@02a151ada4993995686f9ed4f1be7cfbb229e56f # v31 if: ${{ matrix.system.install_nix }} @@ -56,11 +57,13 @@ jobs: name: zed authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}" pushFilter: "${{ inputs.cachix-filter }}" - cachixArgs: '-v' + cachixArgs: "-v" - run: nix build .#${{ inputs.flake-output }} -L --accept-flake-config - name: Limit /nix/store to 50GB on macs if: ${{ ! matrix.system.install_nix }} run: | - [ $(du -sm /nix/store | cut -f1) -gt 50000 ] && nix-collect-garbage -d || : + if [ "$(du -sm /nix/store | cut -f1)" -gt 50000 ]; then + nix-collect-garbage -d || true + fi diff --git a/.github/workflows/randomized_tests.yml b/.github/workflows/randomized_tests.yml index db4d44318eb40bc038788c07598825b88074e801..de96c3df78bdb67edd584696f02316478e4446dd 100644 --- a/.github/workflows/randomized_tests.yml +++ b/.github/workflows/randomized_tests.yml @@ -20,7 +20,7 @@ jobs: name: Run randomized tests if: github.repository_owner == 'zed-industries' runs-on: - - buildjet-16vcpu-ubuntu-2204 + - namespace-profile-16x32-ubuntu-2204 steps: - name: Install Node uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4 diff --git a/.github/workflows/release_nightly.yml b/.github/workflows/release_nightly.yml index df9f6ef40fc9faf87c5be82e3f95288012cd4221..b3500a085b6e52db45993acb7a40c40bb420a847 100644 --- a/.github/workflows/release_nightly.yml +++ b/.github/workflows/release_nightly.yml @@ -12,6 +12,10 @@ env: CARGO_TERM_COLOR: always CARGO_INCREMENTAL: 0 RUST_BACKTRACE: 1 + ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }} + ZED_MINIDUMP_ENDPOINT: ${{ secrets.ZED_SENTRY_MINIDUMP_ENDPOINT }} + DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }} + DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }} jobs: style: @@ -82,8 +86,7 @@ jobs: name: Create a macOS bundle if: github.repository_owner == 'zed-industries' runs-on: - - self-hosted - - bundle + - self-mini-macos needs: tests env: MACOS_CERTIFICATE: ${{ secrets.MACOS_CERTIFICATE }} @@ -91,9 +94,6 @@ jobs: APPLE_NOTARIZATION_KEY: ${{ secrets.APPLE_NOTARIZATION_KEY }} APPLE_NOTARIZATION_KEY_ID: ${{ secrets.APPLE_NOTARIZATION_KEY_ID }} APPLE_NOTARIZATION_ISSUER_ID: ${{ secrets.APPLE_NOTARIZATION_ISSUER_ID }} - DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }} - DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }} - ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }} steps: - name: Install Node 
uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4 @@ -112,6 +112,11 @@ jobs: echo "Publishing version: ${version} on release channel nightly" echo "nightly" > crates/zed/RELEASE_CHANNEL + - name: Setup Sentry CLI + uses: matbour/setup-sentry-cli@3e938c54b3018bdd019973689ef984e033b0454b #v2 + with: + token: ${{ SECRETS.SENTRY_AUTH_TOKEN }} + - name: Create macOS app bundle run: script/bundle-mac @@ -123,12 +128,8 @@ jobs: name: Create a Linux *.tar.gz bundle for x86 if: github.repository_owner == 'zed-industries' runs-on: - - buildjet-16vcpu-ubuntu-2004 + - namespace-profile-16x32-ubuntu-2004 # ubuntu 20.04 for minimal glibc needs: tests - env: - DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }} - DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }} - ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }} steps: - name: Checkout repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 @@ -136,11 +137,16 @@ jobs: clean: false - name: Add Rust to the PATH - run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH + run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH" - name: Install Linux dependencies run: ./script/linux && ./script/install-mold 2.34.0 + - name: Setup Sentry CLI + uses: matbour/setup-sentry-cli@3e938c54b3018bdd019973689ef984e033b0454b #v2 + with: + token: ${{ SECRETS.SENTRY_AUTH_TOKEN }} + - name: Limit target directory size run: script/clear-target-dir-if-larger-than 100 @@ -162,12 +168,8 @@ jobs: name: Create a Linux *.tar.gz bundle for ARM if: github.repository_owner == 'zed-industries' runs-on: - - buildjet-16vcpu-ubuntu-2204-arm + - namespace-profile-32x64-ubuntu-2004-arm # ubuntu 20.04 for minimal glibc needs: tests - env: - DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }} - DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }} - ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }} steps: - name: Checkout repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 @@ -177,6 +179,11 @@ jobs: - name: Install Linux dependencies run: ./script/linux + - name: Setup Sentry CLI + uses: matbour/setup-sentry-cli@3e938c54b3018bdd019973689ef984e033b0454b #v2 + with: + token: ${{ SECRETS.SENTRY_AUTH_TOKEN }} + - name: Limit target directory size run: script/clear-target-dir-if-larger-than 100 @@ -195,12 +202,9 @@ jobs: freebsd: timeout-minutes: 60 - if: github.repository_owner == 'zed-industries' + if: false && github.repository_owner == 'zed-industries' runs-on: github-8vcpu-ubuntu-2404 needs: tests - env: - DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }} - DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }} name: Build Zed on FreeBSD # env: # MYTOKEN : ${{ secrets.MYTOKEN }} @@ -257,8 +261,6 @@ jobs: ACCOUNT_NAME: ${{ vars.AZURE_SIGNING_ACCOUNT_NAME }} CERT_PROFILE_NAME: ${{ vars.AZURE_SIGNING_CERT_PROFILE_NAME }} ENDPOINT: ${{ vars.AZURE_SIGNING_ENDPOINT }} - DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }} - DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }} FILE_DIGEST: SHA256 TIMESTAMP_DIGEST: SHA256 TIMESTAMP_SERVER: "http://timestamp.acs.microsoft.com" @@ -276,8 +278,10 @@ jobs: Write-Host "Publishing version: $version on release channel nightly" "nightly" | Set-Content -Path "crates/zed/RELEASE_CHANNEL" - - name: Install trusted signing - uses: ./.github/actions/install_trusted_signing + - name: Setup 
Sentry CLI + uses: matbour/setup-sentry-cli@3e938c54b3018bdd019973689ef984e033b0454b #v2 + with: + token: ${{ SECRETS.SENTRY_AUTH_TOKEN }} - name: Build Zed installer working-directory: ${{ env.ZED_WORKSPACE }} diff --git a/.github/workflows/unit_evals.yml b/.github/workflows/unit_evals.yml index 705caff37afcba6cfb1f303b28559d1425147437..2e03fb028f14b7a345e33365b3ffe0ba9dbfd756 100644 --- a/.github/workflows/unit_evals.yml +++ b/.github/workflows/unit_evals.yml @@ -23,10 +23,10 @@ jobs: timeout-minutes: 60 name: Run unit evals runs-on: - - buildjet-16vcpu-ubuntu-2204 + - namespace-profile-16x32-ubuntu-2204 steps: - name: Add Rust to the PATH - run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH + run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH" - name: Checkout repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 @@ -37,7 +37,7 @@ jobs: uses: swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2 with: save-if: ${{ github.ref == 'refs/heads/main' }} - cache-provider: "buildjet" + # cache-provider: "buildjet" - name: Install Linux dependencies run: ./script/linux diff --git a/.zed/settings.json b/.zed/settings.json index 1ef6bc28f7dffb3fd7b25489f3f6ff0c1b0f74c9..68e05a426f2474cb663aa5ff843905f375170e0f 100644 --- a/.zed/settings.json +++ b/.zed/settings.json @@ -40,7 +40,7 @@ }, "file_types": { "Dockerfile": ["Dockerfile*[!dockerignore]"], - "JSONC": ["assets/**/*.json", "renovate.json"], + "JSONC": ["**/assets/**/*.json", "renovate.json"], "Git Ignore": ["dockerignore"] }, "hard_tabs": false, diff --git a/Cargo.lock b/Cargo.lock index 81df7ea2b9e4238a65eb2ef84bee9bb52488f0af..bb6877f762d544546eacefc35540290f6fac3646 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3,13 +3,12 @@ version = 4 [[package]] -name = "acp" +name = "acp_thread" version = "0.1.0" dependencies = [ - "agent_servers", - "agentic-coding-protocol", + "agent-client-protocol", "anyhow", - "async-pipe", + "assistant_tool", "buffer_diff", "editor", "env_logger 0.11.8", @@ -18,8 +17,12 @@ dependencies = [ "indoc", "itertools 0.14.0", "language", + "language_model", "markdown", + "parking_lot", "project", + "rand 0.8.5", + "serde", "serde_json", "settings", "smol", @@ -88,6 +91,7 @@ dependencies = [ "assistant_tools", "chrono", "client", + "cloud_llm_client", "collections", "component", "context_server", @@ -111,7 +115,6 @@ dependencies = [ "pretty_assertions", "project", "prompt_store", - "proto", "rand 0.8.5", "ref-cast", "rope", @@ -130,24 +133,100 @@ dependencies = [ "uuid", "workspace", "workspace-hack", - "zed_llm_client", "zstd", ] +[[package]] +name = "agent-client-protocol" +version = "0.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fad72b7b8ee4331b3a4c8d43c107e982a4725564b4ee658ae5c4e79d2b486e8" +dependencies = [ + "anyhow", + "futures 0.3.31", + "log", + "parking_lot", + "schemars", + "serde", + "serde_json", +] + +[[package]] +name = "agent2" +version = "0.1.0" +dependencies = [ + "acp_thread", + "agent-client-protocol", + "agent_servers", + "anyhow", + "assistant_tool", + "client", + "clock", + "cloud_llm_client", + "collections", + "ctor", + "env_logger 0.11.8", + "fs", + "futures 0.3.31", + "gpui", + "gpui_tokio", + "handlebars 4.5.0", + "indoc", + "language", + "language_model", + "language_models", + "log", + "project", + "prompt_store", + "reqwest_client", + "rust-embed", + "schemars", + "serde", + "serde_json", + "settings", + "smol", + "ui", + "util", + "uuid", + "watch", + "workspace-hack", + "worktree", +] + [[package]] name = "agent_servers" 
version = "0.1.0" dependencies = [ + "acp_thread", + "agent-client-protocol", + "agentic-coding-protocol", "anyhow", "collections", + "context_server", + "env_logger 0.11.8", "futures 0.3.31", "gpui", + "indoc", + "itertools 0.14.0", + "language", + "libc", + "log", + "nix 0.29.0", "paths", "project", + "rand 0.8.5", "schemars", "serde", + "serde_json", "settings", + "smol", + "strum 0.27.1", + "tempfile", + "thiserror 2.0.12", + "ui", "util", + "uuid", + "watch", "which 6.0.3", "workspace-hack", ] @@ -157,6 +236,7 @@ name = "agent_settings" version = "0.1.0" dependencies = [ "anyhow", + "cloud_llm_client", "collections", "fs", "gpui", @@ -168,18 +248,19 @@ dependencies = [ "serde_json_lenient", "settings", "workspace-hack", - "zed_llm_client", ] [[package]] name = "agent_ui" version = "0.1.0" dependencies = [ - "acp", + "acp_thread", "agent", + "agent-client-protocol", + "agent2", "agent_servers", "agent_settings", - "agentic-coding-protocol", + "ai_onboarding", "anyhow", "assistant_context", "assistant_slash_command", @@ -190,7 +271,9 @@ dependencies = [ "buffer_diff", "chrono", "client", + "cloud_llm_client", "collections", + "command_palette_hooks", "component", "context_server", "db", @@ -212,6 +295,7 @@ dependencies = [ "jsonschema", "language", "language_model", + "language_models", "languages", "log", "lsp", @@ -250,6 +334,7 @@ dependencies = [ "time_format", "tree-sitter-md", "ui", + "ui_input", "unindent", "urlencoding", "util", @@ -258,21 +343,22 @@ dependencies = [ "workspace", "workspace-hack", "zed_actions", - "zed_llm_client", ] [[package]] name = "agentic-coding-protocol" -version = "0.0.6" +version = "0.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1ac0351749af7bf53c65042ef69fefb9351aa8b7efa0a813d6281377605c37d" +checksum = "a3e6ae951b36fa2f8d9dd6e1af6da2fcaba13d7c866cf6a9e65deda9dc6c5fe4" dependencies = [ "anyhow", "chrono", + "derive_more 2.0.1", "futures 0.3.31", "log", "parking_lot", "schemars", + "semver", "serde", "serde_json", ] @@ -312,6 +398,23 @@ dependencies = [ "memchr", ] +[[package]] +name = "ai_onboarding" +version = "0.1.0" +dependencies = [ + "client", + "cloud_llm_client", + "component", + "gpui", + "language_model", + "serde", + "smallvec", + "telemetry", + "ui", + "workspace-hack", + "zed_actions", +] + [[package]] name = "alacritty_terminal" version = "0.25.1-dev" @@ -642,6 +745,7 @@ dependencies = [ "chrono", "client", "clock", + "cloud_llm_client", "collections", "context_server", "fs", @@ -675,7 +779,6 @@ dependencies = [ "uuid", "workspace", "workspace-hack", - "zed_llm_client", ] [[package]] @@ -685,7 +788,7 @@ dependencies = [ "anyhow", "async-trait", "collections", - "derive_more", + "derive_more 0.99.19", "extension", "futures 0.3.31", "gpui", @@ -748,10 +851,11 @@ dependencies = [ "clock", "collections", "ctor", - "derive_more", + "derive_more 0.99.19", "futures 0.3.31", "gpui", "icons", + "indoc", "language", "language_model", "log", @@ -782,9 +886,11 @@ dependencies = [ "chrono", "client", "clock", + "cloud_llm_client", "collections", "component", - "derive_more", + "derive_more 0.99.19", + "diffy", "editor", "feature_flags", "fs", @@ -834,7 +940,6 @@ dependencies = [ "which 6.0.3", "workspace", "workspace-hack", - "zed_llm_client", "zlog", ] @@ -1028,17 +1133,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "async-recursion" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7d78656ba01f1b93024b7c3a0467f1608e4be67d725749fdcd7d2c7678fd7a2" 
-dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - [[package]] name = "async-recursion" version = "1.1.1" @@ -1132,7 +1226,7 @@ dependencies = [ "serde_json", "serde_path_to_error", "serde_qs 0.10.1", - "smart-default", + "smart-default 0.6.0", "smol_str 0.1.24", "thiserror 1.0.69", "tokio", @@ -1241,7 +1335,7 @@ version = "0.1.0" dependencies = [ "anyhow", "collections", - "derive_more", + "derive_more 0.99.19", "gpui", "parking_lot", "rodio", @@ -1331,7 +1425,7 @@ dependencies = [ "anyhow", "arrayvec", "log", - "nom", + "nom 7.1.3", "num-rational", "v_frame", ] @@ -1840,9 +1934,7 @@ version = "0.1.0" dependencies = [ "aws-smithy-runtime-api", "aws-smithy-types", - "futures 0.3.31", "http_client", - "tokio", "workspace-hack", ] @@ -2154,7 +2246,7 @@ dependencies = [ [[package]] name = "blade-graphics" version = "0.6.0" -source = "git+https://github.com/kvark/blade?rev=416375211bb0b5826b3584dccdb6a43369e499ad#416375211bb0b5826b3584dccdb6a43369e499ad" +source = "git+https://github.com/kvark/blade?rev=e0ec4e720957edd51b945b64dd85605ea54bcfe5#e0ec4e720957edd51b945b64dd85605ea54bcfe5" dependencies = [ "ash", "ash-window", @@ -2187,7 +2279,7 @@ dependencies = [ [[package]] name = "blade-macros" version = "0.3.0" -source = "git+https://github.com/kvark/blade?rev=416375211bb0b5826b3584dccdb6a43369e499ad#416375211bb0b5826b3584dccdb6a43369e499ad" +source = "git+https://github.com/kvark/blade?rev=e0ec4e720957edd51b945b64dd85605ea54bcfe5#e0ec4e720957edd51b945b64dd85605ea54bcfe5" dependencies = [ "proc-macro2", "quote", @@ -2197,7 +2289,7 @@ dependencies = [ [[package]] name = "blade-util" version = "0.2.0" -source = "git+https://github.com/kvark/blade?rev=416375211bb0b5826b3584dccdb6a43369e499ad#416375211bb0b5826b3584dccdb6a43369e499ad" +source = "git+https://github.com/kvark/blade?rev=e0ec4e720957edd51b945b64dd85605ea54bcfe5#e0ec4e720957edd51b945b64dd85605ea54bcfe5" dependencies = [ "blade-graphics", "bytemuck", @@ -2708,7 +2800,7 @@ version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" dependencies = [ - "nom", + "nom 7.1.3", ] [[package]] @@ -2927,15 +3019,16 @@ name = "client" version = "0.1.0" dependencies = [ "anyhow", - "async-recursion 0.3.2", "async-tungstenite", "base64 0.22.1", "chrono", "clock", + "cloud_api_client", + "cloud_llm_client", "cocoa 0.26.0", "collections", "credentials_provider", - "derive_more", + "derive_more 0.99.19", "feature_flags", "fs", "futures 0.3.31", @@ -2974,7 +3067,6 @@ dependencies = [ "windows 0.61.1", "workspace-hack", "worktree", - "zed_llm_client", ] [[package]] @@ -2987,6 +3079,49 @@ dependencies = [ "workspace-hack", ] +[[package]] +name = "cloud_api_client" +version = "0.1.0" +dependencies = [ + "anyhow", + "cloud_api_types", + "futures 0.3.31", + "gpui", + "gpui_tokio", + "http_client", + "parking_lot", + "serde_json", + "workspace-hack", + "yawc", +] + +[[package]] +name = "cloud_api_types" +version = "0.1.0" +dependencies = [ + "anyhow", + "chrono", + "ciborium", + "cloud_llm_client", + "pretty_assertions", + "serde", + "serde_json", + "workspace-hack", +] + +[[package]] +name = "cloud_llm_client" +version = "0.1.0" +dependencies = [ + "anyhow", + "pretty_assertions", + "serde", + "serde_json", + "strum 0.27.1", + "uuid", + "workspace-hack", +] + [[package]] name = "clru" version = "0.6.2" @@ -3113,16 +3248,18 @@ dependencies = [ "chrono", "client", "clock", + "cloud_llm_client", "collab_ui", "collections", 
"command_palette_hooks", "context_server", "ctor", "dap", + "dap-types", "dap_adapters", "dashmap 6.1.0", "debugger_ui", - "derive_more", + "derive_more 0.99.19", "editor", "envy", "extension", @@ -3175,6 +3312,7 @@ dependencies = [ "session", "settings", "sha2", + "smol", "sqlx", "strum 0.27.1", "subtle", @@ -3197,7 +3335,6 @@ dependencies = [ "workspace", "workspace-hack", "worktree", - "zed_llm_client", "zlog", ] @@ -3327,7 +3464,7 @@ name = "command_palette_hooks" version = "0.1.0" dependencies = [ "collections", - "derive_more", + "derive_more 0.99.19", "gpui", "workspace-hack", ] @@ -3415,12 +3552,14 @@ dependencies = [ "futures 0.3.31", "gpui", "log", + "net", "parking_lot", "postage", "schemars", "serde", "serde_json", "smol", + "tempfile", "url", "util", "workspace-hack", @@ -3472,13 +3611,13 @@ dependencies = [ "command_palette_hooks", "ctor", "dirs 4.0.0", + "edit_prediction", "editor", "fs", "futures 0.3.31", "gpui", "http_client", "indoc", - "inline_completion", "itertools 0.14.0", "language", "log", @@ -3492,6 +3631,7 @@ dependencies = [ "serde", "serde_json", "settings", + "sum_tree", "task", "theme", "ui", @@ -3645,17 +3785,6 @@ dependencies = [ "libm", ] -[[package]] -name = "coreaudio-rs" -version = "0.11.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "321077172d79c662f64f5071a03120748d5bb652f5231570141be24cfcd2bace" -dependencies = [ - "bitflags 1.3.2", - "core-foundation-sys", - "coreaudio-sys", -] - [[package]] name = "coreaudio-rs" version = "0.12.1" @@ -3713,29 +3842,6 @@ dependencies = [ "unicode-segmentation", ] -[[package]] -name = "cpal" -version = "0.15.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "873dab07c8f743075e57f524c583985fbaf745602acbe916a01539364369a779" -dependencies = [ - "alsa", - "core-foundation-sys", - "coreaudio-rs 0.11.3", - "dasp_sample", - "jni", - "js-sys", - "libc", - "mach2", - "ndk 0.8.0", - "ndk-context", - "oboe", - "wasm-bindgen", - "wasm-bindgen-futures", - "web-sys", - "windows 0.54.0", -] - [[package]] name = "cpal" version = "0.16.0" @@ -3749,7 +3855,7 @@ dependencies = [ "js-sys", "libc", "mach2", - "ndk 0.9.0", + "ndk", "ndk-context", "num-derive", "num-traits", @@ -3890,6 +3996,42 @@ dependencies = [ "target-lexicon 0.13.2", ] +[[package]] +name = "crash-context" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "031ed29858d90cfdf27fe49fae28028a1f20466db97962fa2f4ea34809aeebf3" +dependencies = [ + "cfg-if", + "libc", + "mach2", +] + +[[package]] +name = "crash-handler" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2066907075af649bcb8bcb1b9b986329b243677e6918b2d920aa64b0aac5ace3" +dependencies = [ + "cfg-if", + "crash-context", + "libc", + "mach2", + "parking_lot", +] + +[[package]] +name = "crashes" +version = "0.1.0" +dependencies = [ + "crash-handler", + "log", + "minidumper", + "paths", + "smol", + "workspace-hack", +] + [[package]] name = "crc" version = "3.2.1" @@ -4219,7 +4361,7 @@ dependencies = [ [[package]] name = "dap-types" version = "0.0.1" -source = "git+https://github.com/zed-industries/dap-types?rev=7f39295b441614ca9dbf44293e53c32f666897f9#7f39295b441614ca9dbf44293e53c32f666897f9" +source = "git+https://github.com/zed-industries/dap-types?rev=1b461b310481d01e02b2603c16d7144b926339f8#1b461b310481d01e02b2603c16d7144b926339f8" dependencies = [ "schemars", "serde", @@ -4245,46 +4387,12 @@ dependencies = [ "serde", "serde_json", "shlex", + "smol", 
"task", "util", "workspace-hack", ] -[[package]] -name = "darling" -version = "0.20.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" -dependencies = [ - "darling_core", - "darling_macro", -] - -[[package]] -name = "darling_core" -version = "0.20.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d00b9596d185e565c2207a0b01f8bd1a135483d02d9b7b0a54b11da8d53412e" -dependencies = [ - "fnv", - "ident_case", - "proc-macro2", - "quote", - "strsim", - "syn 2.0.101", -] - -[[package]] -name = "darling_macro" -version = "0.20.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" -dependencies = [ - "darling_core", - "quote", - "syn 2.0.101", -] - [[package]] name = "dashmap" version = "5.5.3" @@ -4411,17 +4519,21 @@ dependencies = [ "futures 0.3.31", "fuzzy", "gpui", + "hex", "indoc", "itertools 0.14.0", "language", "log", "menu", + "notifications", "parking_lot", + "parse_int", "paths", "picker", "pretty_assertions", "project", "rpc", + "schemars", "serde", "serde_json", "serde_json_lenient", @@ -4446,6 +4558,15 @@ dependencies = [ "zlog", ] +[[package]] +name = "debugid" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef552e6f588e446098f6ba40d89ac146c8c7b64aade83c051ee00bb5d2bc18d" +dependencies = [ + "uuid", +] + [[package]] name = "deepseek" version = "0.1.0" @@ -4497,47 +4618,37 @@ dependencies = [ ] [[package]] -name = "derive_builder" -version = "0.20.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947" -dependencies = [ - "derive_builder_macro", -] - -[[package]] -name = "derive_builder_core" -version = "0.20.2" +name = "derive_more" +version = "0.99.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" +checksum = "3da29a38df43d6f156149c9b43ded5e018ddff2a855cf2cfd62e8cd7d079c69f" dependencies = [ - "darling", + "convert_case 0.4.0", "proc-macro2", "quote", + "rustc_version", "syn 2.0.101", ] [[package]] -name = "derive_builder_macro" -version = "0.20.2" +name = "derive_more" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" +checksum = "093242cf7570c207c83073cf82f79706fe7b8317e98620a47d5be7c3d8497678" dependencies = [ - "derive_builder_core", - "syn 2.0.101", + "derive_more-impl", ] [[package]] -name = "derive_more" -version = "0.99.19" +name = "derive_more-impl" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3da29a38df43d6f156149c9b43ded5e018ddff2a855cf2cfd62e8cd7d079c69f" +checksum = "bda628edc44c4bb645fbe0f758797143e4e07926f7ebf4e9bdfbd3d2ce621df3" dependencies = [ - "convert_case 0.4.0", "proc-macro2", "quote", - "rustc_version", "syn 2.0.101", + "unicode-xid", ] [[package]] @@ -4727,7 +4838,6 @@ name = "docs_preprocessor" version = "0.1.0" dependencies = [ "anyhow", - "clap", "command_palette", "gpui", "mdbook", @@ -4738,6 +4848,7 @@ dependencies = [ "util", "workspace-hack", "zed", + "zlog", ] [[package]] @@ -4859,6 +4970,49 @@ dependencies = [ "signature 1.6.4", ] +[[package]] +name = "edit_prediction" +version = "0.1.0" +dependencies = [ + "client", + "gpui", + 
"language", + "project", + "workspace-hack", +] + +[[package]] +name = "edit_prediction_button" +version = "0.1.0" +dependencies = [ + "anyhow", + "client", + "cloud_llm_client", + "copilot", + "edit_prediction", + "editor", + "feature_flags", + "fs", + "futures 0.3.31", + "gpui", + "indoc", + "language", + "lsp", + "paths", + "project", + "regex", + "serde_json", + "settings", + "supermaven", + "telemetry", + "theme", + "ui", + "workspace", + "workspace-hack", + "zed_actions", + "zeta", +] + [[package]] name = "editor" version = "0.1.0" @@ -4874,6 +5028,7 @@ dependencies = [ "ctor", "dap", "db", + "edit_prediction", "emojis", "file_icons", "fs", @@ -4883,7 +5038,6 @@ dependencies = [ "gpui", "http_client", "indoc", - "inline_completion", "itertools 0.14.0", "language", "languages", @@ -4915,6 +5069,8 @@ dependencies = [ "text", "theme", "time", + "tree-sitter-bash", + "tree-sitter-c", "tree-sitter-html", "tree-sitter-python", "tree-sitter-rust", @@ -5197,6 +5353,7 @@ dependencies = [ "chrono", "clap", "client", + "cloud_llm_client", "collections", "debug_adapter_extension", "dirs 4.0.0", @@ -5236,7 +5393,6 @@ dependencies = [ "uuid", "watch", "workspace-hack", - "zed_llm_client", ] [[package]] @@ -5301,6 +5457,12 @@ dependencies = [ "zune-inflate", ] +[[package]] +name = "extended" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af9673d8203fcb076b19dfd17e38b3d4ae9f44959416ea532ce72415a6020365" + [[package]] name = "extension" version = "0.1.0" @@ -5320,11 +5482,13 @@ dependencies = [ "log", "lsp", "parking_lot", + "pretty_assertions", "semantic_version", "serde", "serde_json", "task", "toml 0.8.20", + "url", "util", "wasm-encoder 0.221.3", "wasmparser 0.221.3", @@ -5875,7 +6039,7 @@ dependencies = [ "ignore", "libc", "log", - "notify", + "notify 8.0.0", "objc", "parking_lot", "paths", @@ -6238,7 +6402,7 @@ dependencies = [ "askpass", "async-trait", "collections", - "derive_more", + "derive_more 0.99.19", "futures 0.3.31", "git2", "gpui", @@ -6309,6 +6473,7 @@ dependencies = [ "buffer_diff", "call", "chrono", + "cloud_llm_client", "collections", "command_palette_hooks", "component", @@ -6319,6 +6484,7 @@ dependencies = [ "fuzzy", "git", "gpui", + "indoc", "itertools 0.14.0", "language", "language_model", @@ -6351,7 +6517,6 @@ dependencies = [ "workspace", "workspace-hack", "zed_actions", - "zed_llm_client", "zlog", ] @@ -7184,6 +7349,17 @@ dependencies = [ "workspace-hack", ] +[[package]] +name = "goblin" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b363a30c165f666402fe6a3024d3bec7ebc898f96a4a23bd1c99f8dbf3f4f47" +dependencies = [ + "log", + "plain", + "scroll", +] + [[package]] name = "google_ai" version = "0.1.0" @@ -7255,7 +7431,7 @@ dependencies = [ "core-video", "cosmic-text", "ctor", - "derive_more", + "derive_more 0.99.19", "embed-resource", "env_logger 0.11.8", "etagere", @@ -7350,9 +7526,9 @@ dependencies = [ [[package]] name = "grid" -version = "0.13.0" +version = "0.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d196ffc1627db18a531359249b2bf8416178d84b729f3cebeb278f285fb9b58c" +checksum = "12101ecc8225ea6d675bc70263074eab6169079621c2186fe0c66590b2df9681" [[package]] name = "group" @@ -7431,18 +7607,16 @@ dependencies = [ [[package]] name = "handlebars" -version = "6.3.2" +version = "5.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "759e2d5aea3287cb1190c8ec394f42866cb5bf74fcbf213f354e3c856ea26098" +checksum 
= "d08485b96a0e6393e9e4d1b8d48cf74ad6c063cd905eb33f42c1ce3f0377539b" dependencies = [ - "derive_builder", "log", - "num-order", "pest", "pest_derive", "serde", "serde_json", - "thiserror 2.0.12", + "thiserror 1.0.69", ] [[package]] @@ -7667,12 +7841,6 @@ dependencies = [ "windows-sys 0.59.0", ] -[[package]] -name = "hound" -version = "3.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62adaabb884c94955b19907d60019f4e145d091c75345379e70d1ee696f7854f" - [[package]] name = "html5ever" version = "0.27.0" @@ -7801,10 +7969,13 @@ version = "0.1.0" dependencies = [ "anyhow", "bytes 1.10.1", - "derive_more", + "derive_more 0.99.19", "futures 0.3.31", "http 1.3.1", + "http-body 1.0.1", "log", + "parking_lot", + "reqwest 0.12.15 (git+https://github.com/zed-industries/reqwest.git?rev=951c770a32f1998d6e999cef3e59e0013e6c4415)", "serde", "serde_json", "url", @@ -8112,12 +8283,6 @@ version = "2.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "25a2bc672d1148e28034f176e01fffebb08b35768468cc954630da77a1449005" -[[package]] -name = "ident_case" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" - [[package]] name = "idna" version = "1.0.3" @@ -8239,7 +8404,7 @@ dependencies = [ "async-trait", "cargo_metadata", "collections", - "derive_more", + "derive_more 0.99.19", "extension", "fs", "futures 0.3.31", @@ -8294,46 +8459,14 @@ dependencies = [ ] [[package]] -name = "inline_completion" -version = "0.1.0" -dependencies = [ - "client", - "gpui", - "language", - "project", - "workspace-hack", -] - -[[package]] -name = "inline_completion_button" -version = "0.1.0" +name = "inotify" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8069d3ec154eb856955c1c0fbffefbf5f3c40a104ec912d4797314c1801abff" dependencies = [ - "anyhow", - "client", - "copilot", - "editor", - "feature_flags", - "fs", - "futures 0.3.31", - "gpui", - "indoc", - "inline_completion", - "language", - "lsp", - "paths", - "project", - "regex", - "serde_json", - "settings", - "supermaven", - "telemetry", - "theme", - "ui", - "workspace", - "workspace-hack", - "zed_actions", - "zed_llm_client", - "zeta", + "bitflags 1.3.2", + "inotify-sys", + "libc", ] [[package]] @@ -8489,7 +8622,7 @@ dependencies = [ "fnv", "lazy_static", "libc", - "mio", + "mio 1.0.3", "rand 0.8.5", "serde", "tempfile", @@ -8966,6 +9099,7 @@ dependencies = [ "task", "text", "theme", + "toml 0.8.20", "tree-sitter", "tree-sitter-elixir", "tree-sitter-embedded-template", @@ -9013,6 +9147,8 @@ dependencies = [ "anyhow", "base64 0.22.1", "client", + "cloud_api_types", + "cloud_llm_client", "collections", "futures 0.3.31", "gpui", @@ -9030,13 +9166,13 @@ dependencies = [ "thiserror 2.0.12", "util", "workspace-hack", - "zed_llm_client", ] [[package]] name = "language_models" version = "0.1.0" dependencies = [ + "ai_onboarding", "anthropic", "anyhow", "aws-config", @@ -9045,14 +9181,14 @@ dependencies = [ "bedrock", "chrono", "client", + "cloud_llm_client", "collections", "component", + "convert_case 0.8.0", "copilot", "credentials_provider", "deepseek", "editor", - "feature_flags", - "fs", "futures 0.3.31", "google_ai", "gpui", @@ -9069,7 +9205,6 @@ dependencies = [ "open_router", "partial-json-fixer", "project", - "proto", "release_channel", "schemars", "serde", @@ -9086,7 +9221,7 @@ dependencies = [ "util", "vercel", "workspace-hack", - "zed_llm_client", + 
"x_ai", ] [[package]] @@ -9118,7 +9253,6 @@ dependencies = [ "collections", "copilot", "editor", - "feature_flags", "futures 0.3.31", "gpui", "itertools 0.14.0", @@ -9144,11 +9278,13 @@ version = "0.1.0" dependencies = [ "anyhow", "async-compression", + "async-fs", "async-tar", "async-trait", "chrono", "collections", "dap", + "feature_flags", "futures 0.3.31", "gpui", "http_client", @@ -9174,9 +9310,11 @@ dependencies = [ "serde_json", "serde_json_lenient", "settings", + "sha2", "smol", "snippet_provider", "task", + "tempfile", "text", "theme", "toml 0.8.20", @@ -9340,7 +9478,7 @@ dependencies = [ "libc", "libspa-sys", "nix 0.27.1", - "nom", + "nom 7.1.3", "system-deps", ] @@ -9369,7 +9507,7 @@ dependencies = [ [[package]] name = "libwebrtc" version = "0.3.10" -source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4" +source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d" dependencies = [ "cxx", "jni", @@ -9449,7 +9587,7 @@ checksum = "23fb14cb19457329c82206317a5663005a4d404783dc74f4252769b0d5f42856" [[package]] name = "livekit" version = "0.7.8" -source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4" +source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d" dependencies = [ "chrono", "futures-util", @@ -9472,7 +9610,7 @@ dependencies = [ [[package]] name = "livekit-api" version = "0.4.2" -source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4" +source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d" dependencies = [ "futures-util", "http 0.2.12", @@ -9496,7 +9634,7 @@ dependencies = [ [[package]] name = "livekit-protocol" version = "0.3.9" -source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4" +source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d" dependencies = [ "futures-util", "livekit-runtime", @@ -9513,7 +9651,7 @@ dependencies = [ [[package]] name = "livekit-runtime" version = "0.4.0" -source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4" +source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d" dependencies = [ "tokio", "tokio-stream", @@ -9545,7 +9683,7 @@ dependencies = [ "core-foundation 0.10.0", "core-video", "coreaudio-rs 0.12.1", - "cpal 0.16.0", + "cpal", "futures 0.3.31", "gpui", "gpui_tokio", @@ -9596,9 +9734,9 @@ dependencies = [ [[package]] name = "lock_api" -version = "0.4.12" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +checksum = "96936507f153605bddfcda068dd804796c84324ed2510809e5b2a624c81da765" dependencies = [ "autocfg", "scopeguard", @@ -9706,12 +9844,11 @@ dependencies = [ [[package]] name 
= "lsp-types" version = "0.95.1" -source = "git+https://github.com/zed-industries/lsp-types?rev=c9c189f1c5dd53c624a419ce35bc77ad6a908d18#c9c189f1c5dd53c624a419ce35bc77ad6a908d18" +source = "git+https://github.com/zed-industries/lsp-types?rev=39f629bdd03d59abd786ed9fc27e8bca02c0c0ec#39f629bdd03d59abd786ed9fc27e8bca02c0c0ec" dependencies = [ "bitflags 1.3.2", "serde", "serde_json", - "serde_repr", "url", ] @@ -9836,7 +9973,7 @@ name = "markdown_preview" version = "0.1.0" dependencies = [ "anyhow", - "async-recursion 1.1.1", + "async-recursion", "collections", "editor", "fs", @@ -9956,9 +10093,9 @@ dependencies = [ [[package]] name = "mdbook" -version = "0.4.48" +version = "0.4.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b6fbb4ac2d9fd7aa987c3510309ea3c80004a968d063c42f0d34fea070817c1" +checksum = "b45a38e19bd200220ef07c892b0157ad3d2365e5b5a267ca01ad12182491eea5" dependencies = [ "ammonia", "anyhow", @@ -9968,12 +10105,11 @@ dependencies = [ "elasticlunr-rs", "env_logger 0.11.8", "futures-util", - "handlebars 6.3.2", - "hex", + "handlebars 5.1.2", "ignore", "log", "memchr", - "notify", + "notify 6.1.1", "notify-debouncer-mini", "once_cell", "opener", @@ -9982,7 +10118,6 @@ dependencies = [ "regex", "serde", "serde_json", - "sha2", "shlex", "tempfile", "tokio", @@ -10103,6 +10238,63 @@ dependencies = [ "unicase", ] +[[package]] +name = "minidump-common" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c4d14bcca0fd3ed165a03000480aaa364c6860c34e900cb2dafdf3b95340e77" +dependencies = [ + "bitflags 2.9.0", + "debugid", + "num-derive", + "num-traits", + "range-map", + "scroll", + "smart-default 0.7.1", +] + +[[package]] +name = "minidump-writer" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abcd9c8a1e6e1e9d56ce3627851f39a17ea83e17c96bc510f29d7e43d78a7d" +dependencies = [ + "bitflags 2.9.0", + "byteorder", + "cfg-if", + "crash-context", + "goblin", + "libc", + "log", + "mach2", + "memmap2", + "memoffset", + "minidump-common", + "nix 0.28.0", + "procfs-core", + "scroll", + "tempfile", + "thiserror 1.0.69", +] + +[[package]] +name = "minidumper" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b4ebc9d1f8847ec1d078f78b35ed598e0ebefa1f242d5f83cd8d7f03960a7d1" +dependencies = [ + "cfg-if", + "crash-context", + "libc", + "log", + "minidump-writer", + "parking_lot", + "polling", + "scroll", + "thiserror 1.0.69", + "uds", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -10125,6 +10317,18 @@ version = "0.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e53debba6bda7a793e5f99b8dacf19e626084f525f7829104ba9898f367d85ff" +[[package]] +name = "mio" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" +dependencies = [ + "libc", + "log", + "wasi 0.11.0+wasi-snapshot-preview1", + "windows-sys 0.48.0", +] + [[package]] name = "mio" version = "1.0.3" @@ -10307,17 +10511,14 @@ dependencies = [ ] [[package]] -name = "ndk" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2076a31b7010b17a38c01907c45b945e8f11495ee4dd588309718901b1f7a5b7" +name = "nc" +version = "0.1.0" dependencies = [ - "bitflags 2.9.0", - "jni-sys", - "log", - "ndk-sys 0.5.0+25.2.9519653", - "num_enum", - "thiserror 1.0.69", + "anyhow", + "futures 0.3.31", + 
"net", + "smol", + "workspace-hack", ] [[package]] @@ -10329,7 +10530,7 @@ dependencies = [ "bitflags 2.9.0", "jni-sys", "log", - "ndk-sys 0.6.0+11769913", + "ndk-sys", "num_enum", "thiserror 1.0.69", ] @@ -10340,15 +10541,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "27b02d87554356db9e9a873add8782d4ea6e3e58ea071a9adb9a2e8ddb884a8b" -[[package]] -name = "ndk-sys" -version = "0.5.0+25.2.9519653" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c196769dd60fd4f363e11d948139556a344e79d451aeb2fa2fd040738ef7691" -dependencies = [ - "jni-sys", -] - [[package]] name = "ndk-sys" version = "0.6.0+11769913" @@ -10457,6 +10649,15 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "nom" +version = "8.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df9761775871bdef83bee530e60050f7e54b1105350d6884eb0fb4f46c2f9405" +dependencies = [ + "memchr", +] + [[package]] name = "noop_proc_macro" version = "0.3.0" @@ -10494,6 +10695,25 @@ dependencies = [ "zed_actions", ] +[[package]] +name = "notify" +version = "6.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6205bd8bb1e454ad2e27422015fb5e4f2bcc7e08fa8f27058670d208324a4d2d" +dependencies = [ + "bitflags 2.9.0", + "crossbeam-channel", + "filetime", + "fsevent-sys 4.1.0", + "inotify 0.9.6", + "kqueue", + "libc", + "log", + "mio 0.8.11", + "walkdir", + "windows-sys 0.48.0", +] + [[package]] name = "notify" version = "8.0.0" @@ -10502,11 +10722,11 @@ dependencies = [ "bitflags 2.9.0", "filetime", "fsevent-sys 4.1.0", - "inotify", + "inotify 0.11.0", "kqueue", "libc", "log", - "mio", + "mio 1.0.3", "notify-types", "walkdir", "windows-sys 0.59.0", @@ -10514,14 +10734,13 @@ dependencies = [ [[package]] name = "notify-debouncer-mini" -version = "0.6.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a689eb4262184d9a1727f9087cd03883ea716682ab03ed24efec57d7716dccb8" +checksum = "5d40b221972a1fc5ef4d858a2f671fb34c75983eb385463dff3780eeff6a9d43" dependencies = [ + "crossbeam-channel", "log", - "notify", - "notify-types", - "tempfile", + "notify 6.1.1", ] [[package]] @@ -10661,21 +10880,6 @@ dependencies = [ "num-traits", ] -[[package]] -name = "num-modular" -version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17bb261bf36fa7d83f4c294f834e91256769097b3cb505d44831e0a179ac647f" - -[[package]] -name = "num-order" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "537b596b97c40fcf8056d153049eb22f481c17ebce72a513ec9286e4986d1bb6" -dependencies = [ - "num-modular", -] - [[package]] name = "num-rational" version = "0.4.2" @@ -10930,39 +11134,52 @@ dependencies = [ ] [[package]] -name = "oboe" -version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8b61bebd49e5d43f5f8cc7ee2891c16e0f41ec7954d36bcb6c14c5e0de867fb" -dependencies = [ - "jni", - "ndk 0.8.0", - "ndk-context", - "num-derive", - "num-traits", - "oboe-sys", -] - -[[package]] -name = "oboe-sys" -version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c8bb09a4a2b1d668170cfe0a7d5bc103f8999fb316c98099b6a9939c9f2e79d" +name = "ollama" +version = "0.1.0" dependencies = [ - "cc", + "anyhow", + "futures 0.3.31", + "http_client", + "schemars", + "serde", + "serde_json", + "workspace-hack", ] [[package]] -name = "ollama" +name = "onboarding" version 
= "0.1.0" dependencies = [ + "ai_onboarding", "anyhow", - "futures 0.3.31", - "http_client", + "client", + "command_palette_hooks", + "component", + "db", + "documented", + "editor", + "feature_flags", + "fs", + "fuzzy", + "gpui", + "itertools 0.14.0", + "language", + "language_model", + "menu", + "notifications", + "picker", + "project", "schemars", "serde", - "serde_json", + "settings", + "theme", + "ui", + "util", + "vim_mode_setting", + "workspace", "workspace-hack", + "zed_actions", + "zlog", ] [[package]] @@ -11313,9 +11530,9 @@ checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" [[package]] name = "parking_lot" -version = "0.12.3" +version = "0.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +checksum = "70d58bf43669b5795d1576d0641cfb6fbb2057bf629506267a92807158584a13" dependencies = [ "lock_api", "parking_lot_core", @@ -11323,9 +11540,9 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.10" +version = "0.9.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +checksum = "bc838d2a56b5b1a6c25f55575dfc605fabb63bb2365f6c2353ef9159aa69e4a5" dependencies = [ "cfg-if", "libc", @@ -11334,6 +11551,15 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "parse_int" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c464266693329dd5a8715098c7f86e6c5fd5d985018b8318f53d9c6c2b21a31" +dependencies = [ + "num-traits", +] + [[package]] name = "partial-json-fixer" version = "0.5.3" @@ -12108,6 +12334,12 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" +[[package]] +name = "plain" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4596b6d070b27117e987119b4dac604f3c58cfb0b191112e24771b2faeac1a6" + [[package]] name = "plist" version = "1.7.1" @@ -12368,6 +12600,16 @@ dependencies = [ "yansi", ] +[[package]] +name = "procfs-core" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d3554923a69f4ce04c4a754260c338f505ce22642d3830e049a399fc2059a29" +dependencies = [ + "bitflags 2.9.0", + "hex", +] + [[package]] name = "prodash" version = "29.0.2" @@ -12405,6 +12647,7 @@ dependencies = [ "anyhow", "askpass", "async-trait", + "base64 0.22.1", "buffer_diff", "circular-buffer", "client", @@ -12450,6 +12693,7 @@ dependencies = [ "sha2", "shellexpand 2.1.2", "shlex", + "smallvec", "smol", "snippet", "snippet_provider", @@ -12480,6 +12724,7 @@ dependencies = [ "editor", "file_icons", "git", + "git_ui", "gpui", "indexmap", "language", @@ -13016,6 +13261,15 @@ dependencies = [ "rand_core 0.5.1", ] +[[package]] +name = "range-map" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12a5a2d6c7039059af621472a4389be1215a816df61aa4d531cfe85264aee95f" +dependencies = [ + "num-traits", +] + [[package]] name = "rangemap" version = "1.5.1" @@ -13358,6 +13612,8 @@ dependencies = [ "clap", "client", "clock", + "crash-handler", + "crashes", "dap", "dap_adapters", "debug_adapter_extension", @@ -13381,6 +13637,7 @@ dependencies = [ "libc", "log", "lsp", + "minidumper", "node_runtime", "paths", "project", @@ -13569,6 +13826,7 @@ dependencies = [ "js-sys", "log", "mime", + 
"mime_guess", "once_cell", "percent-encoding", "pin-project-lite", @@ -13730,12 +13988,15 @@ dependencies = [ [[package]] name = "rodio" -version = "0.20.1" +version = "0.21.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7ceb6607dd738c99bc8cb28eff249b7cd5c8ec88b9db96c0608c1480d140fb1" +checksum = "e40ecf59e742e03336be6a3d53755e789fd05a059fa22dfa0ed624722319e183" dependencies = [ - "cpal 0.15.3", - "hound", + "cpal", + "dasp_sample", + "num-rational", + "symphonia", + "tracing", ] [[package]] @@ -14184,7 +14445,7 @@ dependencies = [ [[package]] name = "scap" version = "0.0.8" -source = "git+https://github.com/zed-industries/scap?rev=28dd306ff2e3374404936dec778fc1e975b8dd12#28dd306ff2e3374404936dec778fc1e975b8dd12" +source = "git+https://github.com/zed-industries/scap?rev=808aa5c45b41e8f44729d02e38fd00a2fe2722e7#808aa5c45b41e8f44729d02e38fd00a2fe2722e7" dependencies = [ "anyhow", "cocoa 0.25.0", @@ -14238,6 +14499,7 @@ dependencies = [ "indexmap", "ref-cast", "schemars_derive", + "semver", "serde", "serde_json", ] @@ -14295,6 +14557,26 @@ dependencies = [ "once_cell", ] +[[package]] +name = "scroll" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ab8598aa408498679922eff7fa985c25d58a90771bd6be794434c5277eab1a6" +dependencies = [ + "scroll_derive", +] + +[[package]] +name = "scroll_derive" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1783eabc414609e28a5ba76aee5ddd52199f7107a0b24c2e9746a1ecc34a683d" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.101", +] + [[package]] name = "scrypt" version = "0.11.0" @@ -14741,6 +15023,27 @@ dependencies = [ "zlog", ] +[[package]] +name = "settings_profile_selector" +version = "0.1.0" +dependencies = [ + "client", + "editor", + "fuzzy", + "gpui", + "language", + "menu", + "picker", + "project", + "serde_json", + "settings", + "theme", + "ui", + "workspace", + "workspace-hack", + "zed_actions", +] + [[package]] name = "settings_ui" version = "0.1.0" @@ -14756,19 +15059,24 @@ dependencies = [ "fs", "fuzzy", "gpui", + "itertools 0.14.0", "language", "log", "menu", + "notifications", "paths", "project", - "schemars", "search", "serde", + "serde_json", "settings", + "telemetry", + "tempfile", "theme", "tree-sitter-json", "tree-sitter-rust", "ui", + "ui_input", "util", "workspace", "workspace-hack", @@ -15014,6 +15322,17 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "smart-default" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eb01866308440fc64d6c44d9e86c5cc17adfe33c4d6eed55da9145044d0ffc1" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.101", +] + [[package]] name = "smol" version = "2.0.2" @@ -15191,7 +15510,7 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7bba3a93db0cc4f7bdece8bb09e77e2e785c20bfebf79eb8340ed80708048790" dependencies = [ - "nom", + "nom 7.1.3", "unicode_categories", ] @@ -15596,12 +15915,12 @@ dependencies = [ "anyhow", "client", "collections", + "edit_prediction", "editor", "env_logger 0.11.8", "futures 0.3.31", "gpui", "http_client", - "inline_completion", "language", "log", "postage", @@ -15751,6 +16070,66 @@ dependencies = [ "zeno", ] +[[package]] +name = "symphonia" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "815c942ae7ee74737bb00f965fa5b5a2ac2ce7b6c01c0cc169bbeaf7abd5f5a9" +dependencies = [ + "lazy_static", + 
"symphonia-codec-pcm", + "symphonia-core", + "symphonia-format-riff", + "symphonia-metadata", +] + +[[package]] +name = "symphonia-codec-pcm" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f395a67057c2ebc5e84d7bb1be71cce1a7ba99f64e0f0f0e303a03f79116f89b" +dependencies = [ + "log", + "symphonia-core", +] + +[[package]] +name = "symphonia-core" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "798306779e3dc7d5231bd5691f5a813496dc79d3f56bf82e25789f2094e022c3" +dependencies = [ + "arrayvec", + "bitflags 1.3.2", + "bytemuck", + "lazy_static", + "log", +] + +[[package]] +name = "symphonia-format-riff" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f7be232f962f937f4b7115cbe62c330929345434c834359425e043bfd15f50" +dependencies = [ + "extended", + "log", + "symphonia-core", + "symphonia-metadata", +] + +[[package]] +name = "symphonia-metadata" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc622b9841a10089c5b18e99eb904f4341615d5aa55bbf4eedde1be721a4023c" +dependencies = [ + "encoding_rs", + "lazy_static", + "log", + "symphonia-core", +] + [[package]] name = "syn" version = "1.0.109" @@ -15931,13 +16310,12 @@ dependencies = [ [[package]] name = "taffy" -version = "0.4.4" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ec17858c2d465b2f734b798b920818a974faf0babb15d7fef81818a4b2d16f1" +checksum = "a13e5d13f79d558b5d353a98072ca8ca0e99da429467804de959aa8c83c9a004" dependencies = [ "arrayvec", "grid", - "num-traits", "serde", "slotmap", ] @@ -16135,7 +16513,7 @@ version = "0.1.0" dependencies = [ "anyhow", "assistant_slash_command", - "async-recursion 1.1.1", + "async-recursion", "breadcrumbs", "client", "collections", @@ -16196,7 +16574,7 @@ version = "0.1.0" dependencies = [ "anyhow", "collections", - "derive_more", + "derive_more 0.99.19", "fs", "futures 0.3.31", "gpui", @@ -16335,9 +16713,8 @@ dependencies = [ [[package]] name = "tiktoken-rs" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25563eeba904d770acf527e8b370fe9a5547bacd20ff84a0b6c3bc41288e5625" +version = "0.8.0" +source = "git+https://github.com/zed-industries/tiktoken-rs?rev=30c32a4522751699adeda0d5840c71c3b75ae73d#30c32a4522751699adeda0d5840c71c3b75ae73d" dependencies = [ "anyhow", "base64 0.22.1", @@ -16479,10 +16856,12 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" name = "title_bar" version = "0.1.0" dependencies = [ + "anyhow", "auto_update", "call", "chrono", "client", + "cloud_llm_client", "collections", "db", "gpui", @@ -16495,6 +16874,7 @@ dependencies = [ "schemars", "serde", "settings", + "settings_ui", "smallvec", "story", "telemetry", @@ -16517,7 +16897,7 @@ dependencies = [ "backtrace", "bytes 1.10.1", "libc", - "mio", + "mio 1.0.3", "parking_lot", "pin-project-lite", "signal-hook-registry", @@ -17235,6 +17615,15 @@ version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971" +[[package]] +name = "uds" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "885c31f06fce836457fe3ef09a59f83fe8db95d270b11cd78f40a4666c4d1661" +dependencies = [ + "libc", +] + [[package]] name = "uds_windows" version = "1.1.0" @@ -18450,11 +18839,11 @@ name = "web_search" version 
= "0.1.0" dependencies = [ "anyhow", + "cloud_llm_client", "collections", "gpui", "serde", "workspace-hack", - "zed_llm_client", ] [[package]] @@ -18463,7 +18852,7 @@ version = "0.1.0" dependencies = [ "anyhow", "client", - "feature_flags", + "cloud_llm_client", "futures 0.3.31", "gpui", "http_client", @@ -18472,7 +18861,6 @@ dependencies = [ "serde_json", "web_search", "workspace-hack", - "zed_llm_client", ] [[package]] @@ -18496,7 +18884,7 @@ dependencies = [ [[package]] name = "webrtc-sys" version = "0.3.7" -source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4" +source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d" dependencies = [ "cc", "cxx", @@ -18509,7 +18897,7 @@ dependencies = [ [[package]] name = "webrtc-sys-build" version = "0.3.6" -source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4" +source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d" dependencies = [ "fs2", "regex", @@ -18544,7 +18932,6 @@ dependencies = [ "serde", "settings", "telemetry", - "theme", "ui", "util", "vim_mode_setting", @@ -18724,8 +19111,7 @@ dependencies = [ [[package]] name = "windows-capture" version = "1.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59d10b4be8b907c7055bc7270dd68d2b920978ffacc1599dcb563a79f0e68d16" +source = "git+https://github.com/zed-industries/windows-capture.git?rev=f0d6c1b6691db75461b732f6d5ff56eed002eeb9#f0d6c1b6691db75461b732f6d5ff56eed002eeb9" dependencies = [ "clap", "ctrlc", @@ -19560,7 +19946,7 @@ version = "0.1.0" dependencies = [ "any_vec", "anyhow", - "async-recursion 1.1.1", + "async-recursion", "bincode", "call", "client", @@ -19637,14 +20023,12 @@ dependencies = [ "cc", "chrono", "cipher", - "clang-sys", "clap", "clap_builder", "codespan-reporting 0.12.0", "concurrent-queue", "core-foundation 0.9.4", "core-foundation-sys", - "coreaudio-sys", "cranelift-codegen", "crc32fast", "crossbeam-epoch", @@ -19695,11 +20079,13 @@ dependencies = [ "lyon_path", "md-5", "memchr", + "mime_guess", "miniz_oxide", - "mio", + "mio 1.0.3", "naga", + "nix 0.28.0", "nix 0.29.0", - "nom", + "nom 7.1.3", "num-bigint", "num-bigint-dig", "num-integer", @@ -19782,9 +20168,7 @@ dependencies = [ "wasmtime-cranelift", "wasmtime-environ", "winapi", - "windows 0.61.1", "windows-core 0.61.0", - "windows-future", "windows-numerics", "windows-sys 0.48.0", "windows-sys 0.52.0", @@ -19890,6 +20274,17 @@ version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec107c4503ea0b4a98ef47356329af139c0a4f7750e621cf2973cd3385ebcb3d" +[[package]] +name = "x_ai" +version = "0.1.0" +dependencies = [ + "anyhow", + "schemars", + "serde", + "strum 0.27.1", + "workspace-hack", +] + [[package]] name = "xattr" version = "0.2.3" @@ -20034,6 +20429,34 @@ dependencies = [ "winapi", ] +[[package]] +name = "yawc" +version = "0.2.4" +source = "git+https://github.com/deviant-forks/yawc?rev=1899688f3e69ace4545aceb97b2a13881cf26142#1899688f3e69ace4545aceb97b2a13881cf26142" +dependencies = [ + "base64 0.22.1", + "bytes 1.10.1", + "flate2", + "futures 0.3.31", + "http-body-util", + "hyper 1.6.0", + "hyper-util", + "js-sys", + "nom 8.0.0", + "pin-project", + "rand 0.8.5", + 
"sha1", + "thiserror 1.0.69", + "tokio", + "tokio-rustls 0.26.2", + "tokio-util", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "webpki-roots", +] + [[package]] name = "yazi" version = "0.2.1" @@ -20087,7 +20510,7 @@ dependencies = [ "async-io", "async-lock", "async-process", - "async-recursion 1.1.1", + "async-recursion", "async-task", "async-trait", "blocking", @@ -20140,7 +20563,7 @@ dependencies = [ [[package]] name = "zed" -version = "0.196.0" +version = "0.200.0" dependencies = [ "activity_indicator", "agent", @@ -20169,6 +20592,7 @@ dependencies = [ "command_palette", "component", "copilot", + "crashes", "dap", "dap_adapters", "db", @@ -20176,11 +20600,13 @@ dependencies = [ "debugger_tools", "debugger_ui", "diagnostics", + "edit_prediction_button", "editor", "env_logger 0.11.8", "extension", "extension_host", "extensions_ui", + "feature_flags", "feedback", "file_finder", "fs", @@ -20194,7 +20620,6 @@ dependencies = [ "http_client", "image_viewer", "indoc", - "inline_completion_button", "inspector_ui", "install_cli", "itertools 0.14.0", @@ -20214,9 +20639,11 @@ dependencies = [ "menu", "migrator", "mimalloc", + "nc", "nix 0.29.0", "node_runtime", "notifications", + "onboarding", "outline", "outline_panel", "parking_lot", @@ -20233,6 +20660,7 @@ dependencies = [ "release_channel", "remote", "repl", + "reqwest 0.12.15 (git+https://github.com/zed-industries/reqwest.git?rev=951c770a32f1998d6e999cef3e59e0013e6c4415)", "reqwest_client", "rope", "search", @@ -20240,6 +20668,7 @@ dependencies = [ "serde_json", "session", "settings", + "settings_profile_selector", "settings_ui", "shellexpand 2.1.2", "smol", @@ -20298,7 +20727,7 @@ dependencies = [ [[package]] name = "zed_emmet" -version = "0.0.3" +version = "0.0.4" dependencies = [ "zed_extension_api 0.1.0", ] @@ -20337,19 +20766,6 @@ dependencies = [ "zed_extension_api 0.1.0", ] -[[package]] -name = "zed_llm_client" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6607f74dee2a18a9ce0f091844944a0e59881359ab62e0768fb0618f55d4c1dc" -dependencies = [ - "anyhow", - "serde", - "serde_json", - "strum 0.27.1", - "uuid", -] - [[package]] name = "zed_proto" version = "0.2.2" @@ -20359,7 +20775,7 @@ dependencies = [ [[package]] name = "zed_ruff" -version = "0.1.0" +version = "0.1.1" dependencies = [ "zed_extension_api 0.1.0", ] @@ -20523,15 +20939,20 @@ dependencies = [ name = "zeta" version = "0.1.0" dependencies = [ + "ai_onboarding", "anyhow", "arrayvec", "call", "client", "clock", + "cloud_api_types", + "cloud_llm_client", "collections", "command_palette_hooks", + "copilot", "ctor", "db", + "edit_prediction", "editor", "feature_flags", "fs", @@ -20539,16 +20960,12 @@ dependencies = [ "gpui", "http_client", "indoc", - "inline_completion", "language", "language_model", "log", "menu", - "migrator", - "paths", "postage", "project", - "proto", "regex", "release_channel", "reqwest_client", @@ -20570,10 +20987,45 @@ dependencies = [ "workspace-hack", "worktree", "zed_actions", - "zed_llm_client", "zlog", ] +[[package]] +name = "zeta_cli" +version = "0.1.0" +dependencies = [ + "anyhow", + "clap", + "client", + "debug_adapter_extension", + "extension", + "fs", + "futures 0.3.31", + "gpui", + "gpui_tokio", + "language", + "language_extension", + "language_model", + "language_models", + "languages", + "node_runtime", + "paths", + "project", + "prompt_store", + "release_channel", + "reqwest_client", + "serde", + "serde_json", + "settings", + "shellexpand 2.1.2", + "smol", + 
"terminal_view", + "util", + "watch", + "workspace-hack", + "zeta", +] + [[package]] name = "zip" version = "0.6.6" diff --git a/Cargo.toml b/Cargo.toml index fd5cbff545351e3248f0d700c199e47f632d32d8..d547110bb4a984e868965da4d281dc394dd2fb33 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,12 +1,14 @@ [workspace] resolver = "2" members = [ + "crates/acp_thread", "crates/activity_indicator", - "crates/acp", - "crates/agent_ui", "crates/agent", - "crates/agent_settings", + "crates/agent2", "crates/agent_servers", + "crates/agent_settings", + "crates/agent_ui", + "crates/ai_onboarding", "crates/anthropic", "crates/askpass", "crates/assets", @@ -28,6 +30,9 @@ members = [ "crates/cli", "crates/client", "crates/clock", + "crates/cloud_api_client", + "crates/cloud_api_types", + "crates/cloud_llm_client", "crates/collab", "crates/collab_ui", "crates/collections", @@ -36,6 +41,7 @@ members = [ "crates/component", "crates/context_server", "crates/copilot", + "crates/crashes", "crates/credentials_provider", "crates/dap", "crates/dap_adapters", @@ -47,8 +53,8 @@ members = [ "crates/diagnostics", "crates/docs_preprocessor", "crates/editor", - "crates/explorer_command_injector", "crates/eval", + "crates/explorer_command_injector", "crates/extension", "crates/extension_api", "crates/extension_cli", @@ -69,15 +75,14 @@ members = [ "crates/gpui", "crates/gpui_macros", "crates/gpui_tokio", - "crates/html_to_markdown", "crates/http_client", "crates/http_client_tls", "crates/icons", "crates/image_viewer", "crates/indexed_docs", - "crates/inline_completion", - "crates/inline_completion_button", + "crates/edit_prediction", + "crates/edit_prediction_button", "crates/inspector_ui", "crates/install_cli", "crates/jj", @@ -98,14 +103,15 @@ members = [ "crates/markdown_preview", "crates/media", "crates/menu", - "crates/svg_preview", "crates/migrator", "crates/mistral", "crates/multi_buffer", + "crates/nc", "crates/net", "crates/node_runtime", "crates/notifications", "crates/ollama", + "crates/onboarding", "crates/open_ai", "crates/open_router", "crates/outline", @@ -137,6 +143,7 @@ members = [ "crates/semantic_version", "crates/session", "crates/settings", + "crates/settings_profile_selector", "crates/settings_ui", "crates/snippet", "crates/snippet_provider", @@ -149,6 +156,7 @@ members = [ "crates/sum_tree", "crates/supermaven", "crates/supermaven_api", + "crates/svg_preview", "crates/tab_switcher", "crates/task", "crates/tasks_ui", @@ -179,9 +187,11 @@ members = [ "crates/welcome", "crates/workspace", "crates/worktree", + "crates/x_ai", "crates/zed", "crates/zed_actions", "crates/zeta", + "crates/zeta_cli", "crates/zlog", "crates/zlog_settings", @@ -218,13 +228,15 @@ edition = "2024" # Workspace member crates # -acp = { path = "crates/acp" } +acp_thread = { path = "crates/acp_thread" } agent = { path = "crates/agent" } +agent2 = { path = "crates/agent2" } activity_indicator = { path = "crates/activity_indicator" } agent_ui = { path = "crates/agent_ui" } agent_settings = { path = "crates/agent_settings" } agent_servers = { path = "crates/agent_servers" } ai = { path = "crates/ai" } +ai_onboarding = { path = "crates/ai_onboarding" } anthropic = { path = "crates/anthropic" } askpass = { path = "crates/askpass" } assets = { path = "crates/assets" } @@ -246,6 +258,9 @@ channel = { path = "crates/channel" } cli = { path = "crates/cli" } client = { path = "crates/client" } clock = { path = "crates/clock" } +cloud_api_client = { path = "crates/cloud_api_client" } +cloud_api_types = { path = "crates/cloud_api_types" } 
+cloud_llm_client = { path = "crates/cloud_llm_client" } collab = { path = "crates/collab" } collab_ui = { path = "crates/collab_ui" } collections = { path = "crates/collections" } @@ -254,6 +269,7 @@ command_palette_hooks = { path = "crates/command_palette_hooks" } component = { path = "crates/component" } context_server = { path = "crates/context_server" } copilot = { path = "crates/copilot" } +crashes = { path = "crates/crashes" } credentials_provider = { path = "crates/credentials_provider" } dap = { path = "crates/dap" } dap_adapters = { path = "crates/dap_adapters" } @@ -290,8 +306,8 @@ http_client_tls = { path = "crates/http_client_tls" } icons = { path = "crates/icons" } image_viewer = { path = "crates/image_viewer" } indexed_docs = { path = "crates/indexed_docs" } -inline_completion = { path = "crates/inline_completion" } -inline_completion_button = { path = "crates/inline_completion_button" } +edit_prediction = { path = "crates/edit_prediction" } +edit_prediction_button = { path = "crates/edit_prediction_button" } inspector_ui = { path = "crates/inspector_ui" } install_cli = { path = "crates/install_cli" } jj = { path = "crates/jj" } @@ -316,10 +332,12 @@ menu = { path = "crates/menu" } migrator = { path = "crates/migrator" } mistral = { path = "crates/mistral" } multi_buffer = { path = "crates/multi_buffer" } +nc = { path = "crates/nc" } net = { path = "crates/net" } node_runtime = { path = "crates/node_runtime" } notifications = { path = "crates/notifications" } ollama = { path = "crates/ollama" } +onboarding = { path = "crates/onboarding" } open_ai = { path = "crates/open_ai" } open_router = { path = "crates/open_router", features = ["schemars"] } outline = { path = "crates/outline" } @@ -330,6 +348,7 @@ picker = { path = "crates/picker" } plugin = { path = "crates/plugin" } plugin_macros = { path = "crates/plugin_macros" } prettier = { path = "crates/prettier" } +settings_profile_selector = { path = "crates/settings_profile_selector" } project = { path = "crates/project" } project_panel = { path = "crates/project_panel" } project_symbols = { path = "crates/project_symbols" } @@ -394,6 +413,7 @@ web_search_providers = { path = "crates/web_search_providers" } welcome = { path = "crates/welcome" } workspace = { path = "crates/workspace" } worktree = { path = "crates/worktree" } +x_ai = { path = "crates/x_ai" } zed = { path = "crates/zed" } zed_actions = { path = "crates/zed_actions" } zeta = { path = "crates/zeta" } @@ -404,7 +424,8 @@ zlog_settings = { path = "crates/zlog_settings" } # External crates # -agentic-coding-protocol = "0.0.6" +agentic-coding-protocol = "0.0.10" +agent-client-protocol = { version = "0.0.23" } aho-corasick = "1.1" alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" } any_vec = "0.14" @@ -432,14 +453,15 @@ aws-smithy-runtime-api = { version = "1.7.4", features = ["http-1x", "client"] } aws-smithy-types = { version = "1.3.0", features = ["http-body-1-x"] } base64 = "0.22" bitflags = "2.6.0" -blade-graphics = { git = "https://github.com/kvark/blade", rev = "416375211bb0b5826b3584dccdb6a43369e499ad" } -blade-macros = { git = "https://github.com/kvark/blade", rev = "416375211bb0b5826b3584dccdb6a43369e499ad" } -blade-util = { git = "https://github.com/kvark/blade", rev = "416375211bb0b5826b3584dccdb6a43369e499ad" } +blade-graphics = { git = "https://github.com/kvark/blade", rev = "e0ec4e720957edd51b945b64dd85605ea54bcfe5" } +blade-macros = { git = "https://github.com/kvark/blade", rev = 
"e0ec4e720957edd51b945b64dd85605ea54bcfe5" } +blade-util = { git = "https://github.com/kvark/blade", rev = "e0ec4e720957edd51b945b64dd85605ea54bcfe5" } blake3 = "1.5.3" bytes = "1.0" cargo_metadata = "0.19" cargo_toml = "0.21" chrono = { version = "0.4", features = ["serde"] } +ciborium = "0.2" circular-buffer = "1.0" clap = { version = "4.4", features = ["derive"] } cocoa = "0.26" @@ -449,9 +471,10 @@ core-foundation = "0.10.0" core-foundation-sys = "0.8.6" core-video = { version = "0.4.3", features = ["metal"] } cpal = "0.16" +crash-handler = "0.6" criterion = { version = "0.5", features = ["html_reports"] } ctor = "0.4.0" -dap-types = { git = "https://github.com/zed-industries/dap-types", rev = "7f39295b441614ca9dbf44293e53c32f666897f9" } +dap-types = { git = "https://github.com/zed-industries/dap-types", rev = "1b461b310481d01e02b2603c16d7144b926339f8" } dashmap = "6.0" derive_more = "0.99.17" dirs = "4.0" @@ -474,6 +497,7 @@ heed = { version = "0.21.0", features = ["read-txn-no-tls"] } hex = "0.4.3" html5ever = "0.27.0" http = "1.1" +http-body = "1.0" hyper = "0.14" ignore = "0.4.22" image = "0.25.1" @@ -487,18 +511,19 @@ json_dotpath = "1.1" jsonschema = "0.30.0" jsonwebtoken = "9.3" jupyter-protocol = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734" } -jupyter-websocket-client = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734" } +jupyter-websocket-client = { git = "https://github.com/ConradIrwin/runtimed" ,rev = "7130c804216b6914355d15d0b91ea91f6babd734" } libc = "0.2" libsqlite3-sys = { version = "0.30.1", features = ["bundled"] } linkify = "0.10.0" log = { version = "0.4.16", features = ["kv_unstable_serde", "serde"] } -lsp-types = { git = "https://github.com/zed-industries/lsp-types", rev = "c9c189f1c5dd53c624a419ce35bc77ad6a908d18" } +lsp-types = { git = "https://github.com/zed-industries/lsp-types", rev = "39f629bdd03d59abd786ed9fc27e8bca02c0c0ec" } markup5ever_rcdom = "0.3.0" metal = "0.29" +minidumper = "0.8" moka = { version = "0.12.10", features = ["sync"] } naga = { version = "25.0", features = ["wgsl-in"] } nanoid = "0.4" -nbformat = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734" } +nbformat = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734" } nix = "0.29" num-format = "0.4.4" objc = "0.2" @@ -507,6 +532,7 @@ ordered-float = "2.1.1" palette = { version = "0.7.5", default-features = false, features = ["std"] } parking_lot = "0.12.1" partial-json-fixer = "0.5.3" +parse_int = "0.9" pathdiff = "0.2" pet = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" } pet-conda = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" } @@ -533,12 +559,13 @@ reqwest = { git = "https://github.com/zed-industries/reqwest.git", rev = "951c77 "charset", "http2", "macos-system-configuration", + "multipart", "rustls-tls-native-roots", "socks", "stream", ] } rsa = "0.9.6" -runtimelib = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734", default-features = false, features = [ +runtimelib = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734", default-features = false, features = [ "async-dispatcher-runtime", ] } rust-embed = { version = "8.4", features = 
["include-exclude"] } @@ -546,7 +573,7 @@ rustc-demangle = "0.1.23" rustc-hash = "2.1.0" rustls = { version = "0.23.26" } rustls-platform-verifier = "0.5.0" -scap = { git = "https://github.com/zed-industries/scap", rev = "28dd306ff2e3374404936dec778fc1e975b8dd12", default-features = false } +scap = { git = "https://github.com/zed-industries/scap", rev = "808aa5c45b41e8f44729d02e38fd00a2fe2722e7", default-features = false } schemars = { version = "1.0", features = ["indexmap2"] } semver = "1.0" serde = { version = "1.0", features = ["derive", "rc"] } @@ -574,7 +601,7 @@ sysinfo = "0.31.0" take-until = "0.2.0" tempfile = "3.20.0" thiserror = "2.0.12" -tiktoken-rs = "0.7.0" +tiktoken-rs = { git = "https://github.com/zed-industries/tiktoken-rs", rev = "30c32a4522751699adeda0d5840c71c3b75ae73d" } time = { version = "0.3", features = [ "macros", "parsing", @@ -634,7 +661,9 @@ which = "6.0.0" windows-core = "0.61" wit-component = "0.221" workspace-hack = "0.1.0" -zed_llm_client = "= 0.8.6" +# We can switch back to the published version once https://github.com/infinitefield/yawc/pull/16 is merged and a new +# version is released. +yawc = { git = "https://github.com/deviant-forks/yawc", rev = "1899688f3e69ace4545aceb97b2a13881cf26142" } zstd = "0.11" [workspace.dependencies.async-stripe] @@ -661,14 +690,16 @@ features = [ "UI_ViewManagement", "Wdk_System_SystemServices", "Win32_Globalization", - "Win32_Graphics_Direct2D", - "Win32_Graphics_Direct2D_Common", + "Win32_Graphics_Direct3D", + "Win32_Graphics_Direct3D11", + "Win32_Graphics_Direct3D_Fxc", + "Win32_Graphics_DirectComposition", "Win32_Graphics_DirectWrite", "Win32_Graphics_Dwm", + "Win32_Graphics_Dxgi", "Win32_Graphics_Dxgi_Common", "Win32_Graphics_Gdi", "Win32_Graphics_Imaging", - "Win32_Graphics_Imaging_D2D", "Win32_Networking_WinSock", "Win32_Security", "Win32_Security_Credentials", @@ -700,6 +731,7 @@ features = [ [patch.crates-io] notify = { git = "https://github.com/zed-industries/notify.git", rev = "bbb9ea5ae52b253e095737847e367c30653a2e96" } notify-types = { git = "https://github.com/zed-industries/notify.git", rev = "bbb9ea5ae52b253e095737847e367c30653a2e96" } +windows-capture = { git = "https://github.com/zed-industries/windows-capture.git", rev = "f0d6c1b6691db75461b732f6d5ff56eed002eeb9" } # Makes the workspace hack crate refer to the local one, but only when you're building locally workspace-hack = { path = "tooling/workspace-hack" } @@ -708,6 +740,11 @@ workspace-hack = { path = "tooling/workspace-hack" } split-debuginfo = "unpacked" codegen-units = 16 +# mirror configuration for crates compiled for the build platform +# (without this cargo will compile ~400 crates twice) +[profile.dev.build-override] +codegen-units = 16 + [profile.dev.package] taffy = { opt-level = 3 } cranelift-codegen = { opt-level = 3 } @@ -730,7 +767,7 @@ feature_flags = { codegen-units = 1 } file_icons = { codegen-units = 1 } fsevent = { codegen-units = 1 } image_viewer = { codegen-units = 1 } -inline_completion_button = { codegen-units = 1 } +edit_prediction_button = { codegen-units = 1 } install_cli = { codegen-units = 1 } journal = { codegen-units = 1 } lmstudio = { codegen-units = 1 } diff --git a/Dockerfile-collab b/Dockerfile-collab index 2dafe296c7c8bb46c758d6c5f67ce6feed055d2b..c1621d6ee67e42117315ea49eac99f6f6260f4b7 100644 --- a/Dockerfile-collab +++ b/Dockerfile-collab @@ -1,6 +1,6 @@ # syntax = docker/dockerfile:1.2 -FROM rust:1.88-bookworm as builder +FROM rust:1.89-bookworm as builder WORKDIR app COPY . . 
diff --git a/Procfile b/Procfile index 5f1231b90a41cc9fdadb1d856b705fe52d37db22..b3f13f66a60b97aed9671ae310ed20d7f4813026 100644 --- a/Procfile +++ b/Procfile @@ -1,3 +1,4 @@ collab: RUST_LOG=${RUST_LOG:-info} cargo run --package=collab serve all +cloud: cd ../cloud; cargo make dev livekit: livekit-server --dev blob_store: ./script/run-local-minio diff --git a/README.md b/README.md index 4c794efc3de3f26fb1e5dbf943f6c7379174791a..38547c1ca441b918b773d8b1a884a1e3f48c785f 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,6 @@ # Zed +[![Zed](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/zed-industries/zed/main/assets/badge/v0.json)](https://zed.dev) [![CI](https://github.com/zed-industries/zed/actions/workflows/ci.yml/badge.svg)](https://github.com/zed-industries/zed/actions/workflows/ci.yml) Welcome to Zed, a high-performance, multiplayer code editor from the creators of [Atom](https://github.com/atom/atom) and [Tree-sitter](https://github.com/tree-sitter/tree-sitter). diff --git a/assets/badge/v0.json b/assets/badge/v0.json new file mode 100644 index 0000000000000000000000000000000000000000..c7d18bb42b71f2d57696ce56b8211e0395afab9d --- /dev/null +++ b/assets/badge/v0.json @@ -0,0 +1,8 @@ +{ + "label": "", + "message": "Zed", + "logoSvg": "", + "logoWidth": 16, + "labelColor": "black", + "color": "white" +} diff --git a/assets/icons/ai_bedrock.svg b/assets/icons/ai_bedrock.svg index 2b672c364ea42e0e6c7b0e0166aa9efb121a424f..c9bbcc82e10bb0b277d4fbe6885077ed3d228a1f 100644 --- a/assets/icons/ai_bedrock.svg +++ b/assets/icons/ai_bedrock.svg @@ -1,4 +1,8 @@ - - - + + + + + + + diff --git a/assets/icons/ai_claude.svg b/assets/icons/ai_claude.svg new file mode 100644 index 0000000000000000000000000000000000000000..a3e3e1f4cd7bcc4924ed3f8164c35c5c8e2a9c4c --- /dev/null +++ b/assets/icons/ai_claude.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/ai_deep_seek.svg b/assets/icons/ai_deep_seek.svg index cf480c834c9f01d914c6fe37885903cdb79ff27f..c8e5483fb3379f4dc0f08ae7e1ab7ee696bdcab9 100644 --- a/assets/icons/ai_deep_seek.svg +++ b/assets/icons/ai_deep_seek.svg @@ -1 +1,3 @@ -DeepSeek + + + diff --git a/assets/icons/ai_gemini.svg b/assets/icons/ai_gemini.svg index 60197dc4adcf912128756b32ead43b8b1da61222..bdde44ed2475313f0dfd418a496f372ca61db22d 100644 --- a/assets/icons/ai_gemini.svg +++ b/assets/icons/ai_gemini.svg @@ -1 +1,3 @@ -Google Gemini + + + diff --git a/assets/icons/ai_lm_studio.svg b/assets/icons/ai_lm_studio.svg index 0b455f48a7382c744685b6dbd3c3a8ba6537ccf9..5cfdeb5578cb34e6781fa74da2d9773096c451e1 100644 --- a/assets/icons/ai_lm_studio.svg +++ b/assets/icons/ai_lm_studio.svg @@ -1,33 +1,15 @@ - - - Artboard - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + diff --git a/assets/icons/ai_mistral.svg b/assets/icons/ai_mistral.svg index 23b8f2ef6cd23f9738cb453798becb2e31f7e0cc..f11c177e2fc9e069cd01ed10ff8914b3b5b7c2d8 100644 --- a/assets/icons/ai_mistral.svg +++ b/assets/icons/ai_mistral.svg @@ -1 +1,8 @@ -Mistral \ No newline at end of file + + + + + + + + diff --git a/assets/icons/ai_ollama.svg b/assets/icons/ai_ollama.svg index d433df39811c4d5e0c9f85c1d9bdab891da8e255..36a88c1ad6d70dfff1fe13d1be93a9aade3b7b5d 100644 --- a/assets/icons/ai_ollama.svg +++ b/assets/icons/ai_ollama.svg @@ -1,14 +1,7 @@ - - - - - - - - - - - - + + + + + diff --git a/assets/icons/ai_open_ai.svg b/assets/icons/ai_open_ai.svg index e659a472d89275919bf83ffe3446fa133ae07ad2..e45ac315a011853a9e9343171659b7623017fb31 100644 --- a/assets/icons/ai_open_ai.svg 
+++ b/assets/icons/ai_open_ai.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/ai_open_ai_compat.svg b/assets/icons/ai_open_ai_compat.svg new file mode 100644 index 0000000000000000000000000000000000000000..f6557caac3304821b051fa6375c7ef32b225f70e --- /dev/null +++ b/assets/icons/ai_open_ai_compat.svg @@ -0,0 +1,4 @@ + + + + diff --git a/assets/icons/ai_open_router.svg b/assets/icons/ai_open_router.svg index cc8597729a8ac4011d5ef8937fdb4f0ddaff7839..b6f5164e0b385f26e7b22a12253d18200dbff24e 100644 --- a/assets/icons/ai_open_router.svg +++ b/assets/icons/ai_open_router.svg @@ -1,8 +1,8 @@ - - - - - - - + + + + + + + diff --git a/assets/icons/ai_x_ai.svg b/assets/icons/ai_x_ai.svg new file mode 100644 index 0000000000000000000000000000000000000000..d3400fbe9cd4c8f82a38219bd489d1fa1c3d5f8e --- /dev/null +++ b/assets/icons/ai_x_ai.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/ai_zed.svg b/assets/icons/ai_zed.svg index 1c6bb8ad63a09c22301570180d4381d49a35af72..6d78efacd5ffdaa06570ed9459bdd91eaae05bb1 100644 --- a/assets/icons/ai_zed.svg +++ b/assets/icons/ai_zed.svg @@ -1,10 +1,3 @@ - - - - - - - - + diff --git a/assets/icons/audio_off.svg b/assets/icons/audio_off.svg index 93b98471ca1a15e4ef92860e953dde8beb559c37..dfb5a1c45829119ea0dc89bbca3a3f33228ee88f 100644 --- a/assets/icons/audio_off.svg +++ b/assets/icons/audio_off.svg @@ -1 +1,7 @@ - + + + + + + + diff --git a/assets/icons/audio_on.svg b/assets/icons/audio_on.svg index 42310ea32c289e2ecf24a6fa231ae55fce3cb05e..d1bef0d337d6c8a0e79cb0dab8b7d63d5cb2a4d1 100644 --- a/assets/icons/audio_on.svg +++ b/assets/icons/audio_on.svg @@ -1 +1,5 @@ - + + + + + diff --git a/assets/icons/bolt.svg b/assets/icons/bolt.svg deleted file mode 100644 index 2688ede2a502e723e188787fab0cc82e43ca097c..0000000000000000000000000000000000000000 --- a/assets/icons/bolt.svg +++ /dev/null @@ -1,3 +0,0 @@ - - - diff --git a/assets/icons/bolt_filled.svg b/assets/icons/bolt_filled.svg index 543e72adf8f36dc9b7bbe33058d652e5ada072b9..14d8f53e02fe9b82500ba4d9a6a030a8c7cea252 100644 --- a/assets/icons/bolt_filled.svg +++ b/assets/icons/bolt_filled.svg @@ -1,3 +1,3 @@ - - + + diff --git a/assets/icons/bolt_filled_alt.svg b/assets/icons/bolt_filled_alt.svg deleted file mode 100644 index 141e1c5f577bbd9bdc661de6629f863bfc760de9..0000000000000000000000000000000000000000 --- a/assets/icons/bolt_filled_alt.svg +++ /dev/null @@ -1,3 +0,0 @@ - - - diff --git a/assets/icons/bolt_outlined.svg b/assets/icons/bolt_outlined.svg new file mode 100644 index 0000000000000000000000000000000000000000..58fccf778813d3653f1066f45e5573adbf2d9ec2 --- /dev/null +++ b/assets/icons/bolt_outlined.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/brain.svg b/assets/icons/brain.svg deleted file mode 100644 index 80c93814f7c483f9e90d20f81e4ce7d32459ab57..0000000000000000000000000000000000000000 --- a/assets/icons/brain.svg +++ /dev/null @@ -1 +0,0 @@ - diff --git a/assets/icons/chat.svg b/assets/icons/chat.svg new file mode 100644 index 0000000000000000000000000000000000000000..a0548c3d3e6917fbea2bfba825761e01cd215a33 --- /dev/null +++ b/assets/icons/chat.svg @@ -0,0 +1,4 @@ + + + + diff --git a/assets/icons/book_plus.svg b/assets/icons/cloud_download.svg similarity index 51% rename from assets/icons/book_plus.svg rename to assets/icons/cloud_download.svg index 2868f07cd098d717982ff000312f8eb64257c474..bc7a8376d123088119643fc20346506690559bdc 100644 --- a/assets/icons/book_plus.svg +++ b/assets/icons/cloud_download.svg @@ -1 +1 @@ - + \ No newline at end of file diff --git a/assets/icons/debug.svg 
b/assets/icons/debug.svg index 8cea0c460402fbb36769aa0aaadab9f80513d101..ff51e42b1a9483f4f6d0382d67aa34bd3405f1ff 100644 --- a/assets/icons/debug.svg +++ b/assets/icons/debug.svg @@ -1 +1,12 @@ - + + + + + + + + + + + + diff --git a/assets/icons/editor_atom.svg b/assets/icons/editor_atom.svg new file mode 100644 index 0000000000000000000000000000000000000000..cc5fa83843fd6fa8800bf824ad5d184af06b4cb2 --- /dev/null +++ b/assets/icons/editor_atom.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/editor_cursor.svg b/assets/icons/editor_cursor.svg new file mode 100644 index 0000000000000000000000000000000000000000..338697be8a621e80099c308b3dda0a4e11fcfd61 --- /dev/null +++ b/assets/icons/editor_cursor.svg @@ -0,0 +1,9 @@ + + + + + + + + + diff --git a/assets/icons/editor_emacs.svg b/assets/icons/editor_emacs.svg new file mode 100644 index 0000000000000000000000000000000000000000..951d7b2be16387a57e940b964d3ed1621b7a5819 --- /dev/null +++ b/assets/icons/editor_emacs.svg @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/assets/icons/editor_jet_brains.svg b/assets/icons/editor_jet_brains.svg new file mode 100644 index 0000000000000000000000000000000000000000..7d9cf0c65cd31137153b417b31d8d3022e0e3ef3 --- /dev/null +++ b/assets/icons/editor_jet_brains.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/editor_sublime.svg b/assets/icons/editor_sublime.svg new file mode 100644 index 0000000000000000000000000000000000000000..95a04f6b54127dc20ad850271939b15057649ae7 --- /dev/null +++ b/assets/icons/editor_sublime.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/assets/icons/editor_vs_code.svg b/assets/icons/editor_vs_code.svg new file mode 100644 index 0000000000000000000000000000000000000000..2a71ad52af22bbc9cd9a9c557fa52f91d00fc7ce --- /dev/null +++ b/assets/icons/editor_vs_code.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/at_sign.svg b/assets/icons/equal.svg similarity index 62% rename from assets/icons/at_sign.svg rename to assets/icons/equal.svg index 4cf8cd468f17e5cbcd012c6225543f6c4b027969..9b3a151a12fc3dea5f1eb295bf299e6360846ed2 100644 --- a/assets/icons/at_sign.svg +++ b/assets/icons/equal.svg @@ -1 +1 @@ - + diff --git a/assets/icons/exit.svg b/assets/icons/exit.svg index 2cc6ce120dc9af17a642ac3bf2f2451209cb5e5e..1ff9d7882441548e9c3534ae5ffe6b6331391b45 100644 --- a/assets/icons/exit.svg +++ b/assets/icons/exit.svg @@ -1,8 +1,5 @@ - - + + + + diff --git a/assets/icons/file_icons/kdl.svg b/assets/icons/file_icons/kdl.svg new file mode 100644 index 0000000000000000000000000000000000000000..92d9f28428a8192f739bd8a796164b5e8bc71bda --- /dev/null +++ b/assets/icons/file_icons/kdl.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/assets/icons/file_icons/puppet.svg b/assets/icons/file_icons/puppet.svg new file mode 100644 index 0000000000000000000000000000000000000000..cdf903bc62b50d50fc2f3f391c8660cd6c4c2a50 --- /dev/null +++ b/assets/icons/file_icons/puppet.svg @@ -0,0 +1 @@ + diff --git a/assets/icons/file_icons/surrealql.svg b/assets/icons/file_icons/surrealql.svg new file mode 100644 index 0000000000000000000000000000000000000000..076f93e808fc38a28313956e905a95561422b58c --- /dev/null +++ b/assets/icons/file_icons/surrealql.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/file_text.svg b/assets/icons/file_text.svg index 7c602f2ac79c213aed6b219b5b7abd62cad727b6..a9b8f971e00911294565712161c0a36f342cb498 100644 --- a/assets/icons/file_text.svg +++ b/assets/icons/file_text.svg @@ -1 +1,6 @@ - + + + + + + diff --git a/assets/icons/file_tree.svg b/assets/icons/file_tree.svg index 
4c921b135183b7b58126f16c68f39aec22677285..a140cd70b12d1be180d2c683d59400212969c47a 100644 --- a/assets/icons/file_tree.svg +++ b/assets/icons/file_tree.svg @@ -1,5 +1,5 @@ - - - + + + diff --git a/assets/icons/git_branch_small.svg b/assets/icons/git_branch_small.svg index d23fc176ac797fff35c6c9d35176d5e03c6170fe..22832d6fedfc5221c31c81eae497f8172b59c21e 100644 --- a/assets/icons/git_branch_small.svg +++ b/assets/icons/git_branch_small.svg @@ -1,6 +1,7 @@ - - - - - + + + + + + diff --git a/assets/icons/git_onboarding_bg.svg b/assets/icons/git_onboarding_bg.svg deleted file mode 100644 index 18da0230a26c4e67b7e3ac2e64894132102f93c6..0000000000000000000000000000000000000000 --- a/assets/icons/git_onboarding_bg.svg +++ /dev/null @@ -1,40 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/assets/icons/list_tree.svg b/assets/icons/list_tree.svg index 8cf157ec135d13395fc8ac66d8f8086f0d199a2e..09872a60f7ed9c85e89f06b7384b083a7f4b5779 100644 --- a/assets/icons/list_tree.svg +++ b/assets/icons/list_tree.svg @@ -1 +1,7 @@ - \ No newline at end of file + + + + + + + diff --git a/assets/icons/location_edit.svg b/assets/icons/location_edit.svg new file mode 100644 index 0000000000000000000000000000000000000000..de82e8db4e05da232d024d6a92e329fd15a94ff0 --- /dev/null +++ b/assets/icons/location_edit.svg @@ -0,0 +1 @@ + diff --git a/assets/icons/message_bubbles.svg b/assets/icons/message_bubbles.svg deleted file mode 100644 index 03a6c7760cdce8a19ec1fc243fbc47d51d8a0988..0000000000000000000000000000000000000000 --- a/assets/icons/message_bubbles.svg +++ /dev/null @@ -1,6 +0,0 @@ - - - - - - diff --git a/assets/icons/mic.svg b/assets/icons/mic.svg index 01f4c9bf669ba253edaa43dc641fdb9a1b7c51d1..1d9c5bc9edf2a48b3311965fb57758b3ee2e015e 100644 --- a/assets/icons/mic.svg +++ b/assets/icons/mic.svg @@ -1,3 +1,5 @@ - - + + + + diff --git a/assets/icons/mic_mute.svg b/assets/icons/mic_mute.svg index fe5f8201cc4da5e2cf6a1b770c538d421994e1c4..8c61ae2f1ccedc1b27244ed80e1a3fdd75cd4120 100644 --- a/assets/icons/mic_mute.svg +++ b/assets/icons/mic_mute.svg @@ -1,3 +1,8 @@ - - + + + + + + + diff --git a/assets/icons/microscope.svg b/assets/icons/microscope.svg deleted file mode 100644 index 2b3009a28be7068877e52cf3f3d3295f78abe7b6..0000000000000000000000000000000000000000 --- a/assets/icons/microscope.svg +++ /dev/null @@ -1 +0,0 @@ - diff --git a/assets/icons/play.svg b/assets/icons/play.svg deleted file mode 100644 index 2481bda7d6d0f310dd6dccfe4918cf0971ab9d0d..0000000000000000000000000000000000000000 --- a/assets/icons/play.svg +++ /dev/null @@ -1,3 +0,0 @@ - - - diff --git a/assets/icons/play_bug.svg b/assets/icons/play_bug.svg deleted file mode 100644 index 7d265dd42a488ea3ab65b6e60a1597dc8d518d46..0000000000000000000000000000000000000000 --- a/assets/icons/play_bug.svg +++ /dev/null @@ -1,8 +0,0 @@ - - - - - - - - diff --git a/assets/icons/play_filled.svg b/assets/icons/play_filled.svg new file mode 100644 index 0000000000000000000000000000000000000000..c632434305c6bd25da205ca8cee8203b9d3611b1 --- /dev/null +++ b/assets/icons/play_filled.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/play_alt.svg b/assets/icons/play_outlined.svg similarity index 70% rename from assets/icons/play_alt.svg rename to assets/icons/play_outlined.svg index b327ab07b5f99cdca7a73e07bc29498f6148b02a..7e1cacd5af8795501cc30f4e33927f752a1eba7f 100644 --- a/assets/icons/play_alt.svg +++ b/assets/icons/play_outlined.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/reveal.svg 
b/assets/icons/reveal.svg deleted file mode 100644 index ff5444d8f84c311ac79c2f289b9f35a8c897fb0f..0000000000000000000000000000000000000000 --- a/assets/icons/reveal.svg +++ /dev/null @@ -1 +0,0 @@ - diff --git a/assets/icons/screen.svg b/assets/icons/screen.svg index ad252e64cf5c73d4cd6ae48dd0abede47d3323e6..4b686b58f9de2e4993546ddad1a20af395d50330 100644 --- a/assets/icons/screen.svg +++ b/assets/icons/screen.svg @@ -1,8 +1,5 @@ - - + + + + diff --git a/assets/icons/shield_check.svg b/assets/icons/shield_check.svg new file mode 100644 index 0000000000000000000000000000000000000000..6e58c314682a5e87de9b2ca582262a3110f7006d --- /dev/null +++ b/assets/icons/shield_check.svg @@ -0,0 +1,4 @@ + + + + diff --git a/assets/icons/spinner.svg b/assets/icons/spinner.svg deleted file mode 100644 index 4f4034ae8944288faf7bb7bad2cb510c2aaaf495..0000000000000000000000000000000000000000 --- a/assets/icons/spinner.svg +++ /dev/null @@ -1,13 +0,0 @@ - - - - - - - - - - - - - diff --git a/assets/icons/strikethrough.svg b/assets/icons/strikethrough.svg deleted file mode 100644 index d7d09059129462afe14b6af9a7393f7aff96dd81..0000000000000000000000000000000000000000 --- a/assets/icons/strikethrough.svg +++ /dev/null @@ -1,3 +0,0 @@ - - - diff --git a/assets/icons/terminal_alt.svg b/assets/icons/terminal_alt.svg new file mode 100644 index 0000000000000000000000000000000000000000..7afb89db2130b8d9233c1662d7fcf86f63de305a --- /dev/null +++ b/assets/icons/terminal_alt.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/assets/icons/text_thread.svg b/assets/icons/text_thread.svg new file mode 100644 index 0000000000000000000000000000000000000000..75afa934a028f1bddd104effe536db70ad4f241c --- /dev/null +++ b/assets/icons/text_thread.svg @@ -0,0 +1,7 @@ + + + + + + + diff --git a/assets/icons/thread.svg b/assets/icons/thread.svg new file mode 100644 index 0000000000000000000000000000000000000000..8c2596a4c9fca9f75a122dc85225f33696320030 --- /dev/null +++ b/assets/icons/thread.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/thread_from_summary.svg b/assets/icons/thread_from_summary.svg new file mode 100644 index 0000000000000000000000000000000000000000..7519935affc03bf50e9a39bcb5792237fba1e44f --- /dev/null +++ b/assets/icons/thread_from_summary.svg @@ -0,0 +1,6 @@ + + + + + + diff --git a/assets/icons/todo_complete.svg b/assets/icons/todo_complete.svg new file mode 100644 index 0000000000000000000000000000000000000000..9fa2e818bb61137de35d260f4384a0db545d4125 --- /dev/null +++ b/assets/icons/todo_complete.svg @@ -0,0 +1,4 @@ + + + + diff --git a/assets/icons/todo_pending.svg b/assets/icons/todo_pending.svg new file mode 100644 index 0000000000000000000000000000000000000000..dfb013b52b987a3f99e1b8304418b847ff1ccf2b --- /dev/null +++ b/assets/icons/todo_pending.svg @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/assets/icons/todo_progress.svg b/assets/icons/todo_progress.svg new file mode 100644 index 0000000000000000000000000000000000000000..9b2ed7375d9807139261a2d81f7f1f168470d0f4 --- /dev/null +++ b/assets/icons/todo_progress.svg @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git a/assets/icons/tool_bulb.svg b/assets/icons/tool_think.svg similarity index 100% rename from assets/icons/tool_bulb.svg rename to assets/icons/tool_think.svg diff --git a/assets/icons/trash.svg b/assets/icons/trash.svg index b71035b99cc53fa5b038e08e064f96cb5a74762d..1322e90f9fdc1fad9901febff0f71a938621f900 100644 --- a/assets/icons/trash.svg +++ b/assets/icons/trash.svg @@ -1 +1,5 @@ - + + + + + diff --git a/assets/icons/trash_alt.svg 
b/assets/icons/trash_alt.svg deleted file mode 100644 index 6867b421475a6e5cfdc51da124e648d92f88f055..0000000000000000000000000000000000000000 --- a/assets/icons/trash_alt.svg +++ /dev/null @@ -1 +0,0 @@ - diff --git a/assets/icons/user_group.svg b/assets/icons/user_group.svg index aa99277646653c899ee049547e5574b76b25b840..ac1f7bdc633190f88b202d9e5ae7430af225aecd 100644 --- a/assets/icons/user_group.svg +++ b/assets/icons/user_group.svg @@ -1,3 +1,5 @@ - + + + diff --git a/assets/icons/zed_assistant.svg b/assets/icons/zed_assistant.svg index 693d86f929ff170f08edf3d2a0a7a28af17a30bf..d21252de8c234611ddd41caff287e3fc0d540ed3 100644 --- a/assets/icons/zed_assistant.svg +++ b/assets/icons/zed_assistant.svg @@ -1,5 +1,5 @@ - + diff --git a/assets/icons/zed_predict_bg.svg b/assets/icons/zed_predict_bg.svg deleted file mode 100644 index 1dccbb51af0e61106318568898e626dbeea07ff6..0000000000000000000000000000000000000000 --- a/assets/icons/zed_predict_bg.svg +++ /dev/null @@ -1,19 +0,0 @@ - - - - - - - - - - - - - - - - - - - diff --git a/assets/images/pro_trial_stamp.svg b/assets/images/pro_trial_stamp.svg new file mode 100644 index 0000000000000000000000000000000000000000..a3f9095120876949c51f1cd03f8fb8499bf4ea3e --- /dev/null +++ b/assets/images/pro_trial_stamp.svg @@ -0,0 +1 @@ + diff --git a/assets/images/pro_user_stamp.svg b/assets/images/pro_user_stamp.svg new file mode 100644 index 0000000000000000000000000000000000000000..d037a9e8335d31f4b515a674f3bfa9495bf8a6a3 --- /dev/null +++ b/assets/images/pro_user_stamp.svg @@ -0,0 +1 @@ + diff --git a/assets/keymaps/default-linux.json b/assets/keymaps/default-linux.json index 8a46e6c234c20fb8d662c4ecc3004955e293af1d..c436b1a8fb5a6da4bf7dc086ae01b280062935e0 100644 --- a/assets/keymaps/default-linux.json +++ b/assets/keymaps/default-linux.json @@ -232,7 +232,7 @@ "ctrl-n": "agent::NewThread", "ctrl-alt-n": "agent::NewTextThread", "ctrl-shift-h": "agent::OpenHistory", - "ctrl-alt-c": "agent::OpenConfiguration", + "ctrl-alt-c": "agent::OpenSettings", "ctrl-alt-p": "agent::OpenRulesLibrary", "ctrl-i": "agent::ToggleProfileSelector", "ctrl-alt-/": "agent::ToggleModelSelector", @@ -269,15 +269,15 @@ } }, { - "context": "AgentPanel && acp_thread", + "context": "AgentPanel && external_agent_thread", "use_key_equivalents": true, "bindings": { - "ctrl-n": "agent::NewAcpThread", + "ctrl-n": "agent::NewExternalAgentThread", "ctrl-alt-t": "agent::NewThread" } }, { - "context": "MessageEditor > Editor", + "context": "MessageEditor && !Picker > Editor && !use_modifier_to_send", "bindings": { "enter": "agent::Chat", "ctrl-enter": "agent::ChatWithFollow", @@ -287,6 +287,17 @@ "ctrl-shift-n": "agent::RejectAll" } }, + { + "context": "MessageEditor && !Picker > Editor && use_modifier_to_send", + "bindings": { + "ctrl-enter": "agent::Chat", + "enter": "editor::Newline", + "ctrl-i": "agent::ToggleProfileSelector", + "shift-ctrl-r": "agent::OpenAgentDiff", + "ctrl-shift-y": "agent::KeepAll", + "ctrl-shift-n": "agent::RejectAll" + } + }, { "context": "EditMessageEditor > Editor", "bindings": { @@ -320,7 +331,10 @@ "bindings": { "enter": "agent::Chat", "up": "agent::PreviousHistoryMessage", - "down": "agent::NextHistoryMessage" + "down": "agent::NextHistoryMessage", + "shift-ctrl-r": "agent::OpenAgentDiff", + "ctrl-shift-y": "agent::KeepAll", + "ctrl-shift-n": "agent::RejectAll" } }, { @@ -418,7 +432,7 @@ "ctrl-shift-pagedown": "pane::SwapItemRight", "ctrl-f4": ["pane::CloseActiveItem", { "close_pinned": false }], "ctrl-w": ["pane::CloseActiveItem", { "close_pinned": false 
}], - "alt-ctrl-t": ["pane::CloseInactiveItems", { "close_pinned": false }], + "alt-ctrl-t": ["pane::CloseOtherItems", { "close_pinned": false }], "alt-ctrl-shift-w": "workspace::CloseInactiveTabsAndPanes", "ctrl-k e": ["pane::CloseItemsToTheLeft", { "close_pinned": false }], "ctrl-k t": ["pane::CloseItemsToTheRight", { "close_pinned": false }], @@ -471,11 +485,10 @@ "ctrl-k ctrl-d": ["editor::SelectNext", { "replace_newest": true }], // editor.action.moveSelectionToNextFindMatch / find_under_expand_skip "ctrl-k ctrl-shift-d": ["editor::SelectPrevious", { "replace_newest": true }], // editor.action.moveSelectionToPreviousFindMatch "ctrl-k ctrl-i": "editor::Hover", + "ctrl-k ctrl-b": "editor::BlameHover", "ctrl-/": ["editor::ToggleComments", { "advance_downwards": false }], - "ctrl-u": "editor::UndoSelection", - "ctrl-shift-u": "editor::RedoSelection", - "f8": "editor::GoToDiagnostic", - "shift-f8": "editor::GoToPreviousDiagnostic", + "f8": ["editor::GoToDiagnostic", { "severity": { "min": "hint", "max": "error" } }], + "shift-f8": ["editor::GoToPreviousDiagnostic", { "severity": { "min": "hint", "max": "error" } }], "f2": "editor::Rename", "f12": "editor::GoToDefinition", "alt-f12": "editor::GoToDefinitionSplit", @@ -484,7 +497,7 @@ "shift-f12": "editor::GoToImplementation", "alt-ctrl-f12": "editor::GoToTypeDefinitionSplit", "alt-shift-f12": "editor::FindAllReferences", - "ctrl-m": "editor::MoveToEnclosingBracket", + "ctrl-m": "editor::MoveToEnclosingBracket", // from jetbrains "ctrl-|": "editor::MoveToEnclosingBracket", "ctrl-{": "editor::Fold", "ctrl-}": "editor::UnfoldLines", @@ -585,8 +598,9 @@ "ctrl-shift-f": "pane::DeploySearch", "ctrl-shift-h": ["pane::DeploySearch", { "replace_enabled": true }], "ctrl-shift-t": "pane::ReopenClosedItem", - "ctrl-k ctrl-s": "zed::OpenKeymap", + "ctrl-k ctrl-s": "zed::OpenKeymapEditor", "ctrl-k ctrl-t": "theme_selector::Toggle", + "ctrl-alt-super-p": "settings_profile_selector::Toggle", "ctrl-t": "project_symbols::Toggle", "ctrl-p": "file_finder::Toggle", "ctrl-tab": "tab_switcher::Toggle", @@ -651,6 +665,8 @@ { "context": "Editor", "bindings": { + "ctrl-u": "editor::UndoSelection", + "ctrl-shift-u": "editor::RedoSelection", "ctrl-shift-j": "editor::JoinLines", "ctrl-alt-backspace": "editor::DeleteToPreviousSubwordStart", "ctrl-alt-h": "editor::DeleteToPreviousSubwordStart", @@ -832,6 +848,7 @@ "ctrl-delete": ["project_panel::Delete", { "skip_prompt": false }], "alt-ctrl-r": "project_panel::RevealInFileManager", "ctrl-shift-enter": "project_panel::OpenWithSystem", + "alt-d": "project_panel::CompareMarkedFiles", "shift-find": "project_panel::NewSearchInDirectory", "ctrl-alt-shift-f": "project_panel::NewSearchInDirectory", "shift-down": "menu::SelectNext", @@ -855,11 +872,10 @@ "alt-shift-y": "git::UnstageFile", "ctrl-alt-y": "git::ToggleStaged", "space": "git::ToggleStaged", + "shift-space": "git::StageRange", "tab": "git_panel::FocusEditor", "shift-tab": "git_panel::FocusEditor", "escape": "git_panel::ToggleFocus", - "ctrl-enter": "git::Commit", - "ctrl-shift-enter": "git::Amend", "alt-enter": "menu::SecondaryConfirm", "delete": ["git::RestoreFile", { "skip_prompt": false }], "backspace": ["git::RestoreFile", { "skip_prompt": false }], @@ -896,7 +912,9 @@ "ctrl-g backspace": "git::RestoreTrackedFiles", "ctrl-g shift-backspace": "git::TrashUntrackedFiles", "ctrl-space": "git::StageAll", - "ctrl-shift-space": "git::UnstageAll" + "ctrl-shift-space": "git::UnstageAll", + "ctrl-enter": "git::Commit", + "ctrl-shift-enter": "git::Amend" } }, { @@ -915,7 
+933,7 @@ } }, { - "context": "GitPanel > Editor", + "context": "CommitEditor > Editor", "bindings": { "escape": "git_panel::FocusChanges", "tab": "git_panel::FocusChanges", @@ -961,9 +979,14 @@ "context": "CollabPanel && not_editing", "bindings": { "ctrl-backspace": "collab_panel::Remove", - "space": "menu::Confirm", - "ctrl-up": "collab_panel::MoveChannelUp", - "ctrl-down": "collab_panel::MoveChannelDown" + "space": "menu::Confirm" + } + }, + { + "context": "CollabPanel", + "bindings": { + "alt-up": "collab_panel::MoveChannelUp", + "alt-down": "collab_panel::MoveChannelDown" } }, { @@ -997,6 +1020,7 @@ { "context": "FileFinder || (FileFinder > Picker > Editor)", "bindings": { + "ctrl-p": "file_finder::Toggle", "ctrl-shift-a": "file_finder::ToggleSplitMenu", "ctrl-shift-i": "file_finder::ToggleFilterMenu" } @@ -1079,6 +1103,13 @@ "ctrl-enter": "menu::Confirm" } }, + { + "context": "OnboardingAiConfigurationModal", + "use_key_equivalents": true, + "bindings": { + "escape": "menu::Cancel" + } + }, { "context": "Diagnostics", "use_key_equivalents": true, @@ -1112,7 +1143,51 @@ "context": "KeymapEditor", "use_key_equivalents": true, "bindings": { - "ctrl-f": "search::FocusSearch" + "ctrl-f": "search::FocusSearch", + "alt-find": "keymap_editor::ToggleKeystrokeSearch", + "alt-ctrl-f": "keymap_editor::ToggleKeystrokeSearch", + "alt-c": "keymap_editor::ToggleConflictFilter", + "enter": "keymap_editor::EditBinding", + "alt-enter": "keymap_editor::CreateBinding", + "ctrl-c": "keymap_editor::CopyAction", + "ctrl-shift-c": "keymap_editor::CopyContext", + "ctrl-t": "keymap_editor::ShowMatchingKeybinds" + } + }, + { + "context": "KeystrokeInput", + "use_key_equivalents": true, + "bindings": { + "enter": "keystroke_input::StartRecording", + "escape escape escape": "keystroke_input::StopRecording", + "delete": "keystroke_input::ClearKeystrokes" + } + }, + { + "context": "KeybindEditorModal", + "use_key_equivalents": true, + "bindings": { + "ctrl-enter": "menu::Confirm", + "escape": "menu::Cancel" + } + }, + { + "context": "KeybindEditorModal > Editor", + "use_key_equivalents": true, + "bindings": { + "up": "menu::SelectPrevious", + "down": "menu::SelectNext" + } + }, + { + "context": "Onboarding", + "use_key_equivalents": true, + "bindings": { + "ctrl-1": "onboarding::ActivateBasicsPage", + "ctrl-2": "onboarding::ActivateEditingPage", + "ctrl-3": "onboarding::ActivateAISetupPage", + "ctrl-escape": "onboarding::Finish", + "alt-tab": "onboarding::SignIn" } } ] diff --git a/assets/keymaps/default-macos.json b/assets/keymaps/default-macos.json index cb1cf572fbe256f60937ab385e37618932330725..960bac14797692f7db9336d1c90b7156ee952688 100644 --- a/assets/keymaps/default-macos.json +++ b/assets/keymaps/default-macos.json @@ -272,7 +272,7 @@ "cmd-n": "agent::NewThread", "cmd-alt-n": "agent::NewTextThread", "cmd-shift-h": "agent::OpenHistory", - "cmd-alt-c": "agent::OpenConfiguration", + "cmd-alt-c": "agent::OpenSettings", "cmd-alt-p": "agent::OpenRulesLibrary", "cmd-i": "agent::ToggleProfileSelector", "cmd-alt-/": "agent::ToggleModelSelector", @@ -310,15 +310,15 @@ } }, { - "context": "AgentPanel && acp_thread", + "context": "AgentPanel && external_agent_thread", "use_key_equivalents": true, "bindings": { - "cmd-n": "agent::NewAcpThread", + "cmd-n": "agent::NewExternalAgentThread", "cmd-alt-t": "agent::NewThread" } }, { - "context": "MessageEditor > Editor", + "context": "MessageEditor && !Picker > Editor && !use_modifier_to_send", "use_key_equivalents": true, "bindings": { "enter": "agent::Chat", @@ -329,6 
+329,18 @@ "cmd-shift-n": "agent::RejectAll" } }, + { + "context": "MessageEditor && !Picker > Editor && use_modifier_to_send", + "use_key_equivalents": true, + "bindings": { + "cmd-enter": "agent::Chat", + "enter": "editor::Newline", + "cmd-i": "agent::ToggleProfileSelector", + "shift-ctrl-r": "agent::OpenAgentDiff", + "cmd-shift-y": "agent::KeepAll", + "cmd-shift-n": "agent::RejectAll" + } + }, { "context": "EditMessageEditor > Editor", "use_key_equivalents": true, @@ -371,7 +383,10 @@ "bindings": { "enter": "agent::Chat", "up": "agent::PreviousHistoryMessage", - "down": "agent::NextHistoryMessage" + "down": "agent::NextHistoryMessage", + "shift-ctrl-r": "agent::OpenAgentDiff", + "cmd-shift-y": "agent::KeepAll", + "cmd-shift-n": "agent::RejectAll" } }, { @@ -476,7 +491,7 @@ "ctrl-shift-pageup": "pane::SwapItemLeft", "ctrl-shift-pagedown": "pane::SwapItemRight", "cmd-w": ["pane::CloseActiveItem", { "close_pinned": false }], - "alt-cmd-t": ["pane::CloseInactiveItems", { "close_pinned": false }], + "alt-cmd-t": ["pane::CloseOtherItems", { "close_pinned": false }], "ctrl-alt-cmd-w": "workspace::CloseInactiveTabsAndPanes", "cmd-k e": ["pane::CloseItemsToTheLeft", { "close_pinned": false }], "cmd-k t": ["pane::CloseItemsToTheRight", { "close_pinned": false }], @@ -524,11 +539,10 @@ "ctrl-cmd-d": ["editor::SelectPrevious", { "replace_newest": false }], // editor.action.addSelectionToPreviousFindMatch "cmd-k ctrl-cmd-d": ["editor::SelectPrevious", { "replace_newest": true }], // editor.action.moveSelectionToPreviousFindMatch "cmd-k cmd-i": "editor::Hover", + "cmd-k cmd-b": "editor::BlameHover", "cmd-/": ["editor::ToggleComments", { "advance_downwards": false }], - "cmd-u": "editor::UndoSelection", - "cmd-shift-u": "editor::RedoSelection", - "f8": "editor::GoToDiagnostic", - "shift-f8": "editor::GoToPreviousDiagnostic", + "f8": ["editor::GoToDiagnostic", { "severity": { "min": "hint", "max": "error" } }], + "shift-f8": ["editor::GoToPreviousDiagnostic", { "severity": { "min": "hint", "max": "error" } }], "f2": "editor::Rename", "f12": "editor::GoToDefinition", "alt-f12": "editor::GoToDefinitionSplit", @@ -537,7 +551,7 @@ "alt-cmd-f12": "editor::GoToTypeDefinitionSplit", "alt-shift-f12": "editor::FindAllReferences", "cmd-|": "editor::MoveToEnclosingBracket", - "ctrl-m": "editor::MoveToEnclosingBracket", + "ctrl-m": "editor::MoveToEnclosingBracket", // From Jetbrains "alt-cmd-[": "editor::Fold", "alt-cmd-]": "editor::UnfoldLines", "cmd-k cmd-l": "editor::ToggleFold", @@ -651,8 +665,9 @@ "cmd-shift-f": "pane::DeploySearch", "cmd-shift-h": ["pane::DeploySearch", { "replace_enabled": true }], "cmd-shift-t": "pane::ReopenClosedItem", - "cmd-k cmd-s": "zed::OpenKeymap", + "cmd-k cmd-s": "zed::OpenKeymapEditor", "cmd-k cmd-t": "theme_selector::Toggle", + "ctrl-alt-cmd-p": "settings_profile_selector::Toggle", "cmd-t": "project_symbols::Toggle", "cmd-p": "file_finder::Toggle", "ctrl-tab": "tab_switcher::Toggle", @@ -713,6 +728,8 @@ "context": "Editor", "use_key_equivalents": true, "bindings": { + "cmd-u": "editor::UndoSelection", + "cmd-shift-u": "editor::RedoSelection", "ctrl-j": "editor::JoinLines", "ctrl-alt-backspace": "editor::DeleteToPreviousSubwordStart", "ctrl-alt-h": "editor::DeleteToPreviousSubwordStart", @@ -890,6 +907,7 @@ "cmd-delete": ["project_panel::Delete", { "skip_prompt": false }], "alt-cmd-r": "project_panel::RevealInFileManager", "ctrl-shift-enter": "project_panel::OpenWithSystem", + "alt-d": "project_panel::CompareMarkedFiles", "cmd-alt-backspace": ["project_panel::Delete", { 
"skip_prompt": false }], "cmd-alt-shift-f": "project_panel::NewSearchInDirectory", "shift-down": "menu::SelectNext", @@ -929,14 +947,13 @@ "enter": "menu::Confirm", "cmd-alt-y": "git::ToggleStaged", "space": "git::ToggleStaged", + "shift-space": "git::StageRange", "cmd-y": "git::StageFile", "cmd-shift-y": "git::UnstageFile", "alt-down": "git_panel::FocusEditor", "tab": "git_panel::FocusEditor", "shift-tab": "git_panel::FocusEditor", "escape": "git_panel::ToggleFocus", - "cmd-enter": "git::Commit", - "cmd-shift-enter": "git::Amend", "backspace": ["git::RestoreFile", { "skip_prompt": false }], "delete": ["git::RestoreFile", { "skip_prompt": false }], "cmd-backspace": ["git::RestoreFile", { "skip_prompt": true }], @@ -961,7 +978,7 @@ } }, { - "context": "GitPanel > Editor", + "context": "CommitEditor > Editor", "use_key_equivalents": true, "bindings": { "enter": "editor::Newline", @@ -986,7 +1003,9 @@ "ctrl-g backspace": "git::RestoreTrackedFiles", "ctrl-g shift-backspace": "git::TrashUntrackedFiles", "cmd-ctrl-y": "git::StageAll", - "cmd-ctrl-shift-y": "git::UnstageAll" + "cmd-ctrl-shift-y": "git::UnstageAll", + "cmd-enter": "git::Commit", + "cmd-shift-enter": "git::Amend" } }, { @@ -1022,9 +1041,15 @@ "use_key_equivalents": true, "bindings": { "ctrl-backspace": "collab_panel::Remove", - "space": "menu::Confirm", - "cmd-up": "collab_panel::MoveChannelUp", - "cmd-down": "collab_panel::MoveChannelDown" + "space": "menu::Confirm" + } + }, + { + "context": "CollabPanel", + "use_key_equivalents": true, + "bindings": { + "alt-up": "collab_panel::MoveChannelUp", + "alt-down": "collab_panel::MoveChannelDown" } }, { @@ -1096,13 +1121,16 @@ "ctrl-cmd-space": "terminal::ShowCharacterPalette", "cmd-c": "terminal::Copy", "cmd-v": "terminal::Paste", + "cmd-f": "buffer_search::Deploy", "cmd-a": "editor::SelectAll", "cmd-k": "terminal::Clear", "cmd-n": "workspace::NewTerminal", "ctrl-enter": "assistant::InlineAssist", "ctrl-_": null, // emacs undo // Some nice conveniences - "cmd-backspace": ["terminal::SendText", "\u0015"], + "cmd-backspace": ["terminal::SendText", "\u0015"], // ctrl-u: clear line + "alt-delete": ["terminal::SendText", "\u001bd"], // alt-d: delete word forward + "cmd-delete": ["terminal::SendText", "\u000b"], // ctrl-k: delete to end of line "cmd-right": ["terminal::SendText", "\u0005"], "cmd-left": ["terminal::SendText", "\u0001"], // Terminal.app compatibility @@ -1177,6 +1205,13 @@ "cmd-enter": "menu::Confirm" } }, + { + "context": "OnboardingAiConfigurationModal", + "use_key_equivalents": true, + "bindings": { + "escape": "menu::Cancel" + } + }, { "context": "Diagnostics", "use_key_equivalents": true, @@ -1211,7 +1246,50 @@ "context": "KeymapEditor", "use_key_equivalents": true, "bindings": { - "cmd-f": "search::FocusSearch" + "cmd-f": "search::FocusSearch", + "cmd-alt-f": "keymap_editor::ToggleKeystrokeSearch", + "cmd-alt-c": "keymap_editor::ToggleConflictFilter", + "enter": "keymap_editor::EditBinding", + "alt-enter": "keymap_editor::CreateBinding", + "cmd-c": "keymap_editor::CopyAction", + "cmd-shift-c": "keymap_editor::CopyContext", + "cmd-t": "keymap_editor::ShowMatchingKeybinds" + } + }, + { + "context": "KeystrokeInput", + "use_key_equivalents": true, + "bindings": { + "enter": "keystroke_input::StartRecording", + "escape escape escape": "keystroke_input::StopRecording", + "delete": "keystroke_input::ClearKeystrokes" + } + }, + { + "context": "KeybindEditorModal", + "use_key_equivalents": true, + "bindings": { + "cmd-enter": "menu::Confirm", + "escape": "menu::Cancel" + } + }, + 
{ + "context": "KeybindEditorModal > Editor", + "use_key_equivalents": true, + "bindings": { + "up": "menu::SelectPrevious", + "down": "menu::SelectNext" + } + }, + { + "context": "Onboarding", + "use_key_equivalents": true, + "bindings": { + "cmd-1": "onboarding::ActivateBasicsPage", + "cmd-2": "onboarding::ActivateEditingPage", + "cmd-3": "onboarding::ActivateAISetupPage", + "cmd-escape": "onboarding::Finish", + "alt-tab": "onboarding::SignIn" } } ] diff --git a/assets/keymaps/initial.json b/assets/keymaps/initial.json index 0cfd28f0e5d458e0bbffdbbce6cd3b53168ece57..8e4fe59f44ea7346a51e1c064ffa0553315da3b9 100644 --- a/assets/keymaps/initial.json +++ b/assets/keymaps/initial.json @@ -13,9 +13,9 @@ } }, { - "context": "Editor && vim_mode == insert && !menu", + "context": "Editor && vim_mode == insert", "bindings": { - // "j k": "vim::SwitchToNormalMode" + // "j k": "vim::NormalBefore" } } ] diff --git a/assets/keymaps/linux/cursor.json b/assets/keymaps/linux/cursor.json index 347b7885fcc6b013f62e0c6f2ca1504ecc24fb51..1c381b0cf05531e7fd5743d71be1b4d662bb4c0d 100644 --- a/assets/keymaps/linux/cursor.json +++ b/assets/keymaps/linux/cursor.json @@ -8,7 +8,7 @@ "ctrl-shift-i": "agent::ToggleFocus", "ctrl-l": "agent::ToggleFocus", "ctrl-shift-l": "agent::ToggleFocus", - "ctrl-shift-j": "agent::OpenConfiguration" + "ctrl-shift-j": "agent::OpenSettings" } }, { diff --git a/assets/keymaps/linux/emacs.json b/assets/keymaps/linux/emacs.json index 0c633efabee89e5756b36e2ea5e5f31d02a5819d..0ff3796f03d85affdae88d009e88e73516ba385a 100755 --- a/assets/keymaps/linux/emacs.json +++ b/assets/keymaps/linux/emacs.json @@ -114,7 +114,7 @@ "ctrl-x o": "workspace::ActivateNextPane", // other-window "ctrl-x k": "pane::CloseActiveItem", // kill-buffer "ctrl-x 0": "pane::CloseActiveItem", // delete-window - "ctrl-x 1": "pane::CloseInactiveItems", // delete-other-windows + "ctrl-x 1": "pane::CloseOtherItems", // delete-other-windows "ctrl-x 2": "pane::SplitDown", // split-window-below "ctrl-x 3": "pane::SplitRight", // split-window-right "ctrl-x ctrl-f": "file_finder::Toggle", // find-file diff --git a/assets/keymaps/linux/jetbrains.json b/assets/keymaps/linux/jetbrains.json index dbf50b0fcefa99063f0c8f535bd453aade4d7e56..3df1243feda88680a4ce03cd0b25ab9ea9a36edd 100644 --- a/assets/keymaps/linux/jetbrains.json +++ b/assets/keymaps/linux/jetbrains.json @@ -4,6 +4,7 @@ "ctrl-alt-s": "zed::OpenSettings", "ctrl-{": "pane::ActivatePreviousItem", "ctrl-}": "pane::ActivateNextItem", + "shift-escape": null, // Unmap workspace::zoom "ctrl-f2": "debugger::Stop", "f6": "debugger::Pause", "f7": "debugger::StepInto", @@ -44,8 +45,8 @@ "ctrl-alt-right": "pane::GoForward", "alt-f7": "editor::FindAllReferences", "ctrl-alt-f7": "editor::FindAllReferences", - // "ctrl-b": "editor::GoToDefinition", // Conflicts with workspace::ToggleLeftDock - // "ctrl-alt-b": "editor::GoToDefinitionSplit", // Conflicts with workspace::ToggleLeftDock + "ctrl-b": "editor::GoToDefinition", // Conflicts with workspace::ToggleLeftDock + "ctrl-alt-b": "editor::GoToDefinitionSplit", // Conflicts with workspace::ToggleRightDock "ctrl-shift-b": "editor::GoToTypeDefinition", "ctrl-alt-shift-b": "editor::GoToTypeDefinitionSplit", "f2": "editor::GoToDiagnostic", @@ -66,22 +67,66 @@ "context": "Editor && mode == full", "bindings": { "ctrl-f12": "outline::Toggle", - "alt-7": "outline::Toggle", + "ctrl-r": ["buffer_search::Deploy", { "replace_enabled": true }], "ctrl-shift-n": "file_finder::Toggle", "ctrl-g": "go_to_line::Toggle", "alt-enter": 
"editor::ToggleCodeActions" } }, + { + "context": "BufferSearchBar", + "bindings": { + "shift-enter": "search::SelectPreviousMatch" + } + }, + { + "context": "BufferSearchBar || ProjectSearchBar", + "bindings": { + "alt-c": "search::ToggleCaseSensitive", + "alt-e": "search::ToggleSelection", + "alt-x": "search::ToggleRegex", + "alt-w": "search::ToggleWholeWord" + } + }, { "context": "Workspace", "bindings": { + "ctrl-shift-f12": "workspace::CloseAllDocks", + "ctrl-shift-r": ["pane::DeploySearch", { "replace_enabled": true }], + "alt-shift-f10": "task::Spawn", + "ctrl-e": "file_finder::Toggle", + // "ctrl-k": "git_panel::ToggleFocus", // bug: This should also focus commit editor "ctrl-shift-n": "file_finder::Toggle", "ctrl-shift-a": "command_palette::Toggle", "shift shift": "command_palette::Toggle", "ctrl-alt-shift-n": "project_symbols::Toggle", - "alt-1": "workspace::ToggleLeftDock", - "ctrl-e": "tab_switcher::Toggle", - "alt-6": "diagnostics::Deploy" + "alt-0": "git_panel::ToggleFocus", + "alt-1": "project_panel::ToggleFocus", + "alt-5": "debug_panel::ToggleFocus", + "alt-6": "diagnostics::Deploy", + "alt-7": "outline_panel::ToggleFocus" + } + }, + { + "context": "Pane", // this is to override the default Pane mappings to switch tabs + "bindings": { + "alt-1": "project_panel::ToggleFocus", + "alt-2": null, // Bookmarks (left dock) + "alt-3": null, // Find Panel (bottom dock) + "alt-4": null, // Run Panel (bottom dock) + "alt-5": "debug_panel::ToggleFocus", + "alt-6": "diagnostics::Deploy", + "alt-7": "outline_panel::ToggleFocus", + "alt-8": null, // Services (bottom dock) + "alt-9": null, // Git History (bottom dock) + "alt-0": "git_panel::ToggleFocus" + } + }, + { + "context": "Workspace || Editor", + "bindings": { + "alt-f12": "terminal_panel::ToggleFocus", + "ctrl-shift-k": "git::Push" } }, { @@ -95,10 +140,36 @@ "context": "ProjectPanel", "bindings": { "enter": "project_panel::Open", + "ctrl-shift-f": "project_panel::NewSearchInDirectory", "backspace": ["project_panel::Trash", { "skip_prompt": false }], "delete": ["project_panel::Trash", { "skip_prompt": false }], "shift-delete": ["project_panel::Delete", { "skip_prompt": false }], "shift-f6": "project_panel::Rename" } + }, + { + "context": "Terminal", + "bindings": { + "ctrl-shift-t": "workspace::NewTerminal", + "alt-f12": "workspace::CloseActiveDock", + "alt-left": "pane::ActivatePreviousItem", + "alt-right": "pane::ActivateNextItem", + "ctrl-up": "terminal::ScrollLineUp", + "ctrl-down": "terminal::ScrollLineDown", + "shift-pageup": "terminal::ScrollPageUp", + "shift-pagedown": "terminal::ScrollPageDown" + } + }, + { "context": "GitPanel", "bindings": { "alt-0": "workspace::CloseActiveDock" } }, + { "context": "ProjectPanel", "bindings": { "alt-1": "workspace::CloseActiveDock" } }, + { "context": "DebugPanel", "bindings": { "alt-5": "workspace::CloseActiveDock" } }, + { "context": "Diagnostics > Editor", "bindings": { "alt-6": "pane::CloseActiveItem" } }, + { "context": "OutlinePanel", "bindings": { "alt-7": "workspace::CloseActiveDock" } }, + { + "context": "Dock || Workspace || OutlinePanel || ProjectPanel || CollabPanel || (Editor && mode == auto_height)", + "bindings": { + "escape": "editor::ToggleFocus", + "shift-escape": "workspace::CloseActiveDock" + } } ] diff --git a/assets/keymaps/macos/cursor.json b/assets/keymaps/macos/cursor.json index b1d39bef9eb1397ceaeb0fb82956f14a0391b068..fdf9c437cf395c074e42ae9c9dc53c1aa6ff66c2 100644 --- a/assets/keymaps/macos/cursor.json +++ b/assets/keymaps/macos/cursor.json @@ -8,7 +8,7 @@ 
"cmd-shift-i": "agent::ToggleFocus", "cmd-l": "agent::ToggleFocus", "cmd-shift-l": "agent::ToggleFocus", - "cmd-shift-j": "agent::OpenConfiguration" + "cmd-shift-j": "agent::OpenSettings" } }, { diff --git a/assets/keymaps/macos/emacs.json b/assets/keymaps/macos/emacs.json index 0c633efabee89e5756b36e2ea5e5f31d02a5819d..0ff3796f03d85affdae88d009e88e73516ba385a 100755 --- a/assets/keymaps/macos/emacs.json +++ b/assets/keymaps/macos/emacs.json @@ -114,7 +114,7 @@ "ctrl-x o": "workspace::ActivateNextPane", // other-window "ctrl-x k": "pane::CloseActiveItem", // kill-buffer "ctrl-x 0": "pane::CloseActiveItem", // delete-window - "ctrl-x 1": "pane::CloseInactiveItems", // delete-other-windows + "ctrl-x 1": "pane::CloseOtherItems", // delete-other-windows "ctrl-x 2": "pane::SplitDown", // split-window-below "ctrl-x 3": "pane::SplitRight", // split-window-right "ctrl-x ctrl-f": "file_finder::Toggle", // find-file diff --git a/assets/keymaps/macos/jetbrains.json b/assets/keymaps/macos/jetbrains.json index 22c6f18383a32f5def1869b953e31baf665404b9..66962811f48a429f2f5d036241c64d6549f60334 100644 --- a/assets/keymaps/macos/jetbrains.json +++ b/assets/keymaps/macos/jetbrains.json @@ -3,6 +3,8 @@ "bindings": { "cmd-{": "pane::ActivatePreviousItem", "cmd-}": "pane::ActivateNextItem", + "cmd-0": "git_panel::ToggleFocus", // overrides `cmd-0` zoom reset + "shift-escape": null, // Unmap workspace::zoom "ctrl-f2": "debugger::Stop", "f6": "debugger::Pause", "f7": "debugger::StepInto", @@ -63,28 +65,70 @@ "context": "Editor && mode == full", "bindings": { "cmd-f12": "outline::Toggle", - "cmd-7": "outline::Toggle", + "cmd-r": ["buffer_search::Deploy", { "replace_enabled": true }], "cmd-shift-o": "file_finder::Toggle", "cmd-l": "go_to_line::Toggle", "alt-enter": "editor::ToggleCodeActions" } }, { - "context": "BufferSearchBar > Editor", + "context": "BufferSearchBar", "bindings": { "shift-enter": "search::SelectPreviousMatch" } }, + { + "context": "BufferSearchBar || ProjectSearchBar", + "bindings": { + "alt-c": "search::ToggleCaseSensitive", + "alt-e": "search::ToggleSelection", + "alt-x": "search::ToggleRegex", + "alt-w": "search::ToggleWholeWord", + "ctrl-alt-c": "search::ToggleCaseSensitive", + "ctrl-alt-e": "search::ToggleSelection", + "ctrl-alt-w": "search::ToggleWholeWord", + "ctrl-alt-x": "search::ToggleRegex" + } + }, { "context": "Workspace", "bindings": { + "cmd-shift-f12": "workspace::CloseAllDocks", + "cmd-shift-r": ["pane::DeploySearch", { "replace_enabled": true }], + "ctrl-alt-r": "task::Spawn", + "cmd-e": "file_finder::Toggle", + // "cmd-k": "git_panel::ToggleFocus", // bug: This should also focus commit editor "cmd-shift-o": "file_finder::Toggle", "cmd-shift-a": "command_palette::Toggle", "shift shift": "command_palette::Toggle", "cmd-alt-o": "project_symbols::Toggle", // JetBrains: Go to Symbol "cmd-o": "project_symbols::Toggle", // JetBrains: Go to Class - "cmd-1": "workspace::ToggleLeftDock", - "cmd-6": "diagnostics::Deploy" + "cmd-1": "project_panel::ToggleFocus", + "cmd-5": "debug_panel::ToggleFocus", + "cmd-6": "diagnostics::Deploy", + "cmd-7": "outline_panel::ToggleFocus" + } + }, + { + "context": "Pane", // this is to override the default Pane mappings to switch tabs + "bindings": { + "cmd-1": "project_panel::ToggleFocus", + "cmd-2": null, // Bookmarks (left dock) + "cmd-3": null, // Find Panel (bottom dock) + "cmd-4": null, // Run Panel (bottom dock) + "cmd-5": "debug_panel::ToggleFocus", + "cmd-6": "diagnostics::Deploy", + "cmd-7": "outline_panel::ToggleFocus", + "cmd-8": null, // 
Services (bottom dock) + "cmd-9": null, // Git History (bottom dock) + "cmd-0": "git_panel::ToggleFocus" + } + }, + { + "context": "Workspace || Editor", + "bindings": { + "alt-f12": "terminal_panel::ToggleFocus", + "cmd-shift-k": "git::Push" } }, { @@ -98,11 +142,35 @@ "context": "ProjectPanel", "bindings": { "enter": "project_panel::Open", + "cmd-shift-f": "project_panel::NewSearchInDirectory", "cmd-backspace": ["project_panel::Trash", { "skip_prompt": false }], "backspace": ["project_panel::Trash", { "skip_prompt": false }], "delete": ["project_panel::Trash", { "skip_prompt": false }], "shift-delete": ["project_panel::Delete", { "skip_prompt": false }], "shift-f6": "project_panel::Rename" } + }, + { + "context": "Terminal", + "bindings": { + "cmd-t": "workspace::NewTerminal", + "alt-f12": "workspace::CloseActiveDock", + "cmd-up": "terminal::ScrollLineUp", + "cmd-down": "terminal::ScrollLineDown", + "shift-pageup": "terminal::ScrollPageUp", + "shift-pagedown": "terminal::ScrollPageDown" + } + }, + { "context": "GitPanel", "bindings": { "cmd-0": "workspace::CloseActiveDock" } }, + { "context": "ProjectPanel", "bindings": { "cmd-1": "workspace::CloseActiveDock" } }, + { "context": "DebugPanel", "bindings": { "cmd-5": "workspace::CloseActiveDock" } }, + { "context": "Diagnostics > Editor", "bindings": { "cmd-6": "pane::CloseActiveItem" } }, + { "context": "OutlinePanel", "bindings": { "cmd-7": "workspace::CloseActiveDock" } }, + { + "context": "Dock || Workspace || OutlinePanel || ProjectPanel || CollabPanel || (Editor && mode == auto_height)", + "bindings": { + "escape": "editor::ToggleFocus", + "shift-escape": "workspace::CloseActiveDock" + } } ] diff --git a/assets/keymaps/macos/textmate.json b/assets/keymaps/macos/textmate.json index dccb675f6c9734c6da3b797cae199037591085fc..0bd8873b1749d2423d97df480b1aadeb28fe9bab 100644 --- a/assets/keymaps/macos/textmate.json +++ b/assets/keymaps/macos/textmate.json @@ -6,7 +6,7 @@ } }, { - "context": "Editor", + "context": "Editor && mode == full", "bindings": { "cmd-l": "go_to_line::Toggle", "ctrl-shift-d": "editor::DuplicateLineDown", @@ -15,7 +15,12 @@ "cmd-enter": "editor::NewlineBelow", "cmd-alt-enter": "editor::NewlineAbove", "cmd-shift-l": "editor::SelectLine", - "cmd-shift-t": "outline::Toggle", + "cmd-shift-t": "outline::Toggle" + } + }, + { + "context": "Editor", + "bindings": { "alt-backspace": "editor::DeleteToPreviousWordStart", "alt-shift-backspace": "editor::DeleteToNextWordEnd", "alt-delete": "editor::DeleteToNextWordEnd", @@ -39,10 +44,6 @@ "ctrl-_": "editor::ConvertToSnakeCase" } }, - { - "context": "Editor && mode == full", - "bindings": {} - }, { "context": "BufferSearchBar", "bindings": { diff --git a/assets/keymaps/vim.json b/assets/keymaps/vim.json index 571192a4791846011318238ade9aad84091bca4d..57edb1e4c1c534ce90b6c5534d4ceebddd4f9a38 100644 --- a/assets/keymaps/vim.json +++ b/assets/keymaps/vim.json @@ -124,6 +124,7 @@ "g r a": "editor::ToggleCodeActions", "g g": "vim::StartOfDocument", "g h": "editor::Hover", + "g B": "editor::BlameHover", "g t": "pane::ActivateNextItem", "g shift-t": "pane::ActivatePreviousItem", "g d": "editor::GoToDefinition", @@ -219,6 +220,8 @@ { "context": "vim_mode == normal", "bindings": { + "i": "vim::InsertBefore", + "a": "vim::InsertAfter", "ctrl-[": "editor::Cancel", ":": "command_palette::Toggle", "c": "vim::PushChange", @@ -352,9 +355,7 @@ "shift-d": "vim::DeleteToEndOfLine", "shift-j": "vim::JoinLines", "shift-y": "vim::YankLine", - "i": "vim::InsertBefore", "shift-i": 
"vim::InsertFirstNonWhitespace", - "a": "vim::InsertAfter", "shift-a": "vim::InsertEndOfLine", "o": "vim::InsertLineBelow", "shift-o": "vim::InsertLineAbove", @@ -376,7 +377,10 @@ { "context": "vim_mode == helix_normal && !menu", "bindings": { + "i": "vim::HelixInsert", + "a": "vim::HelixAppend", "ctrl-[": "editor::Cancel", + ";": "vim::HelixCollapseSelection", ":": "command_palette::Toggle", "left": "vim::WrappingLeft", "right": "vim::WrappingRight", @@ -466,7 +470,7 @@ } }, { - "context": "vim_mode == insert && showing_signature_help && !showing_completions", + "context": "(vim_mode == insert || vim_mode == normal) && showing_signature_help && !showing_completions", "bindings": { "ctrl-p": "editor::SignatureHelpPrevious", "ctrl-n": "editor::SignatureHelpNext" @@ -723,7 +727,7 @@ } }, { - "context": "AgentPanel || GitPanel || ProjectPanel || CollabPanel || OutlinePanel || ChatPanel || VimControl || EmptyPane || SharedScreen || MarkdownPreview || KeyContextView || DebugPanel", + "context": "VimControl || !Editor && !Terminal", "bindings": { // window related commands (ctrl-w X) "ctrl-w": null, @@ -781,7 +785,7 @@ } }, { - "context": "ChangesList || EmptyPane || SharedScreen || MarkdownPreview || KeyContextView || Welcome", + "context": "!Editor && !Terminal", "bindings": { ":": "command_palette::Toggle", "g /": "pane::DeploySearch" @@ -809,6 +813,7 @@ "p": "project_panel::Open", "x": "project_panel::RevealInFileManager", "s": "project_panel::OpenWithSystem", + "z d": "project_panel::CompareMarkedFiles", "] c": "project_panel::SelectNextGitEntry", "[ c": "project_panel::SelectPrevGitEntry", "] d": "project_panel::SelectNextDiagnostic", @@ -841,6 +846,7 @@ "i": "git_panel::FocusEditor", "x": "git::ToggleStaged", "shift-x": "git::StageAll", + "g x": "git::StageRange", "shift-u": "git::UnstageAll" } }, @@ -856,6 +862,14 @@ "shift-n": null } }, + { + "context": "Picker > Editor", + "bindings": { + "ctrl-h": "editor::Backspace", + "ctrl-u": "editor::DeleteToBeginningOfLine", + "ctrl-w": "editor::DeleteToPreviousWordStart" + } + }, { "context": "GitCommit > Editor && VimControl && vim_mode == normal", "bindings": { diff --git a/assets/settings/default.json b/assets/settings/default.json index dc892bd6a33443df5d3d2a0fa2b8a8d9fca41acd..4734b5d1188b7175f8af2ea4adaadee0157ff0fc 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -84,7 +84,7 @@ "bottom_dock_layout": "contained", // The direction that you want to split panes horizontally. Defaults to "up" "pane_split_direction_horizontal": "up", - // The direction that you want to split panes horizontally. Defaults to "left" + // The direction that you want to split panes vertically. Defaults to "left" "pane_split_direction_vertical": "left", // Centered layout related settings. "centered_layout": { @@ -197,6 +197,8 @@ // "inline" // 3. Place snippets at the bottom of the completion list: // "bottom" + // 4. Do not show snippets in the completion list: + // "none" "snippet_sort_order": "inline", // How to highlight the current line in the editor. // @@ -689,7 +691,10 @@ // 5. Never show the scrollbar: // "never" "show": null - } + }, + // Default depth to expand outline items in the current file. + // Set to 0 to collapse all items that have children, 1 or higher to collapse items at that depth or deeper. + "expand_outlines_with_depth": 100 }, "collaboration_panel": { // Whether to show the collaboration panel button in the status bar. 
@@ -817,7 +822,7 @@ "edit_file": true, "fetch": true, "list_directory": true, - "project_notifications": true, + "project_notifications": false, "move_path": true, "now": true, "find_path": true, @@ -837,7 +842,7 @@ "diagnostics": true, "fetch": true, "list_directory": true, - "project_notifications": true, + "project_notifications": false, "now": true, "find_path": true, "read_file": true, @@ -1074,6 +1079,10 @@ // Send anonymized usage data like what languages you're using Zed with. "metrics": true }, + // Whether to disable all AI features in Zed. + // + // Default: false + "disable_ai": false, // Automatically update Zed. This setting may be ignored on Linux if // installed through a package manager. "auto_update": true, @@ -1135,6 +1144,7 @@ "**/.svn", "**/.hg", "**/.jj", + "**/.repo", "**/CVS", "**/.DS_Store", "**/Thumbs.db", @@ -1670,6 +1680,10 @@ "allowed": true } }, + "SystemVerilog": { + "format_on_save": "off", + "use_on_type_format": false + }, "Vue.js": { "language_servers": ["vue-language-server", "..."], "prettier": { @@ -1705,6 +1719,7 @@ "openai": { "api_url": "https://api.openai.com/v1" }, + "openai_compatible": {}, "open_router": { "api_url": "https://openrouter.ai/api/v1" }, @@ -1862,5 +1877,25 @@ "save_breakpoints": true, "dock": "bottom", "button": true - } + }, + // Configures any number of settings profiles that are temporarily applied on + // top of your existing user settings when selected from + // `settings profile selector: toggle`. + // Examples: + // "profiles": { + // "Presenting": { + // "agent_font_size": 20.0, + // "buffer_font_size": 20.0, + // "theme": "One Light", + // "ui_font_size": 20.0 + // }, + // "Python (ty)": { + // "languages": { + // "Python": { + // "language_servers": ["ty"] + // } + // } + // } + // } + "profiles": [] } diff --git a/assets/settings/initial_debug_tasks.json b/assets/settings/initial_debug_tasks.json index 78fc1fc5f02a03bc83c93a4cf5cc7c517fd301c7..af4512bd51aa82d57ce62e605b45ee61e8f98030 100644 --- a/assets/settings/initial_debug_tasks.json +++ b/assets/settings/initial_debug_tasks.json @@ -15,13 +15,15 @@ "adapter": "JavaScript", "program": "$ZED_FILE", "request": "launch", - "cwd": "$ZED_WORKTREE_ROOT" + "cwd": "$ZED_WORKTREE_ROOT", + "type": "pwa-node" }, { "label": "JavaScript debug terminal", "adapter": "JavaScript", "request": "launch", "cwd": "$ZED_WORKTREE_ROOT", - "console": "integratedTerminal" + "console": "integratedTerminal", + "type": "pwa-node" } ] diff --git a/assets/settings/initial_user_settings.json b/assets/settings/initial_user_settings.json index 71f3beb1d6076ed5a41064291a83662ee7023f03..5ac2063bdb481e057a2d124c1e72f998390b066b 100644 --- a/assets/settings/initial_user_settings.json +++ b/assets/settings/initial_user_settings.json @@ -8,7 +8,7 @@ // command palette (cmd-shift-p / ctrl-shift-p) { "ui_font_size": 16, - "buffer_font_size": 16, + "buffer_font_size": 15, "theme": { "mode": "system", "light": "One Light", diff --git a/compose.yml b/compose.yml index 4cd4c86df646fdd142df1206d18d78f7a4267083..d0d9bac425356687bfb33efab9ee24e76d1b30a0 100644 --- a/compose.yml +++ b/compose.yml @@ -59,5 +59,11 @@ services: depends_on: - postgres + stripe-mock: + image: stripe/stripe-mock:v0.178.0 + ports: + - 12111:12111 + - 12112:12112 + volumes: postgres_data: diff --git a/crates/acp/src/acp.rs b/crates/acp/src/acp.rs deleted file mode 100644 index ddb7c50f7a2f28156fd5c9c1eb3460e3771a1fbd..0000000000000000000000000000000000000000 --- a/crates/acp/src/acp.rs +++ /dev/null @@ -1,1645 +0,0 @@ -pub use 
acp::ToolCallId; -use agent_servers::AgentServer; -use agentic_coding_protocol::{self as acp, UserMessageChunk}; -use anyhow::{Context as _, Result, anyhow}; -use buffer_diff::BufferDiff; -use editor::{MultiBuffer, PathKey}; -use futures::{FutureExt, channel::oneshot, future::BoxFuture}; -use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity}; -use itertools::Itertools; -use language::{Anchor, Buffer, Capability, LanguageRegistry, OffsetRangeExt as _}; -use markdown::Markdown; -use project::Project; -use std::error::Error; -use std::fmt::{Formatter, Write}; -use std::{ - fmt::Display, - mem, - path::{Path, PathBuf}, - sync::Arc, -}; -use ui::{App, IconName}; -use util::ResultExt; - -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct UserMessage { - pub content: Entity, -} - -impl UserMessage { - pub fn from_acp( - message: &acp::SendUserMessageParams, - language_registry: Arc, - cx: &mut App, - ) -> Self { - let mut md_source = String::new(); - - for chunk in &message.chunks { - match chunk { - UserMessageChunk::Text { text } => md_source.push_str(&text), - UserMessageChunk::Path { path } => { - write!(&mut md_source, "{}", MentionPath(&path)).unwrap() - } - } - } - - Self { - content: cx - .new(|cx| Markdown::new(md_source.into(), Some(language_registry), None, cx)), - } - } - - fn to_markdown(&self, cx: &App) -> String { - format!("## User\n\n{}\n\n", self.content.read(cx).source()) - } -} - -#[derive(Debug)] -pub struct MentionPath<'a>(&'a Path); - -impl<'a> MentionPath<'a> { - const PREFIX: &'static str = "@file:"; - - pub fn new(path: &'a Path) -> Self { - MentionPath(path) - } - - pub fn try_parse(url: &'a str) -> Option { - let path = url.strip_prefix(Self::PREFIX)?; - Some(MentionPath(Path::new(path))) - } - - pub fn path(&self) -> &Path { - self.0 - } -} - -impl Display for MentionPath<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "[@{}]({}{})", - self.0.file_name().unwrap_or_default().display(), - Self::PREFIX, - self.0.display() - ) - } -} - -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct AssistantMessage { - pub chunks: Vec, -} - -impl AssistantMessage { - fn to_markdown(&self, cx: &App) -> String { - format!( - "## Assistant\n\n{}\n\n", - self.chunks - .iter() - .map(|chunk| chunk.to_markdown(cx)) - .join("\n\n") - ) - } -} - -#[derive(Clone, Debug, Eq, PartialEq)] -pub enum AssistantMessageChunk { - Text { chunk: Entity }, - Thought { chunk: Entity }, -} - -impl AssistantMessageChunk { - pub fn from_acp( - chunk: acp::AssistantMessageChunk, - language_registry: Arc, - cx: &mut App, - ) -> Self { - match chunk { - acp::AssistantMessageChunk::Text { text } => Self::Text { - chunk: cx.new(|cx| Markdown::new(text.into(), Some(language_registry), None, cx)), - }, - acp::AssistantMessageChunk::Thought { thought } => Self::Thought { - chunk: cx - .new(|cx| Markdown::new(thought.into(), Some(language_registry), None, cx)), - }, - } - } - - pub fn from_str(chunk: &str, language_registry: Arc, cx: &mut App) -> Self { - Self::Text { - chunk: cx.new(|cx| { - Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx) - }), - } - } - - fn to_markdown(&self, cx: &App) -> String { - match self { - Self::Text { chunk } => chunk.read(cx).source().to_string(), - Self::Thought { chunk } => { - format!("\n{}\n", chunk.read(cx).source()) - } - } - } -} - -#[derive(Debug)] -pub enum AgentThreadEntry { - UserMessage(UserMessage), - AssistantMessage(AssistantMessage), - ToolCall(ToolCall), -} - 
-impl AgentThreadEntry { - fn to_markdown(&self, cx: &App) -> String { - match self { - Self::UserMessage(message) => message.to_markdown(cx), - Self::AssistantMessage(message) => message.to_markdown(cx), - Self::ToolCall(too_call) => too_call.to_markdown(cx), - } - } -} - -#[derive(Debug)] -pub struct ToolCall { - pub id: acp::ToolCallId, - pub label: Entity, - pub icon: IconName, - pub content: Option, - pub status: ToolCallStatus, -} - -impl ToolCall { - fn to_markdown(&self, cx: &App) -> String { - let mut markdown = format!( - "**Tool Call: {}**\nStatus: {}\n\n", - self.label.read(cx).source(), - self.status - ); - if let Some(content) = &self.content { - markdown.push_str(content.to_markdown(cx).as_str()); - markdown.push_str("\n\n"); - } - markdown - } -} - -#[derive(Debug)] -pub enum ToolCallStatus { - WaitingForConfirmation { - confirmation: ToolCallConfirmation, - respond_tx: oneshot::Sender, - }, - Allowed { - status: acp::ToolCallStatus, - }, - Rejected, - Canceled, -} - -impl Display for ToolCallStatus { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}", - match self { - ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation", - ToolCallStatus::Allowed { status } => match status { - acp::ToolCallStatus::Running => "Running", - acp::ToolCallStatus::Finished => "Finished", - acp::ToolCallStatus::Error => "Error", - }, - ToolCallStatus::Rejected => "Rejected", - ToolCallStatus::Canceled => "Canceled", - } - ) - } -} - -#[derive(Debug)] -pub enum ToolCallConfirmation { - Edit { - description: Option>, - }, - Execute { - command: String, - root_command: String, - description: Option>, - }, - Mcp { - server_name: String, - tool_name: String, - tool_display_name: String, - description: Option>, - }, - Fetch { - urls: Vec, - description: Option>, - }, - Other { - description: Entity, - }, -} - -impl ToolCallConfirmation { - pub fn from_acp( - confirmation: acp::ToolCallConfirmation, - language_registry: Arc, - cx: &mut App, - ) -> Self { - let to_md = |description: String, cx: &mut App| -> Entity { - cx.new(|cx| { - Markdown::new( - description.into(), - Some(language_registry.clone()), - None, - cx, - ) - }) - }; - - match confirmation { - acp::ToolCallConfirmation::Edit { description } => Self::Edit { - description: description.map(|description| to_md(description, cx)), - }, - acp::ToolCallConfirmation::Execute { - command, - root_command, - description, - } => Self::Execute { - command, - root_command, - description: description.map(|description| to_md(description, cx)), - }, - acp::ToolCallConfirmation::Mcp { - server_name, - tool_name, - tool_display_name, - description, - } => Self::Mcp { - server_name, - tool_name, - tool_display_name, - description: description.map(|description| to_md(description, cx)), - }, - acp::ToolCallConfirmation::Fetch { urls, description } => Self::Fetch { - urls: urls.iter().map(|url| url.into()).collect(), - description: description.map(|description| to_md(description, cx)), - }, - acp::ToolCallConfirmation::Other { description } => Self::Other { - description: to_md(description, cx), - }, - } - } -} - -#[derive(Debug)] -pub enum ToolCallContent { - Markdown { markdown: Entity }, - Diff { diff: Diff }, -} - -impl ToolCallContent { - pub fn from_acp( - content: acp::ToolCallContent, - language_registry: Arc, - cx: &mut App, - ) -> Self { - match content { - acp::ToolCallContent::Markdown { markdown } => Self::Markdown { - markdown: cx.new(|cx| Markdown::new_text(markdown.into(), cx)), - }, - 
acp::ToolCallContent::Diff { diff } => Self::Diff { - diff: Diff::from_acp(diff, language_registry, cx), - }, - } - } - - fn to_markdown(&self, cx: &App) -> String { - match self { - Self::Markdown { markdown } => markdown.read(cx).source().to_string(), - Self::Diff { diff } => diff.to_markdown(cx), - } - } -} - -#[derive(Debug)] -pub struct Diff { - pub multibuffer: Entity, - pub path: PathBuf, - _task: Task>, -} - -impl Diff { - pub fn from_acp( - diff: acp::Diff, - language_registry: Arc, - cx: &mut App, - ) -> Self { - let acp::Diff { - path, - old_text, - new_text, - } = diff; - - let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly)); - - let new_buffer = cx.new(|cx| Buffer::local(new_text, cx)); - let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx)); - let new_buffer_snapshot = new_buffer.read(cx).text_snapshot(); - let old_buffer_snapshot = old_buffer.read(cx).snapshot(); - let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx)); - let diff_task = buffer_diff.update(cx, |diff, cx| { - diff.set_base_text( - old_buffer_snapshot, - Some(language_registry.clone()), - new_buffer_snapshot, - cx, - ) - }); - - let task = cx.spawn({ - let multibuffer = multibuffer.clone(); - let path = path.clone(); - async move |cx| { - diff_task.await?; - - multibuffer - .update(cx, |multibuffer, cx| { - let hunk_ranges = { - let buffer = new_buffer.read(cx); - let diff = buffer_diff.read(cx); - diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx) - .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer)) - .collect::>() - }; - - multibuffer.set_excerpts_for_path( - PathKey::for_buffer(&new_buffer, cx), - new_buffer.clone(), - hunk_ranges, - editor::DEFAULT_MULTIBUFFER_CONTEXT, - cx, - ); - multibuffer.add_diff(buffer_diff.clone(), cx); - }) - .log_err(); - - if let Some(language) = language_registry - .language_for_file_path(&path) - .await - .log_err() - { - new_buffer.update(cx, |buffer, cx| buffer.set_language(Some(language), cx))?; - } - - anyhow::Ok(()) - } - }); - - Self { - multibuffer, - path, - _task: task, - } - } - - fn to_markdown(&self, cx: &App) -> String { - let buffer_text = self - .multibuffer - .read(cx) - .all_buffers() - .iter() - .map(|buffer| buffer.read(cx).text()) - .join("\n"); - format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text) - } -} - -pub struct AcpThread { - entries: Vec, - title: SharedString, - project: Entity, - send_task: Option>, - connection: Arc, - child_status: Option>>, - _io_task: Task<()>, -} - -pub enum AcpThreadEvent { - NewEntry, - EntryUpdated(usize), -} - -impl EventEmitter for AcpThread {} - -#[derive(PartialEq, Eq)] -pub enum ThreadStatus { - Idle, - WaitingForToolConfirmation, - Generating, -} - -#[derive(Debug, Clone)] -pub enum LoadError { - Unsupported { current_version: SharedString }, - Exited(i32), - Other(SharedString), -} - -impl Display for LoadError { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - LoadError::Unsupported { current_version } => { - write!( - f, - "Your installed version of Gemini {} doesn't support the Agentic Coding Protocol (ACP).", - current_version - ) - } - LoadError::Exited(status) => write!(f, "Server exited with status {}", status), - LoadError::Other(msg) => write!(f, "{}", msg), - } - } -} - -impl Error for LoadError {} - -impl AcpThread { - pub async fn spawn( - server: impl AgentServer + 'static, - root_dir: &Path, - project: Entity, - cx: &mut AsyncApp, - ) -> Result> { - let 
command = match server.command(&project, cx).await { - Ok(command) => command, - Err(e) => return Err(anyhow!(LoadError::Other(format!("{e}").into()))), - }; - - let mut child = util::command::new_smol_command(&command.path) - .args(command.args.iter()) - .current_dir(root_dir) - .stdin(std::process::Stdio::piped()) - .stdout(std::process::Stdio::piped()) - .stderr(std::process::Stdio::inherit()) - .kill_on_drop(true) - .spawn()?; - - let stdin = child.stdin.take().unwrap(); - let stdout = child.stdout.take().unwrap(); - - cx.new(|cx| { - let foreground_executor = cx.foreground_executor().clone(); - - let (connection, io_fut) = acp::AgentConnection::connect_to_agent( - AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()), - stdin, - stdout, - move |fut| foreground_executor.spawn(fut).detach(), - ); - - let io_task = cx.background_spawn(async move { - io_fut.await.log_err(); - }); - - let child_status = cx.background_spawn(async move { - match child.status().await { - Err(e) => Err(anyhow!(e)), - Ok(result) if result.success() => Ok(()), - Ok(result) => { - if let Some(version) = server.version(&command).await.log_err() - && !version.supported - { - Err(anyhow!(LoadError::Unsupported { - current_version: version.current_version - })) - } else { - Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127)))) - } - } - } - }); - - Self { - entries: Default::default(), - title: "ACP Thread".into(), - project, - send_task: None, - connection: Arc::new(connection), - child_status: Some(child_status), - _io_task: io_task, - } - }) - } - - #[cfg(test)] - pub fn fake( - stdin: async_pipe::PipeWriter, - stdout: async_pipe::PipeReader, - project: Entity, - cx: &mut Context, - ) -> Self { - let foreground_executor = cx.foreground_executor().clone(); - - let (connection, io_fut) = acp::AgentConnection::connect_to_agent( - AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()), - stdin, - stdout, - move |fut| { - foreground_executor.spawn(fut).detach(); - }, - ); - - let io_task = cx.background_spawn({ - async move { - io_fut.await.log_err(); - } - }); - - Self { - entries: Default::default(), - title: "ACP Thread".into(), - project, - send_task: None, - connection: Arc::new(connection), - child_status: None, - _io_task: io_task, - } - } - - pub fn title(&self) -> SharedString { - self.title.clone() - } - - pub fn entries(&self) -> &[AgentThreadEntry] { - &self.entries - } - - pub fn status(&self) -> ThreadStatus { - if self.send_task.is_some() { - if self.waiting_for_tool_confirmation() { - ThreadStatus::WaitingForToolConfirmation - } else { - ThreadStatus::Generating - } - } else { - ThreadStatus::Idle - } - } - - pub fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context) { - self.entries.push(entry); - cx.emit(AcpThreadEvent::NewEntry); - } - - pub fn push_assistant_chunk( - &mut self, - chunk: acp::AssistantMessageChunk, - cx: &mut Context, - ) { - let entries_len = self.entries.len(); - if let Some(last_entry) = self.entries.last_mut() - && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry - { - cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1)); - - match (chunks.last_mut(), &chunk) { - ( - Some(AssistantMessageChunk::Text { chunk: old_chunk }), - acp::AssistantMessageChunk::Text { text: new_chunk }, - ) - | ( - Some(AssistantMessageChunk::Thought { chunk: old_chunk }), - acp::AssistantMessageChunk::Thought { thought: new_chunk }, - ) => { - old_chunk.update(cx, |old_chunk, cx| { - old_chunk.append(&new_chunk, cx); - }); - } - _ => 
{ - chunks.push(AssistantMessageChunk::from_acp( - chunk, - self.project.read(cx).languages().clone(), - cx, - )); - } - } - } else { - let chunk = AssistantMessageChunk::from_acp( - chunk, - self.project.read(cx).languages().clone(), - cx, - ); - - self.push_entry( - AgentThreadEntry::AssistantMessage(AssistantMessage { - chunks: vec![chunk], - }), - cx, - ); - } - } - - pub fn request_tool_call( - &mut self, - label: String, - icon: acp::Icon, - content: Option, - confirmation: acp::ToolCallConfirmation, - cx: &mut Context, - ) -> ToolCallRequest { - let (tx, rx) = oneshot::channel(); - - let status = ToolCallStatus::WaitingForConfirmation { - confirmation: ToolCallConfirmation::from_acp( - confirmation, - self.project.read(cx).languages().clone(), - cx, - ), - respond_tx: tx, - }; - - let id = self.insert_tool_call(label, status, icon, content, cx); - ToolCallRequest { id, outcome: rx } - } - - pub fn push_tool_call( - &mut self, - label: String, - icon: acp::Icon, - content: Option, - cx: &mut Context, - ) -> acp::ToolCallId { - let status = ToolCallStatus::Allowed { - status: acp::ToolCallStatus::Running, - }; - - self.insert_tool_call(label, status, icon, content, cx) - } - - fn insert_tool_call( - &mut self, - label: String, - status: ToolCallStatus, - icon: acp::Icon, - content: Option, - cx: &mut Context, - ) -> acp::ToolCallId { - let language_registry = self.project.read(cx).languages().clone(); - let id = acp::ToolCallId(self.entries.len() as u64); - - self.push_entry( - AgentThreadEntry::ToolCall(ToolCall { - id, - label: cx.new(|cx| { - Markdown::new(label.into(), Some(language_registry.clone()), None, cx) - }), - icon: acp_icon_to_ui_icon(icon), - content: content - .map(|content| ToolCallContent::from_acp(content, language_registry, cx)), - status, - }), - cx, - ); - - id - } - - pub fn authorize_tool_call( - &mut self, - id: acp::ToolCallId, - outcome: acp::ToolCallConfirmationOutcome, - cx: &mut Context, - ) { - let Some((ix, call)) = self.tool_call_mut(id) else { - return; - }; - - let new_status = if outcome == acp::ToolCallConfirmationOutcome::Reject { - ToolCallStatus::Rejected - } else { - ToolCallStatus::Allowed { - status: acp::ToolCallStatus::Running, - } - }; - - let curr_status = mem::replace(&mut call.status, new_status); - - if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status { - respond_tx.send(outcome).log_err(); - } else if cfg!(debug_assertions) { - panic!("tried to authorize an already authorized tool call"); - } - - cx.emit(AcpThreadEvent::EntryUpdated(ix)); - } - - pub fn update_tool_call( - &mut self, - id: acp::ToolCallId, - new_status: acp::ToolCallStatus, - new_content: Option, - cx: &mut Context, - ) -> Result<()> { - let language_registry = self.project.read(cx).languages().clone(); - let (ix, call) = self.tool_call_mut(id).context("Entry not found")?; - - call.content = new_content - .map(|new_content| ToolCallContent::from_acp(new_content, language_registry, cx)); - - match &mut call.status { - ToolCallStatus::Allowed { status } => { - *status = new_status; - } - ToolCallStatus::WaitingForConfirmation { .. 
} => { - anyhow::bail!("Tool call hasn't been authorized yet") - } - ToolCallStatus::Rejected => { - anyhow::bail!("Tool call was rejected and therefore can't be updated") - } - ToolCallStatus::Canceled => { - call.status = ToolCallStatus::Allowed { status: new_status }; - } - } - - cx.emit(AcpThreadEvent::EntryUpdated(ix)); - Ok(()) - } - - fn tool_call_mut(&mut self, id: acp::ToolCallId) -> Option<(usize, &mut ToolCall)> { - let entry = self.entries.get_mut(id.0 as usize); - debug_assert!( - entry.is_some(), - "We shouldn't give out ids to entries that don't exist" - ); - match entry { - Some(AgentThreadEntry::ToolCall(call)) if call.id == id => Some((id.0 as usize, call)), - _ => { - if cfg!(debug_assertions) { - panic!("entry is not a tool call"); - } - None - } - } - } - - /// Returns true if the last turn is awaiting tool authorization - pub fn waiting_for_tool_confirmation(&self) -> bool { - for entry in self.entries.iter().rev() { - match &entry { - AgentThreadEntry::ToolCall(call) => match call.status { - ToolCallStatus::WaitingForConfirmation { .. } => return true, - ToolCallStatus::Allowed { .. } - | ToolCallStatus::Rejected - | ToolCallStatus::Canceled => continue, - }, - AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => { - // Reached the beginning of the turn - return false; - } - } - } - false - } - - pub fn initialize(&self) -> impl use<> + Future> { - let connection = self.connection.clone(); - async move { Ok(connection.request(acp::InitializeParams).await?) } - } - - pub fn authenticate(&self) -> impl use<> + Future> { - let connection = self.connection.clone(); - async move { Ok(connection.request(acp::AuthenticateParams).await?) } - } - - #[cfg(test)] - pub fn send_raw( - &mut self, - message: &str, - cx: &mut Context, - ) -> BoxFuture<'static, Result<()>> { - self.send( - acp::SendUserMessageParams { - chunks: vec![acp::UserMessageChunk::Text { - text: message.to_string(), - }], - }, - cx, - ) - } - - pub fn send( - &mut self, - message: acp::SendUserMessageParams, - cx: &mut Context, - ) -> BoxFuture<'static, Result<()>> { - let agent = self.connection.clone(); - self.push_entry( - AgentThreadEntry::UserMessage(UserMessage::from_acp( - &message, - self.project.read(cx).languages().clone(), - cx, - )), - cx, - ); - - let (tx, rx) = oneshot::channel(); - let cancel = self.cancel(cx); - - self.send_task = Some(cx.spawn(async move |this, cx| { - cancel.await.log_err(); - - let result = agent.request(message).await; - tx.send(result).log_err(); - this.update(cx, |this, _cx| this.send_task.take()).log_err(); - })); - - async move { - match rx.await { - Ok(Err(e)) => Err(e)?, - _ => Ok(()), - } - } - .boxed() - } - - pub fn cancel(&mut self, cx: &mut Context) -> Task> { - let agent = self.connection.clone(); - - if self.send_task.take().is_some() { - cx.spawn(async move |this, cx| { - agent.request(acp::CancelSendMessageParams).await?; - - this.update(cx, |this, _cx| { - for entry in this.entries.iter_mut() { - if let AgentThreadEntry::ToolCall(call) = entry { - let cancel = matches!( - call.status, - ToolCallStatus::WaitingForConfirmation { .. } - | ToolCallStatus::Allowed { - status: acp::ToolCallStatus::Running - } - ); - - if cancel { - let curr_status = - mem::replace(&mut call.status, ToolCallStatus::Canceled); - - if let ToolCallStatus::WaitingForConfirmation { - respond_tx, .. 
- } = curr_status - { - respond_tx - .send(acp::ToolCallConfirmationOutcome::Cancel) - .ok(); - } - } - } - } - }) - }) - } else { - Task::ready(Ok(())) - } - } - - pub fn child_status(&mut self) -> Option>> { - self.child_status.take() - } - - pub fn to_markdown(&self, cx: &App) -> String { - self.entries.iter().map(|e| e.to_markdown(cx)).collect() - } -} - -struct AcpClientDelegate { - thread: WeakEntity, - cx: AsyncApp, - // sent_buffer_versions: HashMap, HashMap>, -} - -impl AcpClientDelegate { - fn new(thread: WeakEntity, cx: AsyncApp) -> Self { - Self { thread, cx } - } -} - -impl acp::Client for AcpClientDelegate { - async fn stream_assistant_message_chunk( - &self, - params: acp::StreamAssistantMessageChunkParams, - ) -> Result<()> { - let cx = &mut self.cx.clone(); - - cx.update(|cx| { - self.thread - .update(cx, |thread, cx| { - thread.push_assistant_chunk(params.chunk, cx) - }) - .ok(); - })?; - - Ok(()) - } - - async fn request_tool_call_confirmation( - &self, - request: acp::RequestToolCallConfirmationParams, - ) -> Result { - let cx = &mut self.cx.clone(); - let ToolCallRequest { id, outcome } = cx - .update(|cx| { - self.thread.update(cx, |thread, cx| { - thread.request_tool_call( - request.label, - request.icon, - request.content, - request.confirmation, - cx, - ) - }) - })? - .context("Failed to update thread")?; - - Ok(acp::RequestToolCallConfirmationResponse { - id, - outcome: outcome.await?, - }) - } - - async fn push_tool_call( - &self, - request: acp::PushToolCallParams, - ) -> Result { - let cx = &mut self.cx.clone(); - let id = cx - .update(|cx| { - self.thread.update(cx, |thread, cx| { - thread.push_tool_call(request.label, request.icon, request.content, cx) - }) - })? - .context("Failed to update thread")?; - - Ok(acp::PushToolCallResponse { id }) - } - - async fn update_tool_call(&self, request: acp::UpdateToolCallParams) -> Result<()> { - let cx = &mut self.cx.clone(); - - cx.update(|cx| { - self.thread.update(cx, |thread, cx| { - thread.update_tool_call(request.tool_call_id, request.status, request.content, cx) - }) - })? 
- .context("Failed to update thread")??; - - Ok(()) - } -} - -fn acp_icon_to_ui_icon(icon: acp::Icon) -> IconName { - match icon { - acp::Icon::FileSearch => IconName::ToolSearch, - acp::Icon::Folder => IconName::ToolFolder, - acp::Icon::Globe => IconName::ToolWeb, - acp::Icon::Hammer => IconName::ToolHammer, - acp::Icon::LightBulb => IconName::ToolBulb, - acp::Icon::Pencil => IconName::ToolPencil, - acp::Icon::Regex => IconName::ToolRegex, - acp::Icon::Terminal => IconName::ToolTerminal, - } -} - -pub struct ToolCallRequest { - pub id: acp::ToolCallId, - pub outcome: oneshot::Receiver, -} - -#[cfg(test)] -mod tests { - use super::*; - use agent_servers::{AgentServerCommand, AgentServerVersion}; - use async_pipe::{PipeReader, PipeWriter}; - use futures::{channel::mpsc, future::LocalBoxFuture, select}; - use gpui::{AsyncApp, TestAppContext}; - use indoc::indoc; - use project::FakeFs; - use serde_json::json; - use settings::SettingsStore; - use smol::{future::BoxedLocal, stream::StreamExt as _}; - use std::{cell::RefCell, env, path::Path, rc::Rc, time::Duration}; - use util::path; - - fn init_test(cx: &mut TestAppContext) { - env_logger::try_init().ok(); - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - Project::init_settings(cx); - language::init(cx); - }); - } - - #[gpui::test] - async fn test_thinking_concatenation(cx: &mut TestAppContext) { - init_test(cx); - - let fs = FakeFs::new(cx.executor()); - let project = Project::test(fs, [], cx).await; - let (thread, fake_server) = fake_acp_thread(project, cx); - - fake_server.update(cx, |fake_server, _| { - fake_server.on_user_message(move |_, server, mut cx| async move { - server - .update(&mut cx, |server, _| { - server.send_to_zed(acp::StreamAssistantMessageChunkParams { - chunk: acp::AssistantMessageChunk::Thought { - thought: "Thinking ".into(), - }, - }) - })? - .await - .unwrap(); - server - .update(&mut cx, |server, _| { - server.send_to_zed(acp::StreamAssistantMessageChunkParams { - chunk: acp::AssistantMessageChunk::Thought { - thought: "hard!".into(), - }, - }) - })? - .await - .unwrap(); - - Ok(()) - }) - }); - - thread - .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx)) - .await - .unwrap(); - - let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx)); - assert_eq!( - output, - indoc! {r#" - ## User - - Hello from Zed! - - ## Assistant - - - Thinking hard! - - - "#} - ); - } - - #[gpui::test] - async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) { - init_test(cx); - - let fs = FakeFs::new(cx.executor()); - let project = Project::test(fs, [], cx).await; - let (thread, fake_server) = fake_acp_thread(project, cx); - - let (end_turn_tx, end_turn_rx) = oneshot::channel::<()>(); - - let tool_call_id = Rc::new(RefCell::new(None)); - let end_turn_rx = Rc::new(RefCell::new(Some(end_turn_rx))); - fake_server.update(cx, |fake_server, _| { - let tool_call_id = tool_call_id.clone(); - fake_server.on_user_message(move |_, server, mut cx| { - let end_turn_rx = end_turn_rx.clone(); - let tool_call_id = tool_call_id.clone(); - async move { - let tool_call_result = server - .update(&mut cx, |server, _| { - server.send_to_zed(acp::PushToolCallParams { - label: "Fetch".to_string(), - icon: acp::Icon::Globe, - content: None, - }) - })? 
- .await - .unwrap(); - *tool_call_id.clone().borrow_mut() = Some(tool_call_result.id); - end_turn_rx.take().unwrap().await.ok(); - - Ok(()) - } - }) - }); - - let request = thread.update(cx, |thread, cx| { - thread.send_raw("Fetch https://example.com", cx) - }); - - run_until_first_tool_call(&thread, cx).await; - - thread.read_with(cx, |thread, _| { - assert!(matches!( - thread.entries[1], - AgentThreadEntry::ToolCall(ToolCall { - status: ToolCallStatus::Allowed { - status: acp::ToolCallStatus::Running, - .. - }, - .. - }) - )); - }); - - cx.run_until_parked(); - - thread - .update(cx, |thread, cx| thread.cancel(cx)) - .await - .unwrap(); - - thread.read_with(cx, |thread, _| { - assert!(matches!( - &thread.entries[1], - AgentThreadEntry::ToolCall(ToolCall { - status: ToolCallStatus::Canceled, - .. - }) - )); - }); - - fake_server - .update(cx, |fake_server, _| { - fake_server.send_to_zed(acp::UpdateToolCallParams { - tool_call_id: tool_call_id.borrow().unwrap(), - status: acp::ToolCallStatus::Finished, - content: None, - }) - }) - .await - .unwrap(); - - drop(end_turn_tx); - request.await.unwrap(); - - thread.read_with(cx, |thread, _| { - assert!(matches!( - thread.entries[1], - AgentThreadEntry::ToolCall(ToolCall { - status: ToolCallStatus::Allowed { - status: acp::ToolCallStatus::Finished, - .. - }, - .. - }) - )); - }); - } - - #[gpui::test] - #[cfg_attr(not(feature = "gemini"), ignore)] - async fn test_gemini_basic(cx: &mut TestAppContext) { - init_test(cx); - - cx.executor().allow_parking(); - - let fs = FakeFs::new(cx.executor()); - let project = Project::test(fs, [], cx).await; - let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await; - thread - .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx)) - .await - .unwrap(); - - thread.read_with(cx, |thread, _| { - assert_eq!(thread.entries.len(), 2); - assert!(matches!( - thread.entries[0], - AgentThreadEntry::UserMessage(_) - )); - assert!(matches!( - thread.entries[1], - AgentThreadEntry::AssistantMessage(_) - )); - }); - } - - #[gpui::test] - #[cfg_attr(not(feature = "gemini"), ignore)] - async fn test_gemini_path_mentions(cx: &mut TestAppContext) { - init_test(cx); - - cx.executor().allow_parking(); - let tempdir = tempfile::tempdir().unwrap(); - std::fs::write( - tempdir.path().join("foo.rs"), - indoc! {" - fn main() { - println!(\"Hello, world!\"); - } - "}, - ) - .expect("failed to write file"); - let project = Project::example([tempdir.path()], &mut cx.to_async()).await; - let thread = gemini_acp_thread(project.clone(), tempdir.path(), cx).await; - thread - .update(cx, |thread, cx| { - thread.send( - acp::SendUserMessageParams { - chunks: vec![ - acp::UserMessageChunk::Text { - text: "Read the file ".into(), - }, - acp::UserMessageChunk::Path { - path: Path::new("foo.rs").into(), - }, - acp::UserMessageChunk::Text { - text: " and tell me what the content of the println! 
is".into(), - }, - ], - }, - cx, - ) - }) - .await - .unwrap(); - - thread.read_with(cx, |thread, cx| { - assert_eq!(thread.entries.len(), 3); - assert!(matches!( - thread.entries[0], - AgentThreadEntry::UserMessage(_) - )); - assert!(matches!(thread.entries[1], AgentThreadEntry::ToolCall(_))); - let AgentThreadEntry::AssistantMessage(assistant_message) = &thread.entries[2] else { - panic!("Expected AssistantMessage") - }; - assert!( - assistant_message.to_markdown(cx).contains("Hello, world!"), - "unexpected assistant message: {:?}", - assistant_message.to_markdown(cx) - ); - }); - } - - #[gpui::test] - #[cfg_attr(not(feature = "gemini"), ignore)] - async fn test_gemini_tool_call(cx: &mut TestAppContext) { - init_test(cx); - - cx.executor().allow_parking(); - - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - path!("/private/tmp"), - json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}), - ) - .await; - let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await; - let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await; - thread - .update(cx, |thread, cx| { - thread.send_raw( - "Read the '/private/tmp/foo' file and tell me what you see.", - cx, - ) - }) - .await - .unwrap(); - thread.read_with(cx, |thread, _cx| { - assert!(matches!( - &thread.entries()[2], - AgentThreadEntry::ToolCall(ToolCall { - status: ToolCallStatus::Allowed { .. }, - .. - }) - )); - - assert!(matches!( - thread.entries[3], - AgentThreadEntry::AssistantMessage(_) - )); - }); - } - - #[gpui::test] - #[cfg_attr(not(feature = "gemini"), ignore)] - async fn test_gemini_tool_call_with_confirmation(cx: &mut TestAppContext) { - init_test(cx); - - cx.executor().allow_parking(); - - let fs = FakeFs::new(cx.executor()); - let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await; - let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await; - let full_turn = thread.update(cx, |thread, cx| { - thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx) - }); - - run_until_first_tool_call(&thread, cx).await; - - let tool_call_id = thread.read_with(cx, |thread, _cx| { - let AgentThreadEntry::ToolCall(ToolCall { - id, - status: - ToolCallStatus::WaitingForConfirmation { - confirmation: ToolCallConfirmation::Execute { root_command, .. }, - .. - }, - .. - }) = &thread.entries()[2] - else { - panic!(); - }; - - assert_eq!(root_command, "echo"); - - *id - }); - - thread.update(cx, |thread, cx| { - thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx); - - assert!(matches!( - &thread.entries()[2], - AgentThreadEntry::ToolCall(ToolCall { - status: ToolCallStatus::Allowed { .. }, - .. - }) - )); - }); - - full_turn.await.unwrap(); - - thread.read_with(cx, |thread, cx| { - let AgentThreadEntry::ToolCall(ToolCall { - content: Some(ToolCallContent::Markdown { markdown }), - status: ToolCallStatus::Allowed { .. }, - .. 
- }) = &thread.entries()[2] - else { - panic!(); - }; - - markdown.read_with(cx, |md, _cx| { - assert!( - md.source().contains("Hello, world!"), - r#"Expected '{}' to contain "Hello, world!""#, - md.source() - ); - }); - }); - } - - #[gpui::test] - #[cfg_attr(not(feature = "gemini"), ignore)] - async fn test_gemini_cancel(cx: &mut TestAppContext) { - init_test(cx); - - cx.executor().allow_parking(); - - let fs = FakeFs::new(cx.executor()); - let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await; - let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await; - let full_turn = thread.update(cx, |thread, cx| { - thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx) - }); - - let first_tool_call_ix = run_until_first_tool_call(&thread, cx).await; - - thread.read_with(cx, |thread, _cx| { - let AgentThreadEntry::ToolCall(ToolCall { - id, - status: - ToolCallStatus::WaitingForConfirmation { - confirmation: ToolCallConfirmation::Execute { root_command, .. }, - .. - }, - .. - }) = &thread.entries()[first_tool_call_ix] - else { - panic!("{:?}", thread.entries()[1]); - }; - - assert_eq!(root_command, "echo"); - - *id - }); - - thread - .update(cx, |thread, cx| thread.cancel(cx)) - .await - .unwrap(); - full_turn.await.unwrap(); - thread.read_with(cx, |thread, _| { - let AgentThreadEntry::ToolCall(ToolCall { - status: ToolCallStatus::Canceled, - .. - }) = &thread.entries()[first_tool_call_ix] - else { - panic!(); - }; - }); - - thread - .update(cx, |thread, cx| { - thread.send_raw(r#"Stop running and say goodbye to me."#, cx) - }) - .await - .unwrap(); - thread.read_with(cx, |thread, _| { - assert!(matches!( - &thread.entries().last().unwrap(), - AgentThreadEntry::AssistantMessage(..), - )) - }); - } - - async fn run_until_first_tool_call( - thread: &Entity, - cx: &mut TestAppContext, - ) -> usize { - let (mut tx, mut rx) = mpsc::channel::(1); - - let subscription = cx.update(|cx| { - cx.subscribe(thread, move |thread, _, cx| { - for (ix, entry) in thread.read(cx).entries.iter().enumerate() { - if matches!(entry, AgentThreadEntry::ToolCall(_)) { - return tx.try_send(ix).unwrap(); - } - } - }) - }); - - select! 
{ - _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => { - panic!("Timeout waiting for tool call") - } - ix = rx.next().fuse() => { - drop(subscription); - ix.unwrap() - } - } - } - - pub async fn gemini_acp_thread( - project: Entity, - current_dir: impl AsRef, - cx: &mut TestAppContext, - ) -> Entity { - struct DevGemini; - - impl agent_servers::AgentServer for DevGemini { - async fn command( - &self, - _project: &Entity, - _cx: &mut AsyncApp, - ) -> Result { - let cli_path = Path::new(env!("CARGO_MANIFEST_DIR")) - .join("../../../gemini-cli/packages/cli") - .to_string_lossy() - .to_string(); - - Ok(AgentServerCommand { - path: "node".into(), - args: vec![cli_path, "--acp".into()], - env: None, - }) - } - - async fn version( - &self, - _command: &agent_servers::AgentServerCommand, - ) -> Result { - Ok(AgentServerVersion { - current_version: "0.1.0".into(), - supported: true, - }) - } - } - - let thread = AcpThread::spawn(DevGemini, current_dir.as_ref(), project, &mut cx.to_async()) - .await - .unwrap(); - - thread - .update(cx, |thread, _| thread.initialize()) - .await - .unwrap(); - thread - } - - pub fn fake_acp_thread( - project: Entity, - cx: &mut TestAppContext, - ) -> (Entity, Entity) { - let (stdin_tx, stdin_rx) = async_pipe::pipe(); - let (stdout_tx, stdout_rx) = async_pipe::pipe(); - let thread = cx.update(|cx| cx.new(|cx| AcpThread::fake(stdin_tx, stdout_rx, project, cx))); - let agent = cx.update(|cx| cx.new(|cx| FakeAcpServer::new(stdin_rx, stdout_tx, cx))); - (thread, agent) - } - - pub struct FakeAcpServer { - connection: acp::ClientConnection, - _io_task: Task<()>, - on_user_message: Option< - Rc< - dyn Fn( - acp::SendUserMessageParams, - Entity, - AsyncApp, - ) -> LocalBoxFuture<'static, Result<()>>, - >, - >, - } - - #[derive(Clone)] - struct FakeAgent { - server: Entity, - cx: AsyncApp, - } - - impl acp::Agent for FakeAgent { - async fn initialize(&self) -> Result { - Ok(acp::InitializeResponse { - is_authenticated: true, - }) - } - - async fn authenticate(&self) -> Result<()> { - Ok(()) - } - - async fn cancel_send_message(&self) -> Result<()> { - Ok(()) - } - - async fn send_user_message(&self, request: acp::SendUserMessageParams) -> Result<()> { - let mut cx = self.cx.clone(); - let handler = self - .server - .update(&mut cx, |server, _| server.on_user_message.clone()) - .ok() - .flatten(); - if let Some(handler) = handler { - handler(request, self.server.clone(), self.cx.clone()).await - } else { - anyhow::bail!("No handler for on_user_message") - } - } - } - - impl FakeAcpServer { - fn new(stdin: PipeReader, stdout: PipeWriter, cx: &Context) -> Self { - let agent = FakeAgent { - server: cx.entity(), - cx: cx.to_async(), - }; - let foreground_executor = cx.foreground_executor().clone(); - - let (connection, io_fut) = acp::ClientConnection::connect_to_client( - agent.clone(), - stdout, - stdin, - move |fut| { - foreground_executor.spawn(fut).detach(); - }, - ); - FakeAcpServer { - connection: connection, - on_user_message: None, - _io_task: cx.background_spawn(async move { - io_fut.await.log_err(); - }), - } - } - - fn on_user_message( - &mut self, - handler: impl for<'a> Fn(acp::SendUserMessageParams, Entity, AsyncApp) -> F - + 'static, - ) where - F: Future> + 'static, - { - self.on_user_message - .replace(Rc::new(move |request, server, cx| { - handler(request, server, cx).boxed_local() - })); - } - - fn send_to_zed( - &self, - message: T, - ) -> BoxedLocal> { - self.connection - .request(message) - .map(|f| f.map_err(|err| anyhow!(err))) - 
.boxed_local() - } - } -} diff --git a/crates/acp/Cargo.toml b/crates/acp_thread/Cargo.toml similarity index 79% rename from crates/acp/Cargo.toml rename to crates/acp_thread/Cargo.toml index dae6292e28fb6a345de4c06bb2a65da7f3ebad4c..1831c7e4733a58a889d082b8276acecc8dd186bb 100644 --- a/crates/acp/Cargo.toml +++ b/crates/acp_thread/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "acp" +name = "acp_thread" version = "0.1.0" edition.workspace = true publish.workspace = true @@ -9,25 +9,27 @@ license = "GPL-3.0-or-later" workspace = true [lib] -path = "src/acp.rs" +path = "src/acp_thread.rs" doctest = false [features] test-support = ["gpui/test-support", "project/test-support"] -gemini = [] [dependencies] -agent_servers.workspace = true -agentic-coding-protocol.workspace = true +agent-client-protocol.workspace = true anyhow.workspace = true +assistant_tool.workspace = true buffer_diff.workspace = true editor.workspace = true futures.workspace = true gpui.workspace = true itertools.workspace = true language.workspace = true +language_model.workspace = true markdown.workspace = true project.workspace = true +serde.workspace = true +serde_json.workspace = true settings.workspace = true smol.workspace = true ui.workspace = true @@ -35,12 +37,12 @@ util.workspace = true workspace-hack.workspace = true [dev-dependencies] -async-pipe.workspace = true env_logger.workspace = true gpui = { workspace = true, "features" = ["test-support"] } indoc.workspace = true +parking_lot.workspace = true project = { workspace = true, "features" = ["test-support"] } -serde_json.workspace = true +rand.workspace = true tempfile.workspace = true util.workspace = true settings.workspace = true diff --git a/crates/acp/LICENSE-GPL b/crates/acp_thread/LICENSE-GPL similarity index 100% rename from crates/acp/LICENSE-GPL rename to crates/acp_thread/LICENSE-GPL diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs new file mode 100644 index 0000000000000000000000000000000000000000..71827d69486d1cf1cca887d8050663e3d075a975 --- /dev/null +++ b/crates/acp_thread/src/acp_thread.rs @@ -0,0 +1,1833 @@ +mod connection; +pub use connection::*; + +use agent_client_protocol as acp; +use anyhow::{Context as _, Result}; +use assistant_tool::ActionLog; +use buffer_diff::BufferDiff; +use editor::{Bias, MultiBuffer, PathKey}; +use futures::future::{Fuse, FusedFuture}; +use futures::{FutureExt, channel::oneshot, future::BoxFuture}; +use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task}; +use itertools::Itertools; +use language::{ + Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point, + text_diff, +}; +use markdown::Markdown; +use project::{AgentLocation, Project}; +use std::collections::HashMap; +use std::error::Error; +use std::fmt::Formatter; +use std::process::ExitStatus; +use std::rc::Rc; +use std::{ + fmt::Display, + mem, + path::{Path, PathBuf}, + sync::Arc, +}; +use ui::App; +use util::ResultExt; + +#[derive(Debug)] +pub struct UserMessage { + pub content: ContentBlock, +} + +impl UserMessage { + pub fn from_acp( + message: impl IntoIterator, + language_registry: Arc, + cx: &mut App, + ) -> Self { + let mut content = ContentBlock::Empty; + for chunk in message { + content.append(chunk, &language_registry, cx) + } + Self { content: content } + } + + fn to_markdown(&self, cx: &App) -> String { + format!("## User\n\n{}\n\n", self.content.to_markdown(cx)) + } +} + +#[derive(Debug)] +pub struct MentionPath<'a>(&'a Path); + +impl<'a> MentionPath<'a> { + 
const PREFIX: &'static str = "@file:"; + + pub fn new(path: &'a Path) -> Self { + MentionPath(path) + } + + pub fn try_parse(url: &'a str) -> Option { + let path = url.strip_prefix(Self::PREFIX)?; + Some(MentionPath(Path::new(path))) + } + + pub fn path(&self) -> &Path { + self.0 + } +} + +impl Display for MentionPath<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "[@{}]({}{})", + self.0.file_name().unwrap_or_default().display(), + Self::PREFIX, + self.0.display() + ) + } +} + +#[derive(Debug, PartialEq)] +pub struct AssistantMessage { + pub chunks: Vec, +} + +impl AssistantMessage { + pub fn to_markdown(&self, cx: &App) -> String { + format!( + "## Assistant\n\n{}\n\n", + self.chunks + .iter() + .map(|chunk| chunk.to_markdown(cx)) + .join("\n\n") + ) + } +} + +#[derive(Debug, PartialEq)] +pub enum AssistantMessageChunk { + Message { block: ContentBlock }, + Thought { block: ContentBlock }, +} + +impl AssistantMessageChunk { + pub fn from_str(chunk: &str, language_registry: &Arc, cx: &mut App) -> Self { + Self::Message { + block: ContentBlock::new(chunk.into(), language_registry, cx), + } + } + + fn to_markdown(&self, cx: &App) -> String { + match self { + Self::Message { block } => block.to_markdown(cx).to_string(), + Self::Thought { block } => { + format!("\n{}\n", block.to_markdown(cx)) + } + } + } +} + +#[derive(Debug)] +pub enum AgentThreadEntry { + UserMessage(UserMessage), + AssistantMessage(AssistantMessage), + ToolCall(ToolCall), +} + +impl AgentThreadEntry { + fn to_markdown(&self, cx: &App) -> String { + match self { + Self::UserMessage(message) => message.to_markdown(cx), + Self::AssistantMessage(message) => message.to_markdown(cx), + Self::ToolCall(tool_call) => tool_call.to_markdown(cx), + } + } + + pub fn diffs(&self) -> impl Iterator { + if let AgentThreadEntry::ToolCall(call) = self { + itertools::Either::Left(call.diffs()) + } else { + itertools::Either::Right(std::iter::empty()) + } + } + + pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> { + if let AgentThreadEntry::ToolCall(ToolCall { locations, .. 
}) = self { + Some(locations) + } else { + None + } + } +} + +#[derive(Debug)] +pub struct ToolCall { + pub id: acp::ToolCallId, + pub label: Entity, + pub kind: acp::ToolKind, + pub content: Vec, + pub status: ToolCallStatus, + pub locations: Vec, + pub raw_input: Option, + pub raw_output: Option, +} + +impl ToolCall { + fn from_acp( + tool_call: acp::ToolCall, + status: ToolCallStatus, + language_registry: Arc, + cx: &mut App, + ) -> Self { + Self { + id: tool_call.id, + label: cx.new(|cx| { + Markdown::new( + tool_call.title.into(), + Some(language_registry.clone()), + None, + cx, + ) + }), + kind: tool_call.kind, + content: tool_call + .content + .into_iter() + .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx)) + .collect(), + locations: tool_call.locations, + status, + raw_input: tool_call.raw_input, + raw_output: tool_call.raw_output, + } + } + + fn update( + &mut self, + fields: acp::ToolCallUpdateFields, + language_registry: Arc, + cx: &mut App, + ) { + let acp::ToolCallUpdateFields { + kind, + status, + title, + content, + locations, + raw_input, + raw_output, + } = fields; + + if let Some(kind) = kind { + self.kind = kind; + } + + if let Some(status) = status { + self.status = ToolCallStatus::Allowed { status }; + } + + if let Some(title) = title { + self.label.update(cx, |label, cx| { + label.replace(title, cx); + }); + } + + if let Some(content) = content { + self.content = content + .into_iter() + .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx)) + .collect(); + } + + if let Some(locations) = locations { + self.locations = locations; + } + + if let Some(raw_input) = raw_input { + self.raw_input = Some(raw_input); + } + + if let Some(raw_output) = raw_output { + self.raw_output = Some(raw_output); + } + } + + pub fn diffs(&self) -> impl Iterator { + self.content.iter().filter_map(|content| match content { + ToolCallContent::ContentBlock { .. } => None, + ToolCallContent::Diff { diff } => Some(diff), + }) + } + + fn to_markdown(&self, cx: &App) -> String { + let mut markdown = format!( + "**Tool Call: {}**\nStatus: {}\n\n", + self.label.read(cx).source(), + self.status + ); + for content in &self.content { + markdown.push_str(content.to_markdown(cx).as_str()); + markdown.push_str("\n\n"); + } + markdown + } +} + +#[derive(Debug)] +pub enum ToolCallStatus { + WaitingForConfirmation { + options: Vec, + respond_tx: oneshot::Sender, + }, + Allowed { + status: acp::ToolCallStatus, + }, + Rejected, + Canceled, +} + +impl Display for ToolCallStatus { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + ToolCallStatus::WaitingForConfirmation { .. 
} => "Waiting for confirmation", + ToolCallStatus::Allowed { status } => match status { + acp::ToolCallStatus::Pending => "Pending", + acp::ToolCallStatus::InProgress => "In Progress", + acp::ToolCallStatus::Completed => "Completed", + acp::ToolCallStatus::Failed => "Failed", + }, + ToolCallStatus::Rejected => "Rejected", + ToolCallStatus::Canceled => "Canceled", + } + ) + } +} + +#[derive(Debug, PartialEq, Clone)] +pub enum ContentBlock { + Empty, + Markdown { markdown: Entity }, +} + +impl ContentBlock { + pub fn new( + block: acp::ContentBlock, + language_registry: &Arc, + cx: &mut App, + ) -> Self { + let mut this = Self::Empty; + this.append(block, language_registry, cx); + this + } + + pub fn new_combined( + blocks: impl IntoIterator, + language_registry: Arc, + cx: &mut App, + ) -> Self { + let mut this = Self::Empty; + for block in blocks { + this.append(block, &language_registry, cx); + } + this + } + + pub fn append( + &mut self, + block: acp::ContentBlock, + language_registry: &Arc, + cx: &mut App, + ) { + let new_content = match block { + acp::ContentBlock::Text(text_content) => text_content.text.clone(), + acp::ContentBlock::ResourceLink(resource_link) => { + if let Some(path) = resource_link.uri.strip_prefix("file://") { + format!("{}", MentionPath(path.as_ref())) + } else { + resource_link.uri.clone() + } + } + acp::ContentBlock::Image(_) + | acp::ContentBlock::Audio(_) + | acp::ContentBlock::Resource(_) => String::new(), + }; + + match self { + ContentBlock::Empty => { + *self = ContentBlock::Markdown { + markdown: cx.new(|cx| { + Markdown::new( + new_content.into(), + Some(language_registry.clone()), + None, + cx, + ) + }), + }; + } + ContentBlock::Markdown { markdown } => { + markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx)); + } + } + } + + fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str { + match self { + ContentBlock::Empty => "", + ContentBlock::Markdown { markdown } => markdown.read(cx).source(), + } + } + + pub fn markdown(&self) -> Option<&Entity> { + match self { + ContentBlock::Empty => None, + ContentBlock::Markdown { markdown } => Some(markdown), + } + } +} + +#[derive(Debug)] +pub enum ToolCallContent { + ContentBlock { content: ContentBlock }, + Diff { diff: Diff }, +} + +impl ToolCallContent { + pub fn from_acp( + content: acp::ToolCallContent, + language_registry: Arc, + cx: &mut App, + ) -> Self { + match content { + acp::ToolCallContent::Content { content } => Self::ContentBlock { + content: ContentBlock::new(content, &language_registry, cx), + }, + acp::ToolCallContent::Diff { diff } => Self::Diff { + diff: Diff::from_acp(diff, language_registry, cx), + }, + } + } + + pub fn to_markdown(&self, cx: &App) -> String { + match self { + Self::ContentBlock { content } => content.to_markdown(cx).to_string(), + Self::Diff { diff } => diff.to_markdown(cx), + } + } +} + +#[derive(Debug)] +pub struct Diff { + pub multibuffer: Entity, + pub path: PathBuf, + _task: Task>, +} + +impl Diff { + pub fn from_acp( + diff: acp::Diff, + language_registry: Arc, + cx: &mut App, + ) -> Self { + let acp::Diff { + path, + old_text, + new_text, + } = diff; + + let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly)); + + let new_buffer = cx.new(|cx| Buffer::local(new_text, cx)); + let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx)); + let new_buffer_snapshot = new_buffer.read(cx).text_snapshot(); + let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx)); + + let task = cx.spawn({ + let 
multibuffer = multibuffer.clone(); + let path = path.clone(); + async move |cx| { + let language = language_registry + .language_for_file_path(&path) + .await + .log_err(); + + new_buffer.update(cx, |buffer, cx| buffer.set_language(language.clone(), cx))?; + + let old_buffer_snapshot = old_buffer.update(cx, |buffer, cx| { + buffer.set_language(language, cx); + buffer.snapshot() + })?; + + buffer_diff + .update(cx, |diff, cx| { + diff.set_base_text( + old_buffer_snapshot, + Some(language_registry), + new_buffer_snapshot, + cx, + ) + })? + .await?; + + multibuffer + .update(cx, |multibuffer, cx| { + let hunk_ranges = { + let buffer = new_buffer.read(cx); + let diff = buffer_diff.read(cx); + diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx) + .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer)) + .collect::>() + }; + + multibuffer.set_excerpts_for_path( + PathKey::for_buffer(&new_buffer, cx), + new_buffer.clone(), + hunk_ranges, + editor::DEFAULT_MULTIBUFFER_CONTEXT, + cx, + ); + multibuffer.add_diff(buffer_diff, cx); + }) + .log_err(); + + anyhow::Ok(()) + } + }); + + Self { + multibuffer, + path, + _task: task, + } + } + + fn to_markdown(&self, cx: &App) -> String { + let buffer_text = self + .multibuffer + .read(cx) + .all_buffers() + .iter() + .map(|buffer| buffer.read(cx).text()) + .join("\n"); + format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text) + } +} + +#[derive(Debug, Default)] +pub struct Plan { + pub entries: Vec, +} + +#[derive(Debug)] +pub struct PlanStats<'a> { + pub in_progress_entry: Option<&'a PlanEntry>, + pub pending: u32, + pub completed: u32, +} + +impl Plan { + pub fn is_empty(&self) -> bool { + self.entries.is_empty() + } + + pub fn stats(&self) -> PlanStats<'_> { + let mut stats = PlanStats { + in_progress_entry: None, + pending: 0, + completed: 0, + }; + + for entry in &self.entries { + match &entry.status { + acp::PlanEntryStatus::Pending => { + stats.pending += 1; + } + acp::PlanEntryStatus::InProgress => { + stats.in_progress_entry = stats.in_progress_entry.or(Some(entry)); + } + acp::PlanEntryStatus::Completed => { + stats.completed += 1; + } + } + } + + stats + } +} + +#[derive(Debug)] +pub struct PlanEntry { + pub content: Entity, + pub priority: acp::PlanEntryPriority, + pub status: acp::PlanEntryStatus, +} + +impl PlanEntry { + pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self { + Self { + content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)), + priority: entry.priority, + status: entry.status, + } + } +} + +pub struct AcpThread { + title: SharedString, + entries: Vec, + plan: Plan, + project: Entity, + action_log: Entity, + shared_buffers: HashMap, BufferSnapshot>, + send_task: Option>>, + connection: Rc, + session_id: acp::SessionId, +} + +pub enum AcpThreadEvent { + NewEntry, + EntryUpdated(usize), + ToolAuthorizationRequired, + Stopped, + Error, + ServerExited(ExitStatus), +} + +impl EventEmitter for AcpThread {} + +#[derive(PartialEq, Eq)] +pub enum ThreadStatus { + Idle, + WaitingForToolConfirmation, + Generating, +} + +#[derive(Debug, Clone)] +pub enum LoadError { + Unsupported { + error_message: SharedString, + upgrade_message: SharedString, + upgrade_command: String, + }, + Exited(i32), + Other(SharedString), +} + +impl Display for LoadError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + LoadError::Unsupported { error_message, .. 
} => write!(f, "{}", error_message), + LoadError::Exited(status) => write!(f, "Server exited with status {}", status), + LoadError::Other(msg) => write!(f, "{}", msg), + } + } +} + +impl Error for LoadError {} + +impl AcpThread { + pub fn new( + title: impl Into, + connection: Rc, + project: Entity, + session_id: acp::SessionId, + cx: &mut Context, + ) -> Self { + let action_log = cx.new(|_| ActionLog::new(project.clone())); + + Self { + action_log, + shared_buffers: Default::default(), + entries: Default::default(), + plan: Default::default(), + title: title.into(), + project, + send_task: None, + connection, + session_id, + } + } + + pub fn action_log(&self) -> &Entity { + &self.action_log + } + + pub fn project(&self) -> &Entity { + &self.project + } + + pub fn title(&self) -> SharedString { + self.title.clone() + } + + pub fn entries(&self) -> &[AgentThreadEntry] { + &self.entries + } + + pub fn session_id(&self) -> &acp::SessionId { + &self.session_id + } + + pub fn status(&self) -> ThreadStatus { + if self + .send_task + .as_ref() + .map_or(false, |t| !t.is_terminated()) + { + if self.waiting_for_tool_confirmation() { + ThreadStatus::WaitingForToolConfirmation + } else { + ThreadStatus::Generating + } + } else { + ThreadStatus::Idle + } + } + + pub fn has_pending_edit_tool_calls(&self) -> bool { + for entry in self.entries.iter().rev() { + match entry { + AgentThreadEntry::UserMessage(_) => return false, + AgentThreadEntry::ToolCall( + call @ ToolCall { + status: + ToolCallStatus::Allowed { + status: + acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending, + }, + .. + }, + ) if call.diffs().next().is_some() => { + return true; + } + AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {} + } + } + + false + } + + pub fn used_tools_since_last_user_message(&self) -> bool { + for entry in self.entries.iter().rev() { + match entry { + AgentThreadEntry::UserMessage(..) => return false, + AgentThreadEntry::AssistantMessage(..) => continue, + AgentThreadEntry::ToolCall(..) 
=> return true, + } + } + + false + } + + pub fn handle_session_update( + &mut self, + update: acp::SessionUpdate, + cx: &mut Context, + ) -> Result<()> { + match update { + acp::SessionUpdate::UserMessageChunk { content } => { + self.push_user_content_block(content, cx); + } + acp::SessionUpdate::AgentMessageChunk { content } => { + self.push_assistant_content_block(content, false, cx); + } + acp::SessionUpdate::AgentThoughtChunk { content } => { + self.push_assistant_content_block(content, true, cx); + } + acp::SessionUpdate::ToolCall(tool_call) => { + self.upsert_tool_call(tool_call, cx); + } + acp::SessionUpdate::ToolCallUpdate(tool_call_update) => { + self.update_tool_call(tool_call_update, cx)?; + } + acp::SessionUpdate::Plan(plan) => { + self.update_plan(plan, cx); + } + } + Ok(()) + } + + pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context) { + let language_registry = self.project.read(cx).languages().clone(); + let entries_len = self.entries.len(); + + if let Some(last_entry) = self.entries.last_mut() + && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry + { + content.append(chunk, &language_registry, cx); + cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1)); + } else { + let content = ContentBlock::new(chunk, &language_registry, cx); + self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx); + } + } + + pub fn push_assistant_content_block( + &mut self, + chunk: acp::ContentBlock, + is_thought: bool, + cx: &mut Context, + ) { + let language_registry = self.project.read(cx).languages().clone(); + let entries_len = self.entries.len(); + if let Some(last_entry) = self.entries.last_mut() + && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry + { + cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1)); + match (chunks.last_mut(), is_thought) { + (Some(AssistantMessageChunk::Message { block }), false) + | (Some(AssistantMessageChunk::Thought { block }), true) => { + block.append(chunk, &language_registry, cx) + } + _ => { + let block = ContentBlock::new(chunk, &language_registry, cx); + if is_thought { + chunks.push(AssistantMessageChunk::Thought { block }) + } else { + chunks.push(AssistantMessageChunk::Message { block }) + } + } + } + } else { + let block = ContentBlock::new(chunk, &language_registry, cx); + let chunk = if is_thought { + AssistantMessageChunk::Thought { block } + } else { + AssistantMessageChunk::Message { block } + }; + + self.push_entry( + AgentThreadEntry::AssistantMessage(AssistantMessage { + chunks: vec![chunk], + }), + cx, + ); + } + } + + fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context) { + self.entries.push(entry); + cx.emit(AcpThreadEvent::NewEntry); + } + + pub fn update_tool_call( + &mut self, + update: acp::ToolCallUpdate, + cx: &mut Context, + ) -> Result<()> { + let languages = self.project.read(cx).languages().clone(); + + let (ix, current_call) = self + .tool_call_mut(&update.id) + .context("Tool call not found")?; + current_call.update(update.fields, languages, cx); + + cx.emit(AcpThreadEvent::EntryUpdated(ix)); + + Ok(()) + } + + /// Updates a tool call if id matches an existing entry, otherwise inserts a new one. 
+ pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context) { + let status = ToolCallStatus::Allowed { + status: tool_call.status, + }; + self.upsert_tool_call_inner(tool_call, status, cx) + } + + pub fn upsert_tool_call_inner( + &mut self, + tool_call: acp::ToolCall, + status: ToolCallStatus, + cx: &mut Context, + ) { + let language_registry = self.project.read(cx).languages().clone(); + let call = ToolCall::from_acp(tool_call, status, language_registry, cx); + + let location = call.locations.last().cloned(); + + if let Some((ix, current_call)) = self.tool_call_mut(&call.id) { + *current_call = call; + + cx.emit(AcpThreadEvent::EntryUpdated(ix)); + } else { + self.push_entry(AgentThreadEntry::ToolCall(call), cx); + } + + if let Some(location) = location { + self.set_project_location(location, cx) + } + } + + fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> { + // The tool call we are looking for is typically the last one, or very close to the end. + // At the moment, it doesn't seem like a hashmap would be a good fit for this use case. + self.entries + .iter_mut() + .enumerate() + .rev() + .find_map(|(index, tool_call)| { + if let AgentThreadEntry::ToolCall(tool_call) = tool_call + && &tool_call.id == id + { + Some((index, tool_call)) + } else { + None + } + }) + } + + pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context) { + self.project.update(cx, |project, cx| { + let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else { + return; + }; + let buffer = project.open_buffer(path, cx); + cx.spawn(async move |project, cx| { + let buffer = buffer.await?; + + project.update(cx, |project, cx| { + let position = if let Some(line) = location.line { + let snapshot = buffer.read(cx).snapshot(); + let point = snapshot.clip_point(Point::new(line, 0), Bias::Left); + snapshot.anchor_before(point) + } else { + Anchor::MIN + }; + + project.set_agent_location( + Some(AgentLocation { + buffer: buffer.downgrade(), + position, + }), + cx, + ); + }) + }) + .detach_and_log_err(cx); + }); + } + + pub fn request_tool_call_authorization( + &mut self, + tool_call: acp::ToolCall, + options: Vec, + cx: &mut Context, + ) -> oneshot::Receiver { + let (tx, rx) = oneshot::channel(); + + let status = ToolCallStatus::WaitingForConfirmation { + options, + respond_tx: tx, + }; + + self.upsert_tool_call_inner(tool_call, status, cx); + cx.emit(AcpThreadEvent::ToolAuthorizationRequired); + rx + } + + pub fn authorize_tool_call( + &mut self, + id: acp::ToolCallId, + option_id: acp::PermissionOptionId, + option_kind: acp::PermissionOptionKind, + cx: &mut Context, + ) { + let Some((ix, call)) = self.tool_call_mut(&id) else { + return; + }; + + let new_status = match option_kind { + acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => { + ToolCallStatus::Rejected + } + acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => { + ToolCallStatus::Allowed { + status: acp::ToolCallStatus::InProgress, + } + } + }; + + let curr_status = mem::replace(&mut call.status, new_status); + + if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. 
} = curr_status { + respond_tx.send(option_id).log_err(); + } else if cfg!(debug_assertions) { + panic!("tried to authorize an already authorized tool call"); + } + + cx.emit(AcpThreadEvent::EntryUpdated(ix)); + } + + /// Returns true if the last turn is awaiting tool authorization + pub fn waiting_for_tool_confirmation(&self) -> bool { + for entry in self.entries.iter().rev() { + match &entry { + AgentThreadEntry::ToolCall(call) => match call.status { + ToolCallStatus::WaitingForConfirmation { .. } => return true, + ToolCallStatus::Allowed { .. } + | ToolCallStatus::Rejected + | ToolCallStatus::Canceled => continue, + }, + AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => { + // Reached the beginning of the turn + return false; + } + } + } + false + } + + pub fn plan(&self) -> &Plan { + &self.plan + } + + pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context) { + let new_entries_len = request.entries.len(); + let mut new_entries = request.entries.into_iter(); + + // Reuse existing markdown to prevent flickering + for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) { + let PlanEntry { + content, + priority, + status, + } = old; + content.update(cx, |old, cx| { + old.replace(new.content, cx); + }); + *priority = new.priority; + *status = new.status; + } + for new in new_entries { + self.plan.entries.push(PlanEntry::from_acp(new, cx)) + } + self.plan.entries.truncate(new_entries_len); + + cx.notify(); + } + + fn clear_completed_plan_entries(&mut self, cx: &mut Context) { + self.plan + .entries + .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed)); + cx.notify(); + } + + #[cfg(any(test, feature = "test-support"))] + pub fn send_raw( + &mut self, + message: &str, + cx: &mut Context, + ) -> BoxFuture<'static, Result<()>> { + self.send( + vec![acp::ContentBlock::Text(acp::TextContent { + text: message.to_string(), + annotations: None, + })], + cx, + ) + } + + pub fn send( + &mut self, + message: Vec, + cx: &mut Context, + ) -> BoxFuture<'static, Result<()>> { + let block = ContentBlock::new_combined( + message.clone(), + self.project.read(cx).languages().clone(), + cx, + ); + self.push_entry( + AgentThreadEntry::UserMessage(UserMessage { content: block }), + cx, + ); + self.clear_completed_plan_entries(cx); + + let (tx, rx) = oneshot::channel(); + let cancel_task = self.cancel(cx); + + self.send_task = Some( + cx.spawn(async move |this, cx| { + async { + cancel_task.await; + + let result = this + .update(cx, |this, cx| { + this.connection.prompt( + acp::PromptRequest { + prompt: message, + session_id: this.session_id.clone(), + }, + cx, + ) + })? + .await; + + tx.send(result).log_err(); + anyhow::Ok(()) + } + .await + .log_err(); + }) + .fuse(), + ); + + cx.spawn(async move |this, cx| match rx.await { + Ok(Err(e)) => { + this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Error)) + .log_err(); + Err(e)? + } + _ => { + this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped)) + .log_err(); + Ok(()) + } + }) + .boxed() + } + + pub fn cancel(&mut self, cx: &mut Context) -> Task<()> { + let Some(send_task) = self.send_task.take() else { + return Task::ready(()); + }; + + for entry in self.entries.iter_mut() { + if let AgentThreadEntry::ToolCall(call) = entry { + let cancel = matches!( + call.status, + ToolCallStatus::WaitingForConfirmation { .. 
} + | ToolCallStatus::Allowed { + status: acp::ToolCallStatus::InProgress + } + ); + + if cancel { + call.status = ToolCallStatus::Canceled; + } + } + } + + self.connection.cancel(&self.session_id, cx); + + // Wait for the send task to complete + cx.foreground_executor().spawn(send_task) + } + + pub fn read_text_file( + &self, + path: PathBuf, + line: Option, + limit: Option, + reuse_shared_snapshot: bool, + cx: &mut Context, + ) -> Task> { + let project = self.project.clone(); + let action_log = self.action_log.clone(); + cx.spawn(async move |this, cx| { + let load = project.update(cx, |project, cx| { + let path = project + .project_path_for_absolute_path(&path, cx) + .context("invalid path")?; + anyhow::Ok(project.open_buffer(path, cx)) + }); + let buffer = load??.await?; + + let snapshot = if reuse_shared_snapshot { + this.read_with(cx, |this, _| { + this.shared_buffers.get(&buffer.clone()).cloned() + }) + .log_err() + .flatten() + } else { + None + }; + + let snapshot = if let Some(snapshot) = snapshot { + snapshot + } else { + action_log.update(cx, |action_log, cx| { + action_log.buffer_read(buffer.clone(), cx); + })?; + project.update(cx, |project, cx| { + let position = buffer + .read(cx) + .snapshot() + .anchor_before(Point::new(line.unwrap_or_default(), 0)); + project.set_agent_location( + Some(AgentLocation { + buffer: buffer.downgrade(), + position, + }), + cx, + ); + })?; + + buffer.update(cx, |buffer, _| buffer.snapshot())? + }; + + this.update(cx, |this, _| { + let text = snapshot.text(); + this.shared_buffers.insert(buffer.clone(), snapshot); + if line.is_none() && limit.is_none() { + return Ok(text); + } + let limit = limit.unwrap_or(u32::MAX) as usize; + let Some(line) = line else { + return Ok(text.lines().take(limit).collect::()); + }; + + let count = text.lines().count(); + if count < line as usize { + anyhow::bail!("There are only {} lines", count); + } + Ok(text + .lines() + .skip(line as usize + 1) + .take(limit) + .collect::()) + })? + }) + } + + pub fn write_text_file( + &self, + path: PathBuf, + content: String, + cx: &mut Context, + ) -> Task> { + let project = self.project.clone(); + let action_log = self.action_log.clone(); + cx.spawn(async move |this, cx| { + let load = project.update(cx, |project, cx| { + let path = project + .project_path_for_absolute_path(&path, cx) + .context("invalid path")?; + anyhow::Ok(project.open_buffer(path, cx)) + }); + let buffer = load??.await?; + let snapshot = this.update(cx, |this, cx| { + this.shared_buffers + .get(&buffer) + .cloned() + .unwrap_or_else(|| buffer.read(cx).snapshot()) + })?; + let edits = cx + .background_executor() + .spawn(async move { + let old_text = snapshot.text(); + text_diff(old_text.as_str(), &content) + .into_iter() + .map(|(range, replacement)| { + ( + snapshot.anchor_after(range.start) + ..snapshot.anchor_before(range.end), + replacement, + ) + }) + .collect::>() + }) + .await; + cx.update(|cx| { + project.update(cx, |project, cx| { + project.set_agent_location( + Some(AgentLocation { + buffer: buffer.downgrade(), + position: edits + .last() + .map(|(range, _)| range.end) + .unwrap_or(Anchor::MIN), + }), + cx, + ); + }); + + action_log.update(cx, |action_log, cx| { + action_log.buffer_read(buffer.clone(), cx); + }); + buffer.update(cx, |buffer, cx| { + buffer.edit(edits, None, cx); + }); + action_log.update(cx, |action_log, cx| { + action_log.buffer_edited(buffer.clone(), cx); + }); + })?; + project + .update(cx, |project, cx| project.save_buffer(buffer, cx))? 
+ .await + }) + } + + pub fn to_markdown(&self, cx: &App) -> String { + self.entries.iter().map(|e| e.to_markdown(cx)).collect() + } + + pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context) { + cx.emit(AcpThreadEvent::ServerExited(status)); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use anyhow::anyhow; + use futures::{channel::mpsc, future::LocalBoxFuture, select}; + use gpui::{AsyncApp, TestAppContext, WeakEntity}; + use indoc::indoc; + use project::FakeFs; + use rand::Rng as _; + use serde_json::json; + use settings::SettingsStore; + use smol::stream::StreamExt as _; + use std::{cell::RefCell, rc::Rc, time::Duration}; + + use util::path; + + fn init_test(cx: &mut TestAppContext) { + env_logger::try_init().ok(); + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + Project::init_settings(cx); + language::init(cx); + }); + } + + #[gpui::test] + async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, [], cx).await; + let connection = Rc::new(FakeAgentConnection::new()); + let thread = cx + .spawn(async move |mut cx| { + connection + .new_thread(project, Path::new(path!("/test")), &mut cx) + .await + }) + .await + .unwrap(); + + // Test creating a new user message + thread.update(cx, |thread, cx| { + thread.push_user_content_block( + acp::ContentBlock::Text(acp::TextContent { + annotations: None, + text: "Hello, ".to_string(), + }), + cx, + ); + }); + + thread.update(cx, |thread, cx| { + assert_eq!(thread.entries.len(), 1); + if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] { + assert_eq!(user_msg.content.to_markdown(cx), "Hello, "); + } else { + panic!("Expected UserMessage"); + } + }); + + // Test appending to existing user message + thread.update(cx, |thread, cx| { + thread.push_user_content_block( + acp::ContentBlock::Text(acp::TextContent { + annotations: None, + text: "world!".to_string(), + }), + cx, + ); + }); + + thread.update(cx, |thread, cx| { + assert_eq!(thread.entries.len(), 1); + if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] { + assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!"); + } else { + panic!("Expected UserMessage"); + } + }); + + // Test creating new user message after assistant message + thread.update(cx, |thread, cx| { + thread.push_assistant_content_block( + acp::ContentBlock::Text(acp::TextContent { + annotations: None, + text: "Assistant response".to_string(), + }), + false, + cx, + ); + }); + + thread.update(cx, |thread, cx| { + thread.push_user_content_block( + acp::ContentBlock::Text(acp::TextContent { + annotations: None, + text: "New user message".to_string(), + }), + cx, + ); + }); + + thread.update(cx, |thread, cx| { + assert_eq!(thread.entries.len(), 3); + if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] { + assert_eq!(user_msg.content.to_markdown(cx), "New user message"); + } else { + panic!("Expected UserMessage at index 2"); + } + }); + } + + #[gpui::test] + async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, [], cx).await; + let connection = Rc::new(FakeAgentConnection::new().on_user_message( + |_, thread, mut cx| { + async move { + thread.update(&mut cx, |thread, cx| { + thread + .handle_session_update( + acp::SessionUpdate::AgentThoughtChunk { + content: "Thinking ".into(), + }, + cx, + ) 
+ .unwrap(); + thread + .handle_session_update( + acp::SessionUpdate::AgentThoughtChunk { + content: "hard!".into(), + }, + cx, + ) + .unwrap(); + })?; + Ok(acp::PromptResponse { + stop_reason: acp::StopReason::EndTurn, + }) + } + .boxed_local() + }, + )); + + let thread = cx + .spawn(async move |mut cx| { + connection + .new_thread(project, Path::new(path!("/test")), &mut cx) + .await + }) + .await + .unwrap(); + + thread + .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx)) + .await + .unwrap(); + + let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx)); + assert_eq!( + output, + indoc! {r#" + ## User + + Hello from Zed! + + ## Assistant + + + Thinking hard! + + + "#} + ); + } + + #[gpui::test] + async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"})) + .await; + let project = Project::test(fs.clone(), [], cx).await; + let (read_file_tx, read_file_rx) = oneshot::channel::<()>(); + let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx))); + let connection = Rc::new(FakeAgentConnection::new().on_user_message( + move |_, thread, mut cx| { + let read_file_tx = read_file_tx.clone(); + async move { + let content = thread + .update(&mut cx, |thread, cx| { + thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx) + }) + .unwrap() + .await + .unwrap(); + assert_eq!(content, "one\ntwo\nthree\n"); + read_file_tx.take().unwrap().send(()).unwrap(); + thread + .update(&mut cx, |thread, cx| { + thread.write_text_file( + path!("/tmp/foo").into(), + "one\ntwo\nthree\nfour\nfive\n".to_string(), + cx, + ) + }) + .unwrap() + .await + .unwrap(); + Ok(acp::PromptResponse { + stop_reason: acp::StopReason::EndTurn, + }) + } + .boxed_local() + }, + )); + + let (worktree, pathbuf) = project + .update(cx, |project, cx| { + project.find_or_create_worktree(path!("/tmp/foo"), true, cx) + }) + .await + .unwrap(); + let buffer = project + .update(cx, |project, cx| { + project.open_buffer((worktree.read(cx).id(), pathbuf), cx) + }) + .await + .unwrap(); + + let thread = cx + .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx)) + .await + .unwrap(); + + let request = thread.update(cx, |thread, cx| { + thread.send_raw("Extend the count in /tmp/foo", cx) + }); + read_file_rx.await.ok(); + buffer.update(cx, |buffer, cx| { + buffer.edit([(0..0, "zero\n".to_string())], None, cx); + }); + cx.run_until_parked(); + assert_eq!( + buffer.read_with(cx, |buffer, _| buffer.text()), + "zero\none\ntwo\nthree\nfour\nfive\n" + ); + assert_eq!( + String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(), + "zero\none\ntwo\nthree\nfour\nfive\n" + ); + request.await.unwrap(); + } + + #[gpui::test] + async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, [], cx).await; + let id = acp::ToolCallId("test".into()); + + let connection = Rc::new(FakeAgentConnection::new().on_user_message({ + let id = id.clone(); + move |_, thread, mut cx| { + let id = id.clone(); + async move { + thread + .update(&mut cx, |thread, cx| { + thread.handle_session_update( + acp::SessionUpdate::ToolCall(acp::ToolCall { + id: id.clone(), + title: "Label".into(), + kind: acp::ToolKind::Fetch, + status: acp::ToolCallStatus::InProgress, + content: vec![], + locations: vec![], + raw_input: None, + raw_output: None, + }), + cx, + ) + }) + 
.unwrap() + .unwrap(); + Ok(acp::PromptResponse { + stop_reason: acp::StopReason::EndTurn, + }) + } + .boxed_local() + } + })); + + let thread = cx + .spawn(async move |mut cx| { + connection + .new_thread(project, Path::new(path!("/test")), &mut cx) + .await + }) + .await + .unwrap(); + + let request = thread.update(cx, |thread, cx| { + thread.send_raw("Fetch https://example.com", cx) + }); + + run_until_first_tool_call(&thread, cx).await; + + thread.read_with(cx, |thread, _| { + assert!(matches!( + thread.entries[1], + AgentThreadEntry::ToolCall(ToolCall { + status: ToolCallStatus::Allowed { + status: acp::ToolCallStatus::InProgress, + .. + }, + .. + }) + )); + }); + + thread.update(cx, |thread, cx| thread.cancel(cx)).await; + + thread.read_with(cx, |thread, _| { + assert!(matches!( + &thread.entries[1], + AgentThreadEntry::ToolCall(ToolCall { + status: ToolCallStatus::Canceled, + .. + }) + )); + }); + + thread + .update(cx, |thread, cx| { + thread.handle_session_update( + acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate { + id, + fields: acp::ToolCallUpdateFields { + status: Some(acp::ToolCallStatus::Completed), + ..Default::default() + }, + }), + cx, + ) + }) + .unwrap(); + + request.await.unwrap(); + + thread.read_with(cx, |thread, _| { + assert!(matches!( + thread.entries[1], + AgentThreadEntry::ToolCall(ToolCall { + status: ToolCallStatus::Allowed { + status: acp::ToolCallStatus::Completed, + .. + }, + .. + }) + )); + }); + } + + #[gpui::test] + async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.background_executor.clone()); + fs.insert_tree(path!("/test"), json!({})).await; + let project = Project::test(fs, [path!("/test").as_ref()], cx).await; + + let connection = Rc::new(FakeAgentConnection::new().on_user_message({ + move |_, thread, mut cx| { + async move { + thread + .update(&mut cx, |thread, cx| { + thread.handle_session_update( + acp::SessionUpdate::ToolCall(acp::ToolCall { + id: acp::ToolCallId("test".into()), + title: "Label".into(), + kind: acp::ToolKind::Edit, + status: acp::ToolCallStatus::Completed, + content: vec![acp::ToolCallContent::Diff { + diff: acp::Diff { + path: "/test/test.txt".into(), + old_text: None, + new_text: "foo".into(), + }, + }], + locations: vec![], + raw_input: None, + raw_output: None, + }), + cx, + ) + }) + .unwrap() + .unwrap(); + Ok(acp::PromptResponse { + stop_reason: acp::StopReason::EndTurn, + }) + } + .boxed_local() + } + })); + + let thread = connection + .new_thread(project, Path::new(path!("/test")), &mut cx.to_async()) + .await + .unwrap(); + cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx))) + .await + .unwrap(); + + assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls())); + } + + async fn run_until_first_tool_call( + thread: &Entity, + cx: &mut TestAppContext, + ) -> usize { + let (mut tx, mut rx) = mpsc::channel::(1); + + let subscription = cx.update(|cx| { + cx.subscribe(thread, move |thread, _, cx| { + for (ix, entry) in thread.read(cx).entries.iter().enumerate() { + if matches!(entry, AgentThreadEntry::ToolCall(_)) { + return tx.try_send(ix).unwrap(); + } + } + }) + }); + + select! 
{
+            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
+                panic!("Timeout waiting for tool call")
+            }
+            ix = rx.next().fuse() => {
+                drop(subscription);
+                ix.unwrap()
+            }
+        }
+    }
+
+    #[derive(Clone, Default)]
+    struct FakeAgentConnection {
+        auth_methods: Vec<acp::AuthMethod>,
+        sessions: Arc<Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
+        on_user_message: Option<
+            Rc<
+                dyn Fn(
+                        acp::PromptRequest,
+                        WeakEntity<AcpThread>,
+                        AsyncApp,
+                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
+                    + 'static,
+            >,
+        >,
+    }
+
+    impl FakeAgentConnection {
+        fn new() -> Self {
+            Self {
+                auth_methods: Vec::new(),
+                on_user_message: None,
+                sessions: Arc::default(),
+            }
+        }
+
+        #[expect(unused)]
+        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
+            self.auth_methods = auth_methods;
+            self
+        }
+
+        fn on_user_message(
+            mut self,
+            handler: impl Fn(
+                    acp::PromptRequest,
+                    WeakEntity<AcpThread>,
+                    AsyncApp,
+                ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
+                + 'static,
+        ) -> Self {
+            self.on_user_message.replace(Rc::new(handler));
+            self
+        }
+    }
+
+    impl AgentConnection for FakeAgentConnection {
+        fn auth_methods(&self) -> &[acp::AuthMethod] {
+            &self.auth_methods
+        }
+
+        fn new_thread(
+            self: Rc<Self>,
+            project: Entity<Project>,
+            _cwd: &Path,
+            cx: &mut gpui::AsyncApp,
+        ) -> Task<Result<Entity<AcpThread>>> {
+            let session_id = acp::SessionId(
+                rand::thread_rng()
+                    .sample_iter(&rand::distributions::Alphanumeric)
+                    .take(7)
+                    .map(char::from)
+                    .collect::<String>()
+                    .into(),
+            );
+            let thread = cx
+                .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
+                .unwrap();
+            self.sessions.lock().insert(session_id, thread.downgrade());
+            Task::ready(Ok(thread))
+        }
+
+        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
+            if self.auth_methods().iter().any(|m| m.id == method) {
+                Task::ready(Ok(()))
+            } else {
+                Task::ready(Err(anyhow!("Invalid Auth Method")))
+            }
+        }
+
+        fn prompt(
+            &self,
+            params: acp::PromptRequest,
+            cx: &mut App,
+        ) -> Task<Result<acp::PromptResponse>> {
+            let sessions = self.sessions.lock();
+            let thread = sessions.get(&params.session_id).unwrap();
+            if let Some(handler) = &self.on_user_message {
+                let handler = handler.clone();
+                let thread = thread.clone();
+                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
+            } else {
+                Task::ready(Ok(acp::PromptResponse {
+                    stop_reason: acp::StopReason::EndTurn,
+                }))
+            }
+        }
+
+        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
+            let sessions = self.sessions.lock();
+            let thread = sessions.get(&session_id).unwrap().clone();
+
+            cx.spawn(async move |cx| {
+                thread
+                    .update(cx, |thread, cx| thread.cancel(cx))
+                    .unwrap()
+                    .await
+            })
+            .detach();
+        }
+    }
+}
diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs
new file mode 100644
index 0000000000000000000000000000000000000000..cf06563beee4aa6bdc6ecf1fdb85178116ad74dd
--- /dev/null
+++ b/crates/acp_thread/src/connection.rs
@@ -0,0 +1,93 @@
+use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc};
+
+use agent_client_protocol::{self as acp};
+use anyhow::Result;
+use gpui::{AsyncApp, Entity, Task};
+use language_model::LanguageModel;
+use project::Project;
+use ui::App;
+
+use crate::AcpThread;
+
+/// Trait for agents that support listing, selecting, and querying language models.
+///
+/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
+pub trait ModelSelector: 'static {
+    /// Lists all available language models for this agent.
+    ///
+    /// # Parameters
+    /// - `cx`: The GPUI app context for async operations and global access.
+    ///
+    /// # Returns
+    /// A task resolving to the list of models or an error (e.g., if no models are configured).
+    fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>>;
+
+    /// Selects a model for a specific session (thread).
+    ///
+    /// This sets the default model for future interactions in the session.
+    /// If the session doesn't exist or the model is invalid, it returns an error.
+    ///
+    /// # Parameters
+    /// - `session_id`: The ID of the session (thread) to apply the model to.
+    /// - `model`: The model to select (should be one from [list_models]).
+    /// - `cx`: The GPUI app context.
+    ///
+    /// # Returns
+    /// A task resolving to `Ok(())` on success or an error.
+    fn select_model(
+        &self,
+        session_id: acp::SessionId,
+        model: Arc<dyn LanguageModel>,
+        cx: &mut AsyncApp,
+    ) -> Task<Result<()>>;
+
+    /// Retrieves the currently selected model for a specific session (thread).
+    ///
+    /// # Parameters
+    /// - `session_id`: The ID of the session (thread) to query.
+    /// - `cx`: The GPUI app context.
+    ///
+    /// # Returns
+    /// A task resolving to the selected model (always set) or an error (e.g., session not found).
+    fn selected_model(
+        &self,
+        session_id: &acp::SessionId,
+        cx: &mut AsyncApp,
+    ) -> Task<Result<Arc<dyn LanguageModel>>>;
+}
+
+pub trait AgentConnection {
+    fn new_thread(
+        self: Rc<Self>,
+        project: Entity<Project>,
+        cwd: &Path,
+        cx: &mut AsyncApp,
+    ) -> Task<Result<Entity<AcpThread>>>;
+
+    fn auth_methods(&self) -> &[acp::AuthMethod];
+
+    fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
+
+    fn prompt(&self, params: acp::PromptRequest, cx: &mut App)
+        -> Task<Result<acp::PromptResponse>>;
+
+    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
+
+    /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
+    ///
+    /// If the agent does not support model selection, returns [None].
+    /// This allows sharing the selector in UI components.
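+    ///
+    /// # Example
+    ///
+    /// A rough usage sketch (illustrative only; assumes an existing
+    /// `connection: Rc<dyn AgentConnection>`, a valid `session_id`, and an
+    /// async GPUI context `cx`):
+    ///
+    /// ```ignore
+    /// if let Some(selector) = connection.model_selector() {
+    ///     // Pick the first available model and make it the session default.
+    ///     let models = selector.list_models(&mut cx).await?;
+    ///     if let Some(model) = models.first() {
+    ///         selector
+    ///             .select_model(session_id.clone(), model.clone(), &mut cx)
+    ///             .await?;
+    ///     }
+    ///     // The selected model can be queried back at any time.
+    ///     let _current = selector.selected_model(&session_id, &mut cx).await?;
+    /// }
+    /// ```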
+    fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
+        None // Default impl for agents that don't support it
+    }
+}
+
+#[derive(Debug)]
+pub struct AuthRequired;
+
+impl Error for AuthRequired {}
+impl fmt::Display for AuthRequired {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "AuthRequired")
+    }
+}
diff --git a/crates/activity_indicator/src/activity_indicator.rs b/crates/activity_indicator/src/activity_indicator.rs
index b07c5418218c5045ae8018c9be9ec6fd07446544..f8ea7173d8afecb14a8e8ed9e1de6a87660ddc1a 100644
--- a/crates/activity_indicator/src/activity_indicator.rs
+++ b/crates/activity_indicator/src/activity_indicator.rs
@@ -231,7 +231,6 @@ impl ActivityIndicator {
                 status,
             } => {
                 let create_buffer = project.update(cx, |project, cx| project.create_buffer(cx));
-                let project = project.clone();
                 let status = status.clone();
                 let server_name = server_name.clone();
                 cx.spawn_in(window, async move |workspace, cx| {
@@ -247,8 +246,7 @@ impl ActivityIndicator {
                     workspace.update_in(cx, |workspace, window, cx| {
                         workspace.add_item_to_active_pane(
                             Box::new(cx.new(|cx| {
-                                let mut editor =
-                                    Editor::for_buffer(buffer, Some(project.clone()), window, cx);
+                                let mut editor = Editor::for_buffer(buffer, None, window, cx);
                                 editor.set_read_only(true);
                                 editor
                             })),
@@ -448,7 +446,7 @@ impl ActivityIndicator {
                     .into_any_element(),
                 ),
                 message: format!("Debug: {}", session.read(cx).adapter()),
-                tooltip_message: Some(session.read(cx).label().to_string()),
+                tooltip_message: session.read(cx).label().map(|label| label.to_string()),
                 on_click: None,
             });
         }
diff --git a/crates/agent/Cargo.toml b/crates/agent/Cargo.toml
index 135363ab6552a9b6737dfce0e0c95ced3237ae5c..7bc0e82cadd12a4a5e0926cfc869af194431955d 100644
--- a/crates/agent/Cargo.toml
+++ b/crates/agent/Cargo.toml
@@ -25,6 +25,7 @@ assistant_context.workspace = true
 assistant_tool.workspace = true
 chrono.workspace = true
 client.workspace = true
+cloud_llm_client.workspace = true
 collections.workspace = true
 component.workspace = true
 context_server.workspace = true
@@ -35,9 +36,9 @@ futures.workspace = true
 git.workspace = true
 gpui.workspace = true
 heed.workspace = true
+http_client.workspace = true
 icons.workspace = true
 indoc.workspace = true
-http_client.workspace = true
 itertools.workspace = true
 language.workspace = true
 language_model.workspace = true
@@ -46,7 +47,6 @@ paths.workspace = true
 postage.workspace = true
 project.workspace = true
 prompt_store.workspace = true
-proto.workspace = true
 ref-cast.workspace = true
 rope.workspace = true
 schemars.workspace = true
@@ -63,7 +63,6 @@ time.workspace = true
 util.workspace = true
 uuid.workspace = true
 workspace-hack.workspace = true
-zed_llm_client.workspace = true
 zstd.workspace = true
 
 [dev-dependencies]
diff --git a/crates/agent/src/agent_profile.rs b/crates/agent/src/agent_profile.rs
index a89857e71a6b8ed0f4e7a397be2bcd1bce4b1d7a..34ea1c8df7c4e2bbc58ac8a57b11655917c7aac2 100644
--- a/crates/agent/src/agent_profile.rs
+++ b/crates/agent/src/agent_profile.rs
@@ -308,7 +308,12 @@ mod tests {
             unimplemented!()
         }
 
-        fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
+        fn needs_confirmation(
+            &self,
+            _input: &serde_json::Value,
+            _project: &Entity<Project>,
+            _cx: &App,
+        ) -> bool {
             unimplemented!()
         }
 
diff --git a/crates/agent/src/context.rs b/crates/agent/src/context.rs
index ddd13de491ecb0e7d143ae7a6c3e602858fb9b85..cd366b8308abf8c9900d24633eea95d7ff92df6d 100644
--- a/crates/agent/src/context.rs
+++ b/crates/agent/src/context.rs
@@ -42,8 +42,8 @@ impl ContextKind {
ContextKind::Symbol => IconName::Code, ContextKind::Selection => IconName::Context, ContextKind::FetchedUrl => IconName::Globe, - ContextKind::Thread => IconName::MessageBubbles, - ContextKind::TextThread => IconName::MessageBubbles, + ContextKind::Thread => IconName::Thread, + ContextKind::TextThread => IconName::TextThread, ContextKind::Rules => RULES_ICON, ContextKind::Image => IconName::Image, } diff --git a/crates/agent/src/context_server_tool.rs b/crates/agent/src/context_server_tool.rs index da7de1e312cea24c1be63568cc796a49ddfa178c..85e8ac7451405a292af1679406f4d184fe709eb9 100644 --- a/crates/agent/src/context_server_tool.rs +++ b/crates/agent/src/context_server_tool.rs @@ -38,7 +38,7 @@ impl Tool for ContextServerTool { } fn icon(&self) -> IconName { - IconName::Cog + IconName::ToolHammer } fn source(&self) -> ToolSource { @@ -47,7 +47,7 @@ impl Tool for ContextServerTool { } } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { true } diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 6a20ad8f83dd984c74a001fb86ccd564b110ce24..048aa4245d7ae527b094b0485a8bec048427ccec 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -8,11 +8,12 @@ use crate::{ }, tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState}, }; -use agent_settings::{AgentProfileId, AgentSettings, CompletionMode}; +use agent_settings::{AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_PROMPT}; use anyhow::{Result, anyhow}; use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet}; use chrono::{DateTime, Utc}; use client::{ModelRequestUsage, RequestUsage}; +use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, Plan, UsageLimit}; use collections::HashMap; use feature_flags::{self, FeatureFlagAppExt}; use futures::{FutureExt, StreamExt as _, future::Shared}; @@ -21,6 +22,7 @@ use gpui::{ AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity, Window, }; +use http_client::StatusCode; use language_model::{ ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelExt as _, LanguageModelId, LanguageModelRegistry, LanguageModelRequest, @@ -35,7 +37,6 @@ use project::{ git_store::{GitStore, GitStoreCheckpoint, RepositoryState}, }; use prompt_store::{ModelContext, PromptBuilder}; -use proto::Plan; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::Settings; @@ -46,12 +47,23 @@ use std::{ time::{Duration, Instant}, }; use thiserror::Error; -use util::{ResultExt as _, debug_panic, post_inc}; +use util::{ResultExt as _, post_inc}; use uuid::Uuid; -use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit}; -const MAX_RETRY_ATTEMPTS: u8 = 3; -const BASE_RETRY_DELAY_SECS: u64 = 5; +const MAX_RETRY_ATTEMPTS: u8 = 4; +const BASE_RETRY_DELAY: Duration = Duration::from_secs(5); + +#[derive(Debug, Clone)] +enum RetryStrategy { + ExponentialBackoff { + initial_delay: Duration, + max_attempts: u8, + }, + Fixed { + delay: Duration, + max_attempts: u8, + }, +} #[derive( Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema, @@ -383,6 +395,7 @@ pub struct Thread { remaining_turns: u32, configured_model: Option, profile: AgentProfile, + last_error_context: Option<(Arc, CompletionIntent)>, } #[derive(Clone, Debug)] @@ -476,10 +489,11 @@ impl Thread { retry_state: None, 
message_feedback: HashMap::default(), last_auto_capture_at: None, + last_error_context: None, last_received_chunk_at: None, request_callback: None, remaining_turns: u32::MAX, - configured_model, + configured_model: configured_model.clone(), profile: AgentProfile::new(profile_id, tools), } } @@ -600,6 +614,7 @@ impl Thread { feedback: None, message_feedback: HashMap::default(), last_auto_capture_at: None, + last_error_context: None, last_received_chunk_at: None, request_callback: None, remaining_turns: u32::MAX, @@ -926,7 +941,7 @@ impl Thread { } pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec { - self.tool_use.tool_uses_for_message(id, cx) + self.tool_use.tool_uses_for_message(id, &self.project, cx) } pub fn tool_results_for_message( @@ -1251,9 +1266,58 @@ impl Thread { self.flush_notifications(model.clone(), intent, cx); - let request = self.to_completion_request(model.clone(), intent, cx); + let _checkpoint = self.finalize_pending_checkpoint(cx); + self.stream_completion( + self.to_completion_request(model.clone(), intent, cx), + model, + intent, + window, + cx, + ); + } + + pub fn retry_last_completion( + &mut self, + window: Option, + cx: &mut Context, + ) { + // Clear any existing error state + self.retry_state = None; - self.stream_completion(request, model, intent, window, cx); + // Use the last error context if available, otherwise fall back to configured model + let (model, intent) = if let Some((model, intent)) = self.last_error_context.take() { + (model, intent) + } else if let Some(configured_model) = self.configured_model.as_ref() { + let model = configured_model.model.clone(); + let intent = if self.has_pending_tool_uses() { + CompletionIntent::ToolResults + } else { + CompletionIntent::UserPrompt + }; + (model, intent) + } else if let Some(configured_model) = self.get_or_init_configured_model(cx) { + let model = configured_model.model.clone(); + let intent = if self.has_pending_tool_uses() { + CompletionIntent::ToolResults + } else { + CompletionIntent::UserPrompt + }; + (model, intent) + } else { + return; + }; + + self.send_to_model(model, intent, window, cx); + } + + pub fn enable_burn_mode_and_retry( + &mut self, + window: Option, + cx: &mut Context, + ) { + self.completion_mode = CompletionMode::Burn; + cx.emit(ThreadEvent::ProfileChanged); + self.retry_last_completion(window, cx); } pub fn used_tools_since_last_user_message(&self) -> bool { @@ -1517,21 +1581,21 @@ impl Thread { model: Arc, cx: &mut App, ) -> Option { - let action_log = self.action_log.read(cx); - - action_log.unnotified_stale_buffers(cx).next()?; - // Represent notification as a simulated `project_notifications` tool call let tool_name = Arc::from("project_notifications"); - let Some(tool) = self.tools.read(cx).tool(&tool_name, cx) else { - debug_panic!("`project_notifications` tool not found"); - return None; - }; + let tool = self.tools.read(cx).tool(&tool_name, cx)?; if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) { return None; } + if self + .action_log + .update(cx, |log, cx| log.unnotified_user_edits(cx).is_none()) + { + return None; + } + let input = serde_json::json!({}); let request = Arc::new(LanguageModelRequest::default()); // unused let window = None; @@ -1616,7 +1680,7 @@ impl Thread { let completion_mode = request .mode - .unwrap_or(zed_llm_client::CompletionMode::Normal); + .unwrap_or(cloud_llm_client::CompletionMode::Normal); self.last_received_chunk_at = Some(Instant::now()); @@ -1933,18 +1997,6 @@ impl Thread { project.set_agent_location(None, 
cx); }); - fn emit_generic_error(error: &anyhow::Error, cx: &mut Context) { - let error_message = error - .chain() - .map(|err| err.to_string()) - .collect::>() - .join("\n"); - cx.emit(ThreadEvent::ShowError(ThreadError::Message { - header: "Error interacting with language model".into(), - message: SharedString::from(error_message.clone()), - })); - } - if error.is::() { cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired)); } else if let Some(error) = @@ -1956,9 +2008,10 @@ impl Thread { } else if let Some(completion_error) = error.downcast_ref::() { - use LanguageModelCompletionError::*; match &completion_error { - PromptTooLarge { tokens, .. } => { + LanguageModelCompletionError::PromptTooLarge { + tokens, .. + } => { let tokens = tokens.unwrap_or_else(|| { // We didn't get an exact token count from the API, so fall back on our estimate. thread @@ -1979,63 +2032,28 @@ impl Thread { }); cx.notify(); } - RateLimitExceeded { - retry_after: Some(retry_after), - .. - } - | ServerOverloaded { - retry_after: Some(retry_after), - .. - } => { - thread.handle_rate_limit_error( - &completion_error, - *retry_after, - model.clone(), - intent, - window, - cx, - ); - retry_scheduled = true; - } - RateLimitExceeded { .. } | ServerOverloaded { .. } => { - retry_scheduled = thread.handle_retryable_error( - &completion_error, - model.clone(), - intent, - window, - cx, - ); - if !retry_scheduled { - emit_generic_error(error, cx); - } - } - ApiInternalServerError { .. } - | ApiReadResponseError { .. } - | HttpSend { .. } => { - retry_scheduled = thread.handle_retryable_error( - &completion_error, - model.clone(), - intent, - window, - cx, - ); - if !retry_scheduled { - emit_generic_error(error, cx); + _ => { + if let Some(retry_strategy) = + Thread::get_retry_strategy(completion_error) + { + log::info!( + "Retrying with {:?} for language model completion error {:?}", + retry_strategy, + completion_error + ); + + retry_scheduled = thread + .handle_retryable_error_with_delay( + &completion_error, + Some(retry_strategy), + model.clone(), + intent, + window, + cx, + ); } } - NoApiKey { .. } - | HttpResponseError { .. } - | BadRequestFormat { .. } - | AuthenticationError { .. } - | PermissionError { .. } - | ApiEndpointNotFound { .. } - | SerializeRequest { .. } - | BuildRequestBody { .. } - | DeserializeResponse { .. } - | Other { .. } => emit_generic_error(error, cx), } - } else { - emit_generic_error(error, cx); } if !retry_scheduled { @@ -2094,12 +2112,10 @@ impl Thread { return; } - let added_user_message = include_str!("./prompts/summarize_thread_prompt.txt"); - let request = self.to_summarize_request( &model.model, CompletionIntent::ThreadSummarization, - added_user_message.into(), + SUMMARIZE_THREAD_PROMPT.into(), cx, ); @@ -2162,73 +2178,141 @@ impl Thread { }); } - fn handle_rate_limit_error( - &mut self, - error: &LanguageModelCompletionError, - retry_after: Duration, - model: Arc, - intent: CompletionIntent, - window: Option, - cx: &mut Context, - ) { - // For rate limit errors, we only retry once with the specified duration - let retry_message = format!("{error}. 
Retrying in {} seconds…", retry_after.as_secs()); - log::warn!( - "Retrying completion request in {} seconds: {error:?}", - retry_after.as_secs(), - ); - - // Add a UI-only message instead of a regular message - let id = self.next_message_id.post_inc(); - self.messages.push(Message { - id, - role: Role::System, - segments: vec![MessageSegment::Text(retry_message)], - loaded_context: LoadedContext::default(), - creases: Vec::new(), - is_hidden: false, - ui_only: true, - }); - cx.emit(ThreadEvent::MessageAdded(id)); - // Schedule the retry - let thread_handle = cx.entity().downgrade(); - - cx.spawn(async move |_thread, cx| { - cx.background_executor().timer(retry_after).await; + fn get_retry_strategy(error: &LanguageModelCompletionError) -> Option { + use LanguageModelCompletionError::*; - thread_handle - .update(cx, |thread, cx| { - // Retry the completion - thread.send_to_model(model, intent, window, cx); + // General strategy here: + // - If retrying won't help (e.g. invalid API key or payload too large), return None so we don't retry at all. + // - If it's a time-based issue (e.g. server overloaded, rate limit exceeded), retry up to 4 times with exponential backoff. + // - If it's an issue that *might* be fixed by retrying (e.g. internal server error), retry up to 3 times. + match error { + HttpResponseError { + status_code: StatusCode::TOO_MANY_REQUESTS, + .. + } => Some(RetryStrategy::ExponentialBackoff { + initial_delay: BASE_RETRY_DELAY, + max_attempts: MAX_RETRY_ATTEMPTS, + }), + ServerOverloaded { retry_after, .. } | RateLimitExceeded { retry_after, .. } => { + Some(RetryStrategy::Fixed { + delay: retry_after.unwrap_or(BASE_RETRY_DELAY), + max_attempts: MAX_RETRY_ATTEMPTS, }) - .log_err(); - }) - .detach(); - } - - fn handle_retryable_error( - &mut self, - error: &LanguageModelCompletionError, - model: Arc, - intent: CompletionIntent, - window: Option, - cx: &mut Context, - ) -> bool { - self.handle_retryable_error_with_delay(error, None, model, intent, window, cx) + } + UpstreamProviderError { + status, + retry_after, + .. + } => match *status { + StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE => { + Some(RetryStrategy::Fixed { + delay: retry_after.unwrap_or(BASE_RETRY_DELAY), + max_attempts: MAX_RETRY_ATTEMPTS, + }) + } + StatusCode::INTERNAL_SERVER_ERROR => Some(RetryStrategy::Fixed { + delay: retry_after.unwrap_or(BASE_RETRY_DELAY), + // Internal Server Error could be anything, retry up to 3 times. + max_attempts: 3, + }), + status => { + // There is no StatusCode variant for the unofficial HTTP 529 ("The service is overloaded"), + // but we frequently get them in practice. See https://http.dev/529 + if status.as_u16() == 529 { + Some(RetryStrategy::Fixed { + delay: retry_after.unwrap_or(BASE_RETRY_DELAY), + max_attempts: MAX_RETRY_ATTEMPTS, + }) + } else { + Some(RetryStrategy::Fixed { + delay: retry_after.unwrap_or(BASE_RETRY_DELAY), + max_attempts: 2, + }) + } + } + }, + ApiInternalServerError { .. } => Some(RetryStrategy::Fixed { + delay: BASE_RETRY_DELAY, + max_attempts: 3, + }), + ApiReadResponseError { .. } + | HttpSend { .. } + | DeserializeResponse { .. } + | BadRequestFormat { .. } => Some(RetryStrategy::Fixed { + delay: BASE_RETRY_DELAY, + max_attempts: 3, + }), + // Retrying these errors definitely shouldn't help. + HttpResponseError { + status_code: + StatusCode::PAYLOAD_TOO_LARGE | StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED, + .. + } + | AuthenticationError { .. } + | PermissionError { .. } + | NoApiKey { .. } + | ApiEndpointNotFound { .. 
} + | PromptTooLarge { .. } => None, + // These errors might be transient, so retry them + SerializeRequest { .. } | BuildRequestBody { .. } => Some(RetryStrategy::Fixed { + delay: BASE_RETRY_DELAY, + max_attempts: 1, + }), + // Retry all other 4xx and 5xx errors once. + HttpResponseError { status_code, .. } + if status_code.is_client_error() || status_code.is_server_error() => + { + Some(RetryStrategy::Fixed { + delay: BASE_RETRY_DELAY, + max_attempts: 3, + }) + } + // Conservatively assume that any other errors are non-retryable + HttpResponseError { .. } | Other(..) => Some(RetryStrategy::Fixed { + delay: BASE_RETRY_DELAY, + max_attempts: 2, + }), + } } fn handle_retryable_error_with_delay( &mut self, error: &LanguageModelCompletionError, - custom_delay: Option, + strategy: Option, model: Arc, intent: CompletionIntent, window: Option, cx: &mut Context, ) -> bool { + // Store context for the Retry button + self.last_error_context = Some((model.clone(), intent)); + + // Only auto-retry if Burn Mode is enabled + if self.completion_mode != CompletionMode::Burn { + // Show error with retry options + cx.emit(ThreadEvent::ShowError(ThreadError::RetryableError { + message: format!( + "{}\n\nTo automatically retry when similar errors happen, enable Burn Mode.", + error + ) + .into(), + can_enable_burn_mode: true, + })); + return false; + } + + let Some(strategy) = strategy.or_else(|| Self::get_retry_strategy(error)) else { + return false; + }; + + let max_attempts = match &strategy { + RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts, + RetryStrategy::Fixed { max_attempts, .. } => *max_attempts, + }; + let retry_state = self.retry_state.get_or_insert(RetryState { attempt: 0, - max_attempts: MAX_RETRY_ATTEMPTS, + max_attempts, intent, }); @@ -2238,20 +2322,24 @@ impl Thread { let intent = retry_state.intent; if attempt <= max_attempts { - // Use custom delay if provided (e.g., from rate limit), otherwise exponential backoff - let delay = if let Some(custom_delay) = custom_delay { - custom_delay - } else { - let delay_secs = BASE_RETRY_DELAY_SECS * 2u64.pow((attempt - 1) as u32); - Duration::from_secs(delay_secs) + let delay = match &strategy { + RetryStrategy::ExponentialBackoff { initial_delay, .. } => { + let delay_secs = initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32); + Duration::from_secs(delay_secs) + } + RetryStrategy::Fixed { delay, .. } => *delay, }; // Add a transient message to inform the user let delay_secs = delay.as_secs(); - let retry_message = format!( - "{error}. Retrying (attempt {attempt} of {max_attempts}) \ - in {delay_secs} seconds..." - ); + let retry_message = if max_attempts == 1 { + format!("{error}. Retrying in {delay_secs} seconds...") + } else { + format!( + "{error}. Retrying (attempt {attempt} of {max_attempts}) \ + in {delay_secs} seconds..." + ) + }; log::warn!( "Retrying completion request (attempt {attempt} of {max_attempts}) \ in {delay_secs} seconds: {error:?}", @@ -2290,18 +2378,15 @@ impl Thread { // Max retries exceeded self.retry_state = None; - let notification_text = if max_attempts == 1 { - "Failed after retrying.".into() - } else { - format!("Failed after retrying {} times.", max_attempts).into() - }; - // Stop generating since we're giving up on retrying. 
self.pending_completions.clear(); - cx.emit(ThreadEvent::RetriesFailed { - message: notification_text, - }); + // Show error alongside a Retry button, but no + // Enable Burn Mode button (since it's already enabled) + cx.emit(ThreadEvent::ShowError(ThreadError::RetryableError { + message: format!("Failed after retrying: {}", error).into(), + can_enable_burn_mode: false, + })); false } @@ -2469,7 +2554,7 @@ impl Thread { return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx); } - if tool.needs_confirmation(&tool_use.input, cx) + if tool.needs_confirmation(&tool_use.input, &self.project, cx) && !AgentSettings::get_global(cx).always_allow_tool_actions { self.tool_use.confirm_tool_use( @@ -3167,8 +3252,10 @@ impl Thread { } fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context) { - self.project.update(cx, |project, cx| { - project.user_store().update(cx, |user_store, cx| { + self.project + .read(cx) + .user_store() + .update(cx, |user_store, cx| { user_store.update_model_request_usage( ModelRequestUsage(RequestUsage { amount: amount as i32, @@ -3176,8 +3263,7 @@ impl Thread { }), cx, ) - }) - }); + }); } pub fn deny_tool_use( @@ -3213,6 +3299,11 @@ pub enum ThreadError { header: SharedString, message: SharedString, }, + #[error("Retryable error: {message}")] + RetryableError { + message: SharedString, + can_enable_burn_mode: bool, + }, } #[derive(Debug, Clone)] @@ -3258,9 +3349,6 @@ pub enum ThreadEvent { CancelEditing, CompletionCanceled, ProfileChanged, - RetriesFailed { - message: SharedString, - }, } impl EventEmitter for Thread {} @@ -3288,7 +3376,6 @@ mod tests { use futures::stream::BoxStream; use gpui::TestAppContext; use http_client; - use indoc::indoc; use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider}; use language_model::{ LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId, @@ -3617,6 +3704,7 @@ fn main() {{ } #[gpui::test] + #[ignore] // turn this test on when project_notifications tool is re-enabled async fn test_stale_buffer_notification(cx: &mut TestAppContext) { init_test_settings(cx); @@ -3649,6 +3737,7 @@ fn main() {{ cx, ); }); + cx.run_until_parked(); // We shouldn't have a stale buffer notification yet let notifications = thread.read_with(cx, |thread, _| { @@ -3678,11 +3767,13 @@ fn main() {{ cx, ) }); + cx.run_until_parked(); // Check for the stale buffer warning thread.update(cx, |thread, cx| { thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx) }); + cx.run_until_parked(); let notifications = thread.read_with(cx, |thread, _cx| { find_tool_uses(thread, "project_notifications") @@ -3696,12 +3787,8 @@ fn main() {{ panic!("`project_notifications` should return text"); }; - let expected_content = indoc! 
{"[The following is an auto-generated notification; do not reply] - - These files have changed since the last read: - - code.rs - "}; - assert_eq!(notification_content, expected_content); + assert!(notification_content.contains("These files have changed since the last read:")); + assert!(notification_content.contains("code.rs")); // Insert another user message and flush notifications again thread.update(cx, |thread, cx| { @@ -3717,6 +3804,7 @@ fn main() {{ thread.update(cx, |thread, cx| { thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx) }); + cx.run_until_parked(); // There should be no new notifications (we already flushed one) let notifications = thread.read_with(cx, |thread, _cx| { @@ -3957,8 +4045,8 @@ fn main() {{ }); cx.run_until_parked(); - fake_model.stream_last_completion_response("Brief"); - fake_model.stream_last_completion_response(" Introduction"); + fake_model.send_last_completion_stream_text_chunk("Brief"); + fake_model.send_last_completion_stream_text_chunk(" Introduction"); fake_model.end_last_completion_stream(); cx.run_until_parked(); @@ -4051,7 +4139,7 @@ fn main() {{ }); cx.run_until_parked(); - fake_model.stream_last_completion_response("A successful summary"); + fake_model.send_last_completion_stream_text_chunk("A successful summary"); fake_model.end_last_completion_stream(); cx.run_until_parked(); @@ -4171,6 +4259,11 @@ fn main() {{ let project = create_test_project(cx, json!({})).await; let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + // Enable Burn Mode to allow retries + thread.update(cx, |thread, _| { + thread.set_completion_mode(CompletionMode::Burn); + }); + // Create model that returns overloaded error let model = Arc::new(ErrorInjector::new(TestError::Overloaded)); @@ -4192,7 +4285,7 @@ fn main() {{ assert_eq!(retry_state.attempt, 1, "Should be first retry attempt"); assert_eq!( retry_state.max_attempts, MAX_RETRY_ATTEMPTS, - "Should have default max attempts" + "Should retry MAX_RETRY_ATTEMPTS times for overloaded errors" ); }); @@ -4244,6 +4337,11 @@ fn main() {{ let project = create_test_project(cx, json!({})).await; let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + // Enable Burn Mode to allow retries + thread.update(cx, |thread, _| { + thread.set_completion_mode(CompletionMode::Burn); + }); + // Create model that returns internal server error let model = Arc::new(ErrorInjector::new(TestError::InternalServerError)); @@ -4265,7 +4363,7 @@ fn main() {{ let retry_state = thread.retry_state.as_ref().unwrap(); assert_eq!(retry_state.attempt, 1, "Should be first retry attempt"); assert_eq!( - retry_state.max_attempts, MAX_RETRY_ATTEMPTS, + retry_state.max_attempts, 3, "Should have correct max attempts" ); }); @@ -4281,8 +4379,9 @@ fn main() {{ if let MessageSegment::Text(text) = seg { text.contains("internal") && text.contains("Fake") - && text - .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS)) + && text.contains("Retrying") + && text.contains("attempt 1 of 3") + && text.contains("seconds") } else { false } @@ -4320,8 +4419,13 @@ fn main() {{ let project = create_test_project(cx, json!({})).await; let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; - // Create model that returns overloaded error - let model = Arc::new(ErrorInjector::new(TestError::Overloaded)); + // Enable Burn Mode to allow retries + thread.update(cx, |thread, _| { + thread.set_completion_mode(CompletionMode::Burn); + }); + + // Create model 
that returns internal server error + let model = Arc::new(ErrorInjector::new(TestError::InternalServerError)); // Insert a user message thread.update(cx, |thread, cx| { @@ -4371,50 +4475,25 @@ fn main() {{ assert!(thread.retry_state.is_some(), "Should have retry state"); let retry_state = thread.retry_state.as_ref().unwrap(); assert_eq!(retry_state.attempt, 1, "Should be first retry attempt"); + assert_eq!( + retry_state.max_attempts, 3, + "Internal server errors should retry up to 3 times" + ); }); // Advance clock for first retry - cx.executor() - .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS)); + cx.executor().advance_clock(BASE_RETRY_DELAY); cx.run_until_parked(); - // Should have scheduled second retry - count retry messages - let retry_count = thread.update(cx, |thread, _| { - thread - .messages - .iter() - .filter(|m| { - m.ui_only - && m.segments.iter().any(|s| { - if let MessageSegment::Text(text) = s { - text.contains("Retrying") && text.contains("seconds") - } else { - false - } - }) - }) - .count() - }); - assert_eq!(retry_count, 2, "Should have scheduled second retry"); - - // Check retry state updated - thread.read_with(cx, |thread, _| { - assert!(thread.retry_state.is_some(), "Should have retry state"); - let retry_state = thread.retry_state.as_ref().unwrap(); - assert_eq!(retry_state.attempt, 2, "Should be second retry attempt"); - assert_eq!( - retry_state.max_attempts, MAX_RETRY_ATTEMPTS, - "Should have correct max attempts" - ); - }); + // Advance clock for second retry + cx.executor().advance_clock(BASE_RETRY_DELAY); + cx.run_until_parked(); - // Advance clock for second retry (exponential backoff) - cx.executor() - .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 2)); + // Advance clock for third retry + cx.executor().advance_clock(BASE_RETRY_DELAY); cx.run_until_parked(); - // Should have scheduled third retry - // Count all retry messages now + // Should have completed all retries - count retry messages let retry_count = thread.update(cx, |thread, _| { thread .messages @@ -4432,56 +4511,24 @@ fn main() {{ .count() }); assert_eq!( - retry_count, MAX_RETRY_ATTEMPTS as usize, - "Should have scheduled third retry" + retry_count, 3, + "Should have 3 retries for internal server errors" ); - // Check retry state updated + // For internal server errors, we retry 3 times and then give up + // Check that retry_state is cleared after all retries thread.read_with(cx, |thread, _| { - assert!(thread.retry_state.is_some(), "Should have retry state"); - let retry_state = thread.retry_state.as_ref().unwrap(); - assert_eq!( - retry_state.attempt, MAX_RETRY_ATTEMPTS, - "Should be at max retry attempt" - ); - assert_eq!( - retry_state.max_attempts, MAX_RETRY_ATTEMPTS, - "Should have correct max attempts" + assert!( + thread.retry_state.is_none(), + "Retry state should be cleared after all retries" ); }); - // Advance clock for third retry (exponential backoff) - cx.executor() - .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 4)); - cx.run_until_parked(); - - // No more retries should be scheduled after clock was advanced. 
- let retry_count = thread.update(cx, |thread, _| { - thread - .messages - .iter() - .filter(|m| { - m.ui_only - && m.segments.iter().any(|s| { - if let MessageSegment::Text(text) = s { - text.contains("Retrying") && text.contains("seconds") - } else { - false - } - }) - }) - .count() - }); - assert_eq!( - retry_count, MAX_RETRY_ATTEMPTS as usize, - "Should not exceed max retries" - ); - - // Final completion count should be initial + max retries + // Verify total attempts (1 initial + 3 retries) assert_eq!( *completion_count.lock(), - (MAX_RETRY_ATTEMPTS + 1) as usize, - "Should have made initial + max retry attempts" + 4, + "Should have attempted once plus 3 retries" ); } @@ -4492,6 +4539,11 @@ fn main() {{ let project = create_test_project(cx, json!({})).await; let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + // Enable Burn Mode to allow retries + thread.update(cx, |thread, _| { + thread.set_completion_mode(CompletionMode::Burn); + }); + // Create model that returns overloaded error let model = Arc::new(ErrorInjector::new(TestError::Overloaded)); @@ -4501,13 +4553,13 @@ fn main() {{ }); // Track events - let retries_failed = Arc::new(Mutex::new(false)); - let retries_failed_clone = retries_failed.clone(); + let stopped_with_error = Arc::new(Mutex::new(false)); + let stopped_with_error_clone = stopped_with_error.clone(); let _subscription = thread.update(cx, |_, cx| { cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| { - if let ThreadEvent::RetriesFailed { .. } = event { - *retries_failed_clone.lock() = true; + if let ThreadEvent::Stopped(Err(_)) = event { + *stopped_with_error_clone.lock() = true; } }) }); @@ -4519,23 +4571,11 @@ fn main() {{ cx.run_until_parked(); // Advance through all retries - for i in 0..MAX_RETRY_ATTEMPTS { - let delay = if i == 0 { - BASE_RETRY_DELAY_SECS - } else { - BASE_RETRY_DELAY_SECS * 2u64.pow(i as u32 - 1) - }; - cx.executor().advance_clock(Duration::from_secs(delay)); + for _ in 0..MAX_RETRY_ATTEMPTS { + cx.executor().advance_clock(BASE_RETRY_DELAY); cx.run_until_parked(); } - // After the 3rd retry is scheduled, we need to wait for it to execute and fail - // The 3rd retry has a delay of BASE_RETRY_DELAY_SECS * 4 (20 seconds) - let final_delay = BASE_RETRY_DELAY_SECS * 2u64.pow((MAX_RETRY_ATTEMPTS - 1) as u32); - cx.executor() - .advance_clock(Duration::from_secs(final_delay)); - cx.run_until_parked(); - let retry_count = thread.update(cx, |thread, _| { thread .messages @@ -4553,14 +4593,14 @@ fn main() {{ .count() }); - // After max retries, should emit RetriesFailed event + // After max retries, should emit Stopped(Err(...)) event assert_eq!( retry_count, MAX_RETRY_ATTEMPTS as usize, - "Should have attempted max retries" + "Should have attempted MAX_RETRY_ATTEMPTS retries for overloaded errors" ); assert!( - *retries_failed.lock(), - "Should emit RetriesFailed event after max retries exceeded" + *stopped_with_error.lock(), + "Should emit Stopped(Err(...)) event after max retries exceeded" ); // Retry state should be cleared @@ -4578,7 +4618,7 @@ fn main() {{ .count(); assert_eq!( retry_messages, MAX_RETRY_ATTEMPTS as usize, - "Should have one retry message per attempt" + "Should have MAX_RETRY_ATTEMPTS retry messages for overloaded errors" ); }); } @@ -4590,6 +4630,11 @@ fn main() {{ let project = create_test_project(cx, json!({})).await; let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + // Enable Burn Mode to allow retries + thread.update(cx, |thread, _| { + 
thread.set_completion_mode(CompletionMode::Burn); + }); + // We'll use a wrapper to switch behavior after first failure struct RetryTestModel { inner: Arc, @@ -4716,8 +4761,7 @@ fn main() {{ }); // Wait for retry - cx.executor() - .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS)); + cx.executor().advance_clock(BASE_RETRY_DELAY); cx.run_until_parked(); // Stream some successful content @@ -4728,7 +4772,7 @@ fn main() {{ !pending.is_empty(), "Should have a pending completion after retry" ); - fake_model.stream_completion_response(&pending[0], "Success!"); + fake_model.send_completion_stream_text_chunk(&pending[0], "Success!"); fake_model.end_completion_stream(&pending[0]); cx.run_until_parked(); @@ -4759,6 +4803,11 @@ fn main() {{ let project = create_test_project(cx, json!({})).await; let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + // Enable Burn Mode to allow retries + thread.update(cx, |thread, _| { + thread.set_completion_mode(CompletionMode::Burn); + }); + // Create a model that fails once then succeeds struct FailOnceModel { inner: Arc, @@ -4879,8 +4928,7 @@ fn main() {{ }); // Wait for retry delay - cx.executor() - .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS)); + cx.executor().advance_clock(BASE_RETRY_DELAY); cx.run_until_parked(); // The retry should now use our FailOnceModel which should succeed @@ -4892,7 +4940,7 @@ fn main() {{ // Check for pending completions and complete them if let Some(pending) = inner_fake.pending_completions().first() { - inner_fake.stream_completion_response(pending, "Success!"); + inner_fake.send_completion_stream_text_chunk(pending, "Success!"); inner_fake.end_completion_stream(pending); } cx.run_until_parked(); @@ -4921,6 +4969,11 @@ fn main() {{ let project = create_test_project(cx, json!({})).await; let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + // Enable Burn Mode to allow retries + thread.update(cx, |thread, _| { + thread.set_completion_mode(CompletionMode::Burn); + }); + // Create a model that returns rate limit error with retry_after struct RateLimitModel { inner: Arc, @@ -5039,9 +5092,15 @@ fn main() {{ thread.read_with(cx, |thread, _| { assert!( - thread.retry_state.is_none(), - "Rate limit errors should not set retry_state" + thread.retry_state.is_some(), + "Rate limit errors should set retry_state" ); + if let Some(retry_state) = &thread.retry_state { + assert_eq!( + retry_state.max_attempts, MAX_RETRY_ATTEMPTS, + "Rate limit errors should use MAX_RETRY_ATTEMPTS" + ); + } }); // Verify we have one retry message @@ -5074,18 +5133,15 @@ fn main() {{ .find(|msg| msg.role == Role::System && msg.ui_only) .expect("Should have a retry message"); - // Check that the message doesn't contain attempt count + // Check that the message contains attempt count since we use retry_state if let Some(MessageSegment::Text(text)) = retry_message.segments.first() { assert!( - !text.contains("attempt"), - "Rate limit retry message should not contain attempt count" + text.contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS)), + "Rate limit retry message should contain attempt count with MAX_RETRY_ATTEMPTS" ); assert!( - text.contains(&format!( - "Retrying in {} seconds", - TEST_RATE_LIMIT_RETRY_SECS - )), - "Rate limit retry message should contain retry delay" + text.contains("Retrying"), + "Rate limit retry message should contain retry text" ); } }); @@ -5191,6 +5247,79 @@ fn main() {{ ); } + #[gpui::test] + async fn test_no_retry_without_burn_mode(cx: &mut 
TestAppContext) { + init_test_settings(cx); + + let project = create_test_project(cx, json!({})).await; + let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + + // Ensure we're in Normal mode (not Burn mode) + thread.update(cx, |thread, _| { + thread.set_completion_mode(CompletionMode::Normal); + }); + + // Track error events + let error_events = Arc::new(Mutex::new(Vec::new())); + let error_events_clone = error_events.clone(); + + let _subscription = thread.update(cx, |_, cx| { + cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| { + if let ThreadEvent::ShowError(error) = event { + error_events_clone.lock().push(error.clone()); + } + }) + }); + + // Create model that returns overloaded error + let model = Arc::new(ErrorInjector::new(TestError::Overloaded)); + + // Insert a user message + thread.update(cx, |thread, cx| { + thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); + }); + + // Start completion + thread.update(cx, |thread, cx| { + thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); + }); + + cx.run_until_parked(); + + // Verify no retry state was created + thread.read_with(cx, |thread, _| { + assert!( + thread.retry_state.is_none(), + "Should not have retry state in Normal mode" + ); + }); + + // Check that a retryable error was reported + let errors = error_events.lock(); + assert!(!errors.is_empty(), "Should have received an error event"); + + if let ThreadError::RetryableError { + message: _, + can_enable_burn_mode, + } = &errors[0] + { + assert!( + *can_enable_burn_mode, + "Error should indicate burn mode can be enabled" + ); + } else { + panic!("Expected RetryableError, got {:?}", errors[0]); + } + + // Verify the thread is no longer generating + thread.read_with(cx, |thread, _| { + assert!( + !thread.is_generating(), + "Should not be generating after error without retry" + ); + }); + } + #[gpui::test] async fn test_retry_cancelled_on_stop(cx: &mut TestAppContext) { init_test_settings(cx); @@ -5198,6 +5327,11 @@ fn main() {{ let project = create_test_project(cx, json!({})).await; let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + // Enable Burn Mode to allow retries + thread.update(cx, |thread, _| { + thread.set_completion_mode(CompletionMode::Burn); + }); + // Create model that returns overloaded error let model = Arc::new(ErrorInjector::new(TestError::Overloaded)); @@ -5291,7 +5425,7 @@ fn main() {{ fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) { cx.run_until_parked(); - fake_model.stream_last_completion_response("Assistant response"); + fake_model.send_last_completion_stream_text_chunk("Assistant response"); fake_model.end_last_completion_stream(); cx.run_until_parked(); } @@ -5359,7 +5493,7 @@ fn main() {{ let thread = thread_store.update(cx, |store, cx| store.create_thread(cx)); let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None)); - let provider = Arc::new(FakeLanguageModelProvider); + let provider = Arc::new(FakeLanguageModelProvider::default()); let model = provider.test_model(); let model: Arc = Arc::new(model); diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index 0347156cd4df0d8b5d953def949739cab1135025..cc7cb50c9195a6a9a5bd0e0e1a17bd76caf153ee 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -41,6 +41,9 @@ use std::{ }; use util::ResultExt as _; +pub static ZED_STATELESS: 
std::sync::LazyLock = + std::sync::LazyLock::new(|| std::env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty())); + #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum DataType { #[serde(rename = "json")] @@ -874,7 +877,11 @@ impl ThreadsDatabase { let needs_migration_from_heed = mdb_path.exists(); - let connection = Connection::open_file(&sqlite_path.to_string_lossy()); + let connection = if *ZED_STATELESS { + Connection::open_memory(Some("THREAD_FALLBACK_DB")) + } else { + Connection::open_file(&sqlite_path.to_string_lossy()) + }; connection.exec(indoc! {" CREATE TABLE IF NOT EXISTS threads ( diff --git a/crates/agent/src/tool_use.rs b/crates/agent/src/tool_use.rs index 74c719b4e6cf4ad0743a833f8b1c9fcc9da8b929..7392c0878d17adf8038292b10a7a8c349d3ec4e8 100644 --- a/crates/agent/src/tool_use.rs +++ b/crates/agent/src/tool_use.rs @@ -165,7 +165,12 @@ impl ToolUseState { self.pending_tool_uses_by_id.values().collect() } - pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec { + pub fn tool_uses_for_message( + &self, + id: MessageId, + project: &Entity, + cx: &App, + ) -> Vec { let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else { return Vec::new(); }; @@ -211,7 +216,10 @@ impl ToolUseState { let (icon, needs_confirmation) = if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) { - (tool.icon(), tool.needs_confirmation(&tool_use.input, cx)) + ( + tool.icon(), + tool.needs_confirmation(&tool_use.input, project, cx), + ) } else { (IconName::Cog, false) }; diff --git a/crates/agent2/Cargo.toml b/crates/agent2/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..884378fbcc07d22895bf9d5eb54520adb8db4aa2 --- /dev/null +++ b/crates/agent2/Cargo.toml @@ -0,0 +1,57 @@ +[package] +name = "agent2" +version = "0.1.0" +edition = "2021" +license = "GPL-3.0-or-later" +publish = false + +[lib] +path = "src/agent2.rs" + +[lints] +workspace = true + +[dependencies] +acp_thread.workspace = true +agent-client-protocol.workspace = true +agent_servers.workspace = true +anyhow.workspace = true +assistant_tool.workspace = true +cloud_llm_client.workspace = true +collections.workspace = true +fs.workspace = true +futures.workspace = true +gpui.workspace = true +handlebars = { workspace = true, features = ["rust-embed"] } +indoc.workspace = true +language_model.workspace = true +language_models.workspace = true +log.workspace = true +project.workspace = true +prompt_store.workspace = true +rust-embed.workspace = true +schemars.workspace = true +serde.workspace = true +serde_json.workspace = true +settings.workspace = true +smol.workspace = true +ui.workspace = true +util.workspace = true +uuid.workspace = true +watch.workspace = true +workspace-hack.workspace = true + +[dev-dependencies] +ctor.workspace = true +client = { workspace = true, "features" = ["test-support"] } +clock = { workspace = true, "features" = ["test-support"] } +env_logger.workspace = true +fs = { workspace = true, "features" = ["test-support"] } +gpui = { workspace = true, "features" = ["test-support"] } +gpui_tokio.workspace = true +language = { workspace = true, "features" = ["test-support"] } +language_model = { workspace = true, "features" = ["test-support"] } +project = { workspace = true, "features" = ["test-support"] } +reqwest_client.workspace = true +settings = { workspace = true, "features" = ["test-support"] } +worktree = { workspace = true, "features" = ["test-support"] } diff --git a/crates/inline_completion/LICENSE-GPL 
b/crates/agent2/LICENSE-GPL similarity index 100% rename from crates/inline_completion/LICENSE-GPL rename to crates/agent2/LICENSE-GPL diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs new file mode 100644 index 0000000000000000000000000000000000000000..cb568f04c29269f187166fbad25ff69acd6b29dc --- /dev/null +++ b/crates/agent2/src/agent.rs @@ -0,0 +1,702 @@ +use crate::{templates::Templates, AgentResponseEvent, Thread}; +use crate::{FindPathTool, ThinkingTool, ToolCallAuthorization}; +use acp_thread::ModelSelector; +use agent_client_protocol as acp; +use anyhow::{anyhow, Context as _, Result}; +use futures::{future, StreamExt}; +use gpui::{ + App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity, +}; +use language_model::{LanguageModel, LanguageModelRegistry}; +use project::{Project, ProjectItem, ProjectPath, Worktree}; +use prompt_store::{ + ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext, +}; +use std::cell::RefCell; +use std::collections::HashMap; +use std::path::Path; +use std::rc::Rc; +use std::sync::Arc; +use util::ResultExt; + +const RULES_FILE_NAMES: [&'static str; 9] = [ + ".rules", + ".cursorrules", + ".windsurfrules", + ".clinerules", + ".github/copilot-instructions.md", + "CLAUDE.md", + "AGENT.md", + "AGENTS.md", + "GEMINI.md", +]; + +pub struct RulesLoadingError { + pub message: SharedString, +} + +/// Holds both the internal Thread and the AcpThread for a session +struct Session { + /// The internal thread that processes messages + thread: Entity, + /// The ACP thread that handles protocol communication + acp_thread: WeakEntity, + _subscription: Subscription, +} + +pub struct NativeAgent { + /// Session ID -> Session mapping + sessions: HashMap, + /// Shared project context for all threads + project_context: Rc>, + project_context_needs_refresh: watch::Sender<()>, + _maintain_project_context: Task>, + /// Shared templates for all threads + templates: Arc, + project: Entity, + prompt_store: Option>, + _subscriptions: Vec, +} + +impl NativeAgent { + pub async fn new( + project: Entity, + templates: Arc, + prompt_store: Option>, + cx: &mut AsyncApp, + ) -> Result> { + log::info!("Creating new NativeAgent"); + + let project_context = cx + .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))? + .await; + + cx.new(|cx| { + let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)]; + if let Some(prompt_store) = prompt_store.as_ref() { + subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event)) + } + + let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) = + watch::channel(()); + Self { + sessions: HashMap::new(), + project_context: Rc::new(RefCell::new(project_context)), + project_context_needs_refresh: project_context_needs_refresh_tx, + _maintain_project_context: cx.spawn(async move |this, cx| { + Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await + }), + templates, + project, + prompt_store, + _subscriptions: subscriptions, + } + }) + } + + async fn maintain_project_context( + this: WeakEntity, + mut needs_refresh: watch::Receiver<()>, + cx: &mut AsyncApp, + ) -> Result<()> { + while needs_refresh.changed().await.is_ok() { + let project_context = this + .update(cx, |this, cx| { + Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx) + })? 
+ .await; + this.update(cx, |this, _| this.project_context.replace(project_context))?; + } + + Ok(()) + } + + fn build_project_context( + project: &Entity, + prompt_store: Option<&Entity>, + cx: &mut App, + ) -> Task { + let worktrees = project.read(cx).visible_worktrees(cx).collect::>(); + let worktree_tasks = worktrees + .into_iter() + .map(|worktree| { + Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx) + }) + .collect::>(); + let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() { + prompt_store.read_with(cx, |prompt_store, cx| { + let prompts = prompt_store.default_prompt_metadata(); + let load_tasks = prompts.into_iter().map(|prompt_metadata| { + let contents = prompt_store.load(prompt_metadata.id, cx); + async move { (contents.await, prompt_metadata) } + }); + cx.background_spawn(future::join_all(load_tasks)) + }) + } else { + Task::ready(vec![]) + }; + + cx.spawn(async move |_cx| { + let (worktrees, default_user_rules) = + future::join(future::join_all(worktree_tasks), default_user_rules_task).await; + + let worktrees = worktrees + .into_iter() + .map(|(worktree, _rules_error)| { + // TODO: show error message + // if let Some(rules_error) = rules_error { + // this.update(cx, |_, cx| cx.emit(rules_error)).ok(); + // } + worktree + }) + .collect::>(); + + let default_user_rules = default_user_rules + .into_iter() + .flat_map(|(contents, prompt_metadata)| match contents { + Ok(contents) => Some(UserRulesContext { + uuid: match prompt_metadata.id { + PromptId::User { uuid } => uuid, + PromptId::EditWorkflow => return None, + }, + title: prompt_metadata.title.map(|title| title.to_string()), + contents, + }), + Err(_err) => { + // TODO: show error message + // this.update(cx, |_, cx| { + // cx.emit(RulesLoadingError { + // message: format!("{err:?}").into(), + // }); + // }) + // .ok(); + None + } + }) + .collect::>(); + + ProjectContext::new(worktrees, default_user_rules) + }) + } + + fn load_worktree_info_for_system_prompt( + worktree: Entity, + project: Entity, + cx: &mut App, + ) -> Task<(WorktreeContext, Option)> { + let tree = worktree.read(cx); + let root_name = tree.root_name().into(); + let abs_path = tree.abs_path(); + + let mut context = WorktreeContext { + root_name, + abs_path, + rules_file: None, + }; + + let rules_task = Self::load_worktree_rules_file(worktree, project, cx); + let Some(rules_task) = rules_task else { + return Task::ready((context, None)); + }; + + cx.spawn(async move |_| { + let (rules_file, rules_file_error) = match rules_task.await { + Ok(rules_file) => (Some(rules_file), None), + Err(err) => ( + None, + Some(RulesLoadingError { + message: format!("{err}").into(), + }), + ), + }; + context.rules_file = rules_file; + (context, rules_file_error) + }) + } + + fn load_worktree_rules_file( + worktree: Entity, + project: Entity, + cx: &mut App, + ) -> Option>> { + let worktree = worktree.read(cx); + let worktree_id = worktree.id(); + let selected_rules_file = RULES_FILE_NAMES + .into_iter() + .filter_map(|name| { + worktree + .entry_for_path(name) + .filter(|entry| entry.is_file()) + .map(|entry| entry.path.clone()) + }) + .next(); + + // Note that Cline supports `.clinerules` being a directory, but that is not currently + // supported. This doesn't seem to occur often in GitHub repositories. 
+ selected_rules_file.map(|path_in_worktree| { + let project_path = ProjectPath { + worktree_id, + path: path_in_worktree.clone(), + }; + let buffer_task = + project.update(cx, |project, cx| project.open_buffer(project_path, cx)); + let rope_task = cx.spawn(async move |cx| { + buffer_task.await?.read_with(cx, |buffer, cx| { + let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?; + anyhow::Ok((project_entry_id, buffer.as_rope().clone())) + })? + }); + // Build a string from the rope on a background thread. + cx.background_spawn(async move { + let (project_entry_id, rope) = rope_task.await?; + anyhow::Ok(RulesFileContext { + path_in_worktree, + text: rope.to_string().trim().to_string(), + project_entry_id: project_entry_id.to_usize(), + }) + }) + }) + } + + fn handle_project_event( + &mut self, + _project: Entity, + event: &project::Event, + _cx: &mut Context, + ) { + match event { + project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => { + self.project_context_needs_refresh.send(()).ok(); + } + project::Event::WorktreeUpdatedEntries(_, items) => { + if items.iter().any(|(path, _, _)| { + RULES_FILE_NAMES + .iter() + .any(|name| path.as_ref() == Path::new(name)) + }) { + self.project_context_needs_refresh.send(()).ok(); + } + } + _ => {} + } + } + + fn handle_prompts_updated_event( + &mut self, + _prompt_store: Entity, + _event: &prompt_store::PromptsUpdatedEvent, + _cx: &mut Context, + ) { + self.project_context_needs_refresh.send(()).ok(); + } +} + +/// Wrapper struct that implements the AgentConnection trait +#[derive(Clone)] +pub struct NativeAgentConnection(pub Entity); + +impl ModelSelector for NativeAgentConnection { + fn list_models(&self, cx: &mut AsyncApp) -> Task>>> { + log::debug!("NativeAgentConnection::list_models called"); + cx.spawn(async move |cx| { + cx.update(|cx| { + let registry = LanguageModelRegistry::read_global(cx); + let models = registry.available_models(cx).collect::>(); + log::info!("Found {} available models", models.len()); + if models.is_empty() { + Err(anyhow::anyhow!("No models available")) + } else { + Ok(models) + } + })? + }) + } + + fn select_model( + &self, + session_id: acp::SessionId, + model: Arc, + cx: &mut AsyncApp, + ) -> Task> { + log::info!( + "Setting model for session {}: {:?}", + session_id, + model.name() + ); + let agent = self.0.clone(); + + cx.spawn(async move |cx| { + agent.update(cx, |agent, cx| { + if let Some(session) = agent.sessions.get(&session_id) { + session.thread.update(cx, |thread, _cx| { + thread.selected_model = model; + }); + Ok(()) + } else { + Err(anyhow!("Session not found")) + } + })? + }) + } + + fn selected_model( + &self, + session_id: &acp::SessionId, + cx: &mut AsyncApp, + ) -> Task>> { + let agent = self.0.clone(); + let session_id = session_id.clone(); + cx.spawn(async move |cx| { + let thread = agent + .read_with(cx, |agent, _| { + agent + .sessions + .get(&session_id) + .map(|session| session.thread.clone()) + })? 
+ .ok_or_else(|| anyhow::anyhow!("Session not found"))?; + let selected = thread.read_with(cx, |thread, _| thread.selected_model.clone())?; + Ok(selected) + }) + } +} + +impl acp_thread::AgentConnection for NativeAgentConnection { + fn new_thread( + self: Rc, + project: Entity, + cwd: &Path, + cx: &mut AsyncApp, + ) -> Task>> { + let agent = self.0.clone(); + log::info!("Creating new thread for project at: {:?}", cwd); + + cx.spawn(async move |cx| { + log::debug!("Starting thread creation in async context"); + + // Generate session ID + let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into()); + log::info!("Created session with ID: {}", session_id); + + // Create AcpThread + let acp_thread = cx.update(|cx| { + cx.new(|cx| { + acp_thread::AcpThread::new("agent2", self.clone(), project.clone(), session_id.clone(), cx) + }) + })?; + let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?; + + // Create Thread + let thread = agent.update( + cx, + |agent, cx: &mut gpui::Context| -> Result<_> { + // Fetch default model from registry settings + let registry = LanguageModelRegistry::read_global(cx); + + // Log available models for debugging + let available_count = registry.available_models(cx).count(); + log::debug!("Total available models: {}", available_count); + + let default_model = registry + .default_model() + .map(|configured| { + log::info!( + "Using configured default model: {:?} from provider: {:?}", + configured.model.name(), + configured.provider.name() + ); + configured.model + }) + .ok_or_else(|| { + log::warn!("No default model configured in settings"); + anyhow!("No default model configured. Please configure a default model in settings.") + })?; + + let thread = cx.new(|_| { + let mut thread = Thread::new(project.clone(), agent.project_context.clone(), action_log, agent.templates.clone(), default_model); + thread.add_tool(ThinkingTool); + thread.add_tool(FindPathTool::new(project.clone())); + thread + }); + + Ok(thread) + }, + )??; + + // Store the session + agent.update(cx, |agent, cx| { + agent.sessions.insert( + session_id, + Session { + thread, + acp_thread: acp_thread.downgrade(), + _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| { + this.sessions.remove(acp_thread.session_id()); + }) + }, + ); + })?; + + Ok(acp_thread) + }) + } + + fn auth_methods(&self) -> &[acp::AuthMethod] { + &[] // No auth for in-process + } + + fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task> { + Task::ready(Ok(())) + } + + fn model_selector(&self) -> Option> { + Some(Rc::new(self.clone()) as Rc) + } + + fn prompt( + &self, + params: acp::PromptRequest, + cx: &mut App, + ) -> Task> { + let session_id = params.session_id.clone(); + let agent = self.0.clone(); + log::info!("Received prompt request for session: {}", session_id); + log::debug!("Prompt blocks count: {}", params.prompt.len()); + + cx.spawn(async move |cx| { + // Get session + let (thread, acp_thread) = agent + .update(cx, |agent, _| { + agent + .sessions + .get_mut(&session_id) + .map(|s| (s.thread.clone(), s.acp_thread.clone())) + })? 
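+                // Both handles are needed: the Thread drives the model, while the (weak) AcpThread
+                // receives the streamed session updates further down.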
+ .ok_or_else(|| { + log::error!("Session not found: {}", session_id); + anyhow::anyhow!("Session not found") + })?; + log::debug!("Found session for: {}", session_id); + + // Convert prompt to message + let message = convert_prompt_to_message(params.prompt); + log::info!("Converted prompt to message: {} chars", message.len()); + log::debug!("Message content: {}", message); + + // Get model using the ModelSelector capability (always available for agent2) + // Get the selected model from the thread directly + let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?; + + // Send to thread + log::info!("Sending message to thread with model: {:?}", model.name()); + let mut response_stream = + thread.update(cx, |thread, cx| thread.send(model, message, cx))?; + + // Handle response stream and forward to session.acp_thread + while let Some(result) = response_stream.next().await { + match result { + Ok(event) => { + log::trace!("Received completion event: {:?}", event); + + match event { + AgentResponseEvent::Text(text) => { + acp_thread.update(cx, |thread, cx| { + thread.handle_session_update( + acp::SessionUpdate::AgentMessageChunk { + content: acp::ContentBlock::Text(acp::TextContent { + text, + annotations: None, + }), + }, + cx, + ) + })??; + } + AgentResponseEvent::Thinking(text) => { + acp_thread.update(cx, |thread, cx| { + thread.handle_session_update( + acp::SessionUpdate::AgentThoughtChunk { + content: acp::ContentBlock::Text(acp::TextContent { + text, + annotations: None, + }), + }, + cx, + ) + })??; + } + AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization { + tool_call, + options, + response, + }) => { + let recv = acp_thread.update(cx, |thread, cx| { + thread.request_tool_call_authorization(tool_call, options, cx) + })?; + cx.background_spawn(async move { + if let Some(option) = recv + .await + .context("authorization sender was dropped") + .log_err() + { + response + .send(option) + .map(|_| anyhow!("authorization receiver was dropped")) + .log_err(); + } + }) + .detach(); + } + AgentResponseEvent::ToolCall(tool_call) => { + acp_thread.update(cx, |thread, cx| { + thread.handle_session_update( + acp::SessionUpdate::ToolCall(tool_call), + cx, + ) + })??; + } + AgentResponseEvent::ToolCallUpdate(tool_call_update) => { + acp_thread.update(cx, |thread, cx| { + thread.handle_session_update( + acp::SessionUpdate::ToolCallUpdate(tool_call_update), + cx, + ) + })??; + } + AgentResponseEvent::Stop(stop_reason) => { + log::debug!("Assistant message complete: {:?}", stop_reason); + return Ok(acp::PromptResponse { stop_reason }); + } + } + } + Err(e) => { + log::error!("Error in model response stream: {:?}", e); + // TODO: Consider sending an error message to the UI + break; + } + } + } + + log::info!("Response stream completed"); + anyhow::Ok(acp::PromptResponse { + stop_reason: acp::StopReason::EndTurn, + }) + }) + } + + fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) { + log::info!("Cancelling on session: {}", session_id); + self.0.update(cx, |agent, cx| { + if let Some(agent) = agent.sessions.get(session_id) { + agent.thread.update(cx, |thread, _cx| thread.cancel()); + } + }); + } +} + +/// Convert ACP content blocks to a message string +fn convert_prompt_to_message(blocks: Vec) -> String { + log::debug!("Converting {} content blocks to message", blocks.len()); + let mut message = String::new(); + + for block in blocks { + match block { + acp::ContentBlock::Text(text) => { + log::trace!("Processing text block: {} chars", text.text.len()); + 
message.push_str(&text.text); + } + acp::ContentBlock::ResourceLink(link) => { + log::trace!("Processing resource link: {}", link.uri); + message.push_str(&format!(" @{} ", link.uri)); + } + acp::ContentBlock::Image(_) => { + log::trace!("Processing image block"); + message.push_str(" [image] "); + } + acp::ContentBlock::Audio(_) => { + log::trace!("Processing audio block"); + message.push_str(" [audio] "); + } + acp::ContentBlock::Resource(resource) => { + log::trace!("Processing resource block: {:?}", resource.resource); + message.push_str(&format!(" [resource: {:?}] ", resource.resource)); + } + } + } + + message +} + +#[cfg(test)] +mod tests { + use super::*; + use fs::FakeFs; + use gpui::TestAppContext; + use serde_json::json; + use settings::SettingsStore; + + #[gpui::test] + async fn test_maintaining_project_context(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/", + json!({ + "a": {} + }), + ) + .await; + let project = Project::test(fs.clone(), [], cx).await; + let agent = NativeAgent::new(project.clone(), Templates::new(), None, &mut cx.to_async()) + .await + .unwrap(); + agent.read_with(cx, |agent, _| { + assert_eq!(agent.project_context.borrow().worktrees, vec![]) + }); + + let worktree = project + .update(cx, |project, cx| project.create_worktree("/a", true, cx)) + .await + .unwrap(); + cx.run_until_parked(); + agent.read_with(cx, |agent, _| { + assert_eq!( + agent.project_context.borrow().worktrees, + vec![WorktreeContext { + root_name: "a".into(), + abs_path: Path::new("/a").into(), + rules_file: None + }] + ) + }); + + // Creating `/a/.rules` updates the project context. + fs.insert_file("/a/.rules", Vec::new()).await; + cx.run_until_parked(); + agent.read_with(cx, |agent, cx| { + let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap(); + assert_eq!( + agent.project_context.borrow().worktrees, + vec![WorktreeContext { + root_name: "a".into(), + abs_path: Path::new("/a").into(), + rules_file: Some(RulesFileContext { + path_in_worktree: Path::new(".rules").into(), + text: "".into(), + project_entry_id: rules_entry.id.to_usize() + }) + }] + ) + }); + } + + fn init_test(cx: &mut TestAppContext) { + env_logger::try_init().ok(); + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + Project::init_settings(cx); + language::init(cx); + }); + } +} diff --git a/crates/agent2/src/agent2.rs b/crates/agent2/src/agent2.rs new file mode 100644 index 0000000000000000000000000000000000000000..db743c84296f008d9b47317096525cd816460a87 --- /dev/null +++ b/crates/agent2/src/agent2.rs @@ -0,0 +1,13 @@ +mod agent; +mod native_agent_server; +mod templates; +mod thread; +mod tools; + +#[cfg(test)] +mod tests; + +pub use agent::*; +pub use native_agent_server::NativeAgentServer; +pub use thread::*; +pub use tools::*; diff --git a/crates/agent2/src/native_agent_server.rs b/crates/agent2/src/native_agent_server.rs new file mode 100644 index 0000000000000000000000000000000000000000..dd0188b54848903c9fdd5b56db5f44c3d76a84f4 --- /dev/null +++ b/crates/agent2/src/native_agent_server.rs @@ -0,0 +1,60 @@ +use std::path::Path; +use std::rc::Rc; + +use agent_servers::AgentServer; +use anyhow::Result; +use gpui::{App, Entity, Task}; +use project::Project; +use prompt_store::PromptStore; + +use crate::{templates::Templates, NativeAgent, NativeAgentConnection}; + +#[derive(Clone)] +pub struct NativeAgentServer; + +impl AgentServer for NativeAgentServer { + fn name(&self) -> &'static str { + 
"Native Agent" + } + + fn empty_state_headline(&self) -> &'static str { + "Native Agent" + } + + fn empty_state_message(&self) -> &'static str { + "How can I help you today?" + } + + fn logo(&self) -> ui::IconName { + // Using the ZedAssistant icon as it's the native built-in agent + ui::IconName::ZedAssistant + } + + fn connect( + &self, + _root_dir: &Path, + project: &Entity, + cx: &mut App, + ) -> Task>> { + log::info!( + "NativeAgentServer::connect called for path: {:?}", + _root_dir + ); + let project = project.clone(); + let prompt_store = PromptStore::global(cx); + cx.spawn(async move |cx| { + log::debug!("Creating templates for native agent"); + let templates = Templates::new(); + let prompt_store = prompt_store.await?; + + log::debug!("Creating native agent entity"); + let agent = NativeAgent::new(project, templates, Some(prompt_store), cx).await?; + + // Create the connection wrapper + let connection = NativeAgentConnection(agent); + log::info!("NativeAgentServer connection established successfully"); + + Ok(Rc::new(connection) as Rc) + }) + } +} diff --git a/crates/agent2/src/templates.rs b/crates/agent2/src/templates.rs new file mode 100644 index 0000000000000000000000000000000000000000..a63f0ad206308130712b9481cfd7231eb0fd2696 --- /dev/null +++ b/crates/agent2/src/templates.rs @@ -0,0 +1,87 @@ +use anyhow::Result; +use gpui::SharedString; +use handlebars::Handlebars; +use rust_embed::RustEmbed; +use serde::Serialize; +use std::sync::Arc; + +#[derive(RustEmbed)] +#[folder = "src/templates"] +#[include = "*.hbs"] +struct Assets; + +pub struct Templates(Handlebars<'static>); + +impl Templates { + pub fn new() -> Arc { + let mut handlebars = Handlebars::new(); + handlebars.set_strict_mode(true); + handlebars.register_helper("contains", Box::new(contains)); + handlebars.register_embed_templates::().unwrap(); + Arc::new(Self(handlebars)) + } +} + +pub trait Template: Sized { + const TEMPLATE_NAME: &'static str; + + fn render(&self, templates: &Templates) -> Result + where + Self: Serialize + Sized, + { + Ok(templates.0.render(Self::TEMPLATE_NAME, self)?) 
+ } +} + +#[derive(Serialize)] +pub struct SystemPromptTemplate<'a> { + #[serde(flatten)] + pub project: &'a prompt_store::ProjectContext, + pub available_tools: Vec, +} + +impl Template for SystemPromptTemplate<'_> { + const TEMPLATE_NAME: &'static str = "system_prompt.hbs"; +} + +/// Handlebars helper for checking if an item is in a list +fn contains( + h: &handlebars::Helper, + _: &handlebars::Handlebars, + _: &handlebars::Context, + _: &mut handlebars::RenderContext, + out: &mut dyn handlebars::Output, +) -> handlebars::HelperResult { + let list = h + .param(0) + .and_then(|v| v.value().as_array()) + .ok_or_else(|| { + handlebars::RenderError::new("contains: missing or invalid list parameter") + })?; + let query = h.param(1).map(|v| v.value()).ok_or_else(|| { + handlebars::RenderError::new("contains: missing or invalid query parameter") + })?; + + if list.contains(&query) { + out.write("true")?; + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_system_prompt_template() { + let project = prompt_store::ProjectContext::default(); + let template = SystemPromptTemplate { + project: &project, + available_tools: vec!["echo".into()], + }; + let templates = Templates::new(); + let rendered = template.render(&templates).unwrap(); + assert!(rendered.contains("## Fixing Diagnostics")); + } +} diff --git a/crates/agent2/src/templates/system_prompt.hbs b/crates/agent2/src/templates/system_prompt.hbs new file mode 100644 index 0000000000000000000000000000000000000000..a9f67460d81e79f03d0a0a9b60cd4d6c32fc3b20 --- /dev/null +++ b/crates/agent2/src/templates/system_prompt.hbs @@ -0,0 +1,178 @@ +You are a highly skilled software engineer with extensive knowledge in many programming languages, frameworks, design patterns, and best practices. + +## Communication + +1. Be conversational but professional. +2. Refer to the user in the second person and yourself in the first person. +3. Format your responses in markdown. Use backticks to format file, directory, function, and class names. +4. NEVER lie or make things up. +5. Refrain from apologizing all the time when results are unexpected. Instead, just try your best to proceed or explain the circumstances to the user without apologizing. + +{{#if (gt (len available_tools) 0)}} +## Tool Use + +1. Make sure to adhere to the tools schema. +2. Provide every required argument. +3. DO NOT use tools to access items that are already available in the context section. +4. Use only the tools that are currently available. +5. DO NOT use a tool that is not available just because it appears in the conversation. This means the user turned it off. +6. NEVER run commands that don't terminate on their own such as web servers (like `npm run start`, `npm run dev`, `python -m http.server`, etc) or file watchers. +7. Avoid HTML entity escaping - use plain characters instead. + +## Searching and Reading + +If you are unsure how to fulfill the user's request, gather more information with tool calls and/or clarifying questions. + +If appropriate, use tool calls to explore the current project, which contains the following root directories: + +{{#each worktrees}} +- `{{abs_path}}` +{{/each}} + +- Bias towards not asking the user for help if you can find the answer yourself. +- When providing paths to tools, the path should always start with the name of a project root directory listed above. +- Before you read or edit a file, you must first find the full path. DO NOT ever guess a file path! 
+{{# if (contains available_tools 'grep') }}
+- When looking for symbols in the project, prefer the `grep` tool.
+- As you learn about the structure of the project, use that information to scope `grep` searches to targeted subtrees of the project.
+- The user might specify a partial file path. If you don't know the full path, use `find_path` (not `grep`) before you read the file.
+{{/if}}
+{{else}}
+You are being tasked with providing a response, but you have no ability to use tools or to read or write any aspect of the user's system (other than any context the user might have provided to you).
+
+As such, if you need the user to perform any actions for you, you must request them explicitly. Bias towards giving a response to the best of your ability, and then making requests for the user to take action (e.g. to give you more context) only optionally.
+
+The one exception to this is if the user references something you don't know about - for example, the name of a source code file, function, type, or other piece of code that you have no awareness of. In this case, you MUST NOT MAKE SOMETHING UP, or assume you know what that thing is or how it works. Instead, you must ask the user for clarification rather than giving a response.
+{{/if}}
+
+## Code Block Formatting
+
+Whenever you mention a code block, you MUST ONLY use the following format:
+```path/to/Something.blah#L123-456
+(code goes here)
+```
+The `#L123-456` means the line number range 123 through 456, and the path/to/Something.blah
+is a path in the project. (If there is no valid path in the project, then you can use
+/dev/null/path.extension for its path.) This is the ONLY valid way to format code blocks, because the Markdown parser
+does not understand the more common ```language syntax, or bare ``` blocks. It only
+understands this path-based syntax, and if the path is missing, then it will error and you will have to do it over again.
+Just to be really clear about this, if you ever find yourself writing three backticks followed by a language name, STOP!
+You have made a mistake. You can only ever put paths after triple backticks!
+
+Based on all the information I've gathered, here's a summary of how this system works:
+1. The README file is loaded into the system.
+2. The system finds the first two headers, including everything in between. In this case, that would be:
+```path/to/README.md#L8-12
+# First Header
+This is the info under the first header.
+## Sub-header
+```
+3. Then the system finds the last header in the README:
+```path/to/README.md#L27-29
+## Last Header
+This is the last header in the README.
+```
+4. Finally, it passes this information on to the next process.
+
+
+In Markdown, hash marks signify headings. For example:
+```/dev/null/example.md#L1-3
+# Level 1 heading
+## Level 2 heading
+### Level 3 heading
+```
+
+Here are examples of ways you must never render code blocks:
+
+In Markdown, hash marks signify headings. For example:
+```
+# Level 1 heading
+## Level 2 heading
+### Level 3 heading
+```
+
+This example is unacceptable because it does not include the path.
+
+In Markdown, hash marks signify headings. For example:
+```markdown
+# Level 1 heading
+## Level 2 heading
+### Level 3 heading
+```
+
+This example is unacceptable because it has the language instead of the path.
+
+In Markdown, hash marks signify headings.
For example: + # Level 1 heading + ## Level 2 heading + ### Level 3 heading + +This example is unacceptable because it uses indentation to mark the code block +instead of backticks with a path. + +In Markdown, hash marks signify headings. For example: +```markdown +/dev/null/example.md#L1-3 +# Level 1 heading +## Level 2 heading +### Level 3 heading +``` + +This example is unacceptable because the path is in the wrong place. The path must be directly after the opening backticks. + +{{#if (gt (len available_tools) 0)}} +## Fixing Diagnostics + +1. Make 1-2 attempts at fixing diagnostics, then defer to the user. +2. Never simplify code you've written just to solve diagnostics. Complete, mostly correct code is more valuable than perfect code that doesn't solve the problem. + +## Debugging + +When debugging, only make code changes if you are certain that you can solve the problem. +Otherwise, follow debugging best practices: +1. Address the root cause instead of the symptoms. +2. Add descriptive logging statements and error messages to track variable and code state. +3. Add test functions and statements to isolate the problem. + +{{/if}} +## Calling External APIs + +1. Unless explicitly requested by the user, use the best suited external APIs and packages to solve the task. There is no need to ask the user for permission. +2. When selecting which version of an API or package to use, choose one that is compatible with the user's dependency management file(s). If no such file exists or if the package is not present, use the latest version that is in your training data. +3. If an external API requires an API Key, be sure to point this out to the user. Adhere to best security practices (e.g. DO NOT hardcode an API key in a place where it can be exposed) + +## System Information + +Operating System: {{os}} +Default Shell: {{shell}} + +{{#if (or has_rules has_user_rules)}} +## User's Custom Instructions + +The following additional instructions are provided by the user, and should be followed to the best of your ability{{#if (gt (len available_tools) 0)}} without interfering with the tool use guidelines{{/if}}. 
+
+{{#if has_rules}}
+There are project rules that apply to these root directories:
+{{#each worktrees}}
+{{#if rules_file}}
+`{{root_name}}/{{rules_file.path_in_worktree}}`:
+``````
+{{{rules_file.text}}}
+``````
+{{/if}}
+{{/each}}
+{{/if}}
+
+{{#if has_user_rules}}
+The user has specified the following rules that should be applied:
+{{#each user_rules}}
+
+{{#if title}}
+Rules title: {{title}}
+{{/if}}
+``````
+{{{contents}}}
+``````
+{{/each}}
+{{/if}}
+{{/if}}
diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs
new file mode 100644
index 0000000000000000000000000000000000000000..7913f9a24cb988ffff507efbd6d6bdf38a008b66
--- /dev/null
+++ b/crates/agent2/src/tests/mod.rs
@@ -0,0 +1,817 @@
+use super::*;
+use crate::templates::Templates;
+use acp_thread::AgentConnection;
+use agent_client_protocol::{self as acp};
+use anyhow::Result;
+use assistant_tool::ActionLog;
+use client::{Client, UserStore};
+use fs::FakeFs;
+use futures::channel::mpsc::UnboundedReceiver;
+use gpui::{http_client::FakeHttpClient, AppContext, Entity, Task, TestAppContext};
+use indoc::indoc;
+use language_model::{
+    fake_provider::FakeLanguageModel, LanguageModel, LanguageModelCompletionError,
+    LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, LanguageModelToolResult,
+    LanguageModelToolUse, MessageContent, Role, StopReason,
+};
+use project::Project;
+use prompt_store::ProjectContext;
+use reqwest_client::ReqwestClient;
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+use serde_json::json;
+use smol::stream::StreamExt;
+use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
+use util::path;
+
+mod test_tools;
+use test_tools::*;
+
+#[gpui::test]
+#[ignore = "can't run on CI yet"]
+async fn test_echo(cx: &mut TestAppContext) {
+    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
+
+    let events = thread
+        .update(cx, |thread, cx| {
+            thread.send(model.clone(), "Testing: Reply with 'Hello'", cx)
+        })
+        .collect()
+        .await;
+    thread.update(cx, |thread, _cx| {
+        assert_eq!(
+            thread.messages().last().unwrap().content,
+            vec![MessageContent::Text("Hello".to_string())]
+        );
+    });
+    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
+}
+
+#[gpui::test]
+#[ignore = "can't run on CI yet"]
+async fn test_thinking(cx: &mut TestAppContext) {
+    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
+
+    let events = thread
+        .update(cx, |thread, cx| {
+            thread.send(
+                model.clone(),
+                indoc! {"
+                    Testing:
+
+                    Generate a thinking step where you just think the word 'Think',
+                    and have your final answer be 'Hello'
+                "},
+                cx,
+            )
+        })
+        .collect()
+        .await;
+    thread.update(cx, |thread, _cx| {
+        assert_eq!(
+            thread.messages().last().unwrap().to_markdown(),
+            indoc! {"
+                ## assistant
+                Think
+                Hello
+            "}
+        )
+    });
+    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
+}
+
+#[gpui::test]
+async fn test_system_prompt(cx: &mut TestAppContext) {
+    let ThreadTest {
+        model,
+        thread,
+        project_context,
+        ..
+ } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + project_context.borrow_mut().shell = "test-shell".into(); + thread.update(cx, |thread, _| thread.add_tool(EchoTool)); + thread.update(cx, |thread, cx| thread.send(model.clone(), "abc", cx)); + cx.run_until_parked(); + let mut pending_completions = fake_model.pending_completions(); + assert_eq!( + pending_completions.len(), + 1, + "unexpected pending completions: {:?}", + pending_completions + ); + + let pending_completion = pending_completions.pop().unwrap(); + assert_eq!(pending_completion.messages[0].role, Role::System); + + let system_message = &pending_completion.messages[0]; + let system_prompt = system_message.content[0].to_str().unwrap(); + assert!( + system_prompt.contains("test-shell"), + "unexpected system message: {:?}", + system_message + ); + assert!( + system_prompt.contains("## Fixing Diagnostics"), + "unexpected system message: {:?}", + system_message + ); +} + +#[gpui::test] +#[ignore = "can't run on CI yet"] +async fn test_basic_tool_calls(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await; + + // Test a tool call that's likely to complete *before* streaming stops. + let events = thread + .update(cx, |thread, cx| { + thread.add_tool(EchoTool); + thread.send( + model.clone(), + "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.", + cx, + ) + }) + .collect() + .await; + assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]); + + // Test a tool calls that's likely to complete *after* streaming stops. + let events = thread + .update(cx, |thread, cx| { + thread.remove_tool(&AgentTool::name(&EchoTool)); + thread.add_tool(DelayTool); + thread.send( + model.clone(), + "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.", + cx, + ) + }) + .collect() + .await; + assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]); + thread.update(cx, |thread, _cx| { + assert!(thread + .messages() + .last() + .unwrap() + .content + .iter() + .any(|content| { + if let MessageContent::Text(text) = content { + text.contains("Ding") + } else { + false + } + })); + }); +} + +#[gpui::test] +#[ignore = "can't run on CI yet"] +async fn test_streaming_tool_calls(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await; + + // Test a tool call that's likely to complete *before* streaming stops. 
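+    // While the events stream in, assert that a partially-streamed tool use (incomplete input)
+    // is visible on the last message before the tool call finishes.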
+ let mut events = thread.update(cx, |thread, cx| { + thread.add_tool(WordListTool); + thread.send(model.clone(), "Test the word_list tool.", cx) + }); + + let mut saw_partial_tool_use = false; + while let Some(event) = events.next().await { + if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event { + thread.update(cx, |thread, _cx| { + // Look for a tool use in the thread's last message + let last_content = thread.messages().last().unwrap().content.last().unwrap(); + if let MessageContent::ToolUse(last_tool_use) = last_content { + assert_eq!(last_tool_use.name.as_ref(), "word_list"); + if tool_call.status == acp::ToolCallStatus::Pending { + if !last_tool_use.is_input_complete + && last_tool_use.input.get("g").is_none() + { + saw_partial_tool_use = true; + } + } else { + last_tool_use + .input + .get("a") + .expect("'a' has streamed because input is now complete"); + last_tool_use + .input + .get("g") + .expect("'g' has streamed because input is now complete"); + } + } else { + panic!("last content should be a tool use"); + } + }); + } + } + + assert!( + saw_partial_tool_use, + "should see at least one partially streamed tool use in the history" + ); +} + +#[gpui::test] +async fn test_tool_authorization(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + let mut events = thread.update(cx, |thread, cx| { + thread.add_tool(ToolRequiringPermission); + thread.send(model.clone(), "abc", cx) + }); + cx.run_until_parked(); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: "tool_id_1".into(), + name: ToolRequiringPermission.name().into(), + raw_input: "{}".into(), + input: json!({}), + is_input_complete: true, + }, + )); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: "tool_id_2".into(), + name: ToolRequiringPermission.name().into(), + raw_input: "{}".into(), + input: json!({}), + is_input_complete: true, + }, + )); + fake_model.end_last_completion_stream(); + let tool_call_auth_1 = next_tool_call_authorization(&mut events).await; + let tool_call_auth_2 = next_tool_call_authorization(&mut events).await; + + // Approve the first + tool_call_auth_1 + .response + .send(tool_call_auth_1.options[1].id.clone()) + .unwrap(); + cx.run_until_parked(); + + // Reject the second + tool_call_auth_2 + .response + .send(tool_call_auth_1.options[2].id.clone()) + .unwrap(); + cx.run_until_parked(); + + let completion = fake_model.pending_completions().pop().unwrap(); + let message = completion.messages.last().unwrap(); + assert_eq!( + message.content, + vec![ + MessageContent::ToolResult(LanguageModelToolResult { + tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(), + tool_name: ToolRequiringPermission.name().into(), + is_error: false, + content: "Allowed".into(), + output: None + }), + MessageContent::ToolResult(LanguageModelToolResult { + tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(), + tool_name: ToolRequiringPermission.name().into(), + is_error: true, + content: "Permission to run tool denied by user".into(), + output: None + }) + ] + ); +} + +#[gpui::test] +async fn test_tool_hallucination(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. 
} = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + let mut events = thread.update(cx, |thread, cx| thread.send(model.clone(), "abc", cx)); + cx.run_until_parked(); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: "tool_id_1".into(), + name: "nonexistent_tool".into(), + raw_input: "{}".into(), + input: json!({}), + is_input_complete: true, + }, + )); + fake_model.end_last_completion_stream(); + + let tool_call = expect_tool_call(&mut events).await; + assert_eq!(tool_call.title, "nonexistent_tool"); + assert_eq!(tool_call.status, acp::ToolCallStatus::Pending); + let update = expect_tool_call_update(&mut events).await; + assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed)); +} + +async fn expect_tool_call( + events: &mut UnboundedReceiver>, +) -> acp::ToolCall { + let event = events + .next() + .await + .expect("no tool call authorization event received") + .unwrap(); + match event { + AgentResponseEvent::ToolCall(tool_call) => return tool_call, + event => { + panic!("Unexpected event {event:?}"); + } + } +} + +async fn expect_tool_call_update( + events: &mut UnboundedReceiver>, +) -> acp::ToolCallUpdate { + let event = events + .next() + .await + .expect("no tool call authorization event received") + .unwrap(); + match event { + AgentResponseEvent::ToolCallUpdate(tool_call_update) => return tool_call_update, + event => { + panic!("Unexpected event {event:?}"); + } + } +} + +async fn next_tool_call_authorization( + events: &mut UnboundedReceiver>, +) -> ToolCallAuthorization { + loop { + let event = events + .next() + .await + .expect("no tool call authorization event received") + .unwrap(); + if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event { + let permission_kinds = tool_call_authorization + .options + .iter() + .map(|o| o.kind) + .collect::>(); + assert_eq!( + permission_kinds, + vec![ + acp::PermissionOptionKind::AllowAlways, + acp::PermissionOptionKind::AllowOnce, + acp::PermissionOptionKind::RejectOnce, + ] + ); + return tool_call_authorization; + } + } +} + +#[gpui::test] +#[ignore = "can't run on CI yet"] +async fn test_concurrent_tool_calls(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await; + + // Test concurrent tool calls with different delay times + let events = thread + .update(cx, |thread, cx| { + thread.add_tool(DelayTool); + thread.send( + model.clone(), + "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.", + cx, + ) + }) + .collect() + .await; + + let stop_reasons = stop_events(events); + assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]); + + thread.update(cx, |thread, _cx| { + let last_message = thread.messages().last().unwrap(); + let text = last_message + .content + .iter() + .filter_map(|content| { + if let MessageContent::Text(text) = content { + Some(text.as_str()) + } else { + None + } + }) + .collect::(); + + assert!(text.contains("Ding")); + }); +} + +#[gpui::test] +#[ignore = "can't run on CI yet"] +async fn test_cancellation(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. 
} = setup(cx, TestModel::Sonnet4).await; + + let mut events = thread.update(cx, |thread, cx| { + thread.add_tool(InfiniteTool); + thread.add_tool(EchoTool); + thread.send( + model.clone(), + "Call the echo tool and then call the infinite tool, then explain their output", + cx, + ) + }); + + // Wait until both tools are called. + let mut expected_tool_calls = vec!["echo", "infinite"]; + let mut echo_id = None; + let mut echo_completed = false; + while let Some(event) = events.next().await { + match event.unwrap() { + AgentResponseEvent::ToolCall(tool_call) => { + assert_eq!(tool_call.title, expected_tool_calls.remove(0)); + if tool_call.title == "echo" { + echo_id = Some(tool_call.id); + } + } + AgentResponseEvent::ToolCallUpdate(acp::ToolCallUpdate { + id, + fields: + acp::ToolCallUpdateFields { + status: Some(acp::ToolCallStatus::Completed), + .. + }, + }) if Some(&id) == echo_id.as_ref() => { + echo_completed = true; + } + _ => {} + } + + if expected_tool_calls.is_empty() && echo_completed { + break; + } + } + + // Cancel the current send and ensure that the event stream is closed, even + // if one of the tools is still running. + thread.update(cx, |thread, _cx| thread.cancel()); + events.collect::>().await; + + // Ensure we can still send a new message after cancellation. + let events = thread + .update(cx, |thread, cx| { + thread.send(model.clone(), "Testing: reply with 'Hello' then stop.", cx) + }) + .collect::>() + .await; + thread.update(cx, |thread, _cx| { + assert_eq!( + thread.messages().last().unwrap().content, + vec![MessageContent::Text("Hello".to_string())] + ); + }); + assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]); +} + +#[gpui::test] +async fn test_refusal(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + let events = thread.update(cx, |thread, cx| thread.send(model.clone(), "Hello", cx)); + cx.run_until_parked(); + thread.read_with(cx, |thread, _| { + assert_eq!( + thread.to_markdown(), + indoc! {" + ## user + Hello + "} + ); + }); + + fake_model.send_last_completion_stream_text_chunk("Hey!"); + cx.run_until_parked(); + thread.read_with(cx, |thread, _| { + assert_eq!( + thread.to_markdown(), + indoc! {" + ## user + Hello + ## assistant + Hey! + "} + ); + }); + + // If the model refuses to continue, the thread should remove all the messages after the last user message. 
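+    // Simulate that refusal by emitting a Refusal stop reason from the fake model.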
+ fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal)); + let events = events.collect::>().await; + assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]); + thread.read_with(cx, |thread, _| { + assert_eq!(thread.to_markdown(), ""); + }); +} + +#[gpui::test] +async fn test_agent_connection(cx: &mut TestAppContext) { + cx.update(settings::init); + let templates = Templates::new(); + + // Initialize language model system with test provider + cx.update(|cx| { + gpui_tokio::init(cx); + client::init_settings(cx); + + let http_client = FakeHttpClient::with_404_response(); + let clock = Arc::new(clock::FakeSystemClock::new()); + let client = Client::new(clock, http_client, cx); + let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); + language_model::init(client.clone(), cx); + language_models::init(user_store.clone(), client.clone(), cx); + Project::init_settings(cx); + LanguageModelRegistry::test(cx); + }); + cx.executor().forbid_parking(); + + // Create a project for new_thread + let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone())); + fake_fs.insert_tree(path!("/test"), json!({})).await; + let project = Project::test(fake_fs, [Path::new("/test")], cx).await; + let cwd = Path::new("/test"); + + // Create agent and connection + let agent = NativeAgent::new(project.clone(), templates.clone(), None, &mut cx.to_async()) + .await + .unwrap(); + let connection = NativeAgentConnection(agent.clone()); + + // Test model_selector returns Some + let selector_opt = connection.model_selector(); + assert!( + selector_opt.is_some(), + "agent2 should always support ModelSelector" + ); + let selector = selector_opt.unwrap(); + + // Test list_models + let listed_models = cx + .update(|cx| { + let mut async_cx = cx.to_async(); + selector.list_models(&mut async_cx) + }) + .await + .expect("list_models should succeed"); + assert!(!listed_models.is_empty(), "should have at least one model"); + assert_eq!(listed_models[0].id().0, "fake"); + + // Create a thread using new_thread + let connection_rc = Rc::new(connection.clone()); + let acp_thread = cx + .update(|cx| { + let mut async_cx = cx.to_async(); + connection_rc.new_thread(project, cwd, &mut async_cx) + }) + .await + .expect("new_thread should succeed"); + + // Get the session_id from the AcpThread + let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone()); + + // Test selected_model returns the default + let model = cx + .update(|cx| { + let mut async_cx = cx.to_async(); + selector.selected_model(&session_id, &mut async_cx) + }) + .await + .expect("selected_model should succeed"); + let model = model.as_fake(); + assert_eq!(model.id().0, "fake", "should return default model"); + + let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx)); + cx.run_until_parked(); + model.send_last_completion_stream_text_chunk("def"); + cx.run_until_parked(); + acp_thread.read_with(cx, |thread, cx| { + assert_eq!( + thread.to_markdown(cx), + indoc! {" + ## User + + abc + + ## Assistant + + def + + "} + ) + }); + + // Test cancel + cx.update(|cx| connection.cancel(&session_id, cx)); + request.await.expect("prompt should fail gracefully"); + + // Ensure that dropping the ACP thread causes the native thread to be + // dropped as well. 
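+    // After the drop, prompting with the old session id should fail with "Session not found".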
+ cx.update(|_| drop(acp_thread)); + let result = cx + .update(|cx| { + connection.prompt( + acp::PromptRequest { + session_id: session_id.clone(), + prompt: vec!["ghi".into()], + }, + cx, + ) + }) + .await; + assert_eq!( + result.as_ref().unwrap_err().to_string(), + "Session not found", + "unexpected result: {:?}", + result + ); +} + +#[gpui::test] +async fn test_tool_updates_to_completion(cx: &mut TestAppContext) { + let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await; + thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool)); + let fake_model = model.as_fake(); + + let mut events = thread.update(cx, |thread, cx| thread.send(model.clone(), "Think", cx)); + cx.run_until_parked(); + + let input = json!({ "content": "Thinking hard!" }); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: "1".into(), + name: ThinkingTool.name().into(), + raw_input: input.to_string(), + input, + is_input_complete: true, + }, + )); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + let tool_call = expect_tool_call(&mut events).await; + assert_eq!( + tool_call, + acp::ToolCall { + id: acp::ToolCallId("1".into()), + title: "Thinking".into(), + kind: acp::ToolKind::Think, + status: acp::ToolCallStatus::Pending, + content: vec![], + locations: vec![], + raw_input: Some(json!({ "content": "Thinking hard!" })), + raw_output: None, + } + ); + let update = expect_tool_call_update(&mut events).await; + assert_eq!( + update, + acp::ToolCallUpdate { + id: acp::ToolCallId("1".into()), + fields: acp::ToolCallUpdateFields { + status: Some(acp::ToolCallStatus::InProgress,), + ..Default::default() + }, + } + ); + let update = expect_tool_call_update(&mut events).await; + assert_eq!( + update, + acp::ToolCallUpdate { + id: acp::ToolCallId("1".into()), + fields: acp::ToolCallUpdateFields { + content: Some(vec!["Thinking hard!".into()]), + ..Default::default() + }, + } + ); + let update = expect_tool_call_update(&mut events).await; + assert_eq!( + update, + acp::ToolCallUpdate { + id: acp::ToolCallId("1".into()), + fields: acp::ToolCallUpdateFields { + status: Some(acp::ToolCallStatus::Completed), + ..Default::default() + }, + } + ); +} + +/// Filters out the stop events for asserting against in tests +fn stop_events( + result_events: Vec>, +) -> Vec { + result_events + .into_iter() + .filter_map(|event| match event.unwrap() { + AgentResponseEvent::Stop(stop_reason) => Some(stop_reason), + _ => None, + }) + .collect() +} + +struct ThreadTest { + model: Arc, + thread: Entity, + project_context: Rc>, +} + +enum TestModel { + Sonnet4, + Sonnet4Thinking, + Fake, +} + +impl TestModel { + fn id(&self) -> LanguageModelId { + match self { + TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()), + TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()), + TestModel::Fake => unreachable!(), + } + } +} + +async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest { + cx.executor().allow_parking(); + cx.update(|cx| { + settings::init(cx); + Project::init_settings(cx); + }); + let templates = Templates::new(); + + let fs = FakeFs::new(cx.background_executor.clone()); + fs.insert_tree(path!("/test"), json!({})).await; + let project = Project::test(fs, [path!("/test").as_ref()], cx).await; + + let model = cx + .update(|cx| { + gpui_tokio::init(cx); + let http_client = ReqwestClient::user_agent("agent tests").unwrap(); + cx.set_http_client(Arc::new(http_client)); + + 
client::init_settings(cx); + let client = Client::production(cx); + let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); + language_model::init(client.clone(), cx); + language_models::init(user_store.clone(), client.clone(), cx); + + if let TestModel::Fake = model { + Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>) + } else { + let model_id = model.id(); + let models = LanguageModelRegistry::read_global(cx); + let model = models + .available_models(cx) + .find(|model| model.id() == model_id) + .unwrap(); + + let provider = models.provider(&model.provider_id()).unwrap(); + let authenticated = provider.authenticate(cx); + + cx.spawn(async move |_cx| { + authenticated.await.unwrap(); + model + }) + } + }) + .await; + + let project_context = Rc::new(RefCell::new(ProjectContext::default())); + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let thread = cx.new(|_| { + Thread::new( + project, + project_context.clone(), + action_log, + templates, + model.clone(), + ) + }); + ThreadTest { + model, + thread, + project_context, + } +} + +#[cfg(test)] +#[ctor::ctor] +fn init_logger() { + if std::env::var("RUST_LOG").is_ok() { + env_logger::init(); + } +} diff --git a/crates/agent2/src/tests/test_tools.rs b/crates/agent2/src/tests/test_tools.rs new file mode 100644 index 0000000000000000000000000000000000000000..fd6e7e941fdb0e36e3c5a5c4ad56ebaed0a9acdf --- /dev/null +++ b/crates/agent2/src/tests/test_tools.rs @@ -0,0 +1,195 @@ +use super::*; +use anyhow::Result; +use gpui::{App, SharedString, Task}; +use std::future; + +/// A tool that echoes its input +#[derive(JsonSchema, Serialize, Deserialize)] +pub struct EchoToolInput { + /// The text to echo. + text: String, +} + +pub struct EchoTool; + +impl AgentTool for EchoTool { + type Input = EchoToolInput; + + fn name(&self) -> SharedString { + "echo".into() + } + + fn kind(&self) -> acp::ToolKind { + acp::ToolKind::Other + } + + fn initial_title(&self, _: Self::Input) -> SharedString { + "Echo".into() + } + + fn run( + self: Arc, + input: Self::Input, + _event_stream: ToolCallEventStream, + _cx: &mut App, + ) -> Task> { + Task::ready(Ok(input.text)) + } +} + +/// A tool that waits for a specified delay +#[derive(JsonSchema, Serialize, Deserialize)] +pub struct DelayToolInput { + /// The delay in milliseconds. 
+ ms: u64, +} + +pub struct DelayTool; + +impl AgentTool for DelayTool { + type Input = DelayToolInput; + + fn name(&self) -> SharedString { + "delay".into() + } + + fn initial_title(&self, input: Self::Input) -> SharedString { + format!("Delay {}ms", input.ms).into() + } + + fn kind(&self) -> acp::ToolKind { + acp::ToolKind::Other + } + + fn run( + self: Arc, + input: Self::Input, + _event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> + where + Self: Sized, + { + cx.foreground_executor().spawn(async move { + smol::Timer::after(Duration::from_millis(input.ms)).await; + Ok("Ding".to_string()) + }) + } +} + +#[derive(JsonSchema, Serialize, Deserialize)] +pub struct ToolRequiringPermissionInput {} + +pub struct ToolRequiringPermission; + +impl AgentTool for ToolRequiringPermission { + type Input = ToolRequiringPermissionInput; + + fn name(&self) -> SharedString { + "tool_requiring_permission".into() + } + + fn kind(&self) -> acp::ToolKind { + acp::ToolKind::Other + } + + fn initial_title(&self, _input: Self::Input) -> SharedString { + "This tool requires permission".into() + } + + fn run( + self: Arc, + input: Self::Input, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> + where + Self: Sized, + { + let auth_check = self.authorize(input, event_stream); + cx.foreground_executor().spawn(async move { + auth_check.await?; + Ok("Allowed".to_string()) + }) + } +} + +#[derive(JsonSchema, Serialize, Deserialize)] +pub struct InfiniteToolInput {} + +pub struct InfiniteTool; + +impl AgentTool for InfiniteTool { + type Input = InfiniteToolInput; + + fn name(&self) -> SharedString { + "infinite".into() + } + + fn kind(&self) -> acp::ToolKind { + acp::ToolKind::Other + } + + fn initial_title(&self, _input: Self::Input) -> SharedString { + "This is the tool that never ends... it just goes on and on my friends!".into() + } + + fn run( + self: Arc, + _input: Self::Input, + _event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> { + cx.foreground_executor().spawn(async move { + future::pending::<()>().await; + unreachable!() + }) + } +} + +/// A tool that takes an object with map from letters to random words starting with that letter. +/// All fiealds are required! Pass a word for every letter! +#[derive(JsonSchema, Serialize, Deserialize)] +pub struct WordListInput { + /// Provide a random word that starts with A. + a: Option, + /// Provide a random word that starts with B. + b: Option, + /// Provide a random word that starts with C. + c: Option, + /// Provide a random word that starts with D. + d: Option, + /// Provide a random word that starts with E. + e: Option, + /// Provide a random word that starts with F. + f: Option, + /// Provide a random word that starts with G. 
+ g: Option, +} + +pub struct WordListTool; + +impl AgentTool for WordListTool { + type Input = WordListInput; + + fn name(&self) -> SharedString { + "word_list".into() + } + + fn initial_title(&self, _input: Self::Input) -> SharedString { + "List of random words".into() + } + + fn kind(&self) -> acp::ToolKind { + acp::ToolKind::Other + } + + fn run( + self: Arc, + _input: Self::Input, + _event_stream: ToolCallEventStream, + _cx: &mut App, + ) -> Task> { + Task::ready(Ok("ok".to_string())) + } +} diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs new file mode 100644 index 0000000000000000000000000000000000000000..805ffff1c07fc1303a0852f49135a7226a75eb8d --- /dev/null +++ b/crates/agent2/src/thread.rs @@ -0,0 +1,926 @@ +use crate::templates::{SystemPromptTemplate, Template, Templates}; +use agent_client_protocol as acp; +use anyhow::{anyhow, Context as _, Result}; +use assistant_tool::{adapt_schema_to_format, ActionLog}; +use cloud_llm_client::{CompletionIntent, CompletionMode}; +use collections::HashMap; +use futures::{ + channel::{mpsc, oneshot}, + stream::FuturesUnordered, +}; +use gpui::{App, Context, Entity, SharedString, Task}; +use language_model::{ + LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, + LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, + LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat, + LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason, +}; +use log; +use project::Project; +use prompt_store::ProjectContext; +use schemars::{JsonSchema, Schema}; +use serde::{Deserialize, Serialize}; +use smol::stream::StreamExt; +use std::{cell::RefCell, collections::BTreeMap, fmt::Write, future::Future, rc::Rc, sync::Arc}; +use util::{markdown::MarkdownCodeBlock, ResultExt}; + +#[derive(Debug, Clone)] +pub struct AgentMessage { + pub role: Role, + pub content: Vec, +} + +impl AgentMessage { + pub fn to_markdown(&self) -> String { + let mut markdown = format!("## {}\n", self.role); + + for content in &self.content { + match content { + MessageContent::Text(text) => { + markdown.push_str(text); + markdown.push('\n'); + } + MessageContent::Thinking { text, .. 
} => { + markdown.push_str(""); + markdown.push_str(text); + markdown.push_str("\n"); + } + MessageContent::RedactedThinking(_) => markdown.push_str("\n"), + MessageContent::Image(_) => { + markdown.push_str("\n"); + } + MessageContent::ToolUse(tool_use) => { + markdown.push_str(&format!( + "**Tool Use**: {} (ID: {})\n", + tool_use.name, tool_use.id + )); + markdown.push_str(&format!( + "{}\n", + MarkdownCodeBlock { + tag: "json", + text: &format!("{:#}", tool_use.input) + } + )); + } + MessageContent::ToolResult(tool_result) => { + markdown.push_str(&format!( + "**Tool Result**: {} (ID: {})\n\n", + tool_result.tool_name, tool_result.tool_use_id + )); + if tool_result.is_error { + markdown.push_str("**ERROR:**\n"); + } + + match &tool_result.content { + LanguageModelToolResultContent::Text(text) => { + writeln!(markdown, "{text}\n").ok(); + } + LanguageModelToolResultContent::Image(_) => { + writeln!(markdown, "\n").ok(); + } + } + + if let Some(output) = tool_result.output.as_ref() { + writeln!( + markdown, + "**Debug Output**:\n\n```json\n{}\n```\n", + serde_json::to_string_pretty(output).unwrap() + ) + .unwrap(); + } + } + } + } + + markdown + } +} + +#[derive(Debug)] +pub enum AgentResponseEvent { + Text(String), + Thinking(String), + ToolCall(acp::ToolCall), + ToolCallUpdate(acp::ToolCallUpdate), + ToolCallAuthorization(ToolCallAuthorization), + Stop(acp::StopReason), +} + +#[derive(Debug)] +pub struct ToolCallAuthorization { + pub tool_call: acp::ToolCall, + pub options: Vec, + pub response: oneshot::Sender, +} + +pub struct Thread { + messages: Vec, + completion_mode: CompletionMode, + /// Holds the task that handles agent interaction until the end of the turn. + /// Survives across multiple requests as the model performs tool calls and + /// we run tools, report their results. + running_turn: Option>, + pending_tool_uses: HashMap, + tools: BTreeMap>, + project_context: Rc>, + templates: Arc, + pub selected_model: Arc, + _action_log: Entity, +} + +impl Thread { + pub fn new( + _project: Entity, + project_context: Rc>, + action_log: Entity, + templates: Arc, + default_model: Arc, + ) -> Self { + Self { + messages: Vec::new(), + completion_mode: CompletionMode::Normal, + running_turn: None, + pending_tool_uses: HashMap::default(), + tools: BTreeMap::default(), + project_context, + templates, + selected_model: default_model, + _action_log: action_log, + } + } + + pub fn set_mode(&mut self, mode: CompletionMode) { + self.completion_mode = mode; + } + + pub fn messages(&self) -> &[AgentMessage] { + &self.messages + } + + pub fn add_tool(&mut self, tool: impl AgentTool) { + self.tools.insert(tool.name(), tool.erase()); + } + + pub fn remove_tool(&mut self, name: &str) -> bool { + self.tools.remove(name).is_some() + } + + pub fn cancel(&mut self) { + self.running_turn.take(); + + let tool_results = self + .pending_tool_uses + .drain() + .map(|(tool_use_id, tool_use)| { + MessageContent::ToolResult(LanguageModelToolResult { + tool_use_id, + tool_name: tool_use.name.clone(), + is_error: true, + content: LanguageModelToolResultContent::Text("Tool canceled by user".into()), + output: None, + }) + }) + .collect::>(); + self.last_user_message().content.extend(tool_results); + } + + /// Sending a message results in the model streaming a response, which could include tool calls. + /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent. 
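+    /// Tool results are appended to the last user message and the request is re-sent with
+    /// `CompletionIntent::ToolResults` until the model finishes a turn without requesting tools.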
+ /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn. + pub fn send( + &mut self, + model: Arc, + content: impl Into, + cx: &mut Context, + ) -> mpsc::UnboundedReceiver> { + let content = content.into(); + log::info!("Thread::send called with model: {:?}", model.name()); + log::debug!("Thread::send content: {:?}", content); + + cx.notify(); + let (events_tx, events_rx) = + mpsc::unbounded::>(); + let event_stream = AgentResponseEventStream(events_tx); + + let user_message_ix = self.messages.len(); + self.messages.push(AgentMessage { + role: Role::User, + content: vec![content], + }); + log::info!("Total messages in thread: {}", self.messages.len()); + self.running_turn = Some(cx.spawn(async move |thread, cx| { + log::info!("Starting agent turn execution"); + let turn_result = async { + // Perform one request, then keep looping if the model makes tool calls. + let mut completion_intent = CompletionIntent::UserPrompt; + 'outer: loop { + log::debug!( + "Building completion request with intent: {:?}", + completion_intent + ); + let request = thread.update(cx, |thread, cx| { + thread.build_completion_request(completion_intent, cx) + })?; + + // println!( + // "request: {}", + // serde_json::to_string_pretty(&request).unwrap() + // ); + + // Stream events, appending to messages and collecting up tool uses. + log::info!("Calling model.stream_completion"); + let mut events = model.stream_completion(request, cx).await?; + log::debug!("Stream completion started successfully"); + let mut tool_uses = FuturesUnordered::new(); + while let Some(event) = events.next().await { + match event { + Ok(LanguageModelCompletionEvent::Stop(reason)) => { + event_stream.send_stop(reason); + if reason == StopReason::Refusal { + thread.update(cx, |thread, _cx| { + thread.messages.truncate(user_message_ix); + })?; + break 'outer; + } + } + Ok(event) => { + log::trace!("Received completion event: {:?}", event); + thread + .update(cx, |thread, cx| { + tool_uses.extend(thread.handle_streamed_completion_event( + event, + &event_stream, + cx, + )); + }) + .ok(); + } + Err(error) => { + log::error!("Error in completion stream: {:?}", error); + event_stream.send_error(error); + break; + } + } + } + + // If there are no tool uses, the turn is done. + if tool_uses.is_empty() { + log::info!("No tool uses found, completing turn"); + break; + } + log::info!("Found {} tool uses to execute", tool_uses.len()); + + // As tool results trickle in, insert them in the last user + // message so that they can be sent on the next tick of the + // agentic loop. 
+ while let Some(tool_result) = tool_uses.next().await { + log::info!("Tool finished {:?}", tool_result); + + event_stream.send_tool_call_update( + &tool_result.tool_use_id, + acp::ToolCallUpdateFields { + status: Some(if tool_result.is_error { + acp::ToolCallStatus::Failed + } else { + acp::ToolCallStatus::Completed + }), + ..Default::default() + }, + ); + thread + .update(cx, |thread, _cx| { + thread.pending_tool_uses.remove(&tool_result.tool_use_id); + thread + .last_user_message() + .content + .push(MessageContent::ToolResult(tool_result)); + }) + .ok(); + } + + completion_intent = CompletionIntent::ToolResults; + } + + Ok(()) + } + .await; + + if let Err(error) = turn_result { + log::error!("Turn execution failed: {:?}", error); + event_stream.send_error(error); + } else { + log::info!("Turn execution completed successfully"); + } + })); + events_rx + } + + pub fn build_system_message(&self) -> AgentMessage { + log::debug!("Building system message"); + let prompt = SystemPromptTemplate { + project: &self.project_context.borrow(), + available_tools: self.tools.keys().cloned().collect(), + } + .render(&self.templates) + .context("failed to build system prompt") + .expect("Invalid template"); + log::debug!("System message built"); + AgentMessage { + role: Role::System, + content: vec![prompt.into()], + } + } + + /// A helper method that's called on every streamed completion event. + /// Returns an optional tool result task, which the main agentic loop in + /// send will send back to the model when it resolves. + fn handle_streamed_completion_event( + &mut self, + event: LanguageModelCompletionEvent, + event_stream: &AgentResponseEventStream, + cx: &mut Context, + ) -> Option> { + log::trace!("Handling streamed completion event: {:?}", event); + use LanguageModelCompletionEvent::*; + + match event { + StartMessage { .. 
} => { + self.messages.push(AgentMessage { + role: Role::Assistant, + content: Vec::new(), + }); + } + Text(new_text) => self.handle_text_event(new_text, event_stream, cx), + Thinking { text, signature } => { + self.handle_thinking_event(text, signature, event_stream, cx) + } + RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx), + ToolUse(tool_use) => { + return self.handle_tool_use_event(tool_use, event_stream, cx); + } + ToolUseJsonParseError { + id, + tool_name, + raw_input, + json_parse_error, + } => { + return Some(Task::ready(self.handle_tool_use_json_parse_error_event( + id, + tool_name, + raw_input, + json_parse_error, + ))); + } + UsageUpdate(_) | StatusUpdate(_) => {} + Stop(_) => unreachable!(), + } + + None + } + + fn handle_text_event( + &mut self, + new_text: String, + events_stream: &AgentResponseEventStream, + cx: &mut Context, + ) { + events_stream.send_text(&new_text); + + let last_message = self.last_assistant_message(); + if let Some(MessageContent::Text(text)) = last_message.content.last_mut() { + text.push_str(&new_text); + } else { + last_message.content.push(MessageContent::Text(new_text)); + } + + cx.notify(); + } + + fn handle_thinking_event( + &mut self, + new_text: String, + new_signature: Option, + event_stream: &AgentResponseEventStream, + cx: &mut Context, + ) { + event_stream.send_thinking(&new_text); + + let last_message = self.last_assistant_message(); + if let Some(MessageContent::Thinking { text, signature }) = last_message.content.last_mut() + { + text.push_str(&new_text); + *signature = new_signature.or(signature.take()); + } else { + last_message.content.push(MessageContent::Thinking { + text: new_text, + signature: new_signature, + }); + } + + cx.notify(); + } + + fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context) { + let last_message = self.last_assistant_message(); + last_message + .content + .push(MessageContent::RedactedThinking(data)); + cx.notify(); + } + + fn handle_tool_use_event( + &mut self, + tool_use: LanguageModelToolUse, + event_stream: &AgentResponseEventStream, + cx: &mut Context, + ) -> Option> { + cx.notify(); + + let tool = self.tools.get(tool_use.name.as_ref()).cloned(); + + self.pending_tool_uses + .insert(tool_use.id.clone(), tool_use.clone()); + let last_message = self.last_assistant_message(); + + // Ensure the last message ends in the current tool use + let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| { + if let MessageContent::ToolUse(last_tool_use) = content { + if last_tool_use.id == tool_use.id { + *last_tool_use = tool_use.clone(); + false + } else { + true + } + } else { + true + } + }); + + if push_new_tool_use { + event_stream.send_tool_call(tool.as_ref(), &tool_use); + last_message + .content + .push(MessageContent::ToolUse(tool_use.clone())); + } else { + event_stream.send_tool_call_update( + &tool_use.id, + acp::ToolCallUpdateFields { + raw_input: Some(tool_use.input.clone()), + ..Default::default() + }, + ); + } + + if !tool_use.is_input_complete { + return None; + } + + let Some(tool) = tool else { + let content = format!("No tool named {} exists", tool_use.name); + return Some(Task::ready(LanguageModelToolResult { + content: LanguageModelToolResultContent::Text(Arc::from(content)), + tool_use_id: tool_use.id, + tool_name: tool_use.name, + is_error: true, + output: None, + })); + }; + + let tool_result = self.run_tool(tool, tool_use.clone(), event_stream.clone(), cx); + Some(cx.foreground_executor().spawn(async move { + match 
tool_result.await { + Ok(tool_output) => LanguageModelToolResult { + tool_use_id: tool_use.id, + tool_name: tool_use.name, + is_error: false, + content: LanguageModelToolResultContent::Text(Arc::from(tool_output)), + output: None, + }, + Err(error) => LanguageModelToolResult { + tool_use_id: tool_use.id, + tool_name: tool_use.name, + is_error: true, + content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())), + output: None, + }, + } + })) + } + + fn run_tool( + &self, + tool: Arc, + tool_use: LanguageModelToolUse, + event_stream: AgentResponseEventStream, + cx: &mut Context, + ) -> Task> { + cx.spawn(async move |_this, cx| { + let tool_event_stream = ToolCallEventStream::new(tool_use.id, event_stream); + tool_event_stream.send_update(acp::ToolCallUpdateFields { + status: Some(acp::ToolCallStatus::InProgress), + ..Default::default() + }); + cx.update(|cx| tool.run(tool_use.input, tool_event_stream, cx))? + .await + }) + } + + fn handle_tool_use_json_parse_error_event( + &mut self, + tool_use_id: LanguageModelToolUseId, + tool_name: Arc, + raw_input: Arc, + json_parse_error: String, + ) -> LanguageModelToolResult { + let tool_output = format!("Error parsing input JSON: {json_parse_error}"); + LanguageModelToolResult { + tool_use_id, + tool_name, + is_error: true, + content: LanguageModelToolResultContent::Text(tool_output.into()), + output: Some(serde_json::Value::String(raw_input.to_string())), + } + } + + /// Guarantees the last message is from the assistant and returns a mutable reference. + fn last_assistant_message(&mut self) -> &mut AgentMessage { + if self + .messages + .last() + .map_or(true, |m| m.role != Role::Assistant) + { + self.messages.push(AgentMessage { + role: Role::Assistant, + content: Vec::new(), + }); + } + self.messages.last_mut().unwrap() + } + + /// Guarantees the last message is from the user and returns a mutable reference. 
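+    ///
+    /// Tool results are appended here between model turns, so they are included
+    /// in the next completion request (see `send`).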
+ fn last_user_message(&mut self) -> &mut AgentMessage { + if self.messages.last().map_or(true, |m| m.role != Role::User) { + self.messages.push(AgentMessage { + role: Role::User, + content: Vec::new(), + }); + } + self.messages.last_mut().unwrap() + } + + fn build_completion_request( + &self, + completion_intent: CompletionIntent, + cx: &mut App, + ) -> LanguageModelRequest { + log::debug!("Building completion request"); + log::debug!("Completion intent: {:?}", completion_intent); + log::debug!("Completion mode: {:?}", self.completion_mode); + + let messages = self.build_request_messages(); + log::info!("Request will include {} messages", messages.len()); + + let tools: Vec = self + .tools + .values() + .filter_map(|tool| { + let tool_name = tool.name().to_string(); + log::trace!("Including tool: {}", tool_name); + Some(LanguageModelRequestTool { + name: tool_name, + description: tool.description(cx).to_string(), + input_schema: tool + .input_schema(self.selected_model.tool_input_format()) + .log_err()?, + }) + }) + .collect(); + + log::info!("Request includes {} tools", tools.len()); + + let request = LanguageModelRequest { + thread_id: None, + prompt_id: None, + intent: Some(completion_intent), + mode: Some(self.completion_mode), + messages, + tools, + tool_choice: None, + stop: Vec::new(), + temperature: None, + thinking_allowed: true, + }; + + log::debug!("Completion request built successfully"); + request + } + + fn build_request_messages(&self) -> Vec { + log::trace!( + "Building request messages from {} thread messages", + self.messages.len() + ); + + let messages = Some(self.build_system_message()) + .iter() + .chain(self.messages.iter()) + .map(|message| { + log::trace!( + " - {} message with {} content items", + match message.role { + Role::System => "System", + Role::User => "User", + Role::Assistant => "Assistant", + }, + message.content.len() + ); + LanguageModelRequestMessage { + role: message.role, + content: message.content.clone(), + cache: false, + } + }) + .collect(); + messages + } + + pub fn to_markdown(&self) -> String { + let mut markdown = String::new(); + for message in &self.messages { + markdown.push_str(&message.to_markdown()); + } + markdown + } +} + +pub trait AgentTool +where + Self: 'static + Sized, +{ + type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema; + + fn name(&self) -> SharedString; + + fn description(&self, _cx: &mut App) -> SharedString { + let schema = schemars::schema_for!(Self::Input); + SharedString::new( + schema + .get("description") + .and_then(|description| description.as_str()) + .unwrap_or_default(), + ) + } + + fn kind(&self) -> acp::ToolKind; + + /// The initial tool title to display. Can be updated during the tool run. + fn initial_title(&self, input: Self::Input) -> SharedString; + + /// Returns the JSON schema that describes the tool's input. + fn input_schema(&self) -> Schema { + schemars::schema_for!(Self::Input) + } + + /// Allows the tool to authorize a given tool call with the user if necessary + fn authorize( + &self, + input: Self::Input, + event_stream: ToolCallEventStream, + ) -> impl use + Future> { + let json_input = serde_json::json!(&input); + event_stream.authorize(self.initial_title(input).into(), self.kind(), json_input) + } + + /// Runs the tool with the provided input. 
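+    ///
+    /// The returned `String` is what gets sent back to the model as the tool
+    /// result; returning an `Err` reports the failure to the model instead
+    /// (see `FindPathTool` and `ThinkingTool` for concrete implementations).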
+ fn run( + self: Arc, + input: Self::Input, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task>; + + fn erase(self) -> Arc { + Arc::new(Erased(Arc::new(self))) + } +} + +pub struct Erased(T); + +pub trait AnyAgentTool { + fn name(&self) -> SharedString; + fn description(&self, cx: &mut App) -> SharedString; + fn kind(&self) -> acp::ToolKind; + fn initial_title(&self, input: serde_json::Value) -> Result; + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result; + fn run( + self: Arc, + input: serde_json::Value, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task>; +} + +impl AnyAgentTool for Erased> +where + T: AgentTool, +{ + fn name(&self) -> SharedString { + self.0.name() + } + + fn description(&self, cx: &mut App) -> SharedString { + self.0.description(cx) + } + + fn kind(&self) -> agent_client_protocol::ToolKind { + self.0.kind() + } + + fn initial_title(&self, input: serde_json::Value) -> Result { + let parsed_input = serde_json::from_value(input)?; + Ok(self.0.initial_title(parsed_input)) + } + + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { + let mut json = serde_json::to_value(self.0.input_schema())?; + adapt_schema_to_format(&mut json, format)?; + Ok(json) + } + + fn run( + self: Arc, + input: serde_json::Value, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> { + let parsed_input: Result = serde_json::from_value(input).map_err(Into::into); + match parsed_input { + Ok(input) => self.0.clone().run(input, event_stream, cx), + Err(error) => Task::ready(Err(anyhow!(error))), + } + } +} + +#[derive(Clone)] +struct AgentResponseEventStream( + mpsc::UnboundedSender>, +); + +impl AgentResponseEventStream { + fn send_text(&self, text: &str) { + self.0 + .unbounded_send(Ok(AgentResponseEvent::Text(text.to_string()))) + .ok(); + } + + fn send_thinking(&self, text: &str) { + self.0 + .unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string()))) + .ok(); + } + + fn authorize_tool_call( + &self, + id: &LanguageModelToolUseId, + title: String, + kind: acp::ToolKind, + input: serde_json::Value, + ) -> impl use<> + Future> { + let (response_tx, response_rx) = oneshot::channel(); + self.0 + .unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization( + ToolCallAuthorization { + tool_call: Self::initial_tool_call(id, title, kind, input), + options: vec![ + acp::PermissionOption { + id: acp::PermissionOptionId("always_allow".into()), + name: "Always Allow".into(), + kind: acp::PermissionOptionKind::AllowAlways, + }, + acp::PermissionOption { + id: acp::PermissionOptionId("allow".into()), + name: "Allow".into(), + kind: acp::PermissionOptionKind::AllowOnce, + }, + acp::PermissionOption { + id: acp::PermissionOptionId("deny".into()), + name: "Deny".into(), + kind: acp::PermissionOptionKind::RejectOnce, + }, + ], + response: response_tx, + }, + ))) + .ok(); + async move { + match response_rx.await?.0.as_ref() { + "allow" | "always_allow" => Ok(()), + _ => Err(anyhow!("Permission to run tool denied by user")), + } + } + } + + fn send_tool_call( + &self, + tool: Option<&Arc>, + tool_use: &LanguageModelToolUse, + ) { + self.0 + .unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call( + &tool_use.id, + tool.and_then(|t| t.initial_title(tool_use.input.clone()).ok()) + .map(|i| i.into()) + .unwrap_or_else(|| tool_use.name.to_string()), + tool.map(|t| t.kind()).unwrap_or(acp::ToolKind::Other), + tool_use.input.clone(), + )))) + .ok(); + } + + fn initial_tool_call( + id: &LanguageModelToolUseId, + 
title: String, + kind: acp::ToolKind, + input: serde_json::Value, + ) -> acp::ToolCall { + acp::ToolCall { + id: acp::ToolCallId(id.to_string().into()), + title, + kind, + status: acp::ToolCallStatus::Pending, + content: vec![], + locations: vec![], + raw_input: Some(input), + raw_output: None, + } + } + + fn send_tool_call_update( + &self, + tool_use_id: &LanguageModelToolUseId, + fields: acp::ToolCallUpdateFields, + ) { + self.0 + .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate( + acp::ToolCallUpdate { + id: acp::ToolCallId(tool_use_id.to_string().into()), + fields, + }, + ))) + .ok(); + } + + fn send_stop(&self, reason: StopReason) { + match reason { + StopReason::EndTurn => { + self.0 + .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn))) + .ok(); + } + StopReason::MaxTokens => { + self.0 + .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens))) + .ok(); + } + StopReason::Refusal => { + self.0 + .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal))) + .ok(); + } + StopReason::ToolUse => {} + } + } + + fn send_error(&self, error: LanguageModelCompletionError) { + self.0.unbounded_send(Err(error)).ok(); + } +} + +#[derive(Clone)] +pub struct ToolCallEventStream { + tool_use_id: LanguageModelToolUseId, + stream: AgentResponseEventStream, +} + +impl ToolCallEventStream { + fn new(tool_use_id: LanguageModelToolUseId, stream: AgentResponseEventStream) -> Self { + Self { + tool_use_id, + stream, + } + } + + pub fn send_update(&self, fields: acp::ToolCallUpdateFields) { + self.stream.send_tool_call_update(&self.tool_use_id, fields); + } + + pub fn authorize( + &self, + title: String, + kind: acp::ToolKind, + input: serde_json::Value, + ) -> impl use<> + Future> { + self.stream + .authorize_tool_call(&self.tool_use_id, title, kind, input) + } +} diff --git a/crates/agent2/src/tools.rs b/crates/agent2/src/tools.rs new file mode 100644 index 0000000000000000000000000000000000000000..848fe552ed1e913699e26ae320f93ccc379b254b --- /dev/null +++ b/crates/agent2/src/tools.rs @@ -0,0 +1,5 @@ +mod find_path_tool; +mod thinking_tool; + +pub use find_path_tool::*; +pub use thinking_tool::*; diff --git a/crates/agent2/src/tools/find_path_tool.rs b/crates/agent2/src/tools/find_path_tool.rs new file mode 100644 index 0000000000000000000000000000000000000000..e840fec78c899e764db720fb8cbaa55dbe6fa466 --- /dev/null +++ b/crates/agent2/src/tools/find_path_tool.rs @@ -0,0 +1,231 @@ +use agent_client_protocol as acp; +use anyhow::{anyhow, Result}; +use gpui::{App, AppContext, Entity, SharedString, Task}; +use project::Project; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use std::fmt::Write; +use std::{cmp, path::PathBuf, sync::Arc}; +use util::paths::PathMatcher; + +use crate::{AgentTool, ToolCallEventStream}; + +/// Fast file path pattern matching tool that works with any codebase size +/// +/// - Supports glob patterns like "**/*.js" or "src/**/*.ts" +/// - Returns matching file paths sorted alphabetically +/// - Prefer the `grep` tool to this tool when searching for symbols unless you have specific information about paths. +/// - Use this tool when you need to find files by name patterns +/// - Results are paginated with 50 matches per page. Use the optional 'offset' parameter to request subsequent pages. +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct FindPathToolInput { + /// The glob to match against every path in the project. 
+ /// + /// + /// If the project has the following root directories: + /// + /// - directory1/a/something.txt + /// - directory2/a/things.txt + /// - directory3/a/other.txt + /// + /// You can get back the first two paths by providing a glob of "*thing*.txt" + /// + pub glob: String, + + /// Optional starting position for paginated results (0-based). + /// When not provided, starts from the beginning. + #[serde(default)] + pub offset: usize, +} + +#[derive(Debug, Serialize, Deserialize)] +struct FindPathToolOutput { + paths: Vec, +} + +const RESULTS_PER_PAGE: usize = 50; + +pub struct FindPathTool { + project: Entity, +} + +impl FindPathTool { + pub fn new(project: Entity) -> Self { + Self { project } + } +} + +impl AgentTool for FindPathTool { + type Input = FindPathToolInput; + + fn name(&self) -> SharedString { + "find_path".into() + } + + fn kind(&self) -> acp::ToolKind { + acp::ToolKind::Search + } + + fn initial_title(&self, input: Self::Input) -> SharedString { + format!("Find paths matching “`{}`”", input.glob).into() + } + + fn run( + self: Arc, + input: Self::Input, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> { + let search_paths_task = search_paths(&input.glob, self.project.clone(), cx); + + cx.background_spawn(async move { + let matches = search_paths_task.await?; + let paginated_matches: &[PathBuf] = &matches[cmp::min(input.offset, matches.len()) + ..cmp::min(input.offset + RESULTS_PER_PAGE, matches.len())]; + + event_stream.send_update(acp::ToolCallUpdateFields { + title: Some(if paginated_matches.len() == 0 { + "No matches".into() + } else if paginated_matches.len() == 1 { + "1 match".into() + } else { + format!("{} matches", paginated_matches.len()) + }), + content: Some( + paginated_matches + .iter() + .map(|path| acp::ToolCallContent::Content { + content: acp::ContentBlock::ResourceLink(acp::ResourceLink { + uri: format!("file://{}", path.display()), + name: path.to_string_lossy().into(), + annotations: None, + description: None, + mime_type: None, + size: None, + title: None, + }), + }) + .collect(), + ), + raw_output: Some(serde_json::json!({ + "paths": &matches, + })), + ..Default::default() + }); + + if matches.is_empty() { + Ok("No matches found".into()) + } else { + let mut message = format!("Found {} total matches.", matches.len()); + if matches.len() > RESULTS_PER_PAGE { + write!( + &mut message, + "\nShowing results {}-{} (provide 'offset' parameter for more results):", + input.offset + 1, + input.offset + paginated_matches.len() + ) + .unwrap(); + } + + for mat in matches.iter().skip(input.offset).take(RESULTS_PER_PAGE) { + write!(&mut message, "\n{}", mat.display()).unwrap(); + } + + Ok(message) + } + }) + } +} + +fn search_paths(glob: &str, project: Entity, cx: &mut App) -> Task>> { + let path_matcher = match PathMatcher::new([ + // Sometimes models try to search for "". In this case, return all paths in the project. 
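+        // The glob is matched against candidates of the form
+        // `<worktree root name>/<relative path>`, so it may anchor on the root
+        // directory name (see the tests below).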
+ if glob.is_empty() { "*" } else { glob }, + ]) { + Ok(matcher) => matcher, + Err(err) => return Task::ready(Err(anyhow!("Invalid glob: {err}"))), + }; + let snapshots: Vec<_> = project + .read(cx) + .worktrees(cx) + .map(|worktree| worktree.read(cx).snapshot()) + .collect(); + + cx.background_spawn(async move { + Ok(snapshots + .iter() + .flat_map(|snapshot| { + let root_name = PathBuf::from(snapshot.root_name()); + snapshot + .entries(false, 0) + .map(move |entry| root_name.join(&entry.path)) + .filter(|path| path_matcher.is_match(&path)) + }) + .collect()) + }) +} + +#[cfg(test)] +mod test { + use super::*; + use gpui::TestAppContext; + use project::{FakeFs, Project}; + use settings::SettingsStore; + use util::path; + + #[gpui::test] + async fn test_find_path_tool(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + serde_json::json!({ + "apple": { + "banana": { + "carrot": "1", + }, + "bandana": { + "carbonara": "2", + }, + "endive": "3" + } + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + + let matches = cx + .update(|cx| search_paths("root/**/car*", project.clone(), cx)) + .await + .unwrap(); + assert_eq!( + matches, + &[ + PathBuf::from("root/apple/banana/carrot"), + PathBuf::from("root/apple/bandana/carbonara") + ] + ); + + let matches = cx + .update(|cx| search_paths("**/car*", project.clone(), cx)) + .await + .unwrap(); + assert_eq!( + matches, + &[ + PathBuf::from("root/apple/banana/carrot"), + PathBuf::from("root/apple/bandana/carbonara") + ] + ); + } + + fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + Project::init_settings(cx); + }); + } +} diff --git a/crates/agent2/src/tools/thinking_tool.rs b/crates/agent2/src/tools/thinking_tool.rs new file mode 100644 index 0000000000000000000000000000000000000000..bb85d8ecebd7870436a7b97c96af241c3d656665 --- /dev/null +++ b/crates/agent2/src/tools/thinking_tool.rs @@ -0,0 +1,48 @@ +use agent_client_protocol as acp; +use anyhow::Result; +use gpui::{App, SharedString, Task}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; + +use crate::{AgentTool, ToolCallEventStream}; + +/// A tool for thinking through problems, brainstorming ideas, or planning without executing any actions. +/// Use this tool when you need to work through complex problems, develop strategies, or outline approaches before taking action. +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct ThinkingToolInput { + /// Content to think about. This should be a description of what to think about or + /// a problem to solve. 
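+    ///
+    /// The text is echoed back verbatim as the tool call's content; the tool
+    /// performs no other action.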
+ content: String, +} + +pub struct ThinkingTool; + +impl AgentTool for ThinkingTool { + type Input = ThinkingToolInput; + + fn name(&self) -> SharedString { + "thinking".into() + } + + fn kind(&self) -> acp::ToolKind { + acp::ToolKind::Think + } + + fn initial_title(&self, _input: Self::Input) -> SharedString { + "Thinking".into() + } + + fn run( + self: Arc, + input: Self::Input, + event_stream: ToolCallEventStream, + _cx: &mut App, + ) -> Task> { + event_stream.send_update(acp::ToolCallUpdateFields { + content: Some(vec![input.content.into()]), + ..Default::default() + }); + Task::ready(Ok("Finished thinking.".to_string())) + } +} diff --git a/crates/agent_servers/Cargo.toml b/crates/agent_servers/Cargo.toml index 549162c5dd16feeb1959ece447d79faa7b7073e4..81c97c8aa6cc4fa64d017b97ade5ddd535487b81 100644 --- a/crates/agent_servers/Cargo.toml +++ b/crates/agent_servers/Cargo.toml @@ -5,6 +5,10 @@ edition.workspace = true publish.workspace = true license = "GPL-3.0-or-later" +[features] +test-support = ["acp_thread/test-support", "gpui/test-support", "project/test-support"] +e2e = [] + [lints] workspace = true @@ -13,15 +17,42 @@ path = "src/agent_servers.rs" doctest = false [dependencies] +acp_thread.workspace = true +agent-client-protocol.workspace = true +agentic-coding-protocol.workspace = true anyhow.workspace = true collections.workspace = true +context_server.workspace = true futures.workspace = true gpui.workspace = true +indoc.workspace = true +itertools.workspace = true +log.workspace = true paths.workspace = true project.workspace = true +rand.workspace = true schemars.workspace = true serde.workspace = true +serde_json.workspace = true settings.workspace = true +smol.workspace = true +strum.workspace = true +tempfile.workspace = true +thiserror.workspace = true +ui.workspace = true util.workspace = true +uuid.workspace = true +watch.workspace = true which.workspace = true workspace-hack.workspace = true + +[target.'cfg(unix)'.dependencies] +libc.workspace = true +nix.workspace = true + +[dev-dependencies] +env_logger.workspace = true +language.workspace = true +indoc.workspace = true +acp_thread = { workspace = true, features = ["test-support"] } +gpui = { workspace = true, features = ["test-support"] } diff --git a/crates/agent_servers/src/acp.rs b/crates/agent_servers/src/acp.rs new file mode 100644 index 0000000000000000000000000000000000000000..00e3e3df5093c6f1acef32665ab0d3d8846fc39f --- /dev/null +++ b/crates/agent_servers/src/acp.rs @@ -0,0 +1,34 @@ +use std::{path::Path, rc::Rc}; + +use crate::AgentServerCommand; +use acp_thread::AgentConnection; +use anyhow::Result; +use gpui::AsyncApp; +use thiserror::Error; + +mod v0; +mod v1; + +#[derive(Debug, Error)] +#[error("Unsupported version")] +pub struct UnsupportedVersion; + +pub async fn connect( + server_name: &'static str, + command: AgentServerCommand, + root_dir: &Path, + cx: &mut AsyncApp, +) -> Result> { + let conn = v1::AcpConnection::stdio(server_name, command.clone(), &root_dir, cx).await; + + match conn { + Ok(conn) => Ok(Rc::new(conn) as _), + Err(err) if err.is::() => { + // Consider re-using initialize response and subprocess when adding another version here + let conn: Rc = + Rc::new(v0::AcpConnection::stdio(server_name, command, &root_dir, cx).await?); + Ok(conn) + } + Err(err) => Err(err), + } +} diff --git a/crates/agent_servers/src/acp/v0.rs b/crates/agent_servers/src/acp/v0.rs new file mode 100644 index 0000000000000000000000000000000000000000..8d85435f92779e478334fb5b920a3af610287ffe --- /dev/null +++ 
b/crates/agent_servers/src/acp/v0.rs @@ -0,0 +1,509 @@ +// Translates old acp agents into the new schema +use agent_client_protocol as acp; +use agentic_coding_protocol::{self as acp_old, AgentRequest as _}; +use anyhow::{Context as _, Result, anyhow}; +use futures::channel::oneshot; +use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity}; +use project::Project; +use std::{cell::RefCell, path::Path, rc::Rc}; +use ui::App; +use util::ResultExt as _; + +use crate::AgentServerCommand; +use acp_thread::{AcpThread, AgentConnection, AuthRequired}; + +#[derive(Clone)] +struct OldAcpClientDelegate { + thread: Rc>>, + cx: AsyncApp, + next_tool_call_id: Rc>, + // sent_buffer_versions: HashMap, HashMap>, +} + +impl OldAcpClientDelegate { + fn new(thread: Rc>>, cx: AsyncApp) -> Self { + Self { + thread, + cx, + next_tool_call_id: Rc::new(RefCell::new(0)), + } + } +} + +impl acp_old::Client for OldAcpClientDelegate { + async fn stream_assistant_message_chunk( + &self, + params: acp_old::StreamAssistantMessageChunkParams, + ) -> Result<(), acp_old::Error> { + let cx = &mut self.cx.clone(); + + cx.update(|cx| { + self.thread + .borrow() + .update(cx, |thread, cx| match params.chunk { + acp_old::AssistantMessageChunk::Text { text } => { + thread.push_assistant_content_block(text.into(), false, cx) + } + acp_old::AssistantMessageChunk::Thought { thought } => { + thread.push_assistant_content_block(thought.into(), true, cx) + } + }) + .log_err(); + })?; + + Ok(()) + } + + async fn request_tool_call_confirmation( + &self, + request: acp_old::RequestToolCallConfirmationParams, + ) -> Result { + let cx = &mut self.cx.clone(); + + let old_acp_id = *self.next_tool_call_id.borrow() + 1; + self.next_tool_call_id.replace(old_acp_id); + + let tool_call = into_new_tool_call( + acp::ToolCallId(old_acp_id.to_string().into()), + request.tool_call, + ); + + let mut options = match request.confirmation { + acp_old::ToolCallConfirmation::Edit { .. } => vec![( + acp_old::ToolCallConfirmationOutcome::AlwaysAllow, + acp::PermissionOptionKind::AllowAlways, + "Always Allow Edits".to_string(), + )], + acp_old::ToolCallConfirmation::Execute { root_command, .. } => vec![( + acp_old::ToolCallConfirmationOutcome::AlwaysAllow, + acp::PermissionOptionKind::AllowAlways, + format!("Always Allow {}", root_command), + )], + acp_old::ToolCallConfirmation::Mcp { + server_name, + tool_name, + .. + } => vec![ + ( + acp_old::ToolCallConfirmationOutcome::AlwaysAllowMcpServer, + acp::PermissionOptionKind::AllowAlways, + format!("Always Allow {}", server_name), + ), + ( + acp_old::ToolCallConfirmationOutcome::AlwaysAllowTool, + acp::PermissionOptionKind::AllowAlways, + format!("Always Allow {}", tool_name), + ), + ], + acp_old::ToolCallConfirmation::Fetch { .. } => vec![( + acp_old::ToolCallConfirmationOutcome::AlwaysAllow, + acp::PermissionOptionKind::AllowAlways, + "Always Allow".to_string(), + )], + acp_old::ToolCallConfirmation::Other { .. 
} => vec![( + acp_old::ToolCallConfirmationOutcome::AlwaysAllow, + acp::PermissionOptionKind::AllowAlways, + "Always Allow".to_string(), + )], + }; + + options.extend([ + ( + acp_old::ToolCallConfirmationOutcome::Allow, + acp::PermissionOptionKind::AllowOnce, + "Allow".to_string(), + ), + ( + acp_old::ToolCallConfirmationOutcome::Reject, + acp::PermissionOptionKind::RejectOnce, + "Reject".to_string(), + ), + ]); + + let mut outcomes = Vec::with_capacity(options.len()); + let mut acp_options = Vec::with_capacity(options.len()); + + for (index, (outcome, kind, label)) in options.into_iter().enumerate() { + outcomes.push(outcome); + acp_options.push(acp::PermissionOption { + id: acp::PermissionOptionId(index.to_string().into()), + name: label, + kind, + }) + } + + let response = cx + .update(|cx| { + self.thread.borrow().update(cx, |thread, cx| { + thread.request_tool_call_authorization(tool_call, acp_options, cx) + }) + })? + .context("Failed to update thread")? + .await; + + let outcome = match response { + Ok(option_id) => outcomes[option_id.0.parse::().unwrap_or(0)], + Err(oneshot::Canceled) => acp_old::ToolCallConfirmationOutcome::Cancel, + }; + + Ok(acp_old::RequestToolCallConfirmationResponse { + id: acp_old::ToolCallId(old_acp_id), + outcome: outcome, + }) + } + + async fn push_tool_call( + &self, + request: acp_old::PushToolCallParams, + ) -> Result { + let cx = &mut self.cx.clone(); + + let old_acp_id = *self.next_tool_call_id.borrow() + 1; + self.next_tool_call_id.replace(old_acp_id); + + cx.update(|cx| { + self.thread.borrow().update(cx, |thread, cx| { + thread.upsert_tool_call( + into_new_tool_call(acp::ToolCallId(old_acp_id.to_string().into()), request), + cx, + ) + }) + })? + .context("Failed to update thread")?; + + Ok(acp_old::PushToolCallResponse { + id: acp_old::ToolCallId(old_acp_id), + }) + } + + async fn update_tool_call( + &self, + request: acp_old::UpdateToolCallParams, + ) -> Result<(), acp_old::Error> { + let cx = &mut self.cx.clone(); + + cx.update(|cx| { + self.thread.borrow().update(cx, |thread, cx| { + thread.update_tool_call( + acp::ToolCallUpdate { + id: acp::ToolCallId(request.tool_call_id.0.to_string().into()), + fields: acp::ToolCallUpdateFields { + status: Some(into_new_tool_call_status(request.status)), + content: Some( + request + .content + .into_iter() + .map(into_new_tool_call_content) + .collect::>(), + ), + ..Default::default() + }, + }, + cx, + ) + }) + })? + .context("Failed to update thread")??; + + Ok(()) + } + + async fn update_plan(&self, request: acp_old::UpdatePlanParams) -> Result<(), acp_old::Error> { + let cx = &mut self.cx.clone(); + + cx.update(|cx| { + self.thread.borrow().update(cx, |thread, cx| { + thread.update_plan( + acp::Plan { + entries: request + .entries + .into_iter() + .map(into_new_plan_entry) + .collect(), + }, + cx, + ) + }) + })? + .context("Failed to update thread")?; + + Ok(()) + } + + async fn read_text_file( + &self, + acp_old::ReadTextFileParams { path, line, limit }: acp_old::ReadTextFileParams, + ) -> Result { + let content = self + .cx + .update(|cx| { + self.thread.borrow().update(cx, |thread, cx| { + thread.read_text_file(path, line, limit, false, cx) + }) + })? + .context("Failed to update thread")? 
+ .await?; + Ok(acp_old::ReadTextFileResponse { content }) + } + + async fn write_text_file( + &self, + acp_old::WriteTextFileParams { path, content }: acp_old::WriteTextFileParams, + ) -> Result<(), acp_old::Error> { + self.cx + .update(|cx| { + self.thread + .borrow() + .update(cx, |thread, cx| thread.write_text_file(path, content, cx)) + })? + .context("Failed to update thread")? + .await?; + + Ok(()) + } +} + +fn into_new_tool_call(id: acp::ToolCallId, request: acp_old::PushToolCallParams) -> acp::ToolCall { + acp::ToolCall { + id: id, + title: request.label, + kind: acp_kind_from_old_icon(request.icon), + status: acp::ToolCallStatus::InProgress, + content: request + .content + .into_iter() + .map(into_new_tool_call_content) + .collect(), + locations: request + .locations + .into_iter() + .map(into_new_tool_call_location) + .collect(), + raw_input: None, + raw_output: None, + } +} + +fn acp_kind_from_old_icon(icon: acp_old::Icon) -> acp::ToolKind { + match icon { + acp_old::Icon::FileSearch => acp::ToolKind::Search, + acp_old::Icon::Folder => acp::ToolKind::Search, + acp_old::Icon::Globe => acp::ToolKind::Search, + acp_old::Icon::Hammer => acp::ToolKind::Other, + acp_old::Icon::LightBulb => acp::ToolKind::Think, + acp_old::Icon::Pencil => acp::ToolKind::Edit, + acp_old::Icon::Regex => acp::ToolKind::Search, + acp_old::Icon::Terminal => acp::ToolKind::Execute, + } +} + +fn into_new_tool_call_status(status: acp_old::ToolCallStatus) -> acp::ToolCallStatus { + match status { + acp_old::ToolCallStatus::Running => acp::ToolCallStatus::InProgress, + acp_old::ToolCallStatus::Finished => acp::ToolCallStatus::Completed, + acp_old::ToolCallStatus::Error => acp::ToolCallStatus::Failed, + } +} + +fn into_new_tool_call_content(content: acp_old::ToolCallContent) -> acp::ToolCallContent { + match content { + acp_old::ToolCallContent::Markdown { markdown } => markdown.into(), + acp_old::ToolCallContent::Diff { diff } => acp::ToolCallContent::Diff { + diff: into_new_diff(diff), + }, + } +} + +fn into_new_diff(diff: acp_old::Diff) -> acp::Diff { + acp::Diff { + path: diff.path, + old_text: diff.old_text, + new_text: diff.new_text, + } +} + +fn into_new_tool_call_location(location: acp_old::ToolCallLocation) -> acp::ToolCallLocation { + acp::ToolCallLocation { + path: location.path, + line: location.line, + } +} + +fn into_new_plan_entry(entry: acp_old::PlanEntry) -> acp::PlanEntry { + acp::PlanEntry { + content: entry.content, + priority: into_new_plan_priority(entry.priority), + status: into_new_plan_status(entry.status), + } +} + +fn into_new_plan_priority(priority: acp_old::PlanEntryPriority) -> acp::PlanEntryPriority { + match priority { + acp_old::PlanEntryPriority::Low => acp::PlanEntryPriority::Low, + acp_old::PlanEntryPriority::Medium => acp::PlanEntryPriority::Medium, + acp_old::PlanEntryPriority::High => acp::PlanEntryPriority::High, + } +} + +fn into_new_plan_status(status: acp_old::PlanEntryStatus) -> acp::PlanEntryStatus { + match status { + acp_old::PlanEntryStatus::Pending => acp::PlanEntryStatus::Pending, + acp_old::PlanEntryStatus::InProgress => acp::PlanEntryStatus::InProgress, + acp_old::PlanEntryStatus::Completed => acp::PlanEntryStatus::Completed, + } +} + +pub struct AcpConnection { + pub name: &'static str, + pub connection: acp_old::AgentConnection, + pub _child_status: Task>, + pub current_thread: Rc>>, +} + +impl AcpConnection { + pub fn stdio( + name: &'static str, + command: AgentServerCommand, + root_dir: &Path, + cx: &mut AsyncApp, + ) -> Task> { + let root_dir = 
root_dir.to_path_buf(); + + cx.spawn(async move |cx| { + let mut child = util::command::new_smol_command(&command.path) + .args(command.args.iter()) + .current_dir(root_dir) + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::inherit()) + .kill_on_drop(true) + .spawn()?; + + let stdin = child.stdin.take().unwrap(); + let stdout = child.stdout.take().unwrap(); + log::trace!("Spawned (pid: {})", child.id()); + + let foreground_executor = cx.foreground_executor().clone(); + + let thread_rc = Rc::new(RefCell::new(WeakEntity::new_invalid())); + + let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent( + OldAcpClientDelegate::new(thread_rc.clone(), cx.clone()), + stdin, + stdout, + move |fut| foreground_executor.spawn(fut).detach(), + ); + + let io_task = cx.background_spawn(async move { + io_fut.await.log_err(); + }); + + let child_status = cx.background_spawn(async move { + let result = match child.status().await { + Err(e) => Err(anyhow!(e)), + Ok(result) if result.success() => Ok(()), + Ok(result) => Err(anyhow!(result)), + }; + drop(io_task); + result + }); + + Ok(Self { + name, + connection, + _child_status: child_status, + current_thread: thread_rc, + }) + }) + } +} + +impl AgentConnection for AcpConnection { + fn new_thread( + self: Rc, + project: Entity, + _cwd: &Path, + cx: &mut AsyncApp, + ) -> Task>> { + let task = self.connection.request_any( + acp_old::InitializeParams { + protocol_version: acp_old::ProtocolVersion::latest(), + } + .into_any(), + ); + let current_thread = self.current_thread.clone(); + cx.spawn(async move |cx| { + let result = task.await?; + let result = acp_old::InitializeParams::response_from_any(result)?; + + if !result.is_authenticated { + anyhow::bail!(AuthRequired) + } + + cx.update(|cx| { + let thread = cx.new(|cx| { + let session_id = acp::SessionId("acp-old-no-id".into()); + AcpThread::new(self.name, self.clone(), project, session_id, cx) + }); + current_thread.replace(thread.downgrade()); + thread + }) + }) + } + + fn auth_methods(&self) -> &[acp::AuthMethod] { + &[] + } + + fn authenticate(&self, _method_id: acp::AuthMethodId, cx: &mut App) -> Task> { + let task = self + .connection + .request_any(acp_old::AuthenticateParams.into_any()); + cx.foreground_executor().spawn(async move { + task.await?; + Ok(()) + }) + } + + fn prompt( + &self, + params: acp::PromptRequest, + cx: &mut App, + ) -> Task> { + let chunks = params + .prompt + .into_iter() + .filter_map(|block| match block { + acp::ContentBlock::Text(text) => { + Some(acp_old::UserMessageChunk::Text { text: text.text }) + } + acp::ContentBlock::ResourceLink(link) => Some(acp_old::UserMessageChunk::Path { + path: link.uri.into(), + }), + _ => None, + }) + .collect(); + + let task = self + .connection + .request_any(acp_old::SendUserMessageParams { chunks }.into_any()); + cx.foreground_executor().spawn(async move { + task.await?; + anyhow::Ok(acp::PromptResponse { + stop_reason: acp::StopReason::EndTurn, + }) + }) + } + + fn cancel(&self, _session_id: &acp::SessionId, cx: &mut App) { + let task = self + .connection + .request_any(acp_old::CancelSendMessageParams.into_any()); + cx.foreground_executor() + .spawn(async move { + task.await?; + anyhow::Ok(()) + }) + .detach_and_log_err(cx) + } +} diff --git a/crates/agent_servers/src/acp/v1.rs b/crates/agent_servers/src/acp/v1.rs new file mode 100644 index 0000000000000000000000000000000000000000..ff71783b4890be905ce5d3a68202c6ea0371c12f --- /dev/null +++ 
b/crates/agent_servers/src/acp/v1.rs @@ -0,0 +1,282 @@ +use agent_client_protocol::{self as acp, Agent as _}; +use anyhow::anyhow; +use collections::HashMap; +use futures::channel::oneshot; +use project::Project; +use std::cell::RefCell; +use std::path::Path; +use std::rc::Rc; + +use anyhow::{Context as _, Result}; +use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity}; + +use crate::{AgentServerCommand, acp::UnsupportedVersion}; +use acp_thread::{AcpThread, AgentConnection, AuthRequired}; + +pub struct AcpConnection { + server_name: &'static str, + connection: Rc, + sessions: Rc>>, + auth_methods: Vec, + _io_task: Task>, +} + +pub struct AcpSession { + thread: WeakEntity, +} + +const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1; + +impl AcpConnection { + pub async fn stdio( + server_name: &'static str, + command: AgentServerCommand, + root_dir: &Path, + cx: &mut AsyncApp, + ) -> Result { + let mut child = util::command::new_smol_command(&command.path) + .args(command.args.iter().map(|arg| arg.as_str())) + .envs(command.env.iter().flatten()) + .current_dir(root_dir) + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::inherit()) + .kill_on_drop(true) + .spawn()?; + + let stdout = child.stdout.take().expect("Failed to take stdout"); + let stdin = child.stdin.take().expect("Failed to take stdin"); + log::trace!("Spawned (pid: {})", child.id()); + + let sessions = Rc::new(RefCell::new(HashMap::default())); + + let client = ClientDelegate { + sessions: sessions.clone(), + cx: cx.clone(), + }; + let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, { + let foreground_executor = cx.foreground_executor().clone(); + move |fut| { + foreground_executor.spawn(fut).detach(); + } + }); + + let io_task = cx.background_spawn(io_task); + + cx.spawn({ + let sessions = sessions.clone(); + async move |cx| { + let status = child.status().await?; + + for session in sessions.borrow().values() { + session + .thread + .update(cx, |thread, cx| thread.emit_server_exited(status, cx)) + .ok(); + } + + anyhow::Ok(()) + } + }) + .detach(); + + let response = connection + .initialize(acp::InitializeRequest { + protocol_version: acp::VERSION, + client_capabilities: acp::ClientCapabilities { + fs: acp::FileSystemCapability { + read_text_file: true, + write_text_file: true, + }, + }, + }) + .await?; + + if response.protocol_version < MINIMUM_SUPPORTED_VERSION { + return Err(UnsupportedVersion.into()); + } + + Ok(Self { + auth_methods: response.auth_methods, + connection: connection.into(), + server_name, + sessions, + _io_task: io_task, + }) + } +} + +impl AgentConnection for AcpConnection { + fn new_thread( + self: Rc, + project: Entity, + cwd: &Path, + cx: &mut AsyncApp, + ) -> Task>> { + let conn = self.connection.clone(); + let sessions = self.sessions.clone(); + let cwd = cwd.to_path_buf(); + cx.spawn(async move |cx| { + let response = conn + .new_session(acp::NewSessionRequest { + mcp_servers: vec![], + cwd, + }) + .await + .map_err(|err| { + if err.code == acp::ErrorCode::AUTH_REQUIRED.code { + anyhow!(AuthRequired) + } else { + anyhow!(err) + } + })?; + + let session_id = response.session_id; + + let thread = cx.new(|cx| { + AcpThread::new( + self.server_name, + self.clone(), + project, + session_id.clone(), + cx, + ) + })?; + + let session = AcpSession { + thread: thread.downgrade(), + }; + sessions.borrow_mut().insert(session_id, session); + + Ok(thread) + }) + } + + fn auth_methods(&self) -> 
&[acp::AuthMethod] { + &self.auth_methods + } + + fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task> { + let conn = self.connection.clone(); + cx.foreground_executor().spawn(async move { + let result = conn + .authenticate(acp::AuthenticateRequest { + method_id: method_id.clone(), + }) + .await?; + + Ok(result) + }) + } + + fn prompt( + &self, + params: acp::PromptRequest, + cx: &mut App, + ) -> Task> { + let conn = self.connection.clone(); + cx.foreground_executor().spawn(async move { + let response = conn.prompt(params).await?; + Ok(response) + }) + } + + fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) { + let conn = self.connection.clone(); + let params = acp::CancelNotification { + session_id: session_id.clone(), + }; + cx.foreground_executor() + .spawn(async move { conn.cancel(params).await }) + .detach(); + } +} + +struct ClientDelegate { + sessions: Rc>>, + cx: AsyncApp, +} + +impl acp::Client for ClientDelegate { + async fn request_permission( + &self, + arguments: acp::RequestPermissionRequest, + ) -> Result { + let cx = &mut self.cx.clone(); + let rx = self + .sessions + .borrow() + .get(&arguments.session_id) + .context("Failed to get session")? + .thread + .update(cx, |thread, cx| { + thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx) + })?; + + let result = rx.await; + + let outcome = match result { + Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option }, + Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled, + }; + + Ok(acp::RequestPermissionResponse { outcome }) + } + + async fn write_text_file( + &self, + arguments: acp::WriteTextFileRequest, + ) -> Result<(), acp::Error> { + let cx = &mut self.cx.clone(); + let task = self + .sessions + .borrow() + .get(&arguments.session_id) + .context("Failed to get session")? + .thread + .update(cx, |thread, cx| { + thread.write_text_file(arguments.path, arguments.content, cx) + })?; + + task.await?; + + Ok(()) + } + + async fn read_text_file( + &self, + arguments: acp::ReadTextFileRequest, + ) -> Result { + let cx = &mut self.cx.clone(); + let task = self + .sessions + .borrow() + .get(&arguments.session_id) + .context("Failed to get session")? 
+ .thread + .update(cx, |thread, cx| { + thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx) + })?; + + let content = task.await?; + + Ok(acp::ReadTextFileResponse { content }) + } + + async fn session_notification( + &self, + notification: acp::SessionNotification, + ) -> Result<(), acp::Error> { + let cx = &mut self.cx.clone(); + let sessions = self.sessions.borrow(); + let session = sessions + .get(¬ification.session_id) + .context("Failed to get session")?; + + session.thread.update(cx, |thread, cx| { + thread.handle_session_update(notification.update, cx) + })??; + + Ok(()) + } +} diff --git a/crates/agent_servers/src/agent_servers.rs b/crates/agent_servers/src/agent_servers.rs index 5d588cd4aea0f863203201de82b0614cc210e615..b3b8a3317049927986a6a578bc50c4e5506b7650 100644 --- a/crates/agent_servers/src/agent_servers.rs +++ b/crates/agent_servers/src/agent_servers.rs @@ -1,30 +1,79 @@ -use std::{ - path::{Path, PathBuf}, - sync::Arc, -}; +mod acp; +mod claude; +mod gemini; +mod settings; -use anyhow::{Context as _, Result}; +#[cfg(test)] +mod e2e_tests; + +pub use claude::*; +pub use gemini::*; +pub use settings::*; + +use acp_thread::AgentConnection; +use anyhow::Result; use collections::HashMap; -use gpui::{App, AsyncApp, Entity, SharedString}; +use gpui::{App, AsyncApp, Entity, SharedString, Task}; use project::Project; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use settings::{Settings, SettingsSources, SettingsStore}; -use util::{ResultExt, paths}; +use std::{ + path::{Path, PathBuf}, + rc::Rc, + sync::Arc, +}; +use util::ResultExt as _; pub fn init(cx: &mut App) { - AllAgentServersSettings::register(cx); + settings::init(cx); } -#[derive(Default, Deserialize, Serialize, Clone, JsonSchema, Debug)] -pub struct AllAgentServersSettings { - gemini: Option, +pub trait AgentServer: Send { + fn logo(&self) -> ui::IconName; + fn name(&self) -> &'static str; + fn empty_state_headline(&self) -> &'static str; + fn empty_state_message(&self) -> &'static str; + + fn connect( + &self, + root_dir: &Path, + project: &Entity, + cx: &mut App, + ) -> Task>>; } -#[derive(Deserialize, Serialize, Clone, JsonSchema, Debug)] -pub struct AgentServerSettings { - #[serde(flatten)] - command: AgentServerCommand, +impl std::fmt::Debug for AgentServerCommand { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let filtered_env = self.env.as_ref().map(|env| { + env.iter() + .map(|(k, v)| { + ( + k, + if util::redact::should_redact(k) { + "[REDACTED]" + } else { + v + }, + ) + }) + .collect::>() + }); + + f.debug_struct("AgentServerCommand") + .field("path", &self.path) + .field("args", &self.args) + .field("env", &filtered_env) + .finish() + } +} + +pub enum AgentServerVersion { + Supported, + Unsupported { + error_message: SharedString, + upgrade_message: SharedString, + upgrade_command: String, + }, } #[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema)] @@ -36,105 +85,46 @@ pub struct AgentServerCommand { pub env: Option>, } -pub struct Gemini; - -pub struct AgentServerVersion { - pub current_version: SharedString, - pub supported: bool, -} - -pub trait AgentServer: Send { - fn command( - &self, +impl AgentServerCommand { + pub(crate) async fn resolve( + path_bin_name: &'static str, + extra_args: &[&'static str], + fallback_path: Option<&Path>, + settings: Option, project: &Entity, cx: &mut AsyncApp, - ) -> impl Future>; - - fn version( - &self, - command: &AgentServerCommand, - ) -> impl Future> + Send; -} - -const 
GEMINI_ACP_ARG: &str = "--acp"; - -impl AgentServer for Gemini { - async fn command( - &self, - project: &Entity, - cx: &mut AsyncApp, - ) -> Result { - let custom_command = cx.read_global(|settings: &SettingsStore, _| { - let settings = settings.get::(None); - settings - .gemini - .as_ref() - .map(|gemini_settings| AgentServerCommand { - path: gemini_settings.command.path.clone(), - args: gemini_settings - .command - .args - .iter() - .cloned() - .chain(std::iter::once(GEMINI_ACP_ARG.into())) - .collect(), - env: gemini_settings.command.env.clone(), - }) - })?; - - if let Some(custom_command) = custom_command { - return Ok(custom_command); - } - - if let Some(path) = find_bin_in_path("gemini", project, cx).await { - return Ok(AgentServerCommand { - path, - args: vec![GEMINI_ACP_ARG.into()], - env: None, + ) -> Option { + if let Some(agent_settings) = settings { + return Some(Self { + path: agent_settings.command.path, + args: agent_settings + .command + .args + .into_iter() + .chain(extra_args.iter().map(|arg| arg.to_string())) + .collect(), + env: agent_settings.command.env, }); + } else { + match find_bin_in_path(path_bin_name, project, cx).await { + Some(path) => Some(Self { + path, + args: extra_args.iter().map(|arg| arg.to_string()).collect(), + env: None, + }), + None => fallback_path.and_then(|path| { + if path.exists() { + Some(Self { + path: path.to_path_buf(), + args: extra_args.iter().map(|arg| arg.to_string()).collect(), + env: None, + }) + } else { + None + } + }), + } } - - let (fs, node_runtime) = project.update(cx, |project, _| { - (project.fs().clone(), project.node_runtime().cloned()) - })?; - let node_runtime = node_runtime.context("gemini not found on path")?; - - let directory = ::paths::agent_servers_dir().join("gemini"); - fs.create_dir(&directory).await?; - node_runtime - .npm_install_packages(&directory, &[("@google/gemini-cli", "latest")]) - .await?; - let path = directory.join("node_modules/.bin/gemini"); - - Ok(AgentServerCommand { - path, - args: vec![GEMINI_ACP_ARG.into()], - env: None, - }) - } - - async fn version(&self, command: &AgentServerCommand) -> Result { - let version_fut = util::command::new_smol_command(&command.path) - .args(command.args.iter()) - .arg("--version") - .kill_on_drop(true) - .output(); - - let help_fut = util::command::new_smol_command(&command.path) - .args(command.args.iter()) - .arg("--help") - .kill_on_drop(true) - .output(); - - let (version_output, help_output) = futures::future::join(version_fut, help_fut).await; - - let current_version = String::from_utf8(version_output?.stdout)?.into(); - let supported = String::from_utf8(help_output?.stdout)?.contains(GEMINI_ACP_ARG); - - Ok(AgentServerVersion { - current_version, - supported, - }) } } @@ -184,48 +174,3 @@ async fn find_bin_in_path( }) .await } - -impl std::fmt::Debug for AgentServerCommand { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let filtered_env = self.env.as_ref().map(|env| { - env.iter() - .map(|(k, v)| { - ( - k, - if util::redact::should_redact(k) { - "[REDACTED]" - } else { - v - }, - ) - }) - .collect::>() - }); - - f.debug_struct("AgentServerCommand") - .field("path", &self.path) - .field("args", &self.args) - .field("env", &filtered_env) - .finish() - } -} - -impl settings::Settings for AllAgentServersSettings { - const KEY: Option<&'static str> = Some("agent_servers"); - - type FileContent = Self; - - fn load(sources: SettingsSources, _: &mut App) -> Result { - let mut settings = AllAgentServersSettings::default(); - - for value 
in sources.defaults_and_customizations() { - if value.gemini.is_some() { - settings.gemini = value.gemini.clone(); - } - } - - Ok(settings) - } - - fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {} -} diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs new file mode 100644 index 0000000000000000000000000000000000000000..c65508f1520b9b03d866df606c2272ccf43d8dbd --- /dev/null +++ b/crates/agent_servers/src/claude.rs @@ -0,0 +1,1065 @@ +mod mcp_server; +pub mod tools; + +use collections::HashMap; +use context_server::listener::McpServerTool; +use project::Project; +use settings::SettingsStore; +use smol::process::Child; +use std::cell::RefCell; +use std::fmt::Display; +use std::path::Path; +use std::rc::Rc; +use uuid::Uuid; + +use agent_client_protocol as acp; +use anyhow::{Result, anyhow}; +use futures::channel::oneshot; +use futures::{AsyncBufReadExt, AsyncWriteExt}; +use futures::{ + AsyncRead, AsyncWrite, FutureExt, StreamExt, + channel::mpsc::{self, UnboundedReceiver, UnboundedSender}, + io::BufReader, + select_biased, +}; +use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity}; +use serde::{Deserialize, Serialize}; +use util::{ResultExt, debug_panic}; + +use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig}; +use crate::claude::tools::ClaudeTool; +use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings}; +use acp_thread::{AcpThread, AgentConnection}; + +#[derive(Clone)] +pub struct ClaudeCode; + +impl AgentServer for ClaudeCode { + fn name(&self) -> &'static str { + "Claude Code" + } + + fn empty_state_headline(&self) -> &'static str { + self.name() + } + + fn empty_state_message(&self) -> &'static str { + "How can I help you today?" + } + + fn logo(&self) -> ui::IconName { + ui::IconName::AiClaude + } + + fn connect( + &self, + _root_dir: &Path, + _project: &Entity, + _cx: &mut App, + ) -> Task>> { + let connection = ClaudeAgentConnection { + sessions: Default::default(), + }; + + Task::ready(Ok(Rc::new(connection) as _)) + } +} + +struct ClaudeAgentConnection { + sessions: Rc>>, +} + +impl AgentConnection for ClaudeAgentConnection { + fn new_thread( + self: Rc, + project: Entity, + cwd: &Path, + cx: &mut AsyncApp, + ) -> Task>> { + let cwd = cwd.to_owned(); + cx.spawn(async move |cx| { + let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid()); + let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), cx).await?; + + let mut mcp_servers = HashMap::default(); + mcp_servers.insert( + mcp_server::SERVER_NAME.to_string(), + permission_mcp_server.server_config()?, + ); + let mcp_config = McpConfig { mcp_servers }; + + let mcp_config_file = tempfile::NamedTempFile::new()?; + let (mcp_config_file, mcp_config_path) = mcp_config_file.into_parts(); + + let mut mcp_config_file = smol::fs::File::from(mcp_config_file); + mcp_config_file + .write_all(serde_json::to_string(&mcp_config)?.as_bytes()) + .await?; + mcp_config_file.flush().await?; + + let settings = cx.read_global(|settings: &SettingsStore, _| { + settings.get::(None).claude.clone() + })?; + + let Some(command) = AgentServerCommand::resolve( + "claude", + &[], + Some(&util::paths::home_dir().join(".claude/local/claude")), + settings, + &project, + cx, + ) + .await + else { + anyhow::bail!("Failed to find claude binary"); + }; + + let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded(); + let (outgoing_tx, outgoing_rx) = mpsc::unbounded(); + + let session_id = 
acp::SessionId(Uuid::new_v4().to_string().into()); + + log::trace!("Starting session with id: {}", session_id); + + let mut child = spawn_claude( + &command, + ClaudeSessionMode::Start, + session_id.clone(), + &mcp_config_path, + &cwd, + )?; + + let stdin = child.stdin.take().unwrap(); + let stdout = child.stdout.take().unwrap(); + + let pid = child.id(); + log::trace!("Spawned (pid: {})", pid); + + cx.background_spawn(async move { + let mut outgoing_rx = Some(outgoing_rx); + + ClaudeAgentSession::handle_io( + outgoing_rx.take().unwrap(), + incoming_message_tx.clone(), + stdin, + stdout, + ) + .await?; + + log::trace!("Stopped (pid: {})", pid); + + drop(mcp_config_path); + anyhow::Ok(()) + }) + .detach(); + + let turn_state = Rc::new(RefCell::new(TurnState::None)); + + let handler_task = cx.spawn({ + let turn_state = turn_state.clone(); + let mut thread_rx = thread_rx.clone(); + async move |cx| { + while let Some(message) = incoming_message_rx.next().await { + ClaudeAgentSession::handle_message( + thread_rx.clone(), + message, + turn_state.clone(), + cx, + ) + .await + } + + if let Some(status) = child.status().await.log_err() { + if let Some(thread) = thread_rx.recv().await.ok() { + thread + .update(cx, |thread, cx| { + thread.emit_server_exited(status, cx); + }) + .ok(); + } + } + } + }); + + let thread = cx.new(|cx| { + AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx) + })?; + + thread_tx.send(thread.downgrade())?; + + let session = ClaudeAgentSession { + outgoing_tx, + turn_state, + _handler_task: handler_task, + _mcp_server: Some(permission_mcp_server), + }; + + self.sessions.borrow_mut().insert(session_id, session); + + Ok(thread) + }) + } + + fn auth_methods(&self) -> &[acp::AuthMethod] { + &[] + } + + fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task> { + Task::ready(Err(anyhow!("Authentication not supported"))) + } + + fn prompt( + &self, + params: acp::PromptRequest, + cx: &mut App, + ) -> Task> { + let sessions = self.sessions.borrow(); + let Some(session) = sessions.get(¶ms.session_id) else { + return Task::ready(Err(anyhow!( + "Attempted to send message to nonexistent session {}", + params.session_id + ))); + }; + + let (end_tx, end_rx) = oneshot::channel(); + session.turn_state.replace(TurnState::InProgress { end_tx }); + + let mut content = String::new(); + for chunk in params.prompt { + match chunk { + acp::ContentBlock::Text(text_content) => { + content.push_str(&text_content.text); + } + acp::ContentBlock::ResourceLink(resource_link) => { + content.push_str(&format!("@{}", resource_link.uri)); + } + acp::ContentBlock::Audio(_) + | acp::ContentBlock::Image(_) + | acp::ContentBlock::Resource(_) => { + // TODO + } + } + } + + if let Err(err) = session.outgoing_tx.unbounded_send(SdkMessage::User { + message: Message { + role: Role::User, + content: Content::UntaggedText(content), + id: None, + model: None, + stop_reason: None, + stop_sequence: None, + usage: None, + }, + session_id: Some(params.session_id.to_string()), + }) { + return Task::ready(Err(anyhow!(err))); + } + + cx.foreground_executor().spawn(async move { end_rx.await? 
}) + } + + fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) { + let sessions = self.sessions.borrow(); + let Some(session) = sessions.get(&session_id) else { + log::warn!("Attempted to cancel nonexistent session {}", session_id); + return; + }; + + let request_id = new_request_id(); + + let turn_state = session.turn_state.take(); + let TurnState::InProgress { end_tx } = turn_state else { + // Already cancelled or idle, put it back + session.turn_state.replace(turn_state); + return; + }; + + session.turn_state.replace(TurnState::CancelRequested { + end_tx, + request_id: request_id.clone(), + }); + + session + .outgoing_tx + .unbounded_send(SdkMessage::ControlRequest { + request_id, + request: ControlRequest::Interrupt, + }) + .log_err(); + } +} + +#[derive(Clone, Copy)] +enum ClaudeSessionMode { + Start, + #[expect(dead_code)] + Resume, +} + +fn spawn_claude( + command: &AgentServerCommand, + mode: ClaudeSessionMode, + session_id: acp::SessionId, + mcp_config_path: &Path, + root_dir: &Path, +) -> Result { + let child = util::command::new_smol_command(&command.path) + .args([ + "--input-format", + "stream-json", + "--output-format", + "stream-json", + "--print", + "--verbose", + "--mcp-config", + mcp_config_path.to_string_lossy().as_ref(), + "--permission-prompt-tool", + &format!( + "mcp__{}__{}", + mcp_server::SERVER_NAME, + mcp_server::PermissionTool::NAME, + ), + "--allowedTools", + &format!( + "mcp__{}__{},mcp__{}__{}", + mcp_server::SERVER_NAME, + mcp_server::EditTool::NAME, + mcp_server::SERVER_NAME, + mcp_server::ReadTool::NAME + ), + "--disallowedTools", + "Read,Edit", + ]) + .args(match mode { + ClaudeSessionMode::Start => ["--session-id".to_string(), session_id.to_string()], + ClaudeSessionMode::Resume => ["--resume".to_string(), session_id.to_string()], + }) + .args(command.args.iter().map(|arg| arg.as_str())) + .current_dir(root_dir) + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::inherit()) + .kill_on_drop(true) + .spawn()?; + + Ok(child) +} + +struct ClaudeAgentSession { + outgoing_tx: UnboundedSender, + turn_state: Rc>, + _mcp_server: Option, + _handler_task: Task<()>, +} + +#[derive(Debug, Default)] +enum TurnState { + #[default] + None, + InProgress { + end_tx: oneshot::Sender>, + }, + CancelRequested { + end_tx: oneshot::Sender>, + request_id: String, + }, + CancelConfirmed { + end_tx: oneshot::Sender>, + }, +} + +impl TurnState { + fn is_cancelled(&self) -> bool { + matches!(self, TurnState::CancelConfirmed { .. }) + } + + fn end_tx(self) -> Option>> { + match self { + TurnState::None => None, + TurnState::InProgress { end_tx, .. } => Some(end_tx), + TurnState::CancelRequested { end_tx, .. } => Some(end_tx), + TurnState::CancelConfirmed { end_tx } => Some(end_tx), + } + } + + fn confirm_cancellation(self, id: &str) -> Self { + match self { + TurnState::CancelRequested { request_id, end_tx } if request_id == id => { + TurnState::CancelConfirmed { end_tx } + } + _ => self, + } + } +} + +impl ClaudeAgentSession { + async fn handle_message( + mut thread_rx: watch::Receiver>, + message: SdkMessage, + turn_state: Rc>, + cx: &mut AsyncApp, + ) { + match message { + // we should only be sending these out, they don't need to be in the thread + SdkMessage::ControlRequest { .. 
} => {} + SdkMessage::User { + message, + session_id: _, + } => { + let Some(thread) = thread_rx + .recv() + .await + .log_err() + .and_then(|entity| entity.upgrade()) + else { + log::error!("Received an SDK message but thread is gone"); + return; + }; + + for chunk in message.content.chunks() { + match chunk { + ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => { + if !turn_state.borrow().is_cancelled() { + thread + .update(cx, |thread, cx| { + thread.push_user_content_block(text.into(), cx) + }) + .log_err(); + } + } + ContentChunk::ToolResult { + content, + tool_use_id, + } => { + let content = content.to_string(); + thread + .update(cx, |thread, cx| { + thread.update_tool_call( + acp::ToolCallUpdate { + id: acp::ToolCallId(tool_use_id.into()), + fields: acp::ToolCallUpdateFields { + status: if turn_state.borrow().is_cancelled() { + // Do not set to completed if turn was cancelled + None + } else { + Some(acp::ToolCallStatus::Completed) + }, + content: (!content.is_empty()) + .then(|| vec![content.into()]), + ..Default::default() + }, + }, + cx, + ) + }) + .log_err(); + } + ContentChunk::Thinking { .. } + | ContentChunk::RedactedThinking + | ContentChunk::ToolUse { .. } => { + debug_panic!( + "Should not get {:?} with role: assistant. should we handle this?", + chunk + ); + } + + ContentChunk::Image + | ContentChunk::Document + | ContentChunk::WebSearchToolResult => { + thread + .update(cx, |thread, cx| { + thread.push_assistant_content_block( + format!("Unsupported content: {:?}", chunk).into(), + false, + cx, + ) + }) + .log_err(); + } + } + } + } + SdkMessage::Assistant { + message, + session_id: _, + } => { + let Some(thread) = thread_rx + .recv() + .await + .log_err() + .and_then(|entity| entity.upgrade()) + else { + log::error!("Received an SDK message but thread is gone"); + return; + }; + + for chunk in message.content.chunks() { + match chunk { + ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => { + thread + .update(cx, |thread, cx| { + thread.push_assistant_content_block(text.into(), false, cx) + }) + .log_err(); + } + ContentChunk::Thinking { thinking } => { + thread + .update(cx, |thread, cx| { + thread.push_assistant_content_block(thinking.into(), true, cx) + }) + .log_err(); + } + ContentChunk::RedactedThinking => { + thread + .update(cx, |thread, cx| { + thread.push_assistant_content_block( + "[REDACTED]".into(), + true, + cx, + ) + }) + .log_err(); + } + ContentChunk::ToolUse { id, name, input } => { + let claude_tool = ClaudeTool::infer(&name, input); + + thread + .update(cx, |thread, cx| { + if let ClaudeTool::TodoWrite(Some(params)) = claude_tool { + thread.update_plan( + acp::Plan { + entries: params + .todos + .into_iter() + .map(Into::into) + .collect(), + }, + cx, + ) + } else { + thread.upsert_tool_call( + claude_tool.as_acp(acp::ToolCallId(id.into())), + cx, + ); + } + }) + .log_err(); + } + ContentChunk::ToolResult { .. } | ContentChunk::WebSearchToolResult => { + debug_panic!( + "Should not get tool results with role: assistant. should we handle this?" + ); + } + ContentChunk::Image | ContentChunk::Document => { + thread + .update(cx, |thread, cx| { + thread.push_assistant_content_block( + format!("Unsupported content: {:?}", chunk).into(), + false, + cx, + ) + }) + .log_err(); + } + } + } + } + SdkMessage::Result { + is_error, + subtype, + result, + .. 
+ } => { + let turn_state = turn_state.take(); + let was_cancelled = turn_state.is_cancelled(); + let Some(end_turn_tx) = turn_state.end_tx() else { + debug_panic!("Received `SdkMessage::Result` but there wasn't an active turn"); + return; + }; + + if is_error || (!was_cancelled && subtype == ResultErrorType::ErrorDuringExecution) + { + end_turn_tx + .send(Err(anyhow!( + "Error: {}", + result.unwrap_or_else(|| subtype.to_string()) + ))) + .ok(); + } else { + let stop_reason = match subtype { + ResultErrorType::Success => acp::StopReason::EndTurn, + ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests, + ResultErrorType::ErrorDuringExecution => acp::StopReason::Cancelled, + }; + end_turn_tx + .send(Ok(acp::PromptResponse { stop_reason })) + .ok(); + } + } + SdkMessage::ControlResponse { response } => { + if matches!(response.subtype, ResultErrorType::Success) { + let new_state = turn_state.take().confirm_cancellation(&response.request_id); + turn_state.replace(new_state); + } else { + log::error!("Control response error: {:?}", response); + } + } + SdkMessage::System { .. } => {} + } + } + + async fn handle_io( + mut outgoing_rx: UnboundedReceiver, + incoming_tx: UnboundedSender, + mut outgoing_bytes: impl Unpin + AsyncWrite, + incoming_bytes: impl Unpin + AsyncRead, + ) -> Result> { + let mut output_reader = BufReader::new(incoming_bytes); + let mut outgoing_line = Vec::new(); + let mut incoming_line = String::new(); + loop { + select_biased! { + message = outgoing_rx.next() => { + if let Some(message) = message { + outgoing_line.clear(); + serde_json::to_writer(&mut outgoing_line, &message)?; + log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line)); + outgoing_line.push(b'\n'); + outgoing_bytes.write_all(&outgoing_line).await.ok(); + } else { + break; + } + } + bytes_read = output_reader.read_line(&mut incoming_line).fuse() => { + if bytes_read? == 0 { + break + } + log::trace!("recv: {}", &incoming_line); + match serde_json::from_str::(&incoming_line) { + Ok(message) => { + incoming_tx.unbounded_send(message).log_err(); + } + Err(error) => { + log::error!("failed to parse incoming message: {error}. 
Raw: {incoming_line}"); + } + } + incoming_line.clear(); + } + } + } + + Ok(outgoing_rx) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct Message { + role: Role, + content: Content, + #[serde(skip_serializing_if = "Option::is_none")] + id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + model: Option, + #[serde(skip_serializing_if = "Option::is_none")] + stop_reason: Option, + #[serde(skip_serializing_if = "Option::is_none")] + stop_sequence: Option, + #[serde(skip_serializing_if = "Option::is_none")] + usage: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +enum Content { + UntaggedText(String), + Chunks(Vec), +} + +impl Content { + pub fn chunks(self) -> impl Iterator { + match self { + Self::Chunks(chunks) => chunks.into_iter(), + Self::UntaggedText(text) => vec![ContentChunk::Text { text: text.clone() }].into_iter(), + } + } +} + +impl Display for Content { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Content::UntaggedText(txt) => write!(f, "{}", txt), + Content::Chunks(chunks) => { + for chunk in chunks { + write!(f, "{}", chunk)?; + } + Ok(()) + } + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +enum ContentChunk { + Text { + text: String, + }, + ToolUse { + id: String, + name: String, + input: serde_json::Value, + }, + ToolResult { + content: Content, + tool_use_id: String, + }, + Thinking { + thinking: String, + }, + RedactedThinking, + // TODO + Image, + Document, + WebSearchToolResult, + #[serde(untagged)] + UntaggedText(String), +} + +impl Display for ContentChunk { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ContentChunk::Text { text } => write!(f, "{}", text), + ContentChunk::Thinking { thinking } => write!(f, "Thinking: {}", thinking), + ContentChunk::RedactedThinking => write!(f, "Thinking: [REDACTED]"), + ContentChunk::UntaggedText(text) => write!(f, "{}", text), + ContentChunk::ToolResult { content, .. } => write!(f, "{}", content), + ContentChunk::Image + | ContentChunk::Document + | ContentChunk::ToolUse { .. 
} + | ContentChunk::WebSearchToolResult => { + write!(f, "\n{:?}\n", &self) + } + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct Usage { + input_tokens: u32, + cache_creation_input_tokens: u32, + cache_read_input_tokens: u32, + output_tokens: u32, + service_tier: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +enum Role { + System, + Assistant, + User, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct MessageParam { + role: Role, + content: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +enum SdkMessage { + // An assistant message + Assistant { + message: Message, // from Anthropic SDK + #[serde(skip_serializing_if = "Option::is_none")] + session_id: Option, + }, + // A user message + User { + message: Message, // from Anthropic SDK + #[serde(skip_serializing_if = "Option::is_none")] + session_id: Option, + }, + // Emitted as the last message in a conversation + Result { + subtype: ResultErrorType, + duration_ms: f64, + duration_api_ms: f64, + is_error: bool, + num_turns: i32, + #[serde(skip_serializing_if = "Option::is_none")] + result: Option, + session_id: String, + total_cost_usd: f64, + }, + // Emitted as the first message at the start of a conversation + System { + cwd: String, + session_id: String, + tools: Vec, + model: String, + mcp_servers: Vec, + #[serde(rename = "apiKeySource")] + api_key_source: String, + #[serde(rename = "permissionMode")] + permission_mode: PermissionMode, + }, + /// Messages used to control the conversation, outside of chat messages to the model + ControlRequest { + request_id: String, + request: ControlRequest, + }, + /// Response to a control request + ControlResponse { response: ControlResponse }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "subtype", rename_all = "snake_case")] +enum ControlRequest { + /// Cancel the current conversation + Interrupt, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct ControlResponse { + request_id: String, + subtype: ResultErrorType, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] +#[serde(rename_all = "snake_case")] +enum ResultErrorType { + Success, + ErrorMaxTurns, + ErrorDuringExecution, +} + +impl Display for ResultErrorType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ResultErrorType::Success => write!(f, "success"), + ResultErrorType::ErrorMaxTurns => write!(f, "error_max_turns"), + ResultErrorType::ErrorDuringExecution => write!(f, "error_during_execution"), + } + } +} + +fn new_request_id() -> String { + use rand::Rng; + // In the Claude Code TS SDK they just generate a random 12 character string, + // `Math.random().toString(36).substring(2, 15)` + rand::thread_rng() + .sample_iter(&rand::distributions::Alphanumeric) + .take(12) + .map(char::from) + .collect() +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct McpServer { + name: String, + status: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +enum PermissionMode { + Default, + AcceptEdits, + BypassPermissions, + Plan, +} + +#[cfg(test)] +pub(crate) mod tests { + use super::*; + use crate::e2e_tests; + use gpui::TestAppContext; + use serde_json::json; + + crate::common_e2e_tests!(ClaudeCode, allow_option_id = "allow"); + + pub fn local_command() -> AgentServerCommand { + AgentServerCommand { + path: "claude".into(), + args: vec![], + env: None, + } + 
} + + #[gpui::test] + #[cfg_attr(not(feature = "e2e"), ignore)] + async fn test_todo_plan(cx: &mut TestAppContext) { + let fs = e2e_tests::init_test(cx).await; + let project = Project::test(fs, [], cx).await; + let thread = + e2e_tests::new_test_thread(ClaudeCode, project.clone(), "/private/tmp", cx).await; + + thread + .update(cx, |thread, cx| { + thread.send_raw( + "Create a todo plan for initializing a new React app. I'll follow it myself, do not execute on it.", + cx, + ) + }) + .await + .unwrap(); + + let mut entries_len = 0; + + thread.read_with(cx, |thread, _| { + entries_len = thread.plan().entries.len(); + assert!(thread.plan().entries.len() > 0, "Empty plan"); + }); + + thread + .update(cx, |thread, cx| { + thread.send_raw( + "Mark the first entry status as in progress without acting on it.", + cx, + ) + }) + .await + .unwrap(); + + thread.read_with(cx, |thread, _| { + assert!(matches!( + thread.plan().entries[0].status, + acp::PlanEntryStatus::InProgress + )); + assert_eq!(thread.plan().entries.len(), entries_len); + }); + + thread + .update(cx, |thread, cx| { + thread.send_raw( + "Now mark the first entry as completed without acting on it.", + cx, + ) + }) + .await + .unwrap(); + + thread.read_with(cx, |thread, _| { + assert!(matches!( + thread.plan().entries[0].status, + acp::PlanEntryStatus::Completed + )); + assert_eq!(thread.plan().entries.len(), entries_len); + }); + } + + #[test] + fn test_deserialize_content_untagged_text() { + let json = json!("Hello, world!"); + let content: Content = serde_json::from_value(json).unwrap(); + match content { + Content::UntaggedText(text) => assert_eq!(text, "Hello, world!"), + _ => panic!("Expected UntaggedText variant"), + } + } + + #[test] + fn test_deserialize_content_chunks() { + let json = json!([ + { + "type": "text", + "text": "Hello" + }, + { + "type": "tool_use", + "id": "tool_123", + "name": "calculator", + "input": {"operation": "add", "a": 1, "b": 2} + } + ]); + let content: Content = serde_json::from_value(json).unwrap(); + match content { + Content::Chunks(chunks) => { + assert_eq!(chunks.len(), 2); + match &chunks[0] { + ContentChunk::Text { text } => assert_eq!(text, "Hello"), + _ => panic!("Expected Text chunk"), + } + match &chunks[1] { + ContentChunk::ToolUse { id, name, input } => { + assert_eq!(id, "tool_123"); + assert_eq!(name, "calculator"); + assert_eq!(input["operation"], "add"); + assert_eq!(input["a"], 1); + assert_eq!(input["b"], 2); + } + _ => panic!("Expected ToolUse chunk"), + } + } + _ => panic!("Expected Chunks variant"), + } + } + + #[test] + fn test_deserialize_tool_result_untagged_text() { + let json = json!({ + "type": "tool_result", + "content": "Result content", + "tool_use_id": "tool_456" + }); + let chunk: ContentChunk = serde_json::from_value(json).unwrap(); + match chunk { + ContentChunk::ToolResult { + content, + tool_use_id, + } => { + match content { + Content::UntaggedText(text) => assert_eq!(text, "Result content"), + _ => panic!("Expected UntaggedText content"), + } + assert_eq!(tool_use_id, "tool_456"); + } + _ => panic!("Expected ToolResult variant"), + } + } + + #[test] + fn test_deserialize_tool_result_chunks() { + let json = json!({ + "type": "tool_result", + "content": [ + { + "type": "text", + "text": "Processing complete" + }, + { + "type": "text", + "text": "Result: 42" + } + ], + "tool_use_id": "tool_789" + }); + let chunk: ContentChunk = serde_json::from_value(json).unwrap(); + match chunk { + ContentChunk::ToolResult { + content, + tool_use_id, + } => { + match content { + 
Content::Chunks(chunks) => { + assert_eq!(chunks.len(), 2); + match &chunks[0] { + ContentChunk::Text { text } => assert_eq!(text, "Processing complete"), + _ => panic!("Expected Text chunk"), + } + match &chunks[1] { + ContentChunk::Text { text } => assert_eq!(text, "Result: 42"), + _ => panic!("Expected Text chunk"), + } + } + _ => panic!("Expected Chunks content"), + } + assert_eq!(tool_use_id, "tool_789"); + } + _ => panic!("Expected ToolResult variant"), + } + } +} diff --git a/crates/agent_servers/src/claude/mcp_server.rs b/crates/agent_servers/src/claude/mcp_server.rs new file mode 100644 index 0000000000000000000000000000000000000000..53a8556e74545bc339936d0b2f9f78444190af0c --- /dev/null +++ b/crates/agent_servers/src/claude/mcp_server.rs @@ -0,0 +1,302 @@ +use std::path::PathBuf; + +use crate::claude::tools::{ClaudeTool, EditToolParams, ReadToolParams}; +use acp_thread::AcpThread; +use agent_client_protocol as acp; +use anyhow::{Context, Result}; +use collections::HashMap; +use context_server::listener::{McpServerTool, ToolResponse}; +use context_server::types::{ + Implementation, InitializeParams, InitializeResponse, ProtocolVersion, ServerCapabilities, + ToolAnnotations, ToolResponseContent, ToolsCapabilities, requests, +}; +use gpui::{App, AsyncApp, Task, WeakEntity}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; + +pub struct ClaudeZedMcpServer { + server: context_server::listener::McpServer, +} + +pub const SERVER_NAME: &str = "zed"; + +impl ClaudeZedMcpServer { + pub async fn new( + thread_rx: watch::Receiver>, + cx: &AsyncApp, + ) -> Result { + let mut mcp_server = context_server::listener::McpServer::new(cx).await?; + mcp_server.handle_request::(Self::handle_initialize); + + mcp_server.add_tool(PermissionTool { + thread_rx: thread_rx.clone(), + }); + mcp_server.add_tool(ReadTool { + thread_rx: thread_rx.clone(), + }); + mcp_server.add_tool(EditTool { + thread_rx: thread_rx.clone(), + }); + + Ok(Self { server: mcp_server }) + } + + pub fn server_config(&self) -> Result { + #[cfg(not(test))] + let zed_path = std::env::current_exe() + .context("finding current executable path for use in mcp_server")?; + + #[cfg(test)] + let zed_path = crate::e2e_tests::get_zed_path(); + + Ok(McpServerConfig { + command: zed_path, + args: vec![ + "--nc".into(), + self.server.socket_path().display().to_string(), + ], + env: None, + }) + } + + fn handle_initialize(_: InitializeParams, cx: &App) -> Task> { + cx.foreground_executor().spawn(async move { + Ok(InitializeResponse { + protocol_version: ProtocolVersion("2025-06-18".into()), + capabilities: ServerCapabilities { + experimental: None, + logging: None, + completions: None, + prompts: None, + resources: None, + tools: Some(ToolsCapabilities { + list_changed: Some(false), + }), + }, + server_info: Implementation { + name: SERVER_NAME.into(), + version: "0.1.0".into(), + }, + meta: None, + }) + }) + } +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +pub struct McpConfig { + pub mcp_servers: HashMap, +} + +#[derive(Serialize, Clone)] +#[serde(rename_all = "camelCase")] +pub struct McpServerConfig { + pub command: PathBuf, + pub args: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub env: Option>, +} + +// Tools + +#[derive(Clone)] +pub struct PermissionTool { + thread_rx: watch::Receiver>, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct PermissionToolParams { + tool_name: String, + input: serde_json::Value, + tool_use_id: Option, +} + +#[derive(Serialize)] +#[serde(rename_all = 
"camelCase")] +pub struct PermissionToolResponse { + behavior: PermissionToolBehavior, + updated_input: serde_json::Value, +} + +#[derive(Serialize)] +#[serde(rename_all = "snake_case")] +enum PermissionToolBehavior { + Allow, + Deny, +} + +impl McpServerTool for PermissionTool { + type Input = PermissionToolParams; + type Output = (); + + const NAME: &'static str = "Confirmation"; + + fn description(&self) -> &'static str { + "Request permission for tool calls" + } + + async fn run( + &self, + input: Self::Input, + cx: &mut AsyncApp, + ) -> Result> { + let mut thread_rx = self.thread_rx.clone(); + let Some(thread) = thread_rx.recv().await?.upgrade() else { + anyhow::bail!("Thread closed"); + }; + + let claude_tool = ClaudeTool::infer(&input.tool_name, input.input.clone()); + let tool_call_id = acp::ToolCallId(input.tool_use_id.context("Tool ID required")?.into()); + let allow_option_id = acp::PermissionOptionId("allow".into()); + let reject_option_id = acp::PermissionOptionId("reject".into()); + + let chosen_option = thread + .update(cx, |thread, cx| { + thread.request_tool_call_authorization( + claude_tool.as_acp(tool_call_id), + vec![ + acp::PermissionOption { + id: allow_option_id.clone(), + name: "Allow".into(), + kind: acp::PermissionOptionKind::AllowOnce, + }, + acp::PermissionOption { + id: reject_option_id.clone(), + name: "Reject".into(), + kind: acp::PermissionOptionKind::RejectOnce, + }, + ], + cx, + ) + })? + .await?; + + let response = if chosen_option == allow_option_id { + PermissionToolResponse { + behavior: PermissionToolBehavior::Allow, + updated_input: input.input, + } + } else { + debug_assert_eq!(chosen_option, reject_option_id); + PermissionToolResponse { + behavior: PermissionToolBehavior::Deny, + updated_input: input.input, + } + }; + + Ok(ToolResponse { + content: vec![ToolResponseContent::Text { + text: serde_json::to_string(&response)?, + }], + structured_content: (), + }) + } +} + +#[derive(Clone)] +pub struct ReadTool { + thread_rx: watch::Receiver>, +} + +impl McpServerTool for ReadTool { + type Input = ReadToolParams; + type Output = (); + + const NAME: &'static str = "Read"; + + fn description(&self) -> &'static str { + "Read the contents of a file. In sessions with mcp__zed__Read always use it instead of Read as it contains the most up-to-date contents." + } + + fn annotations(&self) -> ToolAnnotations { + ToolAnnotations { + title: Some("Read file".to_string()), + read_only_hint: Some(true), + destructive_hint: Some(false), + open_world_hint: Some(false), + idempotent_hint: None, + } + } + + async fn run( + &self, + input: Self::Input, + cx: &mut AsyncApp, + ) -> Result> { + let mut thread_rx = self.thread_rx.clone(); + let Some(thread) = thread_rx.recv().await?.upgrade() else { + anyhow::bail!("Thread closed"); + }; + + let content = thread + .update(cx, |thread, cx| { + thread.read_text_file(input.abs_path, input.offset, input.limit, false, cx) + })? + .await?; + + Ok(ToolResponse { + content: vec![ToolResponseContent::Text { text: content }], + structured_content: (), + }) + } +} + +#[derive(Clone)] +pub struct EditTool { + thread_rx: watch::Receiver>, +} + +impl McpServerTool for EditTool { + type Input = EditToolParams; + type Output = (); + + const NAME: &'static str = "Edit"; + + fn description(&self) -> &'static str { + "Edits a file. In sessions with mcp__zed__Edit always use it instead of Edit as it will show the diff to the user better." 
+ } + + fn annotations(&self) -> ToolAnnotations { + ToolAnnotations { + title: Some("Edit file".to_string()), + read_only_hint: Some(false), + destructive_hint: Some(false), + open_world_hint: Some(false), + idempotent_hint: Some(false), + } + } + + async fn run( + &self, + input: Self::Input, + cx: &mut AsyncApp, + ) -> Result> { + let mut thread_rx = self.thread_rx.clone(); + let Some(thread) = thread_rx.recv().await?.upgrade() else { + anyhow::bail!("Thread closed"); + }; + + let content = thread + .update(cx, |thread, cx| { + thread.read_text_file(input.abs_path.clone(), None, None, true, cx) + })? + .await?; + + let new_content = content.replace(&input.old_text, &input.new_text); + if new_content == content { + return Err(anyhow::anyhow!("The old_text was not found in the content")); + } + + thread + .update(cx, |thread, cx| { + thread.write_text_file(input.abs_path, new_content, cx) + })? + .await?; + + Ok(ToolResponse { + content: vec![], + structured_content: (), + }) + } +} diff --git a/crates/agent_servers/src/claude/tools.rs b/crates/agent_servers/src/claude/tools.rs new file mode 100644 index 0000000000000000000000000000000000000000..7ca150c0bd0b30b958a4791db9d01684d16460d6 --- /dev/null +++ b/crates/agent_servers/src/claude/tools.rs @@ -0,0 +1,661 @@ +use std::path::PathBuf; + +use agent_client_protocol as acp; +use itertools::Itertools; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use util::ResultExt; + +pub enum ClaudeTool { + Task(Option), + NotebookRead(Option), + NotebookEdit(Option), + Edit(Option), + MultiEdit(Option), + ReadFile(Option), + Write(Option), + Ls(Option), + Glob(Option), + Grep(Option), + Terminal(Option), + WebFetch(Option), + WebSearch(Option), + TodoWrite(Option), + ExitPlanMode(Option), + Other { + name: String, + input: serde_json::Value, + }, +} + +impl ClaudeTool { + pub fn infer(tool_name: &str, input: serde_json::Value) -> Self { + match tool_name { + // Known tools + "mcp__zed__Read" => Self::ReadFile(serde_json::from_value(input).log_err()), + "mcp__zed__Edit" => Self::Edit(serde_json::from_value(input).log_err()), + "MultiEdit" => Self::MultiEdit(serde_json::from_value(input).log_err()), + "Write" => Self::Write(serde_json::from_value(input).log_err()), + "LS" => Self::Ls(serde_json::from_value(input).log_err()), + "Glob" => Self::Glob(serde_json::from_value(input).log_err()), + "Grep" => Self::Grep(serde_json::from_value(input).log_err()), + "Bash" => Self::Terminal(serde_json::from_value(input).log_err()), + "WebFetch" => Self::WebFetch(serde_json::from_value(input).log_err()), + "WebSearch" => Self::WebSearch(serde_json::from_value(input).log_err()), + "TodoWrite" => Self::TodoWrite(serde_json::from_value(input).log_err()), + "exit_plan_mode" => Self::ExitPlanMode(serde_json::from_value(input).log_err()), + "Task" => Self::Task(serde_json::from_value(input).log_err()), + "NotebookRead" => Self::NotebookRead(serde_json::from_value(input).log_err()), + "NotebookEdit" => Self::NotebookEdit(serde_json::from_value(input).log_err()), + // Inferred from name + _ => { + let tool_name = tool_name.to_lowercase(); + + if tool_name.contains("edit") || tool_name.contains("write") { + Self::Edit(None) + } else if tool_name.contains("terminal") { + Self::Terminal(None) + } else { + Self::Other { + name: tool_name.to_string(), + input, + } + } + } + } + } + + pub fn label(&self) -> String { + match &self { + Self::Task(Some(params)) => params.description.clone(), + Self::Task(None) => "Task".into(), + Self::NotebookRead(Some(params)) 
=> { + format!("Read Notebook {}", params.notebook_path.display()) + } + Self::NotebookRead(None) => "Read Notebook".into(), + Self::NotebookEdit(Some(params)) => { + format!("Edit Notebook {}", params.notebook_path.display()) + } + Self::NotebookEdit(None) => "Edit Notebook".into(), + Self::Terminal(Some(params)) => format!("`{}`", params.command), + Self::Terminal(None) => "Terminal".into(), + Self::ReadFile(_) => "Read File".into(), + Self::Ls(Some(params)) => { + format!("List Directory {}", params.path.display()) + } + Self::Ls(None) => "List Directory".into(), + Self::Edit(Some(params)) => { + format!("Edit {}", params.abs_path.display()) + } + Self::Edit(None) => "Edit".into(), + Self::MultiEdit(Some(params)) => { + format!("Multi Edit {}", params.file_path.display()) + } + Self::MultiEdit(None) => "Multi Edit".into(), + Self::Write(Some(params)) => { + format!("Write {}", params.file_path.display()) + } + Self::Write(None) => "Write".into(), + Self::Glob(Some(params)) => { + format!("Glob `{params}`") + } + Self::Glob(None) => "Glob".into(), + Self::Grep(Some(params)) => format!("`{params}`"), + Self::Grep(None) => "Grep".into(), + Self::WebFetch(Some(params)) => format!("Fetch {}", params.url), + Self::WebFetch(None) => "Fetch".into(), + Self::WebSearch(Some(params)) => format!("Web Search: {}", params), + Self::WebSearch(None) => "Web Search".into(), + Self::TodoWrite(Some(params)) => format!( + "Update TODOs: {}", + params.todos.iter().map(|todo| &todo.content).join(", ") + ), + Self::TodoWrite(None) => "Update TODOs".into(), + Self::ExitPlanMode(_) => "Exit Plan Mode".into(), + Self::Other { name, .. } => name.clone(), + } + } + pub fn content(&self) -> Vec { + match &self { + Self::Other { input, .. } => vec![ + format!( + "```json\n{}```", + serde_json::to_string_pretty(&input).unwrap_or("{}".to_string()) + ) + .into(), + ], + Self::Task(Some(params)) => vec![params.prompt.clone().into()], + Self::NotebookRead(Some(params)) => { + vec![params.notebook_path.display().to_string().into()] + } + Self::NotebookEdit(Some(params)) => vec![params.new_source.clone().into()], + Self::Terminal(Some(params)) => vec![ + format!( + "`{}`\n\n{}", + params.command, + params.description.as_deref().unwrap_or_default() + ) + .into(), + ], + Self::ReadFile(Some(params)) => vec![params.abs_path.display().to_string().into()], + Self::Ls(Some(params)) => vec![params.path.display().to_string().into()], + Self::Glob(Some(params)) => vec![params.to_string().into()], + Self::Grep(Some(params)) => vec![format!("`{params}`").into()], + Self::WebFetch(Some(params)) => vec![params.prompt.clone().into()], + Self::WebSearch(Some(params)) => vec![params.to_string().into()], + Self::ExitPlanMode(Some(params)) => vec![params.plan.clone().into()], + Self::Edit(Some(params)) => vec![acp::ToolCallContent::Diff { + diff: acp::Diff { + path: params.abs_path.clone(), + old_text: Some(params.old_text.clone()), + new_text: params.new_text.clone(), + }, + }], + Self::Write(Some(params)) => vec![acp::ToolCallContent::Diff { + diff: acp::Diff { + path: params.file_path.clone(), + old_text: None, + new_text: params.content.clone(), + }, + }], + Self::MultiEdit(Some(params)) => { + // todo: show multiple edits in a multibuffer? 
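+                // Note: only the first edit is surfaced as an ACP diff below;
+                // any remaining MultiEdit entries are not rendered in the tool card.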
+ params + .edits + .first() + .map(|edit| { + vec![acp::ToolCallContent::Diff { + diff: acp::Diff { + path: params.file_path.clone(), + old_text: Some(edit.old_string.clone()), + new_text: edit.new_string.clone(), + }, + }] + }) + .unwrap_or_default() + } + Self::TodoWrite(Some(_)) => { + // These are mapped to plan updates later + vec![] + } + Self::Task(None) + | Self::NotebookRead(None) + | Self::NotebookEdit(None) + | Self::Terminal(None) + | Self::ReadFile(None) + | Self::Ls(None) + | Self::Glob(None) + | Self::Grep(None) + | Self::WebFetch(None) + | Self::WebSearch(None) + | Self::TodoWrite(None) + | Self::ExitPlanMode(None) + | Self::Edit(None) + | Self::Write(None) + | Self::MultiEdit(None) => vec![], + } + } + + pub fn kind(&self) -> acp::ToolKind { + match self { + Self::Task(_) => acp::ToolKind::Think, + Self::NotebookRead(_) => acp::ToolKind::Read, + Self::NotebookEdit(_) => acp::ToolKind::Edit, + Self::Edit(_) => acp::ToolKind::Edit, + Self::MultiEdit(_) => acp::ToolKind::Edit, + Self::Write(_) => acp::ToolKind::Edit, + Self::ReadFile(_) => acp::ToolKind::Read, + Self::Ls(_) => acp::ToolKind::Search, + Self::Glob(_) => acp::ToolKind::Search, + Self::Grep(_) => acp::ToolKind::Search, + Self::Terminal(_) => acp::ToolKind::Execute, + Self::WebSearch(_) => acp::ToolKind::Search, + Self::WebFetch(_) => acp::ToolKind::Fetch, + Self::TodoWrite(_) => acp::ToolKind::Think, + Self::ExitPlanMode(_) => acp::ToolKind::Think, + Self::Other { .. } => acp::ToolKind::Other, + } + } + + pub fn locations(&self) -> Vec { + match &self { + Self::Edit(Some(EditToolParams { abs_path, .. })) => vec![acp::ToolCallLocation { + path: abs_path.clone(), + line: None, + }], + Self::MultiEdit(Some(MultiEditToolParams { file_path, .. })) => { + vec![acp::ToolCallLocation { + path: file_path.clone(), + line: None, + }] + } + Self::Write(Some(WriteToolParams { file_path, .. })) => { + vec![acp::ToolCallLocation { + path: file_path.clone(), + line: None, + }] + } + Self::ReadFile(Some(ReadToolParams { + abs_path, offset, .. + })) => vec![acp::ToolCallLocation { + path: abs_path.clone(), + line: *offset, + }], + Self::NotebookRead(Some(NotebookReadToolParams { notebook_path, .. })) => { + vec![acp::ToolCallLocation { + path: notebook_path.clone(), + line: None, + }] + } + Self::NotebookEdit(Some(NotebookEditToolParams { notebook_path, .. })) => { + vec![acp::ToolCallLocation { + path: notebook_path.clone(), + line: None, + }] + } + Self::Glob(Some(GlobToolParams { + path: Some(path), .. + })) => vec![acp::ToolCallLocation { + path: path.clone(), + line: None, + }], + Self::Ls(Some(LsToolParams { path, .. })) => vec![acp::ToolCallLocation { + path: path.clone(), + line: None, + }], + Self::Grep(Some(GrepToolParams { + path: Some(path), .. + })) => vec![acp::ToolCallLocation { + path: PathBuf::from(path), + line: None, + }], + Self::Task(_) + | Self::NotebookRead(None) + | Self::NotebookEdit(None) + | Self::Edit(None) + | Self::MultiEdit(None) + | Self::Write(None) + | Self::ReadFile(None) + | Self::Ls(None) + | Self::Glob(_) + | Self::Grep(_) + | Self::Terminal(_) + | Self::WebFetch(_) + | Self::WebSearch(_) + | Self::TodoWrite(_) + | Self::ExitPlanMode(_) + | Self::Other { .. 
} => vec![], + } + } + + pub fn as_acp(&self, id: acp::ToolCallId) -> acp::ToolCall { + acp::ToolCall { + id, + kind: self.kind(), + status: acp::ToolCallStatus::InProgress, + title: self.label(), + content: self.content(), + locations: self.locations(), + raw_input: None, + raw_output: None, + } + } +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct EditToolParams { + /// The absolute path to the file to read. + pub abs_path: PathBuf, + /// The old text to replace (must be unique in the file) + pub old_text: String, + /// The new text. + pub new_text: String, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct ReadToolParams { + /// The absolute path to the file to read. + pub abs_path: PathBuf, + /// Which line to start reading from. Omit to start from the beginning. + #[serde(skip_serializing_if = "Option::is_none")] + pub offset: Option, + /// How many lines to read. Omit for the whole file. + #[serde(skip_serializing_if = "Option::is_none")] + pub limit: Option, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct WriteToolParams { + /// Absolute path for new file + pub file_path: PathBuf, + /// File content + pub content: String, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct BashToolParams { + /// Shell command to execute + pub command: String, + /// 5-10 word description of what command does + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + /// Timeout in ms (max 600000ms/10min, default 120000ms) + #[serde(skip_serializing_if = "Option::is_none")] + pub timeout: Option, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct GlobToolParams { + /// Glob pattern like **/*.js or src/**/*.ts + pub pattern: String, + /// Directory to search in (omit for current directory) + #[serde(skip_serializing_if = "Option::is_none")] + pub path: Option, +} + +impl std::fmt::Display for GlobToolParams { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Some(path) = &self.path { + write!(f, "{}", path.display())?; + } + write!(f, "{}", self.pattern) + } +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct LsToolParams { + /// Absolute path to directory + pub path: PathBuf, + /// Array of glob patterns to ignore + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub ignore: Vec, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct GrepToolParams { + /// Regex pattern to search for + pub pattern: String, + /// File/directory to search (defaults to current directory) + #[serde(skip_serializing_if = "Option::is_none")] + pub path: Option, + /// "content" (shows lines), "files_with_matches" (default), "count" + #[serde(skip_serializing_if = "Option::is_none")] + pub output_mode: Option, + /// Filter files with glob pattern like "*.js" + #[serde(skip_serializing_if = "Option::is_none")] + pub glob: Option, + /// File type filter like "js", "py", "rust" + #[serde(rename = "type", skip_serializing_if = "Option::is_none")] + pub file_type: Option, + /// Case insensitive search + #[serde(rename = "-i", default, skip_serializing_if = "is_false")] + pub case_insensitive: bool, + /// Show line numbers (content mode only) + #[serde(rename = "-n", default, skip_serializing_if = "is_false")] + pub line_numbers: bool, + /// Lines after match (content mode only) + #[serde(rename = "-A", skip_serializing_if = "Option::is_none")] + pub after_context: Option, + /// Lines before match (content mode only) + #[serde(rename = "-B", skip_serializing_if = "Option::is_none")] + pub before_context: 
Option, + /// Lines before and after match (content mode only) + #[serde(rename = "-C", skip_serializing_if = "Option::is_none")] + pub context: Option, + /// Enable multiline/cross-line matching + #[serde(default, skip_serializing_if = "is_false")] + pub multiline: bool, + /// Limit output to first N results + #[serde(skip_serializing_if = "Option::is_none")] + pub head_limit: Option, +} + +impl std::fmt::Display for GrepToolParams { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "grep")?; + + // Boolean flags + if self.case_insensitive { + write!(f, " -i")?; + } + if self.line_numbers { + write!(f, " -n")?; + } + + // Context options + if let Some(after) = self.after_context { + write!(f, " -A {}", after)?; + } + if let Some(before) = self.before_context { + write!(f, " -B {}", before)?; + } + if let Some(context) = self.context { + write!(f, " -C {}", context)?; + } + + // Output mode + if let Some(mode) = &self.output_mode { + match mode { + GrepOutputMode::FilesWithMatches => write!(f, " -l")?, + GrepOutputMode::Count => write!(f, " -c")?, + GrepOutputMode::Content => {} // Default mode + } + } + + // Head limit + if let Some(limit) = self.head_limit { + write!(f, " | head -{}", limit)?; + } + + // Glob pattern + if let Some(glob) = &self.glob { + write!(f, " --include=\"{}\"", glob)?; + } + + // File type + if let Some(file_type) = &self.file_type { + write!(f, " --type={}", file_type)?; + } + + // Multiline + if self.multiline { + write!(f, " -P")?; // Perl-compatible regex for multiline + } + + // Pattern (escaped if contains special characters) + write!(f, " \"{}\"", self.pattern)?; + + // Path + if let Some(path) = &self.path { + write!(f, " {}", path)?; + } + + Ok(()) + } +} + +#[derive(Default, Deserialize, Serialize, JsonSchema, strum::Display, Debug)] +#[serde(rename_all = "snake_case")] +pub enum TodoPriority { + High, + #[default] + Medium, + Low, +} + +impl Into for TodoPriority { + fn into(self) -> acp::PlanEntryPriority { + match self { + TodoPriority::High => acp::PlanEntryPriority::High, + TodoPriority::Medium => acp::PlanEntryPriority::Medium, + TodoPriority::Low => acp::PlanEntryPriority::Low, + } + } +} + +#[derive(Deserialize, Serialize, JsonSchema, Debug)] +#[serde(rename_all = "snake_case")] +pub enum TodoStatus { + Pending, + InProgress, + Completed, +} + +impl Into for TodoStatus { + fn into(self) -> acp::PlanEntryStatus { + match self { + TodoStatus::Pending => acp::PlanEntryStatus::Pending, + TodoStatus::InProgress => acp::PlanEntryStatus::InProgress, + TodoStatus::Completed => acp::PlanEntryStatus::Completed, + } + } +} + +#[derive(Deserialize, Serialize, JsonSchema, Debug)] +pub struct Todo { + /// Task description + pub content: String, + /// Current status of the todo + pub status: TodoStatus, + /// Priority level of the todo + #[serde(default)] + pub priority: TodoPriority, +} + +impl Into for Todo { + fn into(self) -> acp::PlanEntry { + acp::PlanEntry { + content: self.content, + priority: self.priority.into(), + status: self.status.into(), + } + } +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct TodoWriteToolParams { + pub todos: Vec, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct ExitPlanModeToolParams { + /// Implementation plan in markdown format + pub plan: String, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct TaskToolParams { + /// Short 3-5 word description of task + pub description: String, + /// Detailed task for agent to perform + pub prompt: String, +} + 
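+// Illustrative sketch (not part of the tool schema): `ClaudeTool::infer` above
+// deserializes the raw JSON input of a tool call into one of these param
+// structs, falling back to `None` when deserialization fails. For example:
+//
+//     let input = serde_json::json!({
+//         "description": "Summarize repo",
+//         "prompt": "Read the README and summarize the project in one paragraph."
+//     });
+//     assert!(matches!(
+//         ClaudeTool::infer("Task", input),
+//         ClaudeTool::Task(Some(_))
+//     ));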
+#[derive(Deserialize, JsonSchema, Debug)] +pub struct NotebookReadToolParams { + /// Absolute path to .ipynb file + pub notebook_path: PathBuf, + /// Specific cell ID to read + #[serde(skip_serializing_if = "Option::is_none")] + pub cell_id: Option, +} + +#[derive(Deserialize, Serialize, JsonSchema, Debug)] +#[serde(rename_all = "snake_case")] +pub enum CellType { + Code, + Markdown, +} + +#[derive(Deserialize, Serialize, JsonSchema, Debug)] +#[serde(rename_all = "snake_case")] +pub enum EditMode { + Replace, + Insert, + Delete, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct NotebookEditToolParams { + /// Absolute path to .ipynb file + pub notebook_path: PathBuf, + /// New cell content + pub new_source: String, + /// Cell ID to edit + #[serde(skip_serializing_if = "Option::is_none")] + pub cell_id: Option, + /// Type of cell (code or markdown) + #[serde(skip_serializing_if = "Option::is_none")] + pub cell_type: Option, + /// Edit operation mode + #[serde(skip_serializing_if = "Option::is_none")] + pub edit_mode: Option, +} + +#[derive(Deserialize, Serialize, JsonSchema, Debug)] +pub struct MultiEditItem { + /// The text to search for and replace + pub old_string: String, + /// The replacement text + pub new_string: String, + /// Whether to replace all occurrences or just the first + #[serde(default, skip_serializing_if = "is_false")] + pub replace_all: bool, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct MultiEditToolParams { + /// Absolute path to file + pub file_path: PathBuf, + /// List of edits to apply + pub edits: Vec, +} + +fn is_false(v: &bool) -> bool { + !*v +} + +#[derive(Deserialize, JsonSchema, Debug)] +#[serde(rename_all = "snake_case")] +pub enum GrepOutputMode { + Content, + FilesWithMatches, + Count, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct WebFetchToolParams { + /// Valid URL to fetch + #[serde(rename = "url")] + pub url: String, + /// What to extract from content + pub prompt: String, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct WebSearchToolParams { + /// Search query (min 2 chars) + pub query: String, + /// Only include these domains + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub allowed_domains: Vec, + /// Exclude these domains + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub blocked_domains: Vec, +} + +impl std::fmt::Display for WebSearchToolParams { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "\"{}\"", self.query)?; + + if !self.allowed_domains.is_empty() { + write!(f, " (allowed: {})", self.allowed_domains.join(", "))?; + } + + if !self.blocked_domains.is_empty() { + write!(f, " (blocked: {})", self.blocked_domains.join(", "))?; + } + + Ok(()) + } +} diff --git a/crates/agent_servers/src/e2e_tests.rs b/crates/agent_servers/src/e2e_tests.rs new file mode 100644 index 0000000000000000000000000000000000000000..ec6ca29b9dd1a902708a8786ddc6853955da5532 --- /dev/null +++ b/crates/agent_servers/src/e2e_tests.rs @@ -0,0 +1,482 @@ +use std::{ + path::{Path, PathBuf}, + sync::Arc, + time::Duration, +}; + +use crate::{AgentServer, AgentServerSettings, AllAgentServersSettings}; +use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus}; +use agent_client_protocol as acp; + +use futures::{FutureExt, StreamExt, channel::mpsc, select}; +use gpui::{Entity, TestAppContext}; +use indoc::indoc; +use project::{FakeFs, Project}; +use settings::{Settings, SettingsStore}; +use util::path; + +pub async fn test_basic(server: impl AgentServer + 
'static, cx: &mut TestAppContext) { + let fs = init_test(cx).await; + let project = Project::test(fs, [], cx).await; + let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; + + thread + .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx)) + .await + .unwrap(); + + thread.read_with(cx, |thread, _| { + assert!( + thread.entries().len() >= 2, + "Expected at least 2 entries. Got: {:?}", + thread.entries() + ); + assert!(matches!( + thread.entries()[0], + AgentThreadEntry::UserMessage(_) + )); + assert!(matches!( + thread.entries()[1], + AgentThreadEntry::AssistantMessage(_) + )); + }); +} + +pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut TestAppContext) { + let _fs = init_test(cx).await; + + let tempdir = tempfile::tempdir().unwrap(); + std::fs::write( + tempdir.path().join("foo.rs"), + indoc! {" + fn main() { + println!(\"Hello, world!\"); + } + "}, + ) + .expect("failed to write file"); + let project = Project::example([tempdir.path()], &mut cx.to_async()).await; + let thread = new_test_thread(server, project.clone(), tempdir.path(), cx).await; + thread + .update(cx, |thread, cx| { + thread.send( + vec![ + acp::ContentBlock::Text(acp::TextContent { + text: "Read the file ".into(), + annotations: None, + }), + acp::ContentBlock::ResourceLink(acp::ResourceLink { + uri: "foo.rs".into(), + name: "foo.rs".into(), + annotations: None, + description: None, + mime_type: None, + size: None, + title: None, + }), + acp::ContentBlock::Text(acp::TextContent { + text: " and tell me what the content of the println! is".into(), + annotations: None, + }), + ], + cx, + ) + }) + .await + .unwrap(); + + thread.read_with(cx, |thread, cx| { + assert!(matches!( + thread.entries()[0], + AgentThreadEntry::UserMessage(_) + )); + let assistant_message = &thread + .entries() + .iter() + .rev() + .find_map(|entry| match entry { + AgentThreadEntry::AssistantMessage(msg) => Some(msg), + _ => None, + }) + .unwrap(); + + assert!( + assistant_message.to_markdown(cx).contains("Hello, world!"), + "unexpected assistant message: {:?}", + assistant_message.to_markdown(cx) + ); + }); + + drop(tempdir); +} + +pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestAppContext) { + let _fs = init_test(cx).await; + + let tempdir = tempfile::tempdir().unwrap(); + let foo_path = tempdir.path().join("foo"); + std::fs::write(&foo_path, "Lorem ipsum dolor").expect("failed to write file"); + + let project = Project::example([tempdir.path()], &mut cx.to_async()).await; + let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; + + thread + .update(cx, |thread, cx| { + thread.send_raw( + &format!("Read {} and tell me what you see.", foo_path.display()), + cx, + ) + }) + .await + .unwrap(); + thread.read_with(cx, |thread, _cx| { + assert!(thread.entries().iter().any(|entry| { + matches!( + entry, + AgentThreadEntry::ToolCall(ToolCall { + status: ToolCallStatus::Allowed { .. }, + .. 
+ }) + ) + })); + assert!( + thread + .entries() + .iter() + .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) }) + ); + }); + + drop(tempdir); +} + +pub async fn test_tool_call_with_permission( + server: impl AgentServer + 'static, + allow_option_id: acp::PermissionOptionId, + cx: &mut TestAppContext, +) { + let fs = init_test(cx).await; + let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await; + let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; + let full_turn = thread.update(cx, |thread, cx| { + thread.send_raw( + r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#, + cx, + ) + }); + + run_until_first_tool_call( + &thread, + |entry| { + matches!( + entry, + AgentThreadEntry::ToolCall(ToolCall { + status: ToolCallStatus::WaitingForConfirmation { .. }, + .. + }) + ) + }, + cx, + ) + .await; + + let tool_call_id = thread.read_with(cx, |thread, cx| { + let AgentThreadEntry::ToolCall(ToolCall { + id, + label, + status: ToolCallStatus::WaitingForConfirmation { .. }, + .. + }) = &thread + .entries() + .iter() + .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_))) + .unwrap() + else { + panic!(); + }; + + let label = label.read(cx).source(); + assert!(label.contains("touch"), "Got: {}", label); + + id.clone() + }); + + thread.update(cx, |thread, cx| { + thread.authorize_tool_call( + tool_call_id, + allow_option_id, + acp::PermissionOptionKind::AllowOnce, + cx, + ); + + assert!(thread.entries().iter().any(|entry| matches!( + entry, + AgentThreadEntry::ToolCall(ToolCall { + status: ToolCallStatus::Allowed { .. }, + .. + }) + ))); + }); + + full_turn.await.unwrap(); + + thread.read_with(cx, |thread, cx| { + let AgentThreadEntry::ToolCall(ToolCall { + content, + status: ToolCallStatus::Allowed { .. }, + .. + }) = thread + .entries() + .iter() + .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_))) + .unwrap() + else { + panic!(); + }; + + assert!( + content.iter().any(|c| c.to_markdown(cx).contains("Hello")), + "Expected content to contain 'Hello'" + ); + }); +} + +pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppContext) { + let fs = init_test(cx).await; + + let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await; + let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; + let _ = thread.update(cx, |thread, cx| { + thread.send_raw( + r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#, + cx, + ) + }); + + let first_tool_call_ix = run_until_first_tool_call( + &thread, + |entry| { + matches!( + entry, + AgentThreadEntry::ToolCall(ToolCall { + status: ToolCallStatus::WaitingForConfirmation { .. }, + .. + }) + ) + }, + cx, + ) + .await; + + thread.read_with(cx, |thread, cx| { + let AgentThreadEntry::ToolCall(ToolCall { + id, + label, + status: ToolCallStatus::WaitingForConfirmation { .. }, + .. + }) = &thread.entries()[first_tool_call_ix] + else { + panic!("{:?}", thread.entries()[1]); + }; + + let label = label.read(cx).source(); + assert!(label.contains("touch"), "Got: {}", label); + + id.clone() + }); + + thread.update(cx, |thread, cx| thread.cancel(cx)).await; + thread.read_with(cx, |thread, _cx| { + let AgentThreadEntry::ToolCall(ToolCall { + status: ToolCallStatus::Canceled, + .. 
+ }) = &thread.entries()[first_tool_call_ix] + else { + panic!(); + }; + }); + + thread + .update(cx, |thread, cx| { + thread.send_raw(r#"Stop running and say goodbye to me."#, cx) + }) + .await + .unwrap(); + thread.read_with(cx, |thread, _| { + assert!(matches!( + &thread.entries().last().unwrap(), + AgentThreadEntry::AssistantMessage(..), + )) + }); +} + +pub async fn test_thread_drop(server: impl AgentServer + 'static, cx: &mut TestAppContext) { + let fs = init_test(cx).await; + let project = Project::test(fs, [], cx).await; + let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; + + thread + .update(cx, |thread, cx| thread.send_raw("Hello from test!", cx)) + .await + .unwrap(); + + thread.read_with(cx, |thread, _| { + assert!(thread.entries().len() >= 2, "Expected at least 2 entries"); + }); + + let weak_thread = thread.downgrade(); + drop(thread); + + cx.executor().run_until_parked(); + assert!(!weak_thread.is_upgradable()); +} + +#[macro_export] +macro_rules! common_e2e_tests { + ($server:expr, allow_option_id = $allow_option_id:expr) => { + mod common_e2e { + use super::*; + + #[::gpui::test] + #[cfg_attr(not(feature = "e2e"), ignore)] + async fn basic(cx: &mut ::gpui::TestAppContext) { + $crate::e2e_tests::test_basic($server, cx).await; + } + + #[::gpui::test] + #[cfg_attr(not(feature = "e2e"), ignore)] + async fn path_mentions(cx: &mut ::gpui::TestAppContext) { + $crate::e2e_tests::test_path_mentions($server, cx).await; + } + + #[::gpui::test] + #[cfg_attr(not(feature = "e2e"), ignore)] + async fn tool_call(cx: &mut ::gpui::TestAppContext) { + $crate::e2e_tests::test_tool_call($server, cx).await; + } + + #[::gpui::test] + #[cfg_attr(not(feature = "e2e"), ignore)] + async fn tool_call_with_permission(cx: &mut ::gpui::TestAppContext) { + $crate::e2e_tests::test_tool_call_with_permission( + $server, + ::agent_client_protocol::PermissionOptionId($allow_option_id.into()), + cx, + ) + .await; + } + + #[::gpui::test] + #[cfg_attr(not(feature = "e2e"), ignore)] + async fn cancel(cx: &mut ::gpui::TestAppContext) { + $crate::e2e_tests::test_cancel($server, cx).await; + } + + #[::gpui::test] + #[cfg_attr(not(feature = "e2e"), ignore)] + async fn thread_drop(cx: &mut ::gpui::TestAppContext) { + $crate::e2e_tests::test_thread_drop($server, cx).await; + } + } + }; +} + +// Helpers + +pub async fn init_test(cx: &mut TestAppContext) -> Arc { + env_logger::try_init().ok(); + + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + Project::init_settings(cx); + language::init(cx); + crate::settings::init(cx); + + crate::AllAgentServersSettings::override_global( + AllAgentServersSettings { + claude: Some(AgentServerSettings { + command: crate::claude::tests::local_command(), + }), + gemini: Some(AgentServerSettings { + command: crate::gemini::tests::local_command(), + }), + }, + cx, + ); + }); + + cx.executor().allow_parking(); + + FakeFs::new(cx.executor()) +} + +pub async fn new_test_thread( + server: impl AgentServer + 'static, + project: Entity, + current_dir: impl AsRef, + cx: &mut TestAppContext, +) -> Entity { + let connection = cx + .update(|cx| server.connect(current_dir.as_ref(), &project, cx)) + .await + .unwrap(); + + let thread = connection + .new_thread(project.clone(), current_dir.as_ref(), &mut cx.to_async()) + .await + .unwrap(); + + thread +} + +pub async fn run_until_first_tool_call( + thread: &Entity, + wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static, + cx: &mut TestAppContext, +) -> usize { + let 
(mut tx, mut rx) = mpsc::channel::(1); + + let subscription = cx.update(|cx| { + cx.subscribe(thread, move |thread, _, cx| { + for (ix, entry) in thread.read(cx).entries().iter().enumerate() { + if wait_until(entry) { + return tx.try_send(ix).unwrap(); + } + } + }) + }); + + select! { + // We have to use a smol timer here because + // cx.background_executor().timer isn't real in the test context + _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(20))) => { + panic!("Timeout waiting for tool call") + } + ix = rx.next().fuse() => { + drop(subscription); + ix.unwrap() + } + } +} + +pub fn get_zed_path() -> PathBuf { + let mut zed_path = std::env::current_exe().unwrap(); + + while zed_path + .file_name() + .map_or(true, |name| name.to_string_lossy() != "debug") + { + if !zed_path.pop() { + panic!("Could not find target directory"); + } + } + + zed_path.push("zed"); + + if !zed_path.exists() { + panic!("\n🚨 Run `cargo build` at least once before running e2e tests\n\n"); + } + + zed_path +} diff --git a/crates/agent_servers/src/gemini.rs b/crates/agent_servers/src/gemini.rs new file mode 100644 index 0000000000000000000000000000000000000000..ad883f6da8bd344044e1db0051ca6f24120d5057 --- /dev/null +++ b/crates/agent_servers/src/gemini.rs @@ -0,0 +1,111 @@ +use std::path::Path; +use std::rc::Rc; + +use crate::{AgentServer, AgentServerCommand}; +use acp_thread::{AgentConnection, LoadError}; +use anyhow::Result; +use gpui::{Entity, Task}; +use project::Project; +use settings::SettingsStore; +use ui::App; + +use crate::AllAgentServersSettings; + +#[derive(Clone)] +pub struct Gemini; + +const ACP_ARG: &str = "--experimental-acp"; + +impl AgentServer for Gemini { + fn name(&self) -> &'static str { + "Gemini" + } + + fn empty_state_headline(&self) -> &'static str { + "Welcome to Gemini" + } + + fn empty_state_message(&self) -> &'static str { + "Ask questions, edit files, run commands.\nBe specific for the best results." 
+    }
+
+    fn logo(&self) -> ui::IconName {
+        ui::IconName::AiGemini
+    }
+
+    fn connect(
+        &self,
+        root_dir: &Path,
+        project: &Entity<Project>,
+        cx: &mut App,
+    ) -> Task<Result<Rc<dyn AgentConnection>>> {
+        let project = project.clone();
+        let root_dir = root_dir.to_path_buf();
+        let server_name = self.name();
+        cx.spawn(async move |cx| {
+            let settings = cx.read_global(|settings: &SettingsStore, _| {
+                settings.get::<AllAgentServersSettings>(None).gemini.clone()
+            })?;
+
+            let Some(command) =
+                AgentServerCommand::resolve("gemini", &[ACP_ARG], None, settings, &project, cx).await
+            else {
+                anyhow::bail!("Failed to find gemini binary");
+            };
+
+            let result = crate::acp::connect(server_name, command.clone(), &root_dir, cx).await;
+            if result.is_err() {
+                let version_fut = util::command::new_smol_command(&command.path)
+                    .args(command.args.iter())
+                    .arg("--version")
+                    .kill_on_drop(true)
+                    .output();
+
+                let help_fut = util::command::new_smol_command(&command.path)
+                    .args(command.args.iter())
+                    .arg("--help")
+                    .kill_on_drop(true)
+                    .output();
+
+                let (version_output, help_output) = futures::future::join(version_fut, help_fut).await;
+
+                let current_version = String::from_utf8(version_output?.stdout)?;
+                let supported = String::from_utf8(help_output?.stdout)?.contains(ACP_ARG);
+
+                if !supported {
+                    return Err(LoadError::Unsupported {
+                        error_message: format!(
+                            "Your installed version of Gemini {} doesn't support the Agentic Coding Protocol (ACP).",
+                            current_version
+                        ).into(),
+                        upgrade_message: "Upgrade Gemini to Latest".into(),
+                        upgrade_command: "npm install -g @google/gemini-cli@latest".into(),
+                    }.into())
+                }
+            }
+            result
+        })
+    }
+}
+
+#[cfg(test)]
+pub(crate) mod tests {
+    use super::*;
+    use crate::AgentServerCommand;
+    use std::path::Path;
+
+    crate::common_e2e_tests!(Gemini, allow_option_id = "proceed_once");
+
+    pub fn local_command() -> AgentServerCommand {
+        let cli_path = Path::new(env!("CARGO_MANIFEST_DIR"))
+            .join("../../../gemini-cli/packages/cli")
+            .to_string_lossy()
+            .to_string();
+
+        AgentServerCommand {
+            path: "node".into(),
+            args: vec![cli_path],
+            env: None,
+        }
+    }
+}
diff --git a/crates/agent_servers/src/settings.rs b/crates/agent_servers/src/settings.rs
new file mode 100644
index 0000000000000000000000000000000000000000..645674b5f15087250c2364fb9a8a846e163ad54c
--- /dev/null
+++ b/crates/agent_servers/src/settings.rs
@@ -0,0 +1,45 @@
+use crate::AgentServerCommand;
+use anyhow::Result;
+use gpui::App;
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+use settings::{Settings, SettingsSources};
+
+pub fn init(cx: &mut App) {
+    AllAgentServersSettings::register(cx);
+}
+
+#[derive(Default, Deserialize, Serialize, Clone, JsonSchema, Debug)]
+pub struct AllAgentServersSettings {
+    pub gemini: Option<AgentServerSettings>,
+    pub claude: Option<AgentServerSettings>,
+}
+
+#[derive(Deserialize, Serialize, Clone, JsonSchema, Debug)]
+pub struct AgentServerSettings {
+    #[serde(flatten)]
+    pub command: AgentServerCommand,
+}
+
+impl settings::Settings for AllAgentServersSettings {
+    const KEY: Option<&'static str> = Some("agent_servers");
+
+    type FileContent = Self;
+
+    fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
+        let mut settings = AllAgentServersSettings::default();
+
+        for AllAgentServersSettings { gemini, claude } in sources.defaults_and_customizations() {
+            if gemini.is_some() {
+                settings.gemini = gemini.clone();
+            }
+            if claude.is_some() {
+                settings.claude = claude.clone();
+            }
+        }
+
+        Ok(settings)
+    }
+
+    fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {}
+}
diff --git
a/crates/agent_settings/Cargo.toml b/crates/agent_settings/Cargo.toml index 3afe5ae54757953a43a6bdd465c095dc70c27288..d34396a5d35dd8919e519e804a93b50dfe046133 100644 --- a/crates/agent_settings/Cargo.toml +++ b/crates/agent_settings/Cargo.toml @@ -13,6 +13,7 @@ path = "src/agent_settings.rs" [dependencies] anyhow.workspace = true +cloud_llm_client.workspace = true collections.workspace = true gpui.workspace = true language_model.workspace = true @@ -20,7 +21,6 @@ schemars.workspace = true serde.workspace = true settings.workspace = true workspace-hack.workspace = true -zed_llm_client.workspace = true [dev-dependencies] fs.workspace = true diff --git a/crates/agent_settings/src/agent_settings.rs b/crates/agent_settings/src/agent_settings.rs index 131cd2dc3f3e4e8967c03cbf1e808ebdeee306cf..e6a79963d670d2bc2067826855202104bcf1621c 100644 --- a/crates/agent_settings/src/agent_settings.rs +++ b/crates/agent_settings/src/agent_settings.rs @@ -13,6 +13,9 @@ use std::borrow::Cow; pub use crate::agent_profile::*; +pub const SUMMARIZE_THREAD_PROMPT: &str = + include_str!("../../agent/src/prompts/summarize_thread_prompt.txt"); + pub fn init(cx: &mut App) { AgentSettings::register(cx); } @@ -69,6 +72,7 @@ pub struct AgentSettings { pub enable_feedback: bool, pub expand_edit_card: bool, pub expand_terminal_card: bool, + pub use_modifier_to_send: bool, } impl AgentSettings { @@ -174,6 +178,10 @@ impl AgentSettingsContent { self.single_file_review = Some(allow); } + pub fn set_use_modifier_to_send(&mut self, always_use: bool) { + self.use_modifier_to_send = Some(always_use); + } + pub fn set_profile(&mut self, profile_id: AgentProfileId) { self.default_profile = Some(profile_id); } @@ -301,6 +309,10 @@ pub struct AgentSettingsContent { /// /// Default: true expand_terminal_card: Option, + /// Whether to always use cmd-enter (or ctrl-enter on Linux) to send messages in the agent panel. 
+ /// + /// Default: false + use_modifier_to_send: Option, } #[derive(Clone, Copy, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Default)] @@ -312,11 +324,11 @@ pub enum CompletionMode { Burn, } -impl From for zed_llm_client::CompletionMode { +impl From for cloud_llm_client::CompletionMode { fn from(value: CompletionMode) -> Self { match value { - CompletionMode::Normal => zed_llm_client::CompletionMode::Normal, - CompletionMode::Burn => zed_llm_client::CompletionMode::Max, + CompletionMode::Normal => cloud_llm_client::CompletionMode::Normal, + CompletionMode::Burn => cloud_llm_client::CompletionMode::Max, } } } @@ -456,6 +468,10 @@ impl Settings for AgentSettings { &mut settings.expand_terminal_card, value.expand_terminal_card, ); + merge( + &mut settings.use_modifier_to_send, + value.use_modifier_to_send, + ); settings .model_parameters diff --git a/crates/agent_ui/Cargo.toml b/crates/agent_ui/Cargo.toml index 72466fe8e7a05f52ac69d79e64a1af3452df089f..c145df0eaecd5c504830e1985a3d38911b000a5e 100644 --- a/crates/agent_ui/Cargo.toml +++ b/crates/agent_ui/Cargo.toml @@ -16,11 +16,13 @@ doctest = false test-support = ["gpui/test-support", "language/test-support"] [dependencies] -acp.workspace = true +acp_thread.workspace = true +agent-client-protocol.workspace = true agent.workspace = true -agentic-coding-protocol.workspace = true -agent_settings.workspace = true +agent2.workspace = true agent_servers.workspace = true +agent_settings.workspace = true +ai_onboarding.workspace = true anyhow.workspace = true assistant_context.workspace = true assistant_slash_command.workspace = true @@ -30,7 +32,9 @@ audio.workspace = true buffer_diff.workspace = true chrono.workspace = true client.workspace = true +cloud_llm_client.workspace = true collections.workspace = true +command_palette_hooks.workspace = true component.workspace = true context_server.workspace = true db.workspace = true @@ -44,14 +48,15 @@ futures.workspace = true fuzzy.workspace = true gpui.workspace = true html_to_markdown.workspace = true -indoc.workspace = true http_client.workspace = true indexed_docs.workspace = true +indoc.workspace = true inventory.workspace = true itertools.workspace = true jsonschema.workspace = true language.workspace = true language_model.workspace = true +language_models.workspace = true log.workspace = true lsp.workspace = true markdown.workspace = true @@ -86,6 +91,7 @@ theme.workspace = true time.workspace = true time_format.workspace = true ui.workspace = true +ui_input.workspace = true urlencoding.workspace = true util.workspace = true uuid.workspace = true @@ -93,7 +99,6 @@ watch.workspace = true workspace-hack.workspace = true workspace.workspace = true zed_actions.workspace = true -zed_llm_client.workspace = true [dev-dependencies] assistant_tools.workspace = true diff --git a/crates/agent_ui/src/acp.rs b/crates/agent_ui/src/acp.rs index 23ada8d77a536391de6968ba9f9c0ee351967266..cc476b1a862b11d964f731cbb0d5ac8e9f100e59 100644 --- a/crates/agent_ui/src/acp.rs +++ b/crates/agent_ui/src/acp.rs @@ -2,4 +2,5 @@ mod completion_provider; mod message_history; mod thread_view; +pub use message_history::MessageHistory; pub use thread_view::AcpThreadView; diff --git a/crates/agent_ui/src/acp/message_history.rs b/crates/agent_ui/src/acp/message_history.rs index 6d9626627af3c4cc7a13d4347069ff7349725afb..c6106c7578230e60576cdeb48318012b52e76e46 100644 --- a/crates/agent_ui/src/acp/message_history.rs +++ b/crates/agent_ui/src/acp/message_history.rs @@ -3,19 +3,25 @@ pub struct MessageHistory { current: 
Option, } -impl MessageHistory { - pub fn new() -> Self { +impl Default for MessageHistory { + fn default() -> Self { MessageHistory { items: Vec::new(), current: None, } } +} +impl MessageHistory { pub fn push(&mut self, message: T) { self.current.take(); self.items.push(message); } + pub fn reset_position(&mut self) { + self.current.take(); + } + pub fn prev(&mut self) -> Option<&T> { if self.items.is_empty() { return None; @@ -39,6 +45,11 @@ impl MessageHistory { None }) } + + #[cfg(test)] + pub fn items(&self) -> &[T] { + &self.items + } } #[cfg(test)] mod tests { @@ -46,7 +57,7 @@ mod tests { #[test] fn test_prev_next() { - let mut history = MessageHistory::new(); + let mut history = MessageHistory::default(); // Test empty history assert_eq!(history.prev(), None); diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 2e3bf54837ca2edd479687437fad90283456f36c..3d1fbba45d763fcf293844a67bd9fc3ab9fbe26b 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -1,59 +1,86 @@ +use acp_thread::{AgentConnection, Plan}; +use agent_servers::AgentServer; +use agent_settings::{AgentSettings, NotifyWhenAgentWaiting}; +use audio::{Audio, Sound}; +use std::cell::RefCell; +use std::collections::BTreeMap; use std::path::Path; +use std::process::ExitStatus; use std::rc::Rc; use std::sync::Arc; use std::time::Duration; -use agentic_coding_protocol::{self as acp}; +use agent_client_protocol as acp; +use assistant_tool::ActionLog; +use buffer_diff::BufferDiff; use collections::{HashMap, HashSet}; use editor::{ AnchorRangeExt, ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, EditorMode, - EditorStyle, MinimapVisibility, MultiBuffer, + EditorStyle, MinimapVisibility, MultiBuffer, PathKey, }; use file_icons::FileIcons; -use futures::channel::oneshot; use gpui::{ - Animation, AnimationExt, App, BorderStyle, EdgesRefinement, Empty, Entity, EntityId, Focusable, - Hsla, Length, ListOffset, ListState, SharedString, StyleRefinement, Subscription, TextStyle, - TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, Window, div, list, percentage, - prelude::*, pulsating_between, + Action, Animation, AnimationExt, App, BorderStyle, EdgesRefinement, Empty, Entity, EntityId, + FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, MouseButton, PlatformDisplay, + SharedString, Stateful, StyleRefinement, Subscription, Task, TextStyle, TextStyleRefinement, + Transformation, UnderlineStyle, WeakEntity, Window, WindowHandle, div, linear_color_stop, + linear_gradient, list, percentage, point, prelude::*, pulsating_between, }; -use gpui::{FocusHandle, Task}; use language::language_settings::SoftWrap; use language::{Buffer, Language}; use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle}; use parking_lot::Mutex; use project::Project; -use settings::Settings as _; +use settings::{Settings as _, SettingsStore}; +use text::{Anchor, BufferSnapshot}; use theme::ThemeSettings; -use ui::{Disclosure, Tooltip, prelude::*}; +use ui::{ + Disclosure, Divider, DividerColor, KeyBinding, Scrollbar, ScrollbarState, Tooltip, prelude::*, +}; use util::ResultExt; -use workspace::Workspace; +use workspace::{CollaboratorId, Workspace}; use zed_actions::agent::{Chat, NextHistoryMessage, PreviousHistoryMessage}; -use ::acp::{ +use ::acp_thread::{ AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk, Diff, - LoadError, MentionPath, ThreadStatus, ToolCall, ToolCallConfirmation, 
ToolCallContent, - ToolCallId, ToolCallStatus, + LoadError, MentionPath, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus, }; use crate::acp::completion_provider::{ContextPickerCompletionProvider, MentionSet}; use crate::acp::message_history::MessageHistory; +use crate::agent_diff::AgentDiff; +use crate::message_editor::{MAX_EDITOR_LINES, MIN_EDITOR_LINES}; +use crate::ui::{AgentNotification, AgentNotificationEvent}; +use crate::{ + AgentDiffPane, AgentPanel, ExpandMessageEditor, Follow, KeepAll, OpenAgentDiff, RejectAll, +}; const RESPONSE_PADDING_X: Pixels = px(19.); pub struct AcpThreadView { + agent: Rc, workspace: WeakEntity, project: Entity, thread_state: ThreadState, diff_editors: HashMap>, message_editor: Entity, + message_set_from_history: Option, + _message_editor_subscription: Subscription, mention_set: Arc>, + notifications: Vec>, + notification_subscriptions: HashMap, Vec>, last_error: Option>, list_state: ListState, + scrollbar_state: ScrollbarState, auth_task: Option>, - expanded_tool_calls: HashSet, + expanded_tool_calls: HashSet, expanded_thinking_blocks: HashSet<(usize, usize)>, - message_history: MessageHistory, + edits_expanded: bool, + plan_expanded: bool, + editor_expanded: bool, + message_history: Rc>>>, + _cancel_task: Option>, + _subscriptions: [Subscription; 1], } enum ThreadState { @@ -62,18 +89,25 @@ enum ThreadState { }, Ready { thread: Entity, - _subscription: Subscription, + _subscription: [Subscription; 2], }, LoadError(LoadError), Unauthenticated { - thread: Entity, + connection: Rc, + }, + ServerExited { + status: ExitStatus, }, } impl AcpThreadView { pub fn new( + agent: Rc, workspace: WeakEntity, project: Entity, + message_history: Rc>>>, + min_lines: usize, + max_lines: Option, window: &mut Window, cx: &mut Context, ) -> Self { @@ -93,8 +127,8 @@ impl AcpThreadView { let mut editor = Editor::new( editor::EditorMode::AutoHeight { - min_lines: 4, - max_lines: None, + min_lines, + max_lines: max_lines, }, buffer, None, @@ -118,40 +152,65 @@ impl AcpThreadView { editor }); - let list_state = ListState::new( - 0, - gpui::ListAlignment::Bottom, - px(2048.0), - cx.processor({ - move |this: &mut Self, index: usize, window, cx| { - let Some((entry, len)) = this.thread().and_then(|thread| { - let entries = &thread.read(cx).entries(); - Some((entries.get(index)?, entries.len())) - }) else { - return Empty.into_any(); - }; - this.render_entry(index, len, entry, window, cx) + let message_editor_subscription = + cx.subscribe(&message_editor, |this, editor, event, cx| { + if let editor::EditorEvent::BufferEdited = &event { + let buffer = editor + .read(cx) + .buffer() + .read(cx) + .as_singleton() + .unwrap() + .read(cx) + .snapshot(); + if let Some(message) = this.message_set_from_history.clone() + && message.version() != buffer.version() + { + this.message_set_from_history = None; + } + + if this.message_set_from_history.is_none() { + this.message_history.borrow_mut().reset_position(); + } } - }), - ); + }); + + let mention_set = mention_set.clone(); + + let list_state = ListState::new(0, gpui::ListAlignment::Bottom, px(2048.0)); + + let subscription = cx.observe_global_in::(window, Self::settings_changed); Self { - workspace, + agent: agent.clone(), + workspace: workspace.clone(), project: project.clone(), - thread_state: Self::initial_state(project, window, cx), + thread_state: Self::initial_state(agent, workspace, project, window, cx), message_editor, + message_set_from_history: None, + _message_editor_subscription: message_editor_subscription, 
mention_set, + notifications: Vec::new(), + notification_subscriptions: HashMap::default(), diff_editors: Default::default(), - list_state: list_state, + list_state: list_state.clone(), + scrollbar_state: ScrollbarState::new(list_state).parent_entity(&cx.entity()), last_error: None, auth_task: None, expanded_tool_calls: HashSet::default(), expanded_thinking_blocks: HashSet::default(), - message_history: MessageHistory::new(), + edits_expanded: false, + plan_expanded: false, + editor_expanded: false, + message_history, + _subscriptions: [subscription], + _cancel_task: None, } } fn initial_state( + agent: Rc, + workspace: WeakEntity, project: Entity, window: &mut Window, cx: &mut Context, @@ -163,10 +222,10 @@ impl AcpThreadView { .map(|worktree| worktree.read(cx).abs_path()) .unwrap_or_else(|| paths::home_dir().as_path().into()); + let connect_task = agent.connect(&root_dir, &project, cx); let load_task = cx.spawn_in(window, async move |this, cx| { - let thread = match AcpThread::spawn(agent_servers::Gemini, &root_dir, project, cx).await - { - Ok(thread) => thread, + let connection = match connect_task.await { + Ok(connection) => connection, Err(err) => { this.update(cx, |this, cx| { this.handle_load_error(err, cx); @@ -177,57 +236,61 @@ impl AcpThreadView { } }; - let init_response = async { - let resp = thread - .read_with(cx, |thread, _cx| thread.initialize())? - .await?; - anyhow::Ok(resp) - }; - - let result = match init_response.await { + // this.update_in(cx, |_this, _window, cx| { + // let status = connection.exit_status(cx); + // cx.spawn(async move |this, cx| { + // let status = status.await.ok(); + // this.update(cx, |this, cx| { + // this.thread_state = ThreadState::ServerExited { status }; + // cx.notify(); + // }) + // .ok(); + // }) + // .detach(); + // }) + // .ok(); + + let result = match connection + .clone() + .new_thread(project.clone(), &root_dir, cx) + .await + { Err(e) => { let mut cx = cx.clone(); - if e.downcast_ref::().is_some() { - let child_status = thread - .update(&mut cx, |thread, _| thread.child_status()) - .ok() - .flatten(); - if let Some(child_status) = child_status { - match child_status.await { - Ok(_) => Err(e), - Err(e) => Err(e), - } - } else { - Err(e) - } - } else { - Err(e) - } - } - Ok(response) => { - if !response.is_authenticated { - this.update(cx, |this, _| { - this.thread_state = ThreadState::Unauthenticated { thread }; + if e.is::() { + this.update(&mut cx, |this, cx| { + this.thread_state = ThreadState::Unauthenticated { connection }; + cx.notify(); }) .ok(); return; - }; - Ok(()) + } else { + Err(e) + } } + Ok(session_id) => Ok(session_id), }; this.update_in(cx, |this, window, cx| { match result { - Ok(()) => { - let subscription = + Ok(thread) => { + let thread_subscription = cx.subscribe_in(&thread, window, Self::handle_thread_event); + + let action_log = thread.read(cx).action_log().clone(); + let action_log_subscription = + cx.observe(&action_log, |_, _, cx| cx.notify()); + this.list_state .splice(0..0, thread.read(cx).entries().len()); + AgentDiff::set_active_thread(&workspace, thread.clone(), window, cx); + this.thread_state = ThreadState::Ready { thread, - _subscription: subscription, + _subscription: [thread_subscription, action_log_subscription], }; + cx.notify(); } Err(err) => { @@ -250,12 +313,13 @@ impl AcpThreadView { cx.notify(); } - fn thread(&self) -> Option<&Entity> { + pub fn thread(&self) -> Option<&Entity> { match &self.thread_state { - ThreadState::Ready { thread, .. 
} | ThreadState::Unauthenticated { thread } => { - Some(thread) - } - ThreadState::Loading { .. } | ThreadState::LoadError(..) => None, + ThreadState::Ready { thread, .. } => Some(thread), + ThreadState::Unauthenticated { .. } + | ThreadState::Loading { .. } + | ThreadState::LoadError(..) + | ThreadState::ServerExited { .. } => None, } } @@ -265,6 +329,7 @@ impl AcpThreadView { ThreadState::Loading { .. } => "Loading…".into(), ThreadState::LoadError(_) => "Failed to load".into(), ThreadState::Unauthenticated { .. } => "Not authenticated".into(), + ThreadState::ServerExited { .. } => "Server exited unexpectedly".into(), } } @@ -272,33 +337,73 @@ impl AcpThreadView { self.last_error.take(); if let Some(thread) = self.thread() { - thread.update(cx, |thread, cx| thread.cancel(cx)).detach(); + self._cancel_task = Some(thread.update(cx, |thread, cx| thread.cancel(cx))); } } + pub fn expand_message_editor( + &mut self, + _: &ExpandMessageEditor, + _window: &mut Window, + cx: &mut Context, + ) { + self.set_editor_is_expanded(!self.editor_expanded, cx); + cx.notify(); + } + + fn set_editor_is_expanded(&mut self, is_expanded: bool, cx: &mut Context) { + self.editor_expanded = is_expanded; + self.message_editor.update(cx, |editor, _| { + if self.editor_expanded { + editor.set_mode(EditorMode::Full { + scale_ui_elements_with_buffer_font_size: false, + show_active_line_background: false, + sized_by_content: false, + }) + } else { + editor.set_mode(EditorMode::AutoHeight { + min_lines: MIN_EDITOR_LINES, + max_lines: Some(MAX_EDITOR_LINES), + }) + } + }); + cx.notify(); + } + fn chat(&mut self, _: &Chat, window: &mut Window, cx: &mut Context) { self.last_error.take(); let mut ix = 0; - let mut chunks: Vec = Vec::new(); - + let mut chunks: Vec = Vec::new(); let project = self.project.clone(); self.message_editor.update(cx, |editor, cx| { let text = editor.text(cx); editor.display_map.update(cx, |map, cx| { let snapshot = map.snapshot(cx); for (crease_id, crease) in snapshot.crease_snapshot.creases() { + // Skip creases that have been edited out of the message buffer. 
+ if !crease.range().start.is_valid(&snapshot.buffer_snapshot) { + continue; + } + if let Some(project_path) = self.mention_set.lock().path_for_crease_id(crease_id) { let crease_range = crease.range().to_offset(&snapshot.buffer_snapshot); if crease_range.start > ix { - chunks.push(acp::UserMessageChunk::Text { - text: text[ix..crease_range.start].to_string(), - }); + chunks.push(text[ix..crease_range.start].into()); } if let Some(abs_path) = project.read(cx).absolute_path(&project_path, cx) { - chunks.push(acp::UserMessageChunk::Path { path: abs_path }); + let path_str = abs_path.display().to_string(); + chunks.push(acp::ContentBlock::ResourceLink(acp::ResourceLink { + uri: path_str.clone(), + name: path_str, + annotations: None, + description: None, + mime_type: None, + size: None, + title: None, + })); } ix = crease_range.end; } @@ -307,9 +412,7 @@ impl AcpThreadView { if ix < text.len() { let last_chunk = text[ix..].trim(); if !last_chunk.is_empty() { - chunks.push(acp::UserMessageChunk::Text { - text: last_chunk.into(), - }); + chunks.push(last_chunk.into()); } } }) @@ -319,9 +422,10 @@ impl AcpThreadView { return; } - let Some(thread) = self.thread() else { return }; - let message = acp::SendUserMessageParams { chunks }; - let task = thread.update(cx, |thread, cx| thread.send(message.clone(), cx)); + let Some(thread) = self.thread() else { + return; + }; + let task = thread.update(cx, |thread, cx| thread.send(chunks.clone(), cx)); cx.spawn(async move |this, cx| { let result = task.await; @@ -337,12 +441,16 @@ impl AcpThreadView { let mention_set = self.mention_set.clone(); + self.set_editor_is_expanded(false, cx); + self.message_editor.update(cx, |editor, cx| { editor.clear(window, cx); editor.remove_creases(mention_set.lock().drain(), cx) }); - self.message_history.push(message); + self.scroll_to_bottom(cx); + + self.message_history.borrow_mut().push(chunks); } fn previous_history_message( @@ -351,11 +459,21 @@ impl AcpThreadView { window: &mut Window, cx: &mut Context, ) { - Self::set_draft_message( + if self.message_set_from_history.is_none() && !self.message_editor.read(cx).is_empty(cx) { + self.message_editor.update(cx, |editor, cx| { + editor.move_up(&Default::default(), window, cx); + }); + return; + } + + self.message_set_from_history = Self::set_draft_message( self.message_editor.clone(), self.mention_set.clone(), self.project.clone(), - self.message_history.prev(), + self.message_history + .borrow_mut() + .prev() + .map(|blocks| blocks.as_slice()), window, cx, ); @@ -367,49 +485,92 @@ impl AcpThreadView { window: &mut Window, cx: &mut Context, ) { - Self::set_draft_message( + if self.message_set_from_history.is_none() { + self.message_editor.update(cx, |editor, cx| { + editor.move_down(&Default::default(), window, cx); + }); + return; + } + + let mut message_history = self.message_history.borrow_mut(); + let next_history = message_history.next(); + + let set_draft_message = Self::set_draft_message( self.message_editor.clone(), self.mention_set.clone(), self.project.clone(), - self.message_history.next(), + Some( + next_history + .map(|blocks| blocks.as_slice()) + .unwrap_or_else(|| &[]), + ), window, cx, ); + // If we reset the text to an empty string because we ran out of history, + // we don't want to mark it as coming from the history + self.message_set_from_history = if next_history.is_some() { + set_draft_message + } else { + None + }; + } + + fn open_agent_diff(&mut self, _: &OpenAgentDiff, window: &mut Window, cx: &mut Context) { + if let Some(thread) = self.thread() 
{ + AgentDiffPane::deploy(thread.clone(), self.workspace.clone(), window, cx).log_err(); + } + } + + fn open_edited_buffer( + &mut self, + buffer: &Entity, + window: &mut Window, + cx: &mut Context, + ) { + let Some(thread) = self.thread() else { + return; + }; + + let Some(diff) = + AgentDiffPane::deploy(thread.clone(), self.workspace.clone(), window, cx).log_err() + else { + return; + }; + + diff.update(cx, |diff, cx| { + diff.move_to_path(PathKey::for_buffer(&buffer, cx), window, cx) + }) } fn set_draft_message( message_editor: Entity, mention_set: Arc>, project: Entity, - message: Option<&acp::SendUserMessageParams>, + message: Option<&[acp::ContentBlock]>, window: &mut Window, cx: &mut Context, - ) { + ) -> Option { cx.notify(); - let Some(message) = message else { - message_editor.update(cx, |editor, cx| { - editor.clear(window, cx); - editor.remove_creases(mention_set.lock().drain(), cx) - }); - return; - }; + let message = message?; let mut text = String::new(); let mut mentions = Vec::new(); - for chunk in &message.chunks { + for chunk in message { match chunk { - acp::UserMessageChunk::Text { text: chunk } => { - text.push_str(&chunk); + acp::ContentBlock::Text(text_content) => { + text.push_str(&text_content.text); } - acp::UserMessageChunk::Path { path } => { + acp::ContentBlock::ResourceLink(resource_link) => { + let path = Path::new(&resource_link.uri); let start = text.len(); - let content = MentionPath::new(path).to_string(); + let content = MentionPath::new(&path).to_string(); text.push_str(&content); let end = text.len(); if let Some(project_path) = - project.read(cx).project_path_for_absolute_path(path, cx) + project.read(cx).project_path_for_absolute_path(&path, cx) { let filename: SharedString = path .file_name() @@ -420,6 +581,9 @@ impl AcpThreadView { mentions.push((start..end, project_path, filename)); } } + acp::ContentBlock::Image(_) + | acp::ContentBlock::Audio(_) + | acp::ContentBlock::Resource(_) => {} } } @@ -452,6 +616,9 @@ impl AcpThreadView { mention_set.lock().insert(crease_id, project_path); } } + + let snapshot = snapshot.as_singleton().unwrap().2.clone(); + Some(snapshot.text) } fn handle_thread_event( @@ -464,7 +631,8 @@ impl AcpThreadView { let count = self.list_state.item_count(); match event { AcpThreadEvent::NewEntry => { - self.sync_thread_entry_view(thread.read(cx).entries().len() - 1, window, cx); + let index = thread.read(cx).entries().len() - 1; + self.sync_thread_entry_view(index, window, cx); self.list_state.splice(count..count, 1); } AcpThreadEvent::EntryUpdated(index) => { @@ -472,6 +640,33 @@ impl AcpThreadView { self.sync_thread_entry_view(index, window, cx); self.list_state.splice(index..index + 1, 1); } + AcpThreadEvent::ToolAuthorizationRequired => { + self.notify_with_sound("Waiting for tool confirmation", IconName::Info, window, cx); + } + AcpThreadEvent::Stopped => { + let used_tools = thread.read(cx).used_tools_since_last_user_message(); + self.notify_with_sound( + if used_tools { + "Finished running tools" + } else { + "New message" + }, + IconName::ZedAssistant, + window, + cx, + ); + } + AcpThreadEvent::Error => { + self.notify_with_sound( + "Agent stopped due to an error", + IconName::Warning, + window, + cx, + ); + } + AcpThreadEvent::ServerExited(status) => { + self.thread_state = ThreadState::ServerExited { status: *status }; + } } cx.notify(); } @@ -482,81 +677,79 @@ impl AcpThreadView { window: &mut Window, cx: &mut Context, ) { - let Some(multibuffer) = self.entry_diff_multibuffer(entry_ix, cx) else { + let 
Some(multibuffers) = self.entry_diff_multibuffers(entry_ix, cx) else { return; }; - if self.diff_editors.contains_key(&multibuffer.entity_id()) { - return; - } + let multibuffers = multibuffers.collect::>(); - let editor = cx.new(|cx| { - let mut editor = Editor::new( - EditorMode::Full { - scale_ui_elements_with_buffer_font_size: false, - show_active_line_background: false, - sized_by_content: true, - }, - multibuffer.clone(), - None, - window, - cx, - ); - editor.set_show_gutter(false, cx); - editor.disable_inline_diagnostics(); - editor.disable_expand_excerpt_buttons(cx); - editor.set_show_vertical_scrollbar(false, cx); - editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx); - editor.set_soft_wrap_mode(SoftWrap::None, cx); - editor.scroll_manager.set_forbid_vertical_scroll(true); - editor.set_show_indent_guides(false, cx); - editor.set_read_only(true); - editor.set_show_breakpoints(false, cx); - editor.set_show_code_actions(false, cx); - editor.set_show_git_diff_gutter(false, cx); - editor.set_expand_all_diff_hunks(cx); - editor.set_text_style_refinement(TextStyleRefinement { - font_size: Some( - TextSize::Small - .rems(cx) - .to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx)) - .into(), - ), - ..Default::default() + for multibuffer in multibuffers { + if self.diff_editors.contains_key(&multibuffer.entity_id()) { + return; + } + + let editor = cx.new(|cx| { + let mut editor = Editor::new( + EditorMode::Full { + scale_ui_elements_with_buffer_font_size: false, + show_active_line_background: false, + sized_by_content: true, + }, + multibuffer.clone(), + None, + window, + cx, + ); + editor.set_show_gutter(false, cx); + editor.disable_inline_diagnostics(); + editor.disable_expand_excerpt_buttons(cx); + editor.set_show_vertical_scrollbar(false, cx); + editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx); + editor.set_soft_wrap_mode(SoftWrap::None, cx); + editor.scroll_manager.set_forbid_vertical_scroll(true); + editor.set_show_indent_guides(false, cx); + editor.set_read_only(true); + editor.set_show_breakpoints(false, cx); + editor.set_show_code_actions(false, cx); + editor.set_show_git_diff_gutter(false, cx); + editor.set_expand_all_diff_hunks(cx); + editor.set_text_style_refinement(diff_editor_text_style_refinement(cx)); + editor }); - editor - }); - let entity_id = multibuffer.entity_id(); - cx.observe_release(&multibuffer, move |this, _, _| { - this.diff_editors.remove(&entity_id); - }) - .detach(); + let entity_id = multibuffer.entity_id(); + cx.observe_release(&multibuffer, move |this, _, _| { + this.diff_editors.remove(&entity_id); + }) + .detach(); - self.diff_editors.insert(entity_id, editor); + self.diff_editors.insert(entity_id, editor); + } } - fn entry_diff_multibuffer(&self, entry_ix: usize, cx: &App) -> Option> { + fn entry_diff_multibuffers( + &self, + entry_ix: usize, + cx: &App, + ) -> Option>> { let entry = self.thread()?.read(cx).entries().get(entry_ix)?; - if let AgentThreadEntry::ToolCall(ToolCall { - content: Some(ToolCallContent::Diff { diff }), - .. 
- }) = &entry - { - Some(diff.multibuffer.clone()) - } else { - None - } + Some(entry.diffs().map(|diff| diff.multibuffer.clone())) } - fn authenticate(&mut self, window: &mut Window, cx: &mut Context) { - let Some(thread) = self.thread().cloned() else { + fn authenticate( + &mut self, + method: acp::AuthMethodId, + window: &mut Window, + cx: &mut Context, + ) { + let ThreadState::Unauthenticated { ref connection } = self.thread_state else { return; }; self.last_error.take(); - let authenticate = thread.read(cx).authenticate(); + let authenticate = connection.authenticate(method, cx); self.auth_task = Some(cx.spawn_in(window, { let project = self.project.clone(); + let agent = self.agent.clone(); async move |this, cx| { let result = authenticate.await; @@ -566,7 +759,13 @@ impl AcpThreadView { Markdown::new(format!("Error: {err}").into(), None, None, cx) })) } else { - this.thread_state = Self::initial_state(project.clone(), window, cx) + this.thread_state = Self::initial_state( + agent, + this.workspace.clone(), + project.clone(), + window, + cx, + ) } this.auth_task.take() }) @@ -577,15 +776,16 @@ impl AcpThreadView { fn authorize_tool_call( &mut self, - id: ToolCallId, - outcome: acp::ToolCallConfirmationOutcome, + tool_call_id: acp::ToolCallId, + option_id: acp::PermissionOptionId, + option_kind: acp::PermissionOptionKind, cx: &mut Context, ) { let Some(thread) = self.thread() else { return; }; thread.update(cx, |thread, cx| { - thread.authorize_tool_call(id, outcome, cx); + thread.authorize_tool_call(tool_call_id, option_id, option_kind, cx); }); cx.notify(); } @@ -598,7 +798,7 @@ impl AcpThreadView { window: &mut Window, cx: &Context, ) -> AnyElement { - match &entry { + let primary = match &entry { AgentThreadEntry::UserMessage(message) => div() .py_4() .px_2() @@ -612,10 +812,12 @@ impl AcpThreadView { .border_1() .border_color(cx.theme().colors().border) .text_xs() - .child(self.render_markdown( - message.content.clone(), - user_message_markdown_style(window, cx), - )), + .children(message.content.markdown().map(|md| { + self.render_markdown( + md.clone(), + user_message_markdown_style(window, cx), + ) + })), ) .into_any(), AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) => { @@ -623,20 +825,28 @@ impl AcpThreadView { let message_body = v_flex() .w_full() .gap_2p5() - .children(chunks.iter().enumerate().map(|(chunk_ix, chunk)| { - match chunk { - AssistantMessageChunk::Text { chunk } => self - .render_markdown(chunk.clone(), style.clone()) - .into_any_element(), - AssistantMessageChunk::Thought { chunk } => self.render_thinking_block( - index, - chunk_ix, - chunk.clone(), - window, - cx, - ), - } - })) + .children(chunks.iter().enumerate().filter_map( + |(chunk_ix, chunk)| match chunk { + AssistantMessageChunk::Message { block } => { + block.markdown().map(|md| { + self.render_markdown(md.clone(), style.clone()) + .into_any_element() + }) + } + AssistantMessageChunk::Thought { block } => { + block.markdown().map(|md| { + self.render_thinking_block( + index, + chunk_ix, + md.clone(), + window, + cx, + ) + .into_any_element() + }) + } + }, + )) .into_any(); v_flex() @@ -649,10 +859,25 @@ impl AcpThreadView { .into_any() } AgentThreadEntry::ToolCall(tool_call) => div() + .w_full() .py_1p5() .px_5() .child(self.render_tool_call(index, tool_call, window, cx)) .into_any(), + }; + + let Some(thread) = self.thread() else { + return primary; + }; + let is_generating = matches!(thread.read(cx).status(), ThreadStatus::Generating); + if index == total_entries - 1 && 
!is_generating { + v_flex() + .w_full() + .child(primary) + .child(self.render_thread_controls(cx)) + .into_any_element() + } else { + primary } } @@ -680,6 +905,7 @@ impl AcpThreadView { cx: &Context, ) -> AnyElement { let header_id = SharedString::from(format!("thinking-block-header-{}", entry_ix)); + let card_header_id = SharedString::from("inner-card-header"); let key = (entry_ix, chunk_ix); let is_open = self.expanded_thinking_blocks.contains(&key); @@ -687,41 +913,53 @@ impl AcpThreadView { .child( h_flex() .id(header_id) - .group("disclosure-header") + .group(&card_header_id) + .relative() .w_full() - .justify_between() + .gap_1p5() .opacity(0.8) .hover(|style| style.opacity(1.)) .child( h_flex() - .gap_1p5() + .size_4() + .justify_center() .child( - Icon::new(IconName::ToolBulb) - .size(IconSize::Small) - .color(Color::Muted), + div() + .group_hover(&card_header_id, |s| s.invisible().w_0()) + .child( + Icon::new(IconName::ToolThink) + .size(IconSize::Small) + .color(Color::Muted), + ), ) .child( - div() - .text_size(self.tool_name_font_size()) - .child("Thinking"), + h_flex() + .absolute() + .inset_0() + .invisible() + .justify_center() + .group_hover(&card_header_id, |s| s.visible()) + .child( + Disclosure::new(("expand", entry_ix), is_open) + .opened_icon(IconName::ChevronUp) + .closed_icon(IconName::ChevronRight) + .on_click(cx.listener({ + move |this, _event, _window, cx| { + if is_open { + this.expanded_thinking_blocks.remove(&key); + } else { + this.expanded_thinking_blocks.insert(key); + } + cx.notify(); + } + })), + ), ), ) .child( - div().visible_on_hover("disclosure-header").child( - Disclosure::new("thinking-disclosure", is_open) - .opened_icon(IconName::ChevronUp) - .closed_icon(IconName::ChevronDown) - .on_click(cx.listener({ - move |this, _event, _window, cx| { - if is_open { - this.expanded_thinking_blocks.remove(&key); - } else { - this.expanded_thinking_blocks.insert(key); - } - cx.notify(); - } - })), - ), + div() + .text_size(self.tool_name_font_size()) + .child("Thinking"), ) .on_click(cx.listener({ move |this, _event, _window, cx| { @@ -752,6 +990,67 @@ impl AcpThreadView { .into_any_element() } + fn render_tool_call_icon( + &self, + group_name: SharedString, + entry_ix: usize, + is_collapsible: bool, + is_open: bool, + tool_call: &ToolCall, + cx: &Context, + ) -> Div { + let tool_icon = Icon::new(match tool_call.kind { + acp::ToolKind::Read => IconName::ToolRead, + acp::ToolKind::Edit => IconName::ToolPencil, + acp::ToolKind::Delete => IconName::ToolDeleteFile, + acp::ToolKind::Move => IconName::ArrowRightLeft, + acp::ToolKind::Search => IconName::ToolSearch, + acp::ToolKind::Execute => IconName::ToolTerminal, + acp::ToolKind::Think => IconName::ToolThink, + acp::ToolKind::Fetch => IconName::ToolWeb, + acp::ToolKind::Other => IconName::ToolHammer, + }) + .size(IconSize::Small) + .color(Color::Muted); + + if is_collapsible { + h_flex() + .size_4() + .justify_center() + .child( + div() + .group_hover(&group_name, |s| s.invisible().w_0()) + .child(tool_icon), + ) + .child( + h_flex() + .absolute() + .inset_0() + .invisible() + .justify_center() + .group_hover(&group_name, |s| s.visible()) + .child( + Disclosure::new(("expand", entry_ix), is_open) + .opened_icon(IconName::ChevronUp) + .closed_icon(IconName::ChevronRight) + .on_click(cx.listener({ + let id = tool_call.id.clone(); + move |this: &mut Self, _, _, cx: &mut Context| { + if is_open { + this.expanded_tool_calls.remove(&id); + } else { + this.expanded_tool_calls.insert(id.clone()); + } + cx.notify(); + } + 
})), + ), + ) + } else { + div().child(tool_icon) + } + } + fn render_tool_call( &self, entry_ix: usize, @@ -759,12 +1058,16 @@ impl AcpThreadView { window: &Window, cx: &Context, ) -> Div { - let header_id = SharedString::from(format!("tool-call-header-{}", entry_ix)); + let header_id = SharedString::from(format!("outer-tool-call-header-{}", entry_ix)); + let card_header_id = SharedString::from("inner-tool-call-header"); let status_icon = match &tool_call.status { - ToolCallStatus::WaitingForConfirmation { .. } => None, ToolCallStatus::Allowed { - status: acp::ToolCallStatus::Running, + status: acp::ToolCallStatus::Pending, + } + | ToolCallStatus::WaitingForConfirmation { .. } => None, + ToolCallStatus::Allowed { + status: acp::ToolCallStatus::InProgress, .. } => Some( Icon::new(IconName::ArrowCircle) @@ -778,13 +1081,13 @@ impl AcpThreadView { .into_any(), ), ToolCallStatus::Allowed { - status: acp::ToolCallStatus::Finished, + status: acp::ToolCallStatus::Completed, .. } => None, ToolCallStatus::Rejected | ToolCallStatus::Canceled | ToolCallStatus::Allowed { - status: acp::ToolCallStatus::Error, + status: acp::ToolCallStatus::Failed, .. } => Some( Icon::new(IconName::X) @@ -802,32 +1105,22 @@ impl AcpThreadView { .any(|content| matches!(content, ToolCallContent::Diff { .. })), }; - let is_collapsible = tool_call.content.is_some() && !needs_confirmation; + let is_collapsible = !tool_call.content.is_empty() && !needs_confirmation; let is_open = !is_collapsible || self.expanded_tool_calls.contains(&tool_call.id); - let content = if is_open { - match &tool_call.status { - ToolCallStatus::WaitingForConfirmation { confirmation, .. } => { - Some(self.render_tool_call_confirmation( - tool_call.id, - confirmation, - tool_call.content.as_ref(), - window, - cx, - )) - } - ToolCallStatus::Allowed { .. 
} | ToolCallStatus::Canceled => { - tool_call.content.as_ref().map(|content| { - div() - .py_1p5() - .child(self.render_tool_call_content(content, window, cx)) - .into_any_element() - }) - } - ToolCallStatus::Rejected => None, - } - } else { - None + let gradient_color = cx.theme().colors().panel_background; + let gradient_overlay = { + div() + .absolute() + .top_0() + .right_0() + .w_12() + .h_full() + .bg(linear_gradient( + 90., + linear_color_stop(gradient_color, 1.), + linear_color_stop(gradient_color.opacity(0.2), 0.), + )) }; v_flex() @@ -846,76 +1139,108 @@ impl AcpThreadView { .justify_between() .map(|this| { if needs_confirmation { - this.px_2() + this.pl_2() + .pr_1() .py_1() .rounded_t_md() - .bg(self.tool_card_header_bg(cx)) .border_b_1() .border_color(self.tool_card_border_color(cx)) + .bg(self.tool_card_header_bg(cx)) } else { this.opacity(0.8).hover(|style| style.opacity(1.)) } }) .child( h_flex() - .id("tool-call-header") - .overflow_x_scroll() + .group(&card_header_id) + .relative() + .w_full() .map(|this| { - if needs_confirmation { - this.text_xs() + if tool_call.locations.len() == 1 { + this.gap_0() } else { - this.text_size(self.tool_name_font_size()) + this.gap_1p5() } }) - .gap_1p5() - .child( - Icon::new(tool_call.icon) - .size(IconSize::Small) - .color(Color::Muted), - ) - .child(self.render_markdown( - tool_call.label.clone(), - default_markdown_style(needs_confirmation, window, cx), - )), - ) - .child( - h_flex() - .gap_0p5() - .when(is_collapsible, |this| { - this.child( - Disclosure::new(("expand", tool_call.id.0), is_open) - .opened_icon(IconName::ChevronUp) - .closed_icon(IconName::ChevronDown) - .on_click(cx.listener({ - let id = tool_call.id; - move |this: &mut Self, _, _, cx: &mut Context| { - if is_open { - this.expanded_tool_calls.remove(&id); - } else { - this.expanded_tool_calls.insert(id); - } - cx.notify(); + .text_size(self.tool_name_font_size()) + .child(self.render_tool_call_icon( + card_header_id, + entry_ix, + is_collapsible, + is_open, + tool_call, + cx, + )) + .child(if tool_call.locations.len() == 1 { + let name = tool_call.locations[0] + .path + .file_name() + .unwrap_or_default() + .display() + .to_string(); + + h_flex() + .id(("open-tool-call-location", entry_ix)) + .w_full() + .max_w_full() + .px_1p5() + .rounded_sm() + .overflow_x_scroll() + .opacity(0.8) + .hover(|label| { + label.opacity(1.).bg(cx + .theme() + .colors() + .element_hover + .opacity(0.5)) + }) + .child(name) + .tooltip(Tooltip::text("Jump to File")) + .on_click(cx.listener(move |this, _, window, cx| { + this.open_tool_call_location(entry_ix, 0, window, cx); + })) + .into_any_element() + } else { + h_flex() + .id("non-card-label-container") + .w_full() + .relative() + .overflow_hidden() + .child( + h_flex() + .id("non-card-label") + .pr_8() + .w_full() + .overflow_x_scroll() + .child(self.render_markdown( + tool_call.label.clone(), + default_markdown_style( + needs_confirmation, + window, + cx, + ), + )), + ) + .child(gradient_overlay) + .on_click(cx.listener({ + let id = tool_call.id.clone(); + move |this: &mut Self, _, _, cx: &mut Context| { + if is_open { + this.expanded_tool_calls.remove(&id); + } else { + this.expanded_tool_calls.insert(id.clone()); } - })), - ) - }) - .children(status_icon), + cx.notify(); + } + })) + .into_any() + }), ) - .on_click(cx.listener({ - let id = tool_call.id; - move |this: &mut Self, _, _, cx: &mut Context| { - if is_open { - this.expanded_tool_calls.remove(&id); - } else { - this.expanded_tool_calls.insert(id); - } - cx.notify(); - } - 
})), + .children(status_icon), ) .when(is_open, |this| { this.child( - div() + v_flex() .text_xs() .when(is_collapsible, |this| { this.mt_1() @@ -924,7 +1249,45 @@ impl AcpThreadView { .bg(cx.theme().colors().editor_background) .rounded_lg() }) - .children(content), + .map(|this| { + if is_open { + match &tool_call.status { + ToolCallStatus::WaitingForConfirmation { options, .. } => this + .children(tool_call.content.iter().map(|content| { + div() + .py_1p5() + .child( + self.render_tool_call_content( + content, window, cx, + ), + ) + .into_any_element() + })) + .child(self.render_permission_buttons( + options, + entry_ix, + tool_call.id.clone(), + tool_call.content.is_empty(), + cx, + )), + ToolCallStatus::Allowed { .. } | ToolCallStatus::Canceled => { + this.children(tool_call.content.iter().map(|content| { + div() + .py_1p5() + .child( + self.render_tool_call_content( + content, window, cx, + ), + ) + .into_any_element() + })) + } + ToolCallStatus::Rejected => this, + } + } else { + this + } + }), ) }) } @@ -936,438 +1299,83 @@ impl AcpThreadView { cx: &Context, ) -> AnyElement { match content { - ToolCallContent::Markdown { markdown } => self - .render_markdown(markdown.clone(), default_markdown_style(false, window, cx)) - .into_any_element(), - ToolCallContent::Diff { - diff: Diff { - path, multibuffer, .. - }, - .. - } => self.render_diff_editor(multibuffer, path), - } - } - - fn render_tool_call_confirmation( - &self, - tool_call_id: ToolCallId, - confirmation: &ToolCallConfirmation, - content: Option<&ToolCallContent>, - window: &Window, - cx: &Context, - ) -> AnyElement { - let confirmation_container = v_flex().mt_1().py_1p5(); - - let button_container = h_flex() - .pt_1p5() - .px_1p5() - .gap_1() - .justify_end() - .border_t_1() - .border_color(self.tool_card_border_color(cx)); - - match confirmation { - ToolCallConfirmation::Edit { description } => confirmation_container - .child( + ToolCallContent::ContentBlock { content } => { + if let Some(md) = content.markdown() { div() - .px_2() - .children(description.clone().map(|description| { - self.render_markdown( - description, - default_markdown_style(false, window, cx), - ) - })), - ) - .children(content.map(|content| self.render_tool_call_content(content, window, cx))) - .child( - button_container - .child( - Button::new(("always_allow", tool_call_id.0), "Always Allow Edits") - .icon(IconName::CheckDouble) - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon_color(Color::Success) - .on_click(cx.listener({ - let id = tool_call_id; - move |this, _, _, cx| { - this.authorize_tool_call( - id, - acp::ToolCallConfirmationOutcome::AlwaysAllow, - cx, - ); - } - })), - ) - .child( - Button::new(("allow", tool_call_id.0), "Allow") - .icon(IconName::Check) - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon_color(Color::Success) - .on_click(cx.listener({ - let id = tool_call_id; - move |this, _, _, cx| { - this.authorize_tool_call( - id, - acp::ToolCallConfirmationOutcome::Allow, - cx, - ); - } - })), - ) - .child( - Button::new(("reject", tool_call_id.0), "Reject") - .icon(IconName::X) - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon_color(Color::Error) - .on_click(cx.listener({ - let id = tool_call_id; - move |this, _, _, cx| { - this.authorize_tool_call( - id, - acp::ToolCallConfirmationOutcome::Reject, - cx, - ); - } - })), - ), - ) - .into_any(), - ToolCallConfirmation::Execute { - command, - root_command, - description, - } => confirmation_container - 
.child(v_flex().px_2().pb_1p5().child(command.clone()).children( - description.clone().map(|description| { - self.render_markdown(description, default_markdown_style(false, window, cx)) - .on_url_click({ - let workspace = self.workspace.clone(); - move |text, window, cx| { - Self::open_link(text, &workspace, window, cx); - } - }) - }), - )) - .children(content.map(|content| self.render_tool_call_content(content, window, cx))) - .child( - button_container - .child( - Button::new( - ("always_allow", tool_call_id.0), - format!("Always Allow {root_command}"), - ) - .icon(IconName::CheckDouble) - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon_color(Color::Success) - .label_size(LabelSize::Small) - .on_click(cx.listener({ - let id = tool_call_id; - move |this, _, _, cx| { - this.authorize_tool_call( - id, - acp::ToolCallConfirmationOutcome::AlwaysAllow, - cx, - ); - } - })), - ) - .child( - Button::new(("allow", tool_call_id.0), "Allow") - .icon(IconName::Check) - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon_color(Color::Success) - .label_size(LabelSize::Small) - .on_click(cx.listener({ - let id = tool_call_id; - move |this, _, _, cx| { - this.authorize_tool_call( - id, - acp::ToolCallConfirmationOutcome::Allow, - cx, - ); - } - })), - ) - .child( - Button::new(("reject", tool_call_id.0), "Reject") - .icon(IconName::X) - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon_color(Color::Error) - .label_size(LabelSize::Small) - .on_click(cx.listener({ - let id = tool_call_id; - move |this, _, _, cx| { - this.authorize_tool_call( - id, - acp::ToolCallConfirmationOutcome::Reject, - cx, - ); - } - })), - ), - ) - .into_any(), - ToolCallConfirmation::Mcp { - server_name, - tool_name: _, - tool_display_name, - description, - } => confirmation_container - .child( - v_flex() - .px_2() - .pb_1p5() - .child(format!("{server_name} - {tool_display_name}")) - .children(description.clone().map(|description| { - self.render_markdown( - description, - default_markdown_style(false, window, cx), - ) - })), - ) - .children(content.map(|content| self.render_tool_call_content(content, window, cx))) - .child( - button_container - .child( - Button::new( - ("always_allow_server", tool_call_id.0), - format!("Always Allow {server_name}"), - ) - .icon(IconName::CheckDouble) - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon_color(Color::Success) - .label_size(LabelSize::Small) - .on_click(cx.listener({ - let id = tool_call_id; - move |this, _, _, cx| { - this.authorize_tool_call( - id, - acp::ToolCallConfirmationOutcome::AlwaysAllowMcpServer, - cx, - ); - } - })), - ) - .child( - Button::new( - ("always_allow_tool", tool_call_id.0), - format!("Always Allow {tool_display_name}"), - ) - .icon(IconName::CheckDouble) - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon_color(Color::Success) - .label_size(LabelSize::Small) - .on_click(cx.listener({ - let id = tool_call_id; - move |this, _, _, cx| { - this.authorize_tool_call( - id, - acp::ToolCallConfirmationOutcome::AlwaysAllowTool, - cx, - ); - } - })), - ) - .child( - Button::new(("allow", tool_call_id.0), "Allow") - .icon(IconName::Check) - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon_color(Color::Success) - .label_size(LabelSize::Small) - .on_click(cx.listener({ - let id = tool_call_id; - move |this, _, _, cx| { - this.authorize_tool_call( - id, - acp::ToolCallConfirmationOutcome::Allow, - cx, - ); - } - })), - ) + 
.p_2() .child( - Button::new(("reject", tool_call_id.0), "Reject") - .icon(IconName::X) - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon_color(Color::Error) - .label_size(LabelSize::Small) - .on_click(cx.listener({ - let id = tool_call_id; - move |this, _, _, cx| { - this.authorize_tool_call( - id, - acp::ToolCallConfirmationOutcome::Reject, - cx, - ); - } - })), - ), - ) - .into_any(), - ToolCallConfirmation::Fetch { description, urls } => confirmation_container - .child( - v_flex() - .px_2() - .pb_1p5() - .gap_1() - .children(urls.iter().map(|url| { - h_flex().child( - Button::new(url.clone(), url) - .icon(IconName::ArrowUpRight) - .icon_color(Color::Muted) - .icon_size(IconSize::XSmall) - .on_click({ - let url = url.clone(); - move |_, _, cx| cx.open_url(&url) - }), - ) - })) - .children(description.clone().map(|description| { self.render_markdown( - description, + md.clone(), default_markdown_style(false, window, cx), - ) - })), - ) - .children(content.map(|content| self.render_tool_call_content(content, window, cx))) - .child( - button_container - .child( - Button::new(("always_allow", tool_call_id.0), "Always Allow") - .icon(IconName::CheckDouble) - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon_color(Color::Success) - .label_size(LabelSize::Small) - .on_click(cx.listener({ - let id = tool_call_id; - move |this, _, _, cx| { - this.authorize_tool_call( - id, - acp::ToolCallConfirmationOutcome::AlwaysAllow, - cx, - ); - } - })), - ) - .child( - Button::new(("allow", tool_call_id.0), "Allow") - .icon(IconName::Check) - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon_color(Color::Success) - .label_size(LabelSize::Small) - .on_click(cx.listener({ - let id = tool_call_id; - move |this, _, _, cx| { - this.authorize_tool_call( - id, - acp::ToolCallConfirmationOutcome::Allow, - cx, - ); - } - })), - ) - .child( - Button::new(("reject", tool_call_id.0), "Reject") - .icon(IconName::X) - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon_color(Color::Error) - .label_size(LabelSize::Small) - .on_click(cx.listener({ - let id = tool_call_id; - move |this, _, _, cx| { - this.authorize_tool_call( - id, - acp::ToolCallConfirmationOutcome::Reject, - cx, - ); - } - })), - ), - ) - .into_any(), - ToolCallConfirmation::Other { description } => confirmation_container - .child(v_flex().px_2().pb_1p5().child(self.render_markdown( - description.clone(), - default_markdown_style(false, window, cx), - ))) - .children(content.map(|content| self.render_tool_call_content(content, window, cx))) - .child( - button_container - .child( - Button::new(("always_allow", tool_call_id.0), "Always Allow") - .icon(IconName::CheckDouble) - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon_color(Color::Success) - .label_size(LabelSize::Small) - .on_click(cx.listener({ - let id = tool_call_id; - move |this, _, _, cx| { - this.authorize_tool_call( - id, - acp::ToolCallConfirmationOutcome::AlwaysAllow, - cx, - ); - } - })), - ) - .child( - Button::new(("allow", tool_call_id.0), "Allow") - .icon(IconName::Check) - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon_color(Color::Success) - .label_size(LabelSize::Small) - .on_click(cx.listener({ - let id = tool_call_id; - move |this, _, _, cx| { - this.authorize_tool_call( - id, - acp::ToolCallConfirmationOutcome::Allow, - cx, - ); - } - })), + ), ) - .child( - Button::new(("reject", tool_call_id.0), "Reject") - .icon(IconName::X) 
- .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon_color(Color::Error) - .label_size(LabelSize::Small) - .on_click(cx.listener({ - let id = tool_call_id; - move |this, _, _, cx| { - this.authorize_tool_call( - id, - acp::ToolCallConfirmationOutcome::Reject, - cx, - ); - } - })), - ), - ) - .into_any(), + .into_any_element() + } else { + Empty.into_any_element() + } + } + ToolCallContent::Diff { + diff: Diff { multibuffer, .. }, + .. + } => self.render_diff_editor(multibuffer), } } - fn render_diff_editor(&self, multibuffer: &Entity, path: &Path) -> AnyElement { + fn render_permission_buttons( + &self, + options: &[acp::PermissionOption], + entry_ix: usize, + tool_call_id: acp::ToolCallId, + empty_content: bool, + cx: &Context, + ) -> Div { + h_flex() + .p_1p5() + .gap_1() + .justify_end() + .when(!empty_content, |this| { + this.border_t_1() + .border_color(self.tool_card_border_color(cx)) + }) + .children(options.iter().map(|option| { + let option_id = SharedString::from(option.id.0.clone()); + Button::new((option_id, entry_ix), option.name.clone()) + .map(|this| match option.kind { + acp::PermissionOptionKind::AllowOnce => { + this.icon(IconName::Check).icon_color(Color::Success) + } + acp::PermissionOptionKind::AllowAlways => { + this.icon(IconName::CheckDouble).icon_color(Color::Success) + } + acp::PermissionOptionKind::RejectOnce => { + this.icon(IconName::X).icon_color(Color::Error) + } + acp::PermissionOptionKind::RejectAlways => { + this.icon(IconName::X).icon_color(Color::Error) + } + }) + .icon_position(IconPosition::Start) + .icon_size(IconSize::XSmall) + .label_size(LabelSize::Small) + .on_click(cx.listener({ + let tool_call_id = tool_call_id.clone(); + let option_id = option.id.clone(); + let option_kind = option.kind; + move |this, _, _, cx| { + this.authorize_tool_call( + tool_call_id.clone(), + option_id.clone(), + option_kind, + cx, + ); + } + })) + })) + } + + fn render_diff_editor(&self, multibuffer: &Entity) -> AnyElement { v_flex() .h_full() - .child(path.to_string_lossy().to_string()) .child( if let Some(editor) = self.diff_editors.get(&multibuffer.entity_id()) { editor.clone().into_any_element() @@ -1378,15 +1386,15 @@ impl AcpThreadView { .into_any() } - fn render_gemini_logo(&self) -> AnyElement { - Icon::new(IconName::AiGemini) + fn render_agent_logo(&self) -> AnyElement { + Icon::new(self.agent.logo()) .color(Color::Muted) .size(IconSize::XLarge) .into_any_element() } - fn render_error_gemini_logo(&self) -> AnyElement { - let logo = Icon::new(IconName::AiGemini) + fn render_error_agent_logo(&self) -> AnyElement { + let logo = Icon::new(self.agent.logo()) .color(Color::Muted) .size(IconSize::XLarge) .into_any_element(); @@ -1405,49 +1413,50 @@ impl AcpThreadView { .into_any_element() } - fn render_empty_state(&self, loading: bool, cx: &App) -> AnyElement { + fn render_empty_state(&self, cx: &App) -> AnyElement { + let loading = matches!(&self.thread_state, ThreadState::Loading { .. 
}); + v_flex() .size_full() .items_center() .justify_center() - .child( - if loading { - h_flex() - .justify_center() - .child(self.render_gemini_logo()) - .with_animation( - "pulsating_icon", - Animation::new(Duration::from_secs(2)) - .repeat() - .with_easing(pulsating_between(0.4, 1.0)), - |icon, delta| icon.opacity(delta), - ).into_any() - } else { - self.render_gemini_logo().into_any_element() - } - ) - .child( + .child(if loading { h_flex() - .mt_4() - .mb_1() .justify_center() - .child(Headline::new(if loading { - "Connecting to Gemini…" - } else { - "Welcome to Gemini" - }).size(HeadlineSize::Medium)), - ) + .child(self.render_agent_logo()) + .with_animation( + "pulsating_icon", + Animation::new(Duration::from_secs(2)) + .repeat() + .with_easing(pulsating_between(0.4, 1.0)), + |icon, delta| icon.opacity(delta), + ) + .into_any() + } else { + self.render_agent_logo().into_any_element() + }) + .child(h_flex().mt_4().mb_1().justify_center().child(if loading { + div() + .child(LoadingLabel::new("").size(LabelSize::Large)) + .into_any_element() + } else { + Headline::new(self.agent.empty_state_headline()) + .size(HeadlineSize::Medium) + .into_any_element() + })) .child( div() .max_w_1_2() .text_sm() .text_center() - .map(|this| if loading { - this.invisible() - } else { - this.text_color(cx.theme().colors().text_muted) + .map(|this| { + if loading { + this.invisible() + } else { + this.text_color(cx.theme().colors().text_muted) + } }) - .child("Ask questions, edit files, run commands.\nBe specific for the best results.") + .child(self.agent.empty_state_message()), ) .into_any() } @@ -1456,7 +1465,7 @@ impl AcpThreadView { v_flex() .items_center() .justify_center() - .child(self.render_error_gemini_logo()) + .child(self.render_error_agent_logo()) .child( h_flex() .mt_4() @@ -1467,11 +1476,33 @@ impl AcpThreadView { .into_any() } - fn render_error_state(&self, e: &LoadError, cx: &Context) -> AnyElement { + fn render_server_exited(&self, status: ExitStatus, _cx: &Context) -> AnyElement { + v_flex() + .items_center() + .justify_center() + .child(self.render_error_agent_logo()) + .child( + v_flex() + .mt_4() + .mb_2() + .gap_0p5() + .text_center() + .items_center() + .child(Headline::new("Server exited unexpectedly").size(HeadlineSize::Medium)) + .child( + Label::new(format!("Exit status: {}", status.code().unwrap_or(-127))) + .size(LabelSize::Small) + .color(Color::Muted), + ), + ) + .into_any_element() + } + + fn render_load_error(&self, e: &LoadError, cx: &Context) -> AnyElement { let mut container = v_flex() .items_center() .justify_center() - .child(self.render_error_gemini_logo()) + .child(self.render_error_agent_logo()) .child( v_flex() .mt_4() @@ -1487,76 +1518,701 @@ impl AcpThreadView { ), ); - if matches!(e, LoadError::Unsupported { .. 
}) { - container = - container.child(Button::new("upgrade", "Upgrade Gemini to Latest").on_click( - cx.listener(|this, _, window, cx| { - this.workspace - .update(cx, |workspace, cx| { - let project = workspace.project().read(cx); - let cwd = project.first_project_directory(cx); - let shell = project.terminal_settings(&cwd, cx).shell.clone(); - let command = - "npm install -g @google/gemini-cli@latest".to_string(); - let spawn_in_terminal = task::SpawnInTerminal { - id: task::TaskId("install".to_string()), - full_label: command.clone(), - label: command.clone(), - command: Some(command.clone()), - args: Vec::new(), - command_label: command.clone(), - cwd, - env: Default::default(), - use_new_terminal: true, - allow_concurrent_runs: true, - reveal: Default::default(), - reveal_target: Default::default(), - hide: Default::default(), - shell, - show_summary: true, - show_command: true, - show_rerun: false, - }; - workspace - .spawn_in_terminal(spawn_in_terminal, window, cx) - .detach(); - }) - .ok(); - }), - )); + if let LoadError::Unsupported { + upgrade_message, + upgrade_command, + .. + } = &e + { + let upgrade_message = upgrade_message.clone(); + let upgrade_command = upgrade_command.clone(); + container = container.child(Button::new("upgrade", upgrade_message).on_click( + cx.listener(move |this, _, window, cx| { + this.workspace + .update(cx, |workspace, cx| { + let project = workspace.project().read(cx); + let cwd = project.first_project_directory(cx); + let shell = project.terminal_settings(&cwd, cx).shell.clone(); + let spawn_in_terminal = task::SpawnInTerminal { + id: task::TaskId("install".to_string()), + full_label: upgrade_command.clone(), + label: upgrade_command.clone(), + command: Some(upgrade_command.clone()), + args: Vec::new(), + command_label: upgrade_command.clone(), + cwd, + env: Default::default(), + use_new_terminal: true, + allow_concurrent_runs: true, + reveal: Default::default(), + reveal_target: Default::default(), + hide: Default::default(), + shell, + show_summary: true, + show_command: true, + show_rerun: false, + }; + workspace + .spawn_in_terminal(spawn_in_terminal, window, cx) + .detach(); + }) + .ok(); + }), + )); } container.into_any() } - fn render_message_editor(&mut self, cx: &mut Context) -> AnyElement { - let settings = ThemeSettings::get_global(cx); - let font_size = TextSize::Small - .rems(cx) - .to_pixels(settings.agent_font_size(cx)); - let line_height = settings.buffer_line_height.value() * font_size; - - let text_style = TextStyle { - color: cx.theme().colors().text, - font_family: settings.buffer_font.family.clone(), - font_fallbacks: settings.buffer_font.fallbacks.clone(), - font_features: settings.buffer_font.features.clone(), - font_size: font_size.into(), - line_height: line_height.into(), - ..Default::default() + fn render_activity_bar( + &self, + thread_entity: &Entity, + window: &mut Window, + cx: &Context, + ) -> Option { + let thread = thread_entity.read(cx); + let action_log = thread.action_log(); + let changed_buffers = action_log.read(cx).changed_buffers(cx); + let plan = thread.plan(); + + if changed_buffers.is_empty() && plan.is_empty() { + return None; + } + + let editor_bg_color = cx.theme().colors().editor_background; + let active_color = cx.theme().colors().element_selected; + let bg_edit_files_disclosure = editor_bg_color.blend(active_color.opacity(0.3)); + + let pending_edits = thread.has_pending_edit_tool_calls(); + + v_flex() + .mt_1() + .mx_2() + .bg(bg_edit_files_disclosure) + .border_1() + .border_b_0() + 
.border_color(cx.theme().colors().border) + .rounded_t_md() + .shadow(vec![gpui::BoxShadow { + color: gpui::black().opacity(0.15), + offset: point(px(1.), px(-1.)), + blur_radius: px(3.), + spread_radius: px(0.), + }]) + .when(!plan.is_empty(), |this| { + this.child(self.render_plan_summary(plan, window, cx)) + .when(self.plan_expanded, |parent| { + parent.child(self.render_plan_entries(plan, window, cx)) + }) + }) + .when(!changed_buffers.is_empty(), |this| { + this.child(Divider::horizontal().color(DividerColor::Border)) + .child(self.render_edits_summary( + action_log, + &changed_buffers, + self.edits_expanded, + pending_edits, + window, + cx, + )) + .when(self.edits_expanded, |parent| { + parent.child(self.render_edited_files( + action_log, + &changed_buffers, + pending_edits, + cx, + )) + }) + }) + .into_any() + .into() + } + + fn render_plan_summary(&self, plan: &Plan, window: &mut Window, cx: &Context) -> Div { + let stats = plan.stats(); + + let title = if let Some(entry) = stats.in_progress_entry + && !self.plan_expanded + { + h_flex() + .w_full() + .cursor_default() + .gap_1() + .text_xs() + .text_color(cx.theme().colors().text_muted) + .justify_between() + .child( + h_flex() + .gap_1() + .child( + Label::new("Current:") + .size(LabelSize::Small) + .color(Color::Muted), + ) + .child(MarkdownElement::new( + entry.content.clone(), + plan_label_markdown_style(&entry.status, window, cx), + )), + ) + .when(stats.pending > 0, |this| { + this.child( + Label::new(format!("{} left", stats.pending)) + .size(LabelSize::Small) + .color(Color::Muted) + .mr_1(), + ) + }) + } else { + let status_label = if stats.pending == 0 { + "All Done".to_string() + } else if stats.completed == 0 { + format!("{} Tasks", plan.entries.len()) + } else { + format!("{}/{}", stats.completed, plan.entries.len()) + }; + + h_flex() + .w_full() + .gap_1() + .justify_between() + .child( + Label::new("Plan") + .size(LabelSize::Small) + .color(Color::Muted), + ) + .child( + Label::new(status_label) + .size(LabelSize::Small) + .color(Color::Muted) + .mr_1(), + ) }; - EditorElement::new( - &self.message_editor, - EditorStyle { - background: cx.theme().colors().editor_background, - local_player: cx.theme().players().local(), - text: text_style, - syntax: cx.theme().syntax().clone(), - ..Default::default() + h_flex() + .p_1() + .justify_between() + .when(self.plan_expanded, |this| { + this.border_b_1().border_color(cx.theme().colors().border) + }) + .child( + h_flex() + .id("plan_summary") + .w_full() + .gap_1() + .child(Disclosure::new("plan_disclosure", self.plan_expanded)) + .child(title) + .on_click(cx.listener(|this, _, _, cx| { + this.plan_expanded = !this.plan_expanded; + cx.notify(); + })), + ) + } + + fn render_plan_entries(&self, plan: &Plan, window: &mut Window, cx: &Context) -> Div { + v_flex().children(plan.entries.iter().enumerate().flat_map(|(index, entry)| { + let element = h_flex() + .py_1() + .px_2() + .gap_2() + .justify_between() + .bg(cx.theme().colors().editor_background) + .when(index < plan.entries.len() - 1, |parent| { + parent.border_color(cx.theme().colors().border).border_b_1() + }) + .child( + h_flex() + .id(("plan_entry", index)) + .gap_1p5() + .max_w_full() + .overflow_x_scroll() + .text_xs() + .text_color(cx.theme().colors().text_muted) + .child(match entry.status { + acp::PlanEntryStatus::Pending => Icon::new(IconName::TodoPending) + .size(IconSize::Small) + .color(Color::Muted) + .into_any_element(), + acp::PlanEntryStatus::InProgress => Icon::new(IconName::TodoProgress) + 
.size(IconSize::Small) + .color(Color::Accent) + .with_animation( + "running", + Animation::new(Duration::from_secs(2)).repeat(), + |icon, delta| { + icon.transform(Transformation::rotate(percentage(delta))) + }, + ) + .into_any_element(), + acp::PlanEntryStatus::Completed => Icon::new(IconName::TodoComplete) + .size(IconSize::Small) + .color(Color::Success) + .into_any_element(), + }) + .child(MarkdownElement::new( + entry.content.clone(), + plan_label_markdown_style(&entry.status, window, cx), + )), + ); + + Some(element) + })) + } + + fn render_edits_summary( + &self, + action_log: &Entity, + changed_buffers: &BTreeMap, Entity>, + expanded: bool, + pending_edits: bool, + window: &mut Window, + cx: &Context, + ) -> Div { + const EDIT_NOT_READY_TOOLTIP_LABEL: &str = "Wait until file edits are complete."; + + let focus_handle = self.focus_handle(cx); + + h_flex() + .p_1() + .justify_between() + .when(expanded, |this| { + this.border_b_1().border_color(cx.theme().colors().border) + }) + .child( + h_flex() + .id("edits-container") + .w_full() + .gap_1() + .child(Disclosure::new("edits-disclosure", expanded)) + .map(|this| { + if pending_edits { + this.child( + Label::new(format!( + "Editing {} {}…", + changed_buffers.len(), + if changed_buffers.len() == 1 { + "file" + } else { + "files" + } + )) + .color(Color::Muted) + .size(LabelSize::Small) + .with_animation( + "edit-label", + Animation::new(Duration::from_secs(2)) + .repeat() + .with_easing(pulsating_between(0.3, 0.7)), + |label, delta| label.alpha(delta), + ), + ) + } else { + this.child( + Label::new("Edits") + .size(LabelSize::Small) + .color(Color::Muted), + ) + .child(Label::new("•").size(LabelSize::XSmall).color(Color::Muted)) + .child( + Label::new(format!( + "{} {}", + changed_buffers.len(), + if changed_buffers.len() == 1 { + "file" + } else { + "files" + } + )) + .size(LabelSize::Small) + .color(Color::Muted), + ) + } + }) + .on_click(cx.listener(|this, _, _, cx| { + this.edits_expanded = !this.edits_expanded; + cx.notify(); + })), + ) + .child( + h_flex() + .gap_1() + .child( + IconButton::new("review-changes", IconName::ListTodo) + .icon_size(IconSize::Small) + .tooltip({ + let focus_handle = focus_handle.clone(); + move |window, cx| { + Tooltip::for_action_in( + "Review Changes", + &OpenAgentDiff, + &focus_handle, + window, + cx, + ) + } + }) + .on_click(cx.listener(|_, _, window, cx| { + window.dispatch_action(OpenAgentDiff.boxed_clone(), cx); + })), + ) + .child(Divider::vertical().color(DividerColor::Border)) + .child( + Button::new("reject-all-changes", "Reject All") + .label_size(LabelSize::Small) + .disabled(pending_edits) + .when(pending_edits, |this| { + this.tooltip(Tooltip::text(EDIT_NOT_READY_TOOLTIP_LABEL)) + }) + .key_binding( + KeyBinding::for_action_in( + &RejectAll, + &focus_handle.clone(), + window, + cx, + ) + .map(|kb| kb.size(rems_from_px(10.))), + ) + .on_click({ + let action_log = action_log.clone(); + cx.listener(move |_, _, _, cx| { + action_log.update(cx, |action_log, cx| { + action_log.reject_all_edits(cx).detach(); + }) + }) + }), + ) + .child( + Button::new("keep-all-changes", "Keep All") + .label_size(LabelSize::Small) + .disabled(pending_edits) + .when(pending_edits, |this| { + this.tooltip(Tooltip::text(EDIT_NOT_READY_TOOLTIP_LABEL)) + }) + .key_binding( + KeyBinding::for_action_in(&KeepAll, &focus_handle, window, cx) + .map(|kb| kb.size(rems_from_px(10.))), + ) + .on_click({ + let action_log = action_log.clone(); + cx.listener(move |_, _, _, cx| { + action_log.update(cx, |action_log, cx| { + 
action_log.keep_all_edits(cx); + }) + }) + }), + ), + ) + } + + fn render_edited_files( + &self, + action_log: &Entity, + changed_buffers: &BTreeMap, Entity>, + pending_edits: bool, + cx: &Context, + ) -> Div { + let editor_bg_color = cx.theme().colors().editor_background; + + v_flex().children(changed_buffers.into_iter().enumerate().flat_map( + |(index, (buffer, _diff))| { + let file = buffer.read(cx).file()?; + let path = file.path(); + + let file_path = path.parent().and_then(|parent| { + let parent_str = parent.to_string_lossy(); + + if parent_str.is_empty() { + None + } else { + Some( + Label::new(format!("/{}{}", parent_str, std::path::MAIN_SEPARATOR_STR)) + .color(Color::Muted) + .size(LabelSize::XSmall) + .buffer_font(cx), + ) + } + }); + + let file_name = path.file_name().map(|name| { + Label::new(name.to_string_lossy().to_string()) + .size(LabelSize::XSmall) + .buffer_font(cx) + }); + + let file_icon = FileIcons::get_icon(&path, cx) + .map(Icon::from_path) + .map(|icon| icon.color(Color::Muted).size(IconSize::Small)) + .unwrap_or_else(|| { + Icon::new(IconName::File) + .color(Color::Muted) + .size(IconSize::Small) + }); + + let overlay_gradient = linear_gradient( + 90., + linear_color_stop(editor_bg_color, 1.), + linear_color_stop(editor_bg_color.opacity(0.2), 0.), + ); + + let element = h_flex() + .group("edited-code") + .id(("file-container", index)) + .relative() + .py_1() + .pl_2() + .pr_1() + .gap_2() + .justify_between() + .bg(editor_bg_color) + .when(index < changed_buffers.len() - 1, |parent| { + parent.border_color(cx.theme().colors().border).border_b_1() + }) + .child( + h_flex() + .id(("file-name", index)) + .pr_8() + .gap_1p5() + .max_w_full() + .overflow_x_scroll() + .child(file_icon) + .child(h_flex().gap_0p5().children(file_name).children(file_path)) + .on_click({ + let buffer = buffer.clone(); + cx.listener(move |this, _, window, cx| { + this.open_edited_buffer(&buffer, window, cx); + }) + }), + ) + .child( + h_flex() + .gap_1() + .visible_on_hover("edited-code") + .child( + Button::new("review", "Review") + .label_size(LabelSize::Small) + .on_click({ + let buffer = buffer.clone(); + cx.listener(move |this, _, window, cx| { + this.open_edited_buffer(&buffer, window, cx); + }) + }), + ) + .child(Divider::vertical().color(DividerColor::BorderVariant)) + .child( + Button::new("reject-file", "Reject") + .label_size(LabelSize::Small) + .disabled(pending_edits) + .on_click({ + let buffer = buffer.clone(); + let action_log = action_log.clone(); + move |_, _, cx| { + action_log.update(cx, |action_log, cx| { + action_log + .reject_edits_in_ranges( + buffer.clone(), + vec![Anchor::MIN..Anchor::MAX], + cx, + ) + .detach_and_log_err(cx); + }) + } + }), + ) + .child( + Button::new("keep-file", "Keep") + .label_size(LabelSize::Small) + .disabled(pending_edits) + .on_click({ + let buffer = buffer.clone(); + let action_log = action_log.clone(); + move |_, _, cx| { + action_log.update(cx, |action_log, cx| { + action_log.keep_edits_in_range( + buffer.clone(), + Anchor::MIN..Anchor::MAX, + cx, + ); + }) + } + }), + ), + ) + .child( + div() + .id("gradient-overlay") + .absolute() + .h_full() + .w_12() + .top_0() + .bottom_0() + .right(px(152.)) + .bg(overlay_gradient), + ); + + Some(element) }, - ) - .into_any() + )) + } + + fn render_message_editor(&mut self, window: &mut Window, cx: &mut Context) -> AnyElement { + let focus_handle = self.message_editor.focus_handle(cx); + let editor_bg_color = cx.theme().colors().editor_background; + let (expand_icon, expand_tooltip) = if 
self.editor_expanded { + (IconName::Minimize, "Minimize Message Editor") + } else { + (IconName::Maximize, "Expand Message Editor") + }; + + v_flex() + .on_action(cx.listener(Self::expand_message_editor)) + .p_2() + .gap_2() + .border_t_1() + .border_color(cx.theme().colors().border) + .bg(editor_bg_color) + .when(self.editor_expanded, |this| { + this.h(vh(0.8, window)).size_full().justify_between() + }) + .child( + v_flex() + .relative() + .size_full() + .pt_1() + .pr_2p5() + .child(div().flex_1().child({ + let settings = ThemeSettings::get_global(cx); + let font_size = TextSize::Small + .rems(cx) + .to_pixels(settings.agent_font_size(cx)); + let line_height = settings.buffer_line_height.value() * font_size; + + let text_style = TextStyle { + color: cx.theme().colors().text, + font_family: settings.buffer_font.family.clone(), + font_fallbacks: settings.buffer_font.fallbacks.clone(), + font_features: settings.buffer_font.features.clone(), + font_size: font_size.into(), + line_height: line_height.into(), + ..Default::default() + }; + + EditorElement::new( + &self.message_editor, + EditorStyle { + background: editor_bg_color, + local_player: cx.theme().players().local(), + text: text_style, + syntax: cx.theme().syntax().clone(), + ..Default::default() + }, + ) + })) + .child( + h_flex() + .absolute() + .top_0() + .right_0() + .opacity(0.5) + .hover(|this| this.opacity(1.0)) + .child( + IconButton::new("toggle-height", expand_icon) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + .tooltip({ + let focus_handle = focus_handle.clone(); + move |window, cx| { + Tooltip::for_action_in( + expand_tooltip, + &ExpandMessageEditor, + &focus_handle, + window, + cx, + ) + } + }) + .on_click(cx.listener(|_, _, window, cx| { + window.dispatch_action(Box::new(ExpandMessageEditor), cx); + })), + ), + ), + ) + .child( + h_flex() + .flex_none() + .justify_between() + .child(self.render_follow_toggle(cx)) + .child(self.render_send_button(cx)), + ) + .into_any() + } + + fn render_send_button(&self, cx: &mut Context) -> AnyElement { + if self.thread().map_or(true, |thread| { + thread.read(cx).status() == ThreadStatus::Idle + }) { + let is_editor_empty = self.message_editor.read(cx).is_empty(cx); + IconButton::new("send-message", IconName::Send) + .icon_color(Color::Accent) + .style(ButtonStyle::Filled) + .disabled(self.thread().is_none() || is_editor_empty) + .when(!is_editor_empty, |button| { + button.tooltip(move |window, cx| Tooltip::for_action("Send", &Chat, window, cx)) + }) + .when(is_editor_empty, |button| { + button.tooltip(Tooltip::text("Type a message to submit")) + }) + .on_click(cx.listener(|this, _, window, cx| { + this.chat(&Chat, window, cx); + })) + .into_any_element() + } else { + IconButton::new("stop-generation", IconName::StopFilled) + .icon_color(Color::Error) + .style(ButtonStyle::Tinted(ui::TintColor::Error)) + .tooltip(move |window, cx| { + Tooltip::for_action("Stop Generation", &editor::actions::Cancel, window, cx) + }) + .on_click(cx.listener(|this, _event, _, cx| this.cancel(cx))) + .into_any_element() + } + } + + fn render_follow_toggle(&self, cx: &mut Context) -> impl IntoElement { + let following = self + .workspace + .read_with(cx, |workspace, _| { + workspace.is_being_followed(CollaboratorId::Agent) + }) + .unwrap_or(false); + + IconButton::new("follow-agent", IconName::Crosshair) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .toggle_state(following) + .selected_icon_color(Some(Color::Custom(cx.theme().players().agent().cursor))) + .tooltip(move |window, 
cx| { + if following { + Tooltip::for_action("Stop Following Agent", &Follow, window, cx) + } else { + Tooltip::with_meta( + "Follow Agent", + Some(&Follow), + "Track the agent's location as it reads and edits files.", + window, + cx, + ) + } + }) + .on_click(cx.listener(move |this, _, window, cx| { + this.workspace + .update(cx, |workspace, cx| { + if following { + workspace.unfollow(CollaboratorId::Agent, window, cx); + } else { + workspace.follow(CollaboratorId::Agent, window, cx); + } + }) + .ok(); + })) } fn render_markdown(&self, markdown: Entity, style: MarkdownStyle) -> MarkdownElement { @@ -1603,6 +2259,64 @@ impl AcpThreadView { } } + fn open_tool_call_location( + &self, + entry_ix: usize, + location_ix: usize, + window: &mut Window, + cx: &mut Context, + ) -> Option<()> { + let location = self + .thread()? + .read(cx) + .entries() + .get(entry_ix)? + .locations()? + .get(location_ix)?; + + let project_path = self + .project + .read(cx) + .find_project_path(&location.path, cx)?; + + let open_task = self + .workspace + .update(cx, |worskpace, cx| { + worskpace.open_path(project_path, None, true, window, cx) + }) + .log_err()?; + + window + .spawn(cx, async move |cx| { + let item = open_task.await?; + + let Some(active_editor) = item.downcast::() else { + return anyhow::Ok(()); + }; + + active_editor.update_in(cx, |editor, window, cx| { + let snapshot = editor.buffer().read(cx).snapshot(cx); + let first_hunk = editor + .diff_hunks_in_ranges( + &[editor::Anchor::min()..editor::Anchor::max()], + &snapshot, + ) + .next(); + if let Some(first_hunk) = first_hunk { + let first_hunk_start = first_hunk.multi_buffer_range().start; + editor.change_selections(Default::default(), window, cx, |selections| { + selections.select_anchor_ranges([first_hunk_start..first_hunk_start]); + }) + } + })?; + + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + + None + } + pub fn open_thread_as_markdown( &self, workspace: Entity, @@ -1615,12 +2329,11 @@ impl AcpThreadView { .languages .language_for_name("Markdown"); - let (thread_summary, markdown) = match &self.thread_state { - ThreadState::Ready { thread, .. } | ThreadState::Unauthenticated { thread } => { - let thread = thread.read(cx); - (thread.title().to_string(), thread.to_markdown(cx)) - } - ThreadState::Loading { .. } | ThreadState::LoadError(..) 
=> return Task::ready(Ok(())), + let (thread_summary, markdown) = if let Some(thread) = self.thread() { + let thread = thread.read(cx); + (thread.title().to_string(), thread.to_markdown(cx)) + } else { + return Task::ready(Ok(())); }; window.spawn(cx, async move |cx| { @@ -1653,31 +2366,175 @@ impl AcpThreadView { cx, ); - anyhow::Ok(()) - })??; - anyhow::Ok(()) - }) - } + anyhow::Ok(()) + })??; + anyhow::Ok(()) + }) + } + + fn scroll_to_top(&mut self, cx: &mut Context) { + self.list_state.scroll_to(ListOffset::default()); + cx.notify(); + } + + pub fn scroll_to_bottom(&mut self, cx: &mut Context) { + if let Some(thread) = self.thread() { + let entry_count = thread.read(cx).entries().len(); + self.list_state.reset(entry_count); + cx.notify(); + } + } + + fn notify_with_sound( + &mut self, + caption: impl Into, + icon: IconName, + window: &mut Window, + cx: &mut Context, + ) { + self.play_notification_sound(window, cx); + self.show_notification(caption, icon, window, cx); + } + + fn play_notification_sound(&self, window: &Window, cx: &mut App) { + let settings = AgentSettings::get_global(cx); + if settings.play_sound_when_agent_done && !window.is_window_active() { + Audio::play_sound(Sound::AgentDone, cx); + } + } + + fn show_notification( + &mut self, + caption: impl Into, + icon: IconName, + window: &mut Window, + cx: &mut Context, + ) { + if window.is_window_active() || !self.notifications.is_empty() { + return; + } + + let title = self.title(cx); + + match AgentSettings::get_global(cx).notify_when_agent_waiting { + NotifyWhenAgentWaiting::PrimaryScreen => { + if let Some(primary) = cx.primary_display() { + self.pop_up(icon, caption.into(), title, window, primary, cx); + } + } + NotifyWhenAgentWaiting::AllScreens => { + let caption = caption.into(); + for screen in cx.displays() { + self.pop_up(icon, caption.clone(), title.clone(), window, screen, cx); + } + } + NotifyWhenAgentWaiting::Never => { + // Don't show anything + } + } + } + + fn pop_up( + &mut self, + icon: IconName, + caption: SharedString, + title: SharedString, + window: &mut Window, + screen: Rc, + cx: &mut Context, + ) { + let options = AgentNotification::window_options(screen, cx); + + let project_name = self.workspace.upgrade().and_then(|workspace| { + workspace + .read(cx) + .project() + .read(cx) + .visible_worktrees(cx) + .next() + .map(|worktree| worktree.read(cx).root_name().to_string()) + }); + + if let Some(screen_window) = cx + .open_window(options, |_, cx| { + cx.new(|_| { + AgentNotification::new(title.clone(), caption.clone(), icon, project_name) + }) + }) + .log_err() + { + if let Some(pop_up) = screen_window.entity(cx).log_err() { + self.notification_subscriptions + .entry(screen_window) + .or_insert_with(Vec::new) + .push(cx.subscribe_in(&pop_up, window, { + |this, _, event, window, cx| match event { + AgentNotificationEvent::Accepted => { + let handle = window.window_handle(); + cx.activate(true); + + let workspace_handle = this.workspace.clone(); + + // If there are multiple Zed windows, activate the correct one. 
+ cx.defer(move |cx| { + handle + .update(cx, |_view, window, _cx| { + window.activate_window(); + + if let Some(workspace) = workspace_handle.upgrade() { + workspace.update(_cx, |workspace, cx| { + workspace.focus_panel::(window, cx); + }); + } + }) + .log_err(); + }); - fn scroll_to_top(&mut self, cx: &mut Context) { - self.list_state.scroll_to(ListOffset::default()); - cx.notify(); + this.dismiss_notifications(cx); + } + AgentNotificationEvent::Dismissed => { + this.dismiss_notifications(cx); + } + } + })); + + self.notifications.push(screen_window); + + // If the user manually refocuses the original window, dismiss the popup. + self.notification_subscriptions + .entry(screen_window) + .or_insert_with(Vec::new) + .push({ + let pop_up_weak = pop_up.downgrade(); + + cx.observe_window_activation(window, move |_, window, cx| { + if window.is_window_active() { + if let Some(pop_up) = pop_up_weak.upgrade() { + pop_up.update(cx, |_, cx| { + cx.emit(AgentNotificationEvent::Dismissed); + }); + } + } + }) + }); + } + } } -} -impl Focusable for AcpThreadView { - fn focus_handle(&self, cx: &App) -> FocusHandle { - self.message_editor.focus_handle(cx) - } -} + fn dismiss_notifications(&mut self, cx: &mut Context) { + for window in self.notifications.drain(..) { + window + .update(cx, |_, window, _| { + window.remove_window(); + }) + .ok(); -impl Render for AcpThreadView { - fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { - let text = self.message_editor.read(cx).text(cx); - let is_editor_empty = text.is_empty(); - let focus_handle = self.message_editor.focus_handle(cx); + self.notification_subscriptions.remove(&window); + } + } - let open_as_markdown = IconButton::new("open-as-markdown", IconName::DocumentText) + fn render_thread_controls(&self, cx: &Context) -> impl IntoElement { + let open_as_markdown = IconButton::new("open-as-markdown", IconName::FileText) .icon_size(IconSize::XSmall) .icon_color(Color::Ignored) .tooltip(Tooltip::text("Open Thread as Markdown")) @@ -1696,69 +2553,151 @@ impl Render for AcpThreadView { this.scroll_to_top(cx); })); + h_flex() + .w_full() + .mr_1() + .pb_2() + .px(RESPONSE_PADDING_X) + .opacity(0.4) + .hover(|style| style.opacity(1.)) + .flex_wrap() + .justify_end() + .child(open_as_markdown) + .child(scroll_to_top) + } + + fn render_vertical_scrollbar(&self, cx: &mut Context) -> Stateful
{ + div() + .id("acp-thread-scrollbar") + .occlude() + .on_mouse_move(cx.listener(|_, _, _, cx| { + cx.notify(); + cx.stop_propagation() + })) + .on_hover(|_, _, cx| { + cx.stop_propagation(); + }) + .on_any_mouse_down(|_, _, cx| { + cx.stop_propagation(); + }) + .on_mouse_up( + MouseButton::Left, + cx.listener(|_, _, _, cx| { + cx.stop_propagation(); + }), + ) + .on_scroll_wheel(cx.listener(|_, _, _, cx| { + cx.notify(); + })) + .h_full() + .absolute() + .right_1() + .top_1() + .bottom_0() + .w(px(12.)) + .cursor_default() + .children(Scrollbar::vertical(self.scrollbar_state.clone()).map(|s| s.auto_hide(cx))) + } + + fn settings_changed(&mut self, _window: &mut Window, cx: &mut Context) { + for diff_editor in self.diff_editors.values() { + diff_editor.update(cx, |diff_editor, cx| { + diff_editor.set_text_style_refinement(diff_editor_text_style_refinement(cx)); + cx.notify(); + }) + } + } +} + +impl Focusable for AcpThreadView { + fn focus_handle(&self, cx: &App) -> FocusHandle { + self.message_editor.focus_handle(cx) + } +} + +impl Render for AcpThreadView { + fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { v_flex() .size_full() .key_context("AcpThread") .on_action(cx.listener(Self::chat)) .on_action(cx.listener(Self::previous_history_message)) .on_action(cx.listener(Self::next_history_message)) + .on_action(cx.listener(Self::open_agent_diff)) + .bg(cx.theme().colors().panel_background) .child(match &self.thread_state { - ThreadState::Unauthenticated { .. } => v_flex() + ThreadState::Unauthenticated { connection } => v_flex() .p_2() .flex_1() .items_center() .justify_center() .child(self.render_pending_auth_state()) - .child(h_flex().mt_1p5().justify_center().child( - Button::new("sign-in", "Sign in to Gemini").on_click( - cx.listener(|this, _, window, cx| this.authenticate(window, cx)), - ), + .child(h_flex().mt_1p5().justify_center().children( + connection.auth_methods().into_iter().map(|method| { + Button::new( + SharedString::from(method.id.0.clone()), + method.name.clone(), + ) + .on_click({ + let method_id = method.id.clone(); + cx.listener(move |this, _, window, cx| { + this.authenticate(method_id.clone(), window, cx) + }) + }) + }), )), - ThreadState::Loading { .. } => { - v_flex().flex_1().child(self.render_empty_state(true, cx)) - } + ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)), ThreadState::LoadError(e) => v_flex() .p_2() .flex_1() .items_center() .justify_center() - .child(self.render_error_state(e, cx)), - ThreadState::Ready { thread, .. } => v_flex().flex_1().map(|this| { - if self.list_state.item_count() > 0 { - this.child( - list(self.list_state.clone()) + .child(self.render_load_error(e, cx)), + ThreadState::ServerExited { status } => v_flex() + .p_2() + .flex_1() + .items_center() + .justify_center() + .child(self.render_server_exited(*status, cx)), + ThreadState::Ready { thread, .. 
} => { + let thread_clone = thread.clone(); + + v_flex().flex_1().map(|this| { + if self.list_state.item_count() > 0 { + this.child( + list( + self.list_state.clone(), + cx.processor(|this, index: usize, window, cx| { + let Some((entry, len)) = this.thread().and_then(|thread| { + let entries = &thread.read(cx).entries(); + Some((entries.get(index)?, entries.len())) + }) else { + return Empty.into_any(); + }; + this.render_entry(index, len, entry, window, cx) + }), + ) .with_sizing_behavior(gpui::ListSizingBehavior::Auto) .flex_grow() .into_any(), - ) - .child( - h_flex() - .group("controls") - .mt_1() - .mr_1() - .py_2() - .px(RESPONSE_PADDING_X) - .opacity(0.4) - .hover(|style| style.opacity(1.)) - .gap_1() - .flex_wrap() - .justify_end() - .child(open_as_markdown) - .child(scroll_to_top) - .into_any_element(), - ) - .children(match thread.read(cx).status() { - ThreadStatus::Idle | ThreadStatus::WaitingForToolConfirmation => None, - ThreadStatus::Generating => div() - .px_5() - .py_2() - .child(LoadingLabel::new("").size(LabelSize::Small)) - .into(), - }) - } else { - this.child(self.render_empty_state(false, cx)) - } - }), + ) + .child(self.render_vertical_scrollbar(cx)) + .children(match thread_clone.read(cx).status() { + ThreadStatus::Idle | ThreadStatus::WaitingForToolConfirmation => { + None + } + ThreadStatus::Generating => div() + .px_5() + .py_2() + .child(LoadingLabel::new("").size(LabelSize::Small)) + .into(), + }) + .children(self.render_activity_bar(&thread_clone, window, cx)) + } else { + this.child(self.render_empty_state(cx)) + } + }) + } }) .when_some(self.last_error.clone(), |el, error| { el.child( @@ -1773,57 +2712,7 @@ impl Render for AcpThreadView { ), ) }) - .child( - v_flex() - .p_2() - .pt_3() - .gap_1() - .bg(cx.theme().colors().editor_background) - .border_t_1() - .border_color(cx.theme().colors().border) - .child(self.render_message_editor(cx)) - .child({ - let thread = self.thread(); - - h_flex().justify_end().child( - if thread.map_or(true, |thread| { - thread.read(cx).status() == ThreadStatus::Idle - }) { - IconButton::new("send-message", IconName::Send) - .icon_color(Color::Accent) - .style(ButtonStyle::Filled) - .disabled(thread.is_none() || is_editor_empty) - .on_click({ - let focus_handle = focus_handle.clone(); - move |_event, window, cx| { - focus_handle.dispatch_action(&Chat, window, cx); - } - }) - .when(!is_editor_empty, |button| { - button.tooltip(move |window, cx| { - Tooltip::for_action("Send", &Chat, window, cx) - }) - }) - .when(is_editor_empty, |button| { - button.tooltip(Tooltip::text("Type a message to submit")) - }) - } else { - IconButton::new("stop-generation", IconName::StopFilled) - .icon_color(Color::Error) - .style(ButtonStyle::Tinted(ui::TintColor::Error)) - .tooltip(move |window, cx| { - Tooltip::for_action( - "Stop Generation", - &editor::actions::Cancel, - window, - cx, - ) - }) - .on_click(cx.listener(|this, _event, _, cx| this.cancel(cx))) - }, - ) - }), - ) + .child(self.render_message_editor(window, cx)) } } @@ -1970,3 +2859,511 @@ fn default_markdown_style(buffer_font: bool, window: &Window, cx: &App) -> Markd ..Default::default() } } + +fn plan_label_markdown_style( + status: &acp::PlanEntryStatus, + window: &Window, + cx: &App, +) -> MarkdownStyle { + let default_md_style = default_markdown_style(false, window, cx); + + MarkdownStyle { + base_text_style: TextStyle { + color: cx.theme().colors().text_muted, + strikethrough: if matches!(status, acp::PlanEntryStatus::Completed) { + Some(gpui::StrikethroughStyle { + thickness: 
px(1.), + color: Some(cx.theme().colors().text_muted.opacity(0.8)), + }) + } else { + None + }, + ..default_md_style.base_text_style + }, + ..default_md_style + } +} + +fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement { + TextStyleRefinement { + font_size: Some( + TextSize::Small + .rems(cx) + .to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx)) + .into(), + ), + ..Default::default() + } +} + +#[cfg(test)] +mod tests { + use agent_client_protocol::SessionId; + use editor::EditorSettings; + use fs::FakeFs; + use futures::future::try_join_all; + use gpui::{SemanticVersion, TestAppContext, VisualTestContext}; + use lsp::{CompletionContext, CompletionTriggerKind}; + use project::CompletionIntent; + use rand::Rng; + use serde_json::json; + use settings::SettingsStore; + use util::path; + + use super::*; + + #[gpui::test] + async fn test_drop(cx: &mut TestAppContext) { + init_test(cx); + + let (thread_view, _cx) = setup_thread_view(StubAgentServer::default(), cx).await; + let weak_view = thread_view.downgrade(); + drop(thread_view); + assert!(!weak_view.is_upgradable()); + } + + #[gpui::test] + async fn test_notification_for_stop_event(cx: &mut TestAppContext) { + init_test(cx); + + let (thread_view, cx) = setup_thread_view(StubAgentServer::default(), cx).await; + + let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone()); + message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Hello", window, cx); + }); + + cx.deactivate_window(); + + thread_view.update_in(cx, |thread_view, window, cx| { + thread_view.chat(&Chat, window, cx); + }); + + cx.run_until_parked(); + + assert!( + cx.windows() + .iter() + .any(|window| window.downcast::().is_some()) + ); + } + + #[gpui::test] + async fn test_notification_for_error(cx: &mut TestAppContext) { + init_test(cx); + + let (thread_view, cx) = + setup_thread_view(StubAgentServer::new(SaboteurAgentConnection), cx).await; + + let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone()); + message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Hello", window, cx); + }); + + cx.deactivate_window(); + + thread_view.update_in(cx, |thread_view, window, cx| { + thread_view.chat(&Chat, window, cx); + }); + + cx.run_until_parked(); + + assert!( + cx.windows() + .iter() + .any(|window| window.downcast::().is_some()) + ); + } + + #[gpui::test] + async fn test_notification_for_tool_authorization(cx: &mut TestAppContext) { + init_test(cx); + + let tool_call_id = acp::ToolCallId("1".into()); + let tool_call = acp::ToolCall { + id: tool_call_id.clone(), + title: "Label".into(), + kind: acp::ToolKind::Edit, + status: acp::ToolCallStatus::Pending, + content: vec!["hi".into()], + locations: vec![], + raw_input: None, + raw_output: None, + }; + let connection = StubAgentConnection::new(vec![acp::SessionUpdate::ToolCall(tool_call)]) + .with_permission_requests(HashMap::from_iter([( + tool_call_id, + vec![acp::PermissionOption { + id: acp::PermissionOptionId("1".into()), + name: "Allow".into(), + kind: acp::PermissionOptionKind::AllowOnce, + }], + )])); + let (thread_view, cx) = setup_thread_view(StubAgentServer::new(connection), cx).await; + + let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone()); + message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Hello", window, cx); + }); + + cx.deactivate_window(); + + thread_view.update_in(cx, |thread_view, window, cx| { + thread_view.chat(&Chat, window, cx); + }); + + cx.run_until_parked(); + 
+ assert!( + cx.windows() + .iter() + .any(|window| window.downcast::().is_some()) + ); + } + + #[gpui::test] + async fn test_crease_removal(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree("/project", json!({"file": ""})).await; + let project = Project::test(fs, [Path::new(path!("/project"))], cx).await; + let agent = StubAgentServer::default(); + let (workspace, cx) = + cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + let thread_view = cx.update(|window, cx| { + cx.new(|cx| { + AcpThreadView::new( + Rc::new(agent), + workspace.downgrade(), + project, + Rc::new(RefCell::new(MessageHistory::default())), + 1, + None, + window, + cx, + ) + }) + }); + + cx.run_until_parked(); + + let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone()); + let excerpt_id = message_editor.update(cx, |editor, cx| { + editor + .buffer() + .read(cx) + .excerpt_ids() + .into_iter() + .next() + .unwrap() + }); + let completions = message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Hello @", window, cx); + let buffer = editor.buffer().read(cx).as_singleton().unwrap(); + let completion_provider = editor.completion_provider().unwrap(); + completion_provider.completions( + excerpt_id, + &buffer, + Anchor::MAX, + CompletionContext { + trigger_kind: CompletionTriggerKind::TRIGGER_CHARACTER, + trigger_character: Some("@".into()), + }, + window, + cx, + ) + }); + let [_, completion]: [_; 2] = completions + .await + .unwrap() + .into_iter() + .flat_map(|response| response.completions) + .collect::>() + .try_into() + .unwrap(); + + message_editor.update_in(cx, |editor, window, cx| { + let snapshot = editor.buffer().read(cx).snapshot(cx); + let start = snapshot + .anchor_in_excerpt(excerpt_id, completion.replace_range.start) + .unwrap(); + let end = snapshot + .anchor_in_excerpt(excerpt_id, completion.replace_range.end) + .unwrap(); + editor.edit([(start..end, completion.new_text)], cx); + (completion.confirm.unwrap())(CompletionIntent::Complete, window, cx); + }); + + cx.run_until_parked(); + + // Backspace over the inserted crease (and the following space). + message_editor.update_in(cx, |editor, window, cx| { + editor.backspace(&Default::default(), window, cx); + editor.backspace(&Default::default(), window, cx); + }); + + thread_view.update_in(cx, |thread_view, window, cx| { + thread_view.chat(&Chat, window, cx); + }); + + cx.run_until_parked(); + + let content = thread_view.update_in(cx, |thread_view, _window, _cx| { + thread_view + .message_history + .borrow() + .items() + .iter() + .flatten() + .cloned() + .collect::>() + }); + + // We don't send a resource link for the deleted crease. + pretty_assertions::assert_matches!(content.as_slice(), [acp::ContentBlock::Text { .. 
}]); + } + + async fn setup_thread_view( + agent: impl AgentServer + 'static, + cx: &mut TestAppContext, + ) -> (Entity, &mut VisualTestContext) { + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, [], cx).await; + let (workspace, cx) = + cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + + let thread_view = cx.update(|window, cx| { + cx.new(|cx| { + AcpThreadView::new( + Rc::new(agent), + workspace.downgrade(), + project, + Rc::new(RefCell::new(MessageHistory::default())), + 1, + None, + window, + cx, + ) + }) + }); + cx.run_until_parked(); + (thread_view, cx) + } + + struct StubAgentServer { + connection: C, + } + + impl StubAgentServer { + fn new(connection: C) -> Self { + Self { connection } + } + } + + impl StubAgentServer { + fn default() -> Self { + Self::new(StubAgentConnection::default()) + } + } + + impl AgentServer for StubAgentServer + where + C: 'static + AgentConnection + Send + Clone, + { + fn logo(&self) -> ui::IconName { + unimplemented!() + } + + fn name(&self) -> &'static str { + unimplemented!() + } + + fn empty_state_headline(&self) -> &'static str { + unimplemented!() + } + + fn empty_state_message(&self) -> &'static str { + unimplemented!() + } + + fn connect( + &self, + _root_dir: &Path, + _project: &Entity, + _cx: &mut App, + ) -> Task>> { + Task::ready(Ok(Rc::new(self.connection.clone()))) + } + } + + #[derive(Clone, Default)] + struct StubAgentConnection { + sessions: Arc>>>, + permission_requests: HashMap>, + updates: Vec, + } + + impl StubAgentConnection { + fn new(updates: Vec) -> Self { + Self { + updates, + permission_requests: HashMap::default(), + sessions: Arc::default(), + } + } + + fn with_permission_requests( + mut self, + permission_requests: HashMap>, + ) -> Self { + self.permission_requests = permission_requests; + self + } + } + + impl AgentConnection for StubAgentConnection { + fn auth_methods(&self) -> &[acp::AuthMethod] { + &[] + } + + fn new_thread( + self: Rc, + project: Entity, + _cwd: &Path, + cx: &mut gpui::AsyncApp, + ) -> Task>> { + let session_id = SessionId( + rand::thread_rng() + .sample_iter(&rand::distributions::Alphanumeric) + .take(7) + .map(char::from) + .collect::() + .into(), + ); + let thread = cx + .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx)) + .unwrap(); + self.sessions.lock().insert(session_id, thread.downgrade()); + Task::ready(Ok(thread)) + } + + fn authenticate( + &self, + _method_id: acp::AuthMethodId, + _cx: &mut App, + ) -> Task> { + unimplemented!() + } + + fn prompt( + &self, + params: acp::PromptRequest, + cx: &mut App, + ) -> Task> { + let sessions = self.sessions.lock(); + let thread = sessions.get(¶ms.session_id).unwrap(); + let mut tasks = vec![]; + for update in &self.updates { + let thread = thread.clone(); + let update = update.clone(); + let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) = &update + && let Some(options) = self.permission_requests.get(&tool_call.id) + { + Some((tool_call.clone(), options.clone())) + } else { + None + }; + let task = cx.spawn(async move |cx| { + if let Some((tool_call, options)) = permission_request { + let permission = thread.update(cx, |thread, cx| { + thread.request_tool_call_authorization( + tool_call.clone(), + options.clone(), + cx, + ) + })?; + permission.await?; + } + thread.update(cx, |thread, cx| { + thread.handle_session_update(update.clone(), cx).unwrap(); + })?; + anyhow::Ok(()) + }); + tasks.push(task); + } + cx.spawn(async move |_| { + 
try_join_all(tasks).await?; + Ok(acp::PromptResponse { + stop_reason: acp::StopReason::EndTurn, + }) + }) + } + + fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) { + unimplemented!() + } + } + + #[derive(Clone)] + struct SaboteurAgentConnection; + + impl AgentConnection for SaboteurAgentConnection { + fn new_thread( + self: Rc, + project: Entity, + _cwd: &Path, + cx: &mut gpui::AsyncApp, + ) -> Task>> { + Task::ready(Ok(cx + .new(|cx| { + AcpThread::new( + "SaboteurAgentConnection", + self, + project, + SessionId("test".into()), + cx, + ) + }) + .unwrap())) + } + + fn auth_methods(&self) -> &[acp::AuthMethod] { + &[] + } + + fn authenticate( + &self, + _method_id: acp::AuthMethodId, + _cx: &mut App, + ) -> Task> { + unimplemented!() + } + + fn prompt( + &self, + _params: acp::PromptRequest, + _cx: &mut App, + ) -> Task> { + Task::ready(Err(anyhow::anyhow!("Error prompting"))) + } + + fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) { + unimplemented!() + } + } + + fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + Project::init_settings(cx); + AgentSettings::register(cx); + workspace::init_settings(cx); + ThemeSettings::register(cx); + release_channel::init(SemanticVersion::default(), cx); + EditorSettings::register(cx); + }); + } +} diff --git a/crates/agent_ui/src/active_thread.rs b/crates/agent_ui/src/active_thread.rs index 383729017a1635e4301fa50d587f70940543130f..71526c8fe14f0f023cfad260525c5366e66a24e5 100644 --- a/crates/agent_ui/src/active_thread.rs +++ b/crates/agent_ui/src/active_thread.rs @@ -14,6 +14,7 @@ use agent_settings::{AgentSettings, NotifyWhenAgentWaiting}; use anyhow::Context as _; use assistant_tool::ToolUseStatus; use audio::{Audio, Sound}; +use cloud_llm_client::CompletionIntent; use collections::{HashMap, HashSet}; use editor::actions::{MoveUp, Paste}; use editor::scroll::Autoscroll; @@ -52,7 +53,6 @@ use util::ResultExt as _; use util::markdown::MarkdownCodeBlock; use workspace::{CollaboratorId, Workspace}; use zed_actions::assistant::OpenRulesLibrary; -use zed_llm_client::CompletionIntent; const CODEBLOCK_CONTAINER_GROUP: &str = "codeblock_container"; const EDIT_PREVIOUS_MESSAGE_MIN_LINES: usize = 1; @@ -69,8 +69,6 @@ pub struct ActiveThread { messages: Vec, list_state: ListState, scrollbar_state: ScrollbarState, - show_scrollbar: bool, - hide_scrollbar_task: Option>, rendered_messages_by_id: HashMap, rendered_tool_uses: HashMap, editing_message: Option<(MessageId, EditingMessageState)>, @@ -780,13 +778,7 @@ impl ActiveThread { cx.observe_global::(|_, cx| cx.notify()), ]; - let list_state = ListState::new(0, ListAlignment::Bottom, px(2048.), { - let this = cx.entity().downgrade(); - move |ix, window: &mut Window, cx: &mut App| { - this.update(cx, |this, cx| this.render_message(ix, window, cx)) - .unwrap() - } - }); + let list_state = ListState::new(0, ListAlignment::Bottom, px(2048.)); let workspace_subscription = if let Some(workspace) = workspace.upgrade() { Some(cx.observe_release(&workspace, |this, _, cx| { @@ -811,9 +803,7 @@ impl ActiveThread { expanded_thinking_segments: HashMap::default(), expanded_code_blocks: HashMap::default(), list_state: list_state.clone(), - scrollbar_state: ScrollbarState::new(list_state), - show_scrollbar: false, - hide_scrollbar_task: None, + scrollbar_state: ScrollbarState::new(list_state).parent_entity(&cx.entity()), editing_message: None, last_error: None, copied_code_block_ids: 
HashSet::default(), @@ -996,30 +986,57 @@ impl ActiveThread { | ThreadEvent::SummaryChanged => { self.save_thread(cx); } - ThreadEvent::Stopped(reason) => match reason { - Ok(StopReason::EndTurn | StopReason::MaxTokens) => { - let used_tools = self.thread.read(cx).used_tools_since_last_user_message(); - self.play_notification_sound(window, cx); - self.show_notification( - if used_tools { - "Finished running tools" - } else { - "New message" - }, - IconName::ZedAssistant, - window, - cx, - ); + ThreadEvent::Stopped(reason) => { + match reason { + Ok(StopReason::EndTurn | StopReason::MaxTokens) => { + let used_tools = self.thread.read(cx).used_tools_since_last_user_message(); + self.notify_with_sound( + if used_tools { + "Finished running tools" + } else { + "New message" + }, + IconName::ZedAssistant, + window, + cx, + ); + } + Ok(StopReason::ToolUse) => { + // Don't notify for intermediate tool use + } + Ok(StopReason::Refusal) => { + self.notify_with_sound( + "Language model refused to respond", + IconName::Warning, + window, + cx, + ); + } + Err(error) => { + self.notify_with_sound( + "Agent stopped due to an error", + IconName::Warning, + window, + cx, + ); + + let error_message = error + .chain() + .map(|err| err.to_string()) + .collect::>() + .join("\n"); + self.last_error = Some(ThreadError::Message { + header: "Error".into(), + message: error_message.into(), + }); + } } - _ => {} - }, + } ThreadEvent::ToolConfirmationNeeded => { - self.play_notification_sound(window, cx); - self.show_notification("Waiting for tool confirmation", IconName::Info, window, cx); + self.notify_with_sound("Waiting for tool confirmation", IconName::Info, window, cx); } ThreadEvent::ToolUseLimitReached => { - self.play_notification_sound(window, cx); - self.show_notification( + self.notify_with_sound( "Consecutive tool use limit reached.", IconName::Warning, window, @@ -1162,9 +1179,6 @@ impl ActiveThread { self.save_thread(cx); cx.notify(); } - ThreadEvent::RetriesFailed { message } => { - self.show_notification(message, ui::IconName::Warning, window, cx); - } } } @@ -1219,6 +1233,17 @@ impl ActiveThread { } } + fn notify_with_sound( + &mut self, + caption: impl Into, + icon: IconName, + window: &mut Window, + cx: &mut Context, + ) { + self.play_notification_sound(window, cx); + self.show_notification(caption, icon, window, cx); + } + fn pop_up( &mut self, icon: IconName, @@ -1811,7 +1836,12 @@ impl ActiveThread { ))) } - fn render_message(&self, ix: usize, window: &mut Window, cx: &mut Context) -> AnyElement { + fn render_message( + &mut self, + ix: usize, + window: &mut Window, + cx: &mut Context, + ) -> AnyElement { let message_id = self.messages[ix]; let workspace = self.workspace.clone(); let thread = self.thread.read(cx); @@ -2594,7 +2624,7 @@ impl ActiveThread { h_flex() .gap_1p5() .child( - Icon::new(IconName::ToolBulb) + Icon::new(IconName::ToolThink) .size(IconSize::Small) .color(Color::Muted), ) @@ -3167,7 +3197,10 @@ impl ActiveThread { .border_color(self.tool_card_border_color(cx)) .rounded_b_lg() .child( - LoadingLabel::new("Waiting for Confirmation").size(LabelSize::Small) + div() + .min_w(rems_from_px(145.)) + .child(LoadingLabel::new("Waiting for Confirmation").size(LabelSize::Small) + ) ) .child( h_flex() @@ -3212,7 +3245,6 @@ impl ActiveThread { }, )) }) - .child(ui::Divider::vertical()) .child({ let tool_id = tool_use.id.clone(); Button::new("allow-tool-action", "Allow") @@ -3466,60 +3498,37 @@ impl ActiveThread { } } - fn render_vertical_scrollbar(&self, cx: &mut Context) -> Option> { - 
if !self.show_scrollbar && !self.scrollbar_state.is_dragging() {
-            return None;
-        }
-
-        Some(
-            div()
-                .occlude()
-                .id("active-thread-scrollbar")
-                .on_mouse_move(cx.listener(|_, _, _, cx| {
-                    cx.notify();
-                    cx.stop_propagation()
-                }))
-                .on_hover(|_, _, cx| {
-                    cx.stop_propagation();
-                })
-                .on_any_mouse_down(|_, _, cx| {
+    fn render_vertical_scrollbar(&self, cx: &mut Context<Self>) -> Stateful<Div> {
+        div()
+            .occlude()
+            .id("active-thread-scrollbar")
+            .on_mouse_move(cx.listener(|_, _, _, cx| {
+                cx.notify();
+                cx.stop_propagation()
+            }))
+            .on_hover(|_, _, cx| {
+                cx.stop_propagation();
+            })
+            .on_any_mouse_down(|_, _, cx| {
+                cx.stop_propagation();
+            })
+            .on_mouse_up(
+                MouseButton::Left,
+                cx.listener(|_, _, _, cx| {
                     cx.stop_propagation();
-                })
-                .on_mouse_up(
-                    MouseButton::Left,
-                    cx.listener(|_, _, _, cx| {
-                        cx.stop_propagation();
-                    }),
-                )
-                .on_scroll_wheel(cx.listener(|_, _, _, cx| {
-                    cx.notify();
-                }))
-                .h_full()
-                .absolute()
-                .right_1()
-                .top_1()
-                .bottom_0()
-                .w(px(12.))
-                .cursor_default()
-                .children(Scrollbar::vertical(self.scrollbar_state.clone())),
-        )
-    }
-
-    fn hide_scrollbar_later(&mut self, cx: &mut Context) {
-        const SCROLLBAR_SHOW_INTERVAL: Duration = Duration::from_secs(1);
-        self.hide_scrollbar_task = Some(cx.spawn(async move |thread, cx| {
-            cx.background_executor()
-                .timer(SCROLLBAR_SHOW_INTERVAL)
-                .await;
-            thread
-                .update(cx, |thread, cx| {
-                    if !thread.scrollbar_state.is_dragging() {
-                        thread.show_scrollbar = false;
-                        cx.notify();
-                    }
-                })
-                .log_err();
-        }))
+                }),
+            )
+            .on_scroll_wheel(cx.listener(|_, _, _, cx| {
+                cx.notify();
+            }))
+            .h_full()
+            .absolute()
+            .right_1()
+            .top_1()
+            .bottom_0()
+            .w(px(12.))
+            .cursor_default()
+            .children(Scrollbar::vertical(self.scrollbar_state.clone()).map(|s| s.auto_hide(cx)))
     }
 
     pub fn is_codeblock_expanded(&self, message_id: MessageId, ix: usize) -> bool {
@@ -3560,26 +3569,8 @@ impl Render for ActiveThread {
             .size_full()
             .relative()
             .bg(cx.theme().colors().panel_background)
-            .on_mouse_move(cx.listener(|this, _, _, cx| {
-                this.show_scrollbar = true;
-                this.hide_scrollbar_later(cx);
-                cx.notify();
-            }))
-            .on_scroll_wheel(cx.listener(|this, _, _, cx| {
-                this.show_scrollbar = true;
-                this.hide_scrollbar_later(cx);
-                cx.notify();
-            }))
-            .on_mouse_up(
-                MouseButton::Left,
-                cx.listener(|this, _, _, cx| {
-                    this.hide_scrollbar_later(cx);
-                }),
-            )
-            .child(list(self.list_state.clone()).flex_grow())
-            .when_some(self.render_vertical_scrollbar(cx), |this, scrollbar| {
-                this.child(scrollbar)
-            })
+            .child(list(self.list_state.clone(), cx.processor(Self::render_message)).flex_grow())
+            .child(self.render_vertical_scrollbar(cx))
     }
 }
@@ -3687,8 +3678,11 @@ pub(crate) fn open_context(
         AgentContextHandle::Thread(thread_context) => workspace.update(cx, |workspace, cx| {
             if let Some(panel) = workspace.panel::(cx) {
-                panel.update(cx, |panel, cx| {
-                    panel.open_thread(thread_context.thread.clone(), window, cx);
+                let thread = thread_context.thread.clone();
+                window.defer(cx, move |window, cx| {
+                    panel.update(cx, |panel, cx| {
+                        panel.open_thread(thread, window, cx);
+                    });
+                });
             }
         }),
@@ -3696,8 +3690,11 @@ AgentContextHandle::TextThread(text_thread_context) => {
             workspace.update(cx, |workspace, cx| {
                 if let Some(panel) = workspace.panel::(cx) {
-                    panel.update(cx, |panel, cx| {
-                        panel.open_prompt_editor(text_thread_context.context.clone(), window, cx)
+                    let context = text_thread_context.context.clone();
+                    window.defer(cx, move |window, cx| {
+                        panel.update(cx, |panel, cx| {
+                            panel.open_prompt_editor(context, window, cx)
+                        });
+                    });
                 }
             })
@@ -3852,7 +3849,7 @@ mod tests {
         LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
             registry.set_default_model(
                 Some(ConfiguredModel {
-                    provider: Arc::new(FakeLanguageModelProvider),
+                    provider: Arc::new(FakeLanguageModelProvider::default()),
                    model,
                }),
                cx,
@@ -3936,7 +3933,7 @@ mod tests {
        LanguageModelRegistry::global(cx).update(cx, 
|registry, cx| { registry.set_default_model( Some(ConfiguredModel { - provider: Arc::new(FakeLanguageModelProvider), + provider: Arc::new(FakeLanguageModelProvider::default()), model: model.clone(), }), cx, diff --git a/crates/agent_ui/src/agent_configuration.rs b/crates/agent_ui/src/agent_configuration.rs index 8bfdd507611112b2930fd07270667050796533e3..02c15b7e4164c298a2c10b854824f9a6f14b521b 100644 --- a/crates/agent_ui/src/agent_configuration.rs +++ b/crates/agent_ui/src/agent_configuration.rs @@ -1,3 +1,4 @@ +mod add_llm_provider_modal; mod configure_context_server_modal; mod manage_profiles_modal; mod tool_picker; @@ -6,6 +7,7 @@ use std::{sync::Arc, time::Duration}; use agent_settings::AgentSettings; use assistant_tool::{ToolSource, ToolWorkingSet}; +use cloud_llm_client::Plan; use collections::HashMap; use context_server::ContextServerId; use extension::ExtensionManifest; @@ -26,8 +28,8 @@ use project::{ }; use settings::{Settings, update_settings_file}; use ui::{ - ContextMenu, Disclosure, Divider, DividerColor, ElevationIndex, Indicator, PopoverMenu, - Scrollbar, ScrollbarState, Switch, SwitchColor, Tooltip, prelude::*, + Chip, ContextMenu, Disclosure, Divider, DividerColor, ElevationIndex, Indicator, PopoverMenu, + Scrollbar, ScrollbarState, Switch, SwitchColor, SwitchField, Tooltip, prelude::*, }; use util::ResultExt as _; use workspace::Workspace; @@ -36,7 +38,10 @@ use zed_actions::ExtensionCategoryFilter; pub(crate) use configure_context_server_modal::ConfigureContextServerModal; pub(crate) use manage_profiles_modal::ManageProfilesModal; -use crate::AddContextServer; +use crate::{ + AddContextServer, + agent_configuration::add_llm_provider_modal::{AddLlmProviderModal, LlmCompatibleProvider}, +}; pub struct AgentConfiguration { fs: Arc, @@ -171,7 +176,24 @@ impl AgentConfiguration { .copied() .unwrap_or(false); + let is_zed_provider = provider.id() == ZED_CLOUD_PROVIDER_ID; + let current_plan = if is_zed_provider { + self.workspace + .upgrade() + .and_then(|workspace| workspace.read(cx).user_store().read(cx).plan()) + } else { + None + }; + + let is_signed_in = self + .workspace + .read_with(cx, |workspace, _| { + workspace.client().status().borrow().is_connected() + }) + .unwrap_or(false); + v_flex() + .w_full() .when(is_expanded, |this| this.mb_2()) .child( div() @@ -202,20 +224,39 @@ impl AgentConfiguration { .hover(|hover| hover.bg(cx.theme().colors().element_hover)) .child( h_flex() + .w_full() .gap_2() .child( Icon::new(provider.icon()) .size(IconSize::Small) .color(Color::Muted), ) - .child(Label::new(provider_name.clone()).size(LabelSize::Large)) - .when( - provider.is_authenticated(cx) && !is_expanded, - |parent| { - parent.child( - Icon::new(IconName::Check).color(Color::Success), + .child( + h_flex() + .w_full() + .gap_1() + .child( + Label::new(provider_name.clone()) + .size(LabelSize::Large), ) - }, + .map(|this| { + if is_zed_provider && is_signed_in { + this.child( + self.render_zed_plan_info(current_plan, cx), + ) + } else { + this.when( + provider.is_authenticated(cx) + && !is_expanded, + |parent| { + parent.child( + Icon::new(IconName::Check) + .color(Color::Success), + ) + }, + ) + } + }), ), ) .child( @@ -276,21 +317,78 @@ impl AgentConfiguration { let providers = LanguageModelRegistry::read_global(cx).providers(); v_flex() + .w_full() .child( - v_flex() + h_flex() .p(DynamicSpacing::Base16.rems(cx)) .pr(DynamicSpacing::Base20.rems(cx)) .pb_0() .mb_2p5() - .gap_0p5() - .child(Headline::new("LLM Providers")) + .items_start() + .justify_between() .child( - 
Label::new("Add at least one provider to use AI-powered features.") - .color(Color::Muted), + v_flex() + .w_full() + .gap_0p5() + .child( + h_flex() + .w_full() + .gap_2() + .justify_between() + .child(Headline::new("LLM Providers")) + .child( + PopoverMenu::new("add-provider-popover") + .trigger( + Button::new("add-provider", "Add Provider") + .icon_position(IconPosition::Start) + .icon(IconName::Plus) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .label_size(LabelSize::Small), + ) + .anchor(gpui::Corner::TopRight) + .menu({ + let workspace = self.workspace.clone(); + move |window, cx| { + Some(ContextMenu::build( + window, + cx, + |menu, _window, _cx| { + menu.header("Compatible APIs").entry( + "OpenAI", + None, + { + let workspace = + workspace.clone(); + move |window, cx| { + workspace + .update(cx, |workspace, cx| { + AddLlmProviderModal::toggle( + LlmCompatibleProvider::OpenAi, + workspace, + window, + cx, + ); + }) + .log_err(); + } + }, + ) + }, + )) + } + }), + ), + ) + .child( + Label::new("Add at least one provider to use AI-powered features.") + .color(Color::Muted), + ), ), ) .child( div() + .w_full() .pl(DynamicSpacing::Base08.rems(cx)) .pr(DynamicSpacing::Base20.rems(cx)) .children( @@ -303,119 +401,80 @@ impl AgentConfiguration { fn render_command_permission(&mut self, cx: &mut Context) -> impl IntoElement { let always_allow_tool_actions = AgentSettings::get_global(cx).always_allow_tool_actions; + let fs = self.fs.clone(); - h_flex() - .gap_4() - .justify_between() - .flex_wrap() - .child( - v_flex() - .gap_0p5() - .max_w_5_6() - .child(Label::new("Allow running editing tools without asking for confirmation")) - .child( - Label::new( - "The agent can perform potentially destructive actions without asking for your confirmation.", - ) - .color(Color::Muted), - ), - ) - .child( - Switch::new( - "always-allow-tool-actions-switch", - always_allow_tool_actions.into(), - ) - .color(SwitchColor::Accent) - .on_click({ - let fs = self.fs.clone(); - move |state, _window, cx| { - let allow = state == &ToggleState::Selected; - update_settings_file::( - fs.clone(), - cx, - move |settings, _| { - settings.set_always_allow_tool_actions(allow); - }, - ); - } - }), - ) + SwitchField::new( + "always-allow-tool-actions-switch", + "Allow running commands without asking for confirmation", + Some( + "The agent can perform potentially destructive actions without asking for your confirmation.".into(), + ), + always_allow_tool_actions, + move |state, _window, cx| { + let allow = state == &ToggleState::Selected; + update_settings_file::(fs.clone(), cx, move |settings, _| { + settings.set_always_allow_tool_actions(allow); + }); + }, + ) } fn render_single_file_review(&mut self, cx: &mut Context) -> impl IntoElement { let single_file_review = AgentSettings::get_global(cx).single_file_review; + let fs = self.fs.clone(); - h_flex() - .gap_4() - .justify_between() - .flex_wrap() - .child( - v_flex() - .gap_0p5() - .max_w_5_6() - .child(Label::new("Enable single-file agent reviews")) - .child( - Label::new( - "Agent edits are also displayed in single-file editors for review.", - ) - .color(Color::Muted), - ), - ) - .child( - Switch::new("single-file-review-switch", single_file_review.into()) - .color(SwitchColor::Accent) - .on_click({ - let fs = self.fs.clone(); - move |state, _window, cx| { - let allow = state == &ToggleState::Selected; - update_settings_file::( - fs.clone(), - cx, - move |settings, _| { - settings.set_single_file_review(allow); - }, - ); - } - }), - ) + SwitchField::new( + 
"single-file-review", + "Enable single-file agent reviews", + Some("Agent edits are also displayed in single-file editors for review.".into()), + single_file_review, + move |state, _window, cx| { + let allow = state == &ToggleState::Selected; + update_settings_file::(fs.clone(), cx, move |settings, _| { + settings.set_single_file_review(allow); + }); + }, + ) } fn render_sound_notification(&mut self, cx: &mut Context) -> impl IntoElement { let play_sound_when_agent_done = AgentSettings::get_global(cx).play_sound_when_agent_done; + let fs = self.fs.clone(); - h_flex() - .gap_4() - .justify_between() - .flex_wrap() - .child( - v_flex() - .gap_0p5() - .max_w_5_6() - .child(Label::new("Play sound when finished generating")) - .child( - Label::new( - "Hear a notification sound when the agent is done generating changes or needs your input.", - ) - .color(Color::Muted), - ), - ) - .child( - Switch::new("play-sound-notification-switch", play_sound_when_agent_done.into()) - .color(SwitchColor::Accent) - .on_click({ - let fs = self.fs.clone(); - move |state, _window, cx| { - let allow = state == &ToggleState::Selected; - update_settings_file::( - fs.clone(), - cx, - move |settings, _| { - settings.set_play_sound_when_agent_done(allow); - }, - ); - } - }), - ) + SwitchField::new( + "sound-notification", + "Play sound when finished generating", + Some( + "Hear a notification sound when the agent is done generating changes or needs your input.".into(), + ), + play_sound_when_agent_done, + move |state, _window, cx| { + let allow = state == &ToggleState::Selected; + update_settings_file::(fs.clone(), cx, move |settings, _| { + settings.set_play_sound_when_agent_done(allow); + }); + }, + ) + } + + fn render_modifier_to_send(&mut self, cx: &mut Context) -> impl IntoElement { + let use_modifier_to_send = AgentSettings::get_global(cx).use_modifier_to_send; + let fs = self.fs.clone(); + + SwitchField::new( + "modifier-send", + "Use modifier to submit a message", + Some( + "Make a modifier (cmd-enter on macOS, ctrl-enter on Linux) required to send messages.".into(), + ), + use_modifier_to_send, + move |state, _window, cx| { + let allow = state == &ToggleState::Selected; + update_settings_file::(fs.clone(), cx, move |settings, _| { + settings.set_use_modifier_to_send(allow); + }); + }, + ) } fn render_general_settings_section(&mut self, cx: &mut Context) -> impl IntoElement { @@ -429,6 +488,38 @@ impl AgentConfiguration { .child(self.render_command_permission(cx)) .child(self.render_single_file_review(cx)) .child(self.render_sound_notification(cx)) + .child(self.render_modifier_to_send(cx)) + } + + fn render_zed_plan_info(&self, plan: Option, cx: &mut Context) -> impl IntoElement { + if let Some(plan) = plan { + let free_chip_bg = cx + .theme() + .colors() + .editor_background + .opacity(0.5) + .blend(cx.theme().colors().text_accent.opacity(0.05)); + + let pro_chip_bg = cx + .theme() + .colors() + .editor_background + .opacity(0.5) + .blend(cx.theme().colors().text_accent.opacity(0.2)); + + let (plan_name, label_color, bg_color) = match plan { + Plan::ZedFree => ("Free", Color::Default, free_chip_bg), + Plan::ZedProTrial => ("Pro Trial", Color::Accent, pro_chip_bg), + Plan::ZedPro => ("Pro", Color::Accent, pro_chip_bg), + }; + + Chip::new(plan_name.to_string()) + .bg_color(bg_color) + .label_color(label_color) + .into_any_element() + } else { + div().into_any_element() + } } fn render_context_servers_section( @@ -448,7 +539,7 @@ impl AgentConfiguration { v_flex() .gap_0p5() .child(Headline::new("Model Context 
Protocol (MCP) Servers")) - .child(Label::new("Connect to context servers via the Model Context Protocol either via Zed extensions or directly.").color(Color::Muted)), + .child(Label::new("Connect to context servers through the Model Context Protocol, either using Zed extensions or directly.").color(Color::Muted)), ) .children( context_server_ids.into_iter().map(|context_server_id| { @@ -491,6 +582,7 @@ impl AgentConfiguration { category_filter: Some( ExtensionCategoryFilter::ContextServers, ), + id: None, } .boxed_clone(), cx, diff --git a/crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs b/crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs new file mode 100644 index 0000000000000000000000000000000000000000..401a6334886e18ef2e53bbd5b68392597d0db1e9 --- /dev/null +++ b/crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs @@ -0,0 +1,635 @@ +use std::sync::Arc; + +use anyhow::Result; +use collections::HashSet; +use fs::Fs; +use gpui::{DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Render, Task}; +use language_model::LanguageModelRegistry; +use language_models::{ + AllLanguageModelSettings, OpenAiCompatibleSettingsContent, + provider::open_ai_compatible::AvailableModel, +}; +use settings::update_settings_file; +use ui::{Banner, KeyBinding, Modal, ModalFooter, ModalHeader, Section, prelude::*}; +use ui_input::SingleLineInput; +use workspace::{ModalView, Workspace}; + +#[derive(Clone, Copy)] +pub enum LlmCompatibleProvider { + OpenAi, +} + +impl LlmCompatibleProvider { + fn name(&self) -> &'static str { + match self { + LlmCompatibleProvider::OpenAi => "OpenAI", + } + } + + fn api_url(&self) -> &'static str { + match self { + LlmCompatibleProvider::OpenAi => "https://api.openai.com/v1", + } + } +} + +struct AddLlmProviderInput { + provider_name: Entity, + api_url: Entity, + api_key: Entity, + models: Vec, +} + +impl AddLlmProviderInput { + fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut App) -> Self { + let provider_name = single_line_input("Provider Name", provider.name(), None, window, cx); + let api_url = single_line_input("API URL", provider.api_url(), None, window, cx); + let api_key = single_line_input( + "API Key", + "000000000000000000000000000000000000000000000000", + None, + window, + cx, + ); + + Self { + provider_name, + api_url, + api_key, + models: vec![ModelInput::new(window, cx)], + } + } + + fn add_model(&mut self, window: &mut Window, cx: &mut App) { + self.models.push(ModelInput::new(window, cx)); + } + + fn remove_model(&mut self, index: usize) { + self.models.remove(index); + } +} + +struct ModelInput { + name: Entity, + max_completion_tokens: Entity, + max_output_tokens: Entity, + max_tokens: Entity, +} + +impl ModelInput { + fn new(window: &mut Window, cx: &mut App) -> Self { + let model_name = single_line_input( + "Model Name", + "e.g. 
gpt-4o, claude-opus-4, gemini-2.5-pro", + None, + window, + cx, + ); + let max_completion_tokens = single_line_input( + "Max Completion Tokens", + "200000", + Some("200000"), + window, + cx, + ); + let max_output_tokens = single_line_input( + "Max Output Tokens", + "Max Output Tokens", + Some("32000"), + window, + cx, + ); + let max_tokens = single_line_input("Max Tokens", "Max Tokens", Some("200000"), window, cx); + Self { + name: model_name, + max_completion_tokens, + max_output_tokens, + max_tokens, + } + } + + fn parse(&self, cx: &App) -> Result { + let name = self.name.read(cx).text(cx); + if name.is_empty() { + return Err(SharedString::from("Model Name cannot be empty")); + } + Ok(AvailableModel { + name, + display_name: None, + max_completion_tokens: Some( + self.max_completion_tokens + .read(cx) + .text(cx) + .parse::() + .map_err(|_| SharedString::from("Max Completion Tokens must be a number"))?, + ), + max_output_tokens: Some( + self.max_output_tokens + .read(cx) + .text(cx) + .parse::() + .map_err(|_| SharedString::from("Max Output Tokens must be a number"))?, + ), + max_tokens: self + .max_tokens + .read(cx) + .text(cx) + .parse::() + .map_err(|_| SharedString::from("Max Tokens must be a number"))?, + }) + } +} + +fn single_line_input( + label: impl Into, + placeholder: impl Into, + text: Option<&str>, + window: &mut Window, + cx: &mut App, +) -> Entity { + cx.new(|cx| { + let input = SingleLineInput::new(window, cx, placeholder).label(label); + if let Some(text) = text { + input + .editor() + .update(cx, |editor, cx| editor.set_text(text, window, cx)); + } + input + }) +} + +fn save_provider_to_settings( + input: &AddLlmProviderInput, + cx: &mut App, +) -> Task> { + let provider_name: Arc = input.provider_name.read(cx).text(cx).into(); + if provider_name.is_empty() { + return Task::ready(Err("Provider Name cannot be empty".into())); + } + + if LanguageModelRegistry::read_global(cx) + .providers() + .iter() + .any(|provider| { + provider.id().0.as_ref() == provider_name.as_ref() + || provider.name().0.as_ref() == provider_name.as_ref() + }) + { + return Task::ready(Err( + "Provider Name is already taken by another provider".into() + )); + } + + let api_url = input.api_url.read(cx).text(cx); + if api_url.is_empty() { + return Task::ready(Err("API URL cannot be empty".into())); + } + + let api_key = input.api_key.read(cx).text(cx); + if api_key.is_empty() { + return Task::ready(Err("API Key cannot be empty".into())); + } + + let mut models = Vec::new(); + let mut model_names: HashSet = HashSet::default(); + for model in &input.models { + match model.parse(cx) { + Ok(model) => { + if !model_names.insert(model.name.clone()) { + return Task::ready(Err("Model Names must be unique".into())); + } + models.push(model) + } + Err(err) => return Task::ready(Err(err)), + } + } + + let fs = ::global(cx); + let task = cx.write_credentials(&api_url, "Bearer", api_key.as_bytes()); + cx.spawn(async move |cx| { + task.await + .map_err(|_| "Failed to write API key to keychain")?; + cx.update(|cx| { + update_settings_file::(fs, cx, |settings, _cx| { + settings.openai_compatible.get_or_insert_default().insert( + provider_name, + OpenAiCompatibleSettingsContent { + api_url, + available_models: models, + }, + ); + }); + }) + .ok(); + Ok(()) + }) +} + +pub struct AddLlmProviderModal { + provider: LlmCompatibleProvider, + input: AddLlmProviderInput, + focus_handle: FocusHandle, + last_error: Option, +} + +impl AddLlmProviderModal { + pub fn toggle( + provider: LlmCompatibleProvider, + workspace: &mut 
Workspace, + window: &mut Window, + cx: &mut Context, + ) { + workspace.toggle_modal(window, cx, |window, cx| Self::new(provider, window, cx)); + } + + fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut Context) -> Self { + Self { + input: AddLlmProviderInput::new(provider, window, cx), + provider, + last_error: None, + focus_handle: cx.focus_handle(), + } + } + + fn confirm(&mut self, _: &menu::Confirm, _: &mut Window, cx: &mut Context) { + let task = save_provider_to_settings(&self.input, cx); + cx.spawn(async move |this, cx| { + let result = task.await; + this.update(cx, |this, cx| match result { + Ok(_) => { + cx.emit(DismissEvent); + } + Err(error) => { + this.last_error = Some(error); + cx.notify(); + } + }) + }) + .detach_and_log_err(cx); + } + + fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context) { + cx.emit(DismissEvent); + } + + fn render_model_section(&self, cx: &mut Context) -> impl IntoElement { + v_flex() + .mt_1() + .gap_2() + .child( + h_flex() + .justify_between() + .child(Label::new("Models").size(LabelSize::Small)) + .child( + Button::new("add-model", "Add Model") + .icon(IconName::Plus) + .icon_position(IconPosition::Start) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + .label_size(LabelSize::Small) + .on_click(cx.listener(|this, _, window, cx| { + this.input.add_model(window, cx); + cx.notify(); + })), + ), + ) + .children( + self.input + .models + .iter() + .enumerate() + .map(|(ix, _)| self.render_model(ix, cx)), + ) + } + + fn render_model(&self, ix: usize, cx: &mut Context) -> impl IntoElement + use<> { + let has_more_than_one_model = self.input.models.len() > 1; + let model = &self.input.models[ix]; + + v_flex() + .p_2() + .gap_2() + .rounded_sm() + .border_1() + .border_dashed() + .border_color(cx.theme().colors().border.opacity(0.6)) + .bg(cx.theme().colors().element_active.opacity(0.15)) + .child(model.name.clone()) + .child( + h_flex() + .gap_2() + .child(model.max_completion_tokens.clone()) + .child(model.max_output_tokens.clone()), + ) + .child(model.max_tokens.clone()) + .when(has_more_than_one_model, |this| { + this.child( + Button::new(("remove-model", ix), "Remove Model") + .icon(IconName::Trash) + .icon_position(IconPosition::Start) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + .label_size(LabelSize::Small) + .style(ButtonStyle::Outlined) + .full_width() + .on_click(cx.listener(move |this, _, _window, cx| { + this.input.remove_model(ix); + cx.notify(); + })), + ) + }) + } +} + +impl EventEmitter for AddLlmProviderModal {} + +impl Focusable for AddLlmProviderModal { + fn focus_handle(&self, _cx: &App) -> FocusHandle { + self.focus_handle.clone() + } +} + +impl ModalView for AddLlmProviderModal {} + +impl Render for AddLlmProviderModal { + fn render(&mut self, window: &mut ui::Window, cx: &mut ui::Context) -> impl IntoElement { + let focus_handle = self.focus_handle(cx); + + div() + .id("add-llm-provider-modal") + .key_context("AddLlmProviderModal") + .w(rems(34.)) + .elevation_3(cx) + .on_action(cx.listener(Self::cancel)) + .capture_any_mouse_down(cx.listener(|this, _, window, cx| { + this.focus_handle(cx).focus(window); + })) + .child( + Modal::new("configure-context-server", None) + .header(ModalHeader::new().headline("Add LLM Provider").description( + match self.provider { + LlmCompatibleProvider::OpenAi => { + "This provider will use an OpenAI compatible API." 
+ } + }, + )) + .when_some(self.last_error.clone(), |this, error| { + this.section( + Section::new().child( + Banner::new() + .severity(ui::Severity::Warning) + .child(div().text_xs().child(error)), + ), + ) + }) + .child( + v_flex() + .id("modal_content") + .size_full() + .max_h_128() + .overflow_y_scroll() + .px(DynamicSpacing::Base12.rems(cx)) + .gap(DynamicSpacing::Base04.rems(cx)) + .child(self.input.provider_name.clone()) + .child(self.input.api_url.clone()) + .child(self.input.api_key.clone()) + .child(self.render_model_section(cx)), + ) + .footer( + ModalFooter::new().end_slot( + h_flex() + .gap_1() + .child( + Button::new("cancel", "Cancel") + .key_binding( + KeyBinding::for_action_in( + &menu::Cancel, + &focus_handle, + window, + cx, + ) + .map(|kb| kb.size(rems_from_px(12.))), + ) + .on_click(cx.listener(|this, _event, window, cx| { + this.cancel(&menu::Cancel, window, cx) + })), + ) + .child( + Button::new("save-server", "Save Provider") + .key_binding( + KeyBinding::for_action_in( + &menu::Confirm, + &focus_handle, + window, + cx, + ) + .map(|kb| kb.size(rems_from_px(12.))), + ) + .on_click(cx.listener(|this, _event, window, cx| { + this.confirm(&menu::Confirm, window, cx) + })), + ), + ), + ), + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use editor::EditorSettings; + use fs::FakeFs; + use gpui::{TestAppContext, VisualTestContext}; + use language::language_settings; + use language_model::{ + LanguageModelProviderId, LanguageModelProviderName, + fake_provider::FakeLanguageModelProvider, + }; + use project::Project; + use settings::{Settings as _, SettingsStore}; + use util::path; + + #[gpui::test] + async fn test_save_provider_invalid_inputs(cx: &mut TestAppContext) { + let cx = setup_test(cx).await; + + assert_eq!( + save_provider_validation_errors("", "someurl", "somekey", vec![], cx,).await, + Some("Provider Name cannot be empty".into()) + ); + + assert_eq!( + save_provider_validation_errors("someprovider", "", "somekey", vec![], cx,).await, + Some("API URL cannot be empty".into()) + ); + + assert_eq!( + save_provider_validation_errors("someprovider", "someurl", "", vec![], cx,).await, + Some("API Key cannot be empty".into()) + ); + + assert_eq!( + save_provider_validation_errors( + "someprovider", + "someurl", + "somekey", + vec![("", "200000", "200000", "32000")], + cx, + ) + .await, + Some("Model Name cannot be empty".into()) + ); + + assert_eq!( + save_provider_validation_errors( + "someprovider", + "someurl", + "somekey", + vec![("somemodel", "abc", "200000", "32000")], + cx, + ) + .await, + Some("Max Tokens must be a number".into()) + ); + + assert_eq!( + save_provider_validation_errors( + "someprovider", + "someurl", + "somekey", + vec![("somemodel", "200000", "abc", "32000")], + cx, + ) + .await, + Some("Max Completion Tokens must be a number".into()) + ); + + assert_eq!( + save_provider_validation_errors( + "someprovider", + "someurl", + "somekey", + vec![("somemodel", "200000", "200000", "abc")], + cx, + ) + .await, + Some("Max Output Tokens must be a number".into()) + ); + + assert_eq!( + save_provider_validation_errors( + "someprovider", + "someurl", + "somekey", + vec![ + ("somemodel", "200000", "200000", "32000"), + ("somemodel", "200000", "200000", "32000"), + ], + cx, + ) + .await, + Some("Model Names must be unique".into()) + ); + } + + #[gpui::test] + async fn test_save_provider_name_conflict(cx: &mut TestAppContext) { + let cx = setup_test(cx).await; + + cx.update(|_window, cx| { + LanguageModelRegistry::global(cx).update(cx, |registry, cx| { 
+ registry.register_provider( + FakeLanguageModelProvider::new( + LanguageModelProviderId::new("someprovider"), + LanguageModelProviderName::new("Some Provider"), + ), + cx, + ); + }); + }); + + assert_eq!( + save_provider_validation_errors( + "someprovider", + "someurl", + "someapikey", + vec![("somemodel", "200000", "200000", "32000")], + cx, + ) + .await, + Some("Provider Name is already taken by another provider".into()) + ); + } + + async fn setup_test(cx: &mut TestAppContext) -> &mut VisualTestContext { + cx.update(|cx| { + let store = SettingsStore::test(cx); + cx.set_global(store); + workspace::init_settings(cx); + Project::init_settings(cx); + theme::init(theme::LoadThemes::JustBase, cx); + language_settings::init(cx); + EditorSettings::register(cx); + language_model::init_settings(cx); + language_models::init_settings(cx); + }); + + let fs = FakeFs::new(cx.executor()); + cx.update(|cx| ::set_global(fs.clone(), cx)); + let project = Project::test(fs, [path!("/dir").as_ref()], cx).await; + let (_, cx) = + cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + + cx + } + + async fn save_provider_validation_errors( + provider_name: &str, + api_url: &str, + api_key: &str, + models: Vec<(&str, &str, &str, &str)>, + cx: &mut VisualTestContext, + ) -> Option { + fn set_text( + input: &Entity, + text: &str, + window: &mut Window, + cx: &mut App, + ) { + input.update(cx, |input, cx| { + input.editor().update(cx, |editor, cx| { + editor.set_text(text, window, cx); + }); + }); + } + + let task = cx.update(|window, cx| { + let mut input = AddLlmProviderInput::new(LlmCompatibleProvider::OpenAi, window, cx); + set_text(&input.provider_name, provider_name, window, cx); + set_text(&input.api_url, api_url, window, cx); + set_text(&input.api_key, api_key, window, cx); + + for (i, (name, max_tokens, max_completion_tokens, max_output_tokens)) in + models.iter().enumerate() + { + if i >= input.models.len() { + input.models.push(ModelInput::new(window, cx)); + } + let model = &mut input.models[i]; + set_text(&model.name, name, window, cx); + set_text(&model.max_tokens, max_tokens, window, cx); + set_text( + &model.max_completion_tokens, + max_completion_tokens, + window, + cx, + ); + set_text(&model.max_output_tokens, max_output_tokens, window, cx); + } + save_provider_to_settings(&input, cx) + }); + + task.await.err() + } +} diff --git a/crates/agent_ui/src/agent_configuration/configure_context_server_modal.rs b/crates/agent_ui/src/agent_configuration/configure_context_server_modal.rs index 9e5f6e09c82489dd4ccdc89f188e962ceeec596d..06d035d836853068c8ed402ee0e85ff85d9af6b2 100644 --- a/crates/agent_ui/src/agent_configuration/configure_context_server_modal.rs +++ b/crates/agent_ui/src/agent_configuration/configure_context_server_modal.rs @@ -1,4 +1,5 @@ use std::{ + path::PathBuf, sync::{Arc, Mutex}, time::Duration, }; @@ -188,7 +189,7 @@ fn context_server_input(existing: Option<(ContextServerId, ContextServerCommand) } None => ( "some-mcp-server".to_string(), - "".to_string(), + PathBuf::new(), "[]".to_string(), "{}".to_string(), ), @@ -199,13 +200,14 @@ fn context_server_input(existing: Option<(ContextServerId, ContextServerCommand) /// The name of your MCP server "{name}": {{ /// The command which runs the MCP server - "command": "{command}", + "command": "{}", /// The arguments to pass to the MCP server "args": {args}, /// The environment variables to set "env": {env} }} -}}"# +}}"#, + command.display() ) } diff --git 
a/crates/agent_ui/src/agent_configuration/manage_profiles_modal.rs b/crates/agent_ui/src/agent_configuration/manage_profiles_modal.rs index 45536ff13b6414f4cfd3f3a00a440770c83bc0fa..5d44bb2d9218cc072cf3d245638b29ff98111861 100644 --- a/crates/agent_ui/src/agent_configuration/manage_profiles_modal.rs +++ b/crates/agent_ui/src/agent_configuration/manage_profiles_modal.rs @@ -483,7 +483,7 @@ impl ManageProfilesModal { let icon = match mode.profile_id.as_str() { "write" => IconName::Pencil, - "ask" => IconName::MessageBubbles, + "ask" => IconName::Chat, _ => IconName::UserRoundPen, }; diff --git a/crates/agent_ui/src/agent_diff.rs b/crates/agent_ui/src/agent_diff.rs index 1a0f3ff27d83a98d343985b3f827aab26afd192a..e1ceaf761dbb1e818235610ba232cc9a1d150132 100644 --- a/crates/agent_ui/src/agent_diff.rs +++ b/crates/agent_ui/src/agent_diff.rs @@ -1,7 +1,9 @@ use crate::{Keep, KeepAll, OpenAgentDiff, Reject, RejectAll}; -use agent::{Thread, ThreadEvent}; +use acp_thread::{AcpThread, AcpThreadEvent}; +use agent::{Thread, ThreadEvent, ThreadSummary}; use agent_settings::AgentSettings; use anyhow::Result; +use assistant_tool::ActionLog; use buffer_diff::DiffHunkStatus; use collections::{HashMap, HashSet}; use editor::{ @@ -41,16 +43,108 @@ use zed_actions::assistant::ToggleFocus; pub struct AgentDiffPane { multibuffer: Entity, editor: Entity, - thread: Entity, + thread: AgentDiffThread, focus_handle: FocusHandle, workspace: WeakEntity, title: SharedString, _subscriptions: Vec, } +#[derive(PartialEq, Eq, Clone)] +pub enum AgentDiffThread { + Native(Entity), + AcpThread(Entity), +} + +impl AgentDiffThread { + fn project(&self, cx: &App) -> Entity { + match self { + AgentDiffThread::Native(thread) => thread.read(cx).project().clone(), + AgentDiffThread::AcpThread(thread) => thread.read(cx).project().clone(), + } + } + fn action_log(&self, cx: &App) -> Entity { + match self { + AgentDiffThread::Native(thread) => thread.read(cx).action_log().clone(), + AgentDiffThread::AcpThread(thread) => thread.read(cx).action_log().clone(), + } + } + + fn summary(&self, cx: &App) -> ThreadSummary { + match self { + AgentDiffThread::Native(thread) => thread.read(cx).summary().clone(), + AgentDiffThread::AcpThread(thread) => ThreadSummary::Ready(thread.read(cx).title()), + } + } + + fn is_generating(&self, cx: &App) -> bool { + match self { + AgentDiffThread::Native(thread) => thread.read(cx).is_generating(), + AgentDiffThread::AcpThread(thread) => { + thread.read(cx).status() == acp_thread::ThreadStatus::Generating + } + } + } + + fn has_pending_edit_tool_uses(&self, cx: &App) -> bool { + match self { + AgentDiffThread::Native(thread) => thread.read(cx).has_pending_edit_tool_uses(), + AgentDiffThread::AcpThread(thread) => thread.read(cx).has_pending_edit_tool_calls(), + } + } + + fn downgrade(&self) -> WeakAgentDiffThread { + match self { + AgentDiffThread::Native(thread) => WeakAgentDiffThread::Native(thread.downgrade()), + AgentDiffThread::AcpThread(thread) => { + WeakAgentDiffThread::AcpThread(thread.downgrade()) + } + } + } +} + +impl From> for AgentDiffThread { + fn from(entity: Entity) -> Self { + AgentDiffThread::Native(entity) + } +} + +impl From> for AgentDiffThread { + fn from(entity: Entity) -> Self { + AgentDiffThread::AcpThread(entity) + } +} + +#[derive(PartialEq, Eq, Clone)] +pub enum WeakAgentDiffThread { + Native(WeakEntity), + AcpThread(WeakEntity), +} + +impl WeakAgentDiffThread { + pub fn upgrade(&self) -> Option { + match self { + WeakAgentDiffThread::Native(weak) => 
weak.upgrade().map(AgentDiffThread::Native), + WeakAgentDiffThread::AcpThread(weak) => weak.upgrade().map(AgentDiffThread::AcpThread), + } + } +} + +impl From> for WeakAgentDiffThread { + fn from(entity: WeakEntity) -> Self { + WeakAgentDiffThread::Native(entity) + } +} + +impl From> for WeakAgentDiffThread { + fn from(entity: WeakEntity) -> Self { + WeakAgentDiffThread::AcpThread(entity) + } +} + impl AgentDiffPane { pub fn deploy( - thread: Entity, + thread: impl Into, workspace: WeakEntity, window: &mut Window, cx: &mut App, @@ -61,14 +155,16 @@ impl AgentDiffPane { } pub fn deploy_in_workspace( - thread: Entity, + thread: impl Into, workspace: &mut Workspace, window: &mut Window, cx: &mut Context, ) -> Entity { + let thread = thread.into(); let existing_diff = workspace .items_of_type::(cx) .find(|diff| diff.read(cx).thread == thread); + if let Some(existing_diff) = existing_diff { workspace.activate_item(&existing_diff, true, true, window, cx); existing_diff @@ -81,7 +177,7 @@ impl AgentDiffPane { } pub fn new( - thread: Entity, + thread: AgentDiffThread, workspace: WeakEntity, window: &mut Window, cx: &mut Context, @@ -89,7 +185,7 @@ impl AgentDiffPane { let focus_handle = cx.focus_handle(); let multibuffer = cx.new(|_| MultiBuffer::new(Capability::ReadWrite)); - let project = thread.read(cx).project().clone(); + let project = thread.project(cx).clone(); let editor = cx.new(|cx| { let mut editor = Editor::for_multibuffer(multibuffer.clone(), Some(project.clone()), window, cx); @@ -100,16 +196,27 @@ impl AgentDiffPane { editor }); - let action_log = thread.read(cx).action_log().clone(); + let action_log = thread.action_log(cx).clone(); + let mut this = Self { - _subscriptions: vec![ - cx.observe_in(&action_log, window, |this, _action_log, window, cx| { - this.update_excerpts(window, cx) - }), - cx.subscribe(&thread, |this, _thread, event, cx| { - this.handle_thread_event(event, cx) - }), - ], + _subscriptions: [ + Some( + cx.observe_in(&action_log, window, |this, _action_log, window, cx| { + this.update_excerpts(window, cx) + }), + ), + match &thread { + AgentDiffThread::Native(thread) => { + Some(cx.subscribe(&thread, |this, _thread, event, cx| { + this.handle_thread_event(event, cx) + })) + } + AgentDiffThread::AcpThread(_) => None, + }, + ] + .into_iter() + .flatten() + .collect(), title: SharedString::default(), multibuffer, editor, @@ -123,8 +230,7 @@ impl AgentDiffPane { } fn update_excerpts(&mut self, window: &mut Window, cx: &mut Context) { - let thread = self.thread.read(cx); - let changed_buffers = thread.action_log().read(cx).changed_buffers(cx); + let changed_buffers = self.thread.action_log(cx).read(cx).changed_buffers(cx); let mut paths_to_delete = self.multibuffer.read(cx).paths().collect::>(); for (buffer, diff_handle) in changed_buffers { @@ -211,7 +317,7 @@ impl AgentDiffPane { } fn update_title(&mut self, cx: &mut Context) { - let new_title = self.thread.read(cx).summary().unwrap_or("Agent Changes"); + let new_title = self.thread.summary(cx).unwrap_or("Agent Changes"); if new_title != self.title { self.title = new_title; cx.emit(EditorEvent::TitleChanged); @@ -275,14 +381,15 @@ impl AgentDiffPane { fn keep_all(&mut self, _: &KeepAll, _window: &mut Window, cx: &mut Context) { self.thread - .update(cx, |thread, cx| thread.keep_all_edits(cx)); + .action_log(cx) + .update(cx, |action_log, cx| action_log.keep_all_edits(cx)) } } fn keep_edits_in_selection( editor: &mut Editor, buffer_snapshot: &MultiBufferSnapshot, - thread: &Entity, + thread: &AgentDiffThread, window: 
&mut Window, cx: &mut Context, ) { @@ -297,7 +404,7 @@ fn keep_edits_in_selection( fn reject_edits_in_selection( editor: &mut Editor, buffer_snapshot: &MultiBufferSnapshot, - thread: &Entity, + thread: &AgentDiffThread, window: &mut Window, cx: &mut Context, ) { @@ -311,7 +418,7 @@ fn reject_edits_in_selection( fn keep_edits_in_ranges( editor: &mut Editor, buffer_snapshot: &MultiBufferSnapshot, - thread: &Entity, + thread: &AgentDiffThread, ranges: Vec>, window: &mut Window, cx: &mut Context, @@ -326,8 +433,8 @@ fn keep_edits_in_ranges( for hunk in &diff_hunks_in_ranges { let buffer = multibuffer.read(cx).buffer(hunk.buffer_id); if let Some(buffer) = buffer { - thread.update(cx, |thread, cx| { - thread.keep_edits_in_range(buffer, hunk.buffer_range.clone(), cx) + thread.action_log(cx).update(cx, |action_log, cx| { + action_log.keep_edits_in_range(buffer, hunk.buffer_range.clone(), cx) }); } } @@ -336,7 +443,7 @@ fn keep_edits_in_ranges( fn reject_edits_in_ranges( editor: &mut Editor, buffer_snapshot: &MultiBufferSnapshot, - thread: &Entity, + thread: &AgentDiffThread, ranges: Vec>, window: &mut Window, cx: &mut Context, @@ -362,8 +469,9 @@ fn reject_edits_in_ranges( for (buffer, ranges) in ranges_by_buffer { thread - .update(cx, |thread, cx| { - thread.reject_edits_in_ranges(buffer, ranges, cx) + .action_log(cx) + .update(cx, |action_log, cx| { + action_log.reject_edits_in_ranges(buffer, ranges, cx) }) .detach_and_log_err(cx); } @@ -461,7 +569,7 @@ impl Item for AgentDiffPane { } fn tab_content(&self, params: TabContentParams, _window: &Window, cx: &App) -> AnyElement { - let summary = self.thread.read(cx).summary().unwrap_or("Agent Changes"); + let summary = self.thread.summary(cx).unwrap_or("Agent Changes"); Label::new(format!("Review: {}", summary)) .color(if params.selected { Color::Default @@ -641,7 +749,7 @@ impl Render for AgentDiffPane { } } -fn diff_hunk_controls(thread: &Entity) -> editor::RenderDiffHunkControlsFn { +fn diff_hunk_controls(thread: &AgentDiffThread) -> editor::RenderDiffHunkControlsFn { let thread = thread.clone(); Arc::new( @@ -676,7 +784,7 @@ fn render_diff_hunk_controls( hunk_range: Range, is_created_file: bool, line_height: Pixels, - thread: &Entity, + thread: &AgentDiffThread, editor: &Entity, window: &mut Window, cx: &mut App, @@ -1112,11 +1220,8 @@ impl Render for AgentDiffToolbar { return Empty.into_any(); }; - let has_pending_edit_tool_use = agent_diff - .read(cx) - .thread - .read(cx) - .has_pending_edit_tool_uses(); + let has_pending_edit_tool_use = + agent_diff.read(cx).thread.has_pending_edit_tool_uses(cx); if has_pending_edit_tool_use { return div().px_2().child(spinner_icon).into_any(); @@ -1187,8 +1292,8 @@ pub enum EditorState { } struct WorkspaceThread { - thread: WeakEntity, - _thread_subscriptions: [Subscription; 2], + thread: WeakAgentDiffThread, + _thread_subscriptions: (Subscription, Subscription), singleton_editors: HashMap, HashMap, Subscription>>, _settings_subscription: Subscription, _workspace_subscription: Option, @@ -1212,23 +1317,23 @@ impl AgentDiff { pub fn set_active_thread( workspace: &WeakEntity, - thread: &Entity, + thread: impl Into, window: &mut Window, cx: &mut App, ) { Self::global(cx).update(cx, |this, cx| { - this.register_active_thread_impl(workspace, thread, window, cx); + this.register_active_thread_impl(workspace, thread.into(), window, cx); }); } fn register_active_thread_impl( &mut self, workspace: &WeakEntity, - thread: &Entity, + thread: AgentDiffThread, window: &mut Window, cx: &mut Context, ) { - let action_log = 
thread.read(cx).action_log().clone(); + let action_log = thread.action_log(cx).clone(); let action_log_subscription = cx.observe_in(&action_log, window, { let workspace = workspace.clone(); @@ -1237,17 +1342,25 @@ impl AgentDiff { } }); - let thread_subscription = cx.subscribe_in(&thread, window, { - let workspace = workspace.clone(); - move |this, _thread, event, window, cx| { - this.handle_thread_event(&workspace, event, window, cx) - } - }); + let thread_subscription = match &thread { + AgentDiffThread::Native(thread) => cx.subscribe_in(&thread, window, { + let workspace = workspace.clone(); + move |this, _thread, event, window, cx| { + this.handle_native_thread_event(&workspace, event, window, cx) + } + }), + AgentDiffThread::AcpThread(thread) => cx.subscribe_in(&thread, window, { + let workspace = workspace.clone(); + move |this, thread, event, window, cx| { + this.handle_acp_thread_event(&workspace, thread, event, window, cx) + } + }), + }; if let Some(workspace_thread) = self.workspace_threads.get_mut(&workspace) { // replace thread and action log subscription, but keep editors workspace_thread.thread = thread.downgrade(); - workspace_thread._thread_subscriptions = [action_log_subscription, thread_subscription]; + workspace_thread._thread_subscriptions = (action_log_subscription, thread_subscription); self.update_reviewing_editors(&workspace, window, cx); return; } @@ -1272,7 +1385,7 @@ impl AgentDiff { workspace.clone(), WorkspaceThread { thread: thread.downgrade(), - _thread_subscriptions: [action_log_subscription, thread_subscription], + _thread_subscriptions: (action_log_subscription, thread_subscription), singleton_editors: HashMap::default(), _settings_subscription: settings_subscription, _workspace_subscription: workspace_subscription, @@ -1319,7 +1432,7 @@ impl AgentDiff { fn register_review_action( workspace: &mut Workspace, - review: impl Fn(&Entity, &Entity, &mut Window, &mut App) -> PostReviewState + review: impl Fn(&Entity, &AgentDiffThread, &mut Window, &mut App) -> PostReviewState + 'static, this: &Entity, ) { @@ -1338,7 +1451,7 @@ impl AgentDiff { }); } - fn handle_thread_event( + fn handle_native_thread_event( &mut self, workspace: &WeakEntity, event: &ThreadEvent, @@ -1375,11 +1488,46 @@ impl AgentDiff { | ThreadEvent::ToolConfirmationNeeded | ThreadEvent::ToolUseLimitReached | ThreadEvent::CancelEditing - | ThreadEvent::RetriesFailed { .. 
} | ThreadEvent::ProfileChanged => {} } } + fn handle_acp_thread_event( + &mut self, + workspace: &WeakEntity, + thread: &Entity, + event: &AcpThreadEvent, + window: &mut Window, + cx: &mut Context, + ) { + match event { + AcpThreadEvent::NewEntry => { + if thread + .read(cx) + .entries() + .last() + .map_or(false, |entry| entry.diffs().next().is_some()) + { + self.update_reviewing_editors(workspace, window, cx); + } + } + AcpThreadEvent::EntryUpdated(ix) => { + if thread + .read(cx) + .entries() + .get(*ix) + .map_or(false, |entry| entry.diffs().next().is_some()) + { + self.update_reviewing_editors(workspace, window, cx); + } + } + AcpThreadEvent::Stopped + | AcpThreadEvent::ToolAuthorizationRequired + | AcpThreadEvent::Error + | AcpThreadEvent::ServerExited(_) => {} + } + } + fn handle_workspace_event( &mut self, workspace: &Entity, @@ -1485,7 +1633,7 @@ impl AgentDiff { return; }; - let action_log = thread.read(cx).action_log(); + let action_log = thread.action_log(cx); let changed_buffers = action_log.read(cx).changed_buffers(cx); let mut unaffected = self.reviewing_editors.clone(); @@ -1510,7 +1658,7 @@ impl AgentDiff { multibuffer.add_diff(diff_handle.clone(), cx); }); - let new_state = if thread.read(cx).is_generating() { + let new_state = if thread.is_generating(cx) { EditorState::Generating } else { EditorState::Reviewing @@ -1606,7 +1754,7 @@ impl AgentDiff { fn keep_all( editor: &Entity, - thread: &Entity, + thread: &AgentDiffThread, window: &mut Window, cx: &mut App, ) -> PostReviewState { @@ -1626,7 +1774,7 @@ impl AgentDiff { fn reject_all( editor: &Entity, - thread: &Entity, + thread: &AgentDiffThread, window: &mut Window, cx: &mut App, ) -> PostReviewState { @@ -1646,7 +1794,7 @@ impl AgentDiff { fn keep( editor: &Entity, - thread: &Entity, + thread: &AgentDiffThread, window: &mut Window, cx: &mut App, ) -> PostReviewState { @@ -1659,7 +1807,7 @@ impl AgentDiff { fn reject( editor: &Entity, - thread: &Entity, + thread: &AgentDiffThread, window: &mut Window, cx: &mut App, ) -> PostReviewState { @@ -1682,7 +1830,7 @@ impl AgentDiff { fn review_in_active_editor( &mut self, workspace: &mut Workspace, - review: impl Fn(&Entity, &Entity, &mut Window, &mut App) -> PostReviewState, + review: impl Fn(&Entity, &AgentDiffThread, &mut Window, &mut App) -> PostReviewState, window: &mut Window, cx: &mut Context, ) -> Option>> { @@ -1703,7 +1851,7 @@ impl AgentDiff { if let PostReviewState::AllReviewed = review(&editor, &thread, window, cx) { if let Some(curr_buffer) = editor.read(cx).buffer().read(cx).as_singleton() { - let changed_buffers = thread.read(cx).action_log().read(cx).changed_buffers(cx); + let changed_buffers = thread.action_log(cx).read(cx).changed_buffers(cx); let mut keys = changed_buffers.keys().cycle(); keys.find(|k| *k == &curr_buffer); @@ -1801,8 +1949,9 @@ mod tests { }) .await .unwrap(); - let thread = thread_store.update(cx, |store, cx| store.create_thread(cx)); - let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); + let thread = + AgentDiffThread::Native(thread_store.update(cx, |store, cx| store.create_thread(cx))); + let action_log = cx.read(|cx| thread.action_log(cx)); let (workspace, cx) = cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); @@ -1988,8 +2137,9 @@ mod tests { }); // Set the active thread + let thread = AgentDiffThread::Native(thread); cx.update(|window, cx| { - AgentDiff::set_active_thread(&workspace.downgrade(), &thread, window, cx) + AgentDiff::set_active_thread(&workspace.downgrade(), 
thread.clone(), window, cx) }); let buffer1 = project diff --git a/crates/agent_ui/src/agent_model_selector.rs b/crates/agent_ui/src/agent_model_selector.rs index f7b9157bbb9c07abac6a80dddfc014443165a712..b989e7bf1e9147c7f6beb90b5054120cef7b818f 100644 --- a/crates/agent_ui/src/agent_model_selector.rs +++ b/crates/agent_ui/src/agent_model_selector.rs @@ -1,8 +1,6 @@ use crate::{ ModelUsageContext, - language_model_selector::{ - LanguageModelSelector, ToggleModelSelector, language_model_selector, - }, + language_model_selector::{LanguageModelSelector, language_model_selector}, }; use agent_settings::AgentSettings; use fs::Fs; @@ -12,6 +10,7 @@ use picker::popover_menu::PickerPopoverMenu; use settings::update_settings_file; use std::sync::Arc; use ui::{ButtonLike, PopoverMenuHandle, Tooltip, prelude::*}; +use zed_actions::agent::ToggleModelSelector; pub struct AgentModelSelector { selector: Entity, @@ -96,22 +95,18 @@ impl Render for AgentModelSelector { let model_name = model .as_ref() .map(|model| model.model.name().0) - .unwrap_or_else(|| SharedString::from("No model selected")); - let provider_icon = model - .as_ref() - .map(|model| model.provider.icon()) - .unwrap_or_else(|| IconName::Ai); + .unwrap_or_else(|| SharedString::from("Select a Model")); + + let provider_icon = model.as_ref().map(|model| model.provider.icon()); let focus_handle = self.focus_handle.clone(); PickerPopoverMenu::new( self.selector.clone(), ButtonLike::new("active-model") - .child( - Icon::new(provider_icon) - .color(Color::Muted) - .size(IconSize::XSmall), - ) + .when_some(provider_icon, |this, icon| { + this.child(Icon::new(icon).color(Color::Muted).size(IconSize::XSmall)) + }) .child( Label::new(model_name) .color(Color::Muted) diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 8485c5f09218341081407478bc70650579c44154..6b8e36066baf6de905b045a7ea32ee36263058fb 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -1,14 +1,18 @@ -use std::ops::Range; +use std::cell::RefCell; +use std::ops::{Not, Range}; use std::path::Path; use std::rc::Rc; use std::sync::Arc; use std::time::Duration; +use agent_servers::AgentServer; use db::kvp::{Dismissable, KEY_VALUE_STORE}; use serde::{Deserialize, Serialize}; -use crate::NewAcpThread; -use crate::language_model_selector::ToggleModelSelector; +use crate::NewExternalAgentThread; +use crate::agent_diff::AgentDiffThread; +use crate::message_editor::{MAX_EDITOR_LINES, MIN_EDITOR_LINES}; +use crate::ui::NewThreadButton; use crate::{ AddContextServer, AgentDiffPane, ContinueThread, ContinueWithBurnMode, DeleteRecentlyOpenThread, ExpandMessageEditor, Follow, InlineAssistant, NewTextThread, @@ -25,7 +29,7 @@ use crate::{ render_remaining_tokens, }, thread_history::{HistoryEntryElement, ThreadHistory}, - ui::AgentOnboardingModal, + ui::{AgentOnboardingModal, EndTrialUpsell}, }; use agent::{ Thread, ThreadError, ThreadEvent, ThreadId, ThreadSummary, TokenUsageRatio, @@ -34,27 +38,28 @@ use agent::{ thread_store::{TextThreadStore, ThreadStore}, }; use agent_settings::{AgentDockPosition, AgentSettings, CompletionMode, DefaultView}; +use ai_onboarding::AgentPanelOnboarding; use anyhow::{Result, anyhow}; use assistant_context::{AssistantContext, ContextEvent, ContextSummary}; use assistant_slash_command::SlashCommandWorkingSet; use assistant_tool::ToolWorkingSet; use client::{UserStore, zed_urls}; +use cloud_llm_client::{CompletionIntent, Plan, UsageLimit}; use editor::{Anchor, AnchorRangeExt as _, Editor, 
EditorEvent, MultiBuffer}; use feature_flags::{self, FeatureFlagAppExt}; use fs::Fs; use gpui::{ Action, Animation, AnimationExt as _, AnyElement, App, AsyncWindowContext, ClipboardItem, Corner, DismissEvent, Entity, EventEmitter, ExternalPaths, FocusHandle, Focusable, Hsla, - KeyContext, Pixels, Subscription, Task, UpdateGlobal, WeakEntity, linear_color_stop, - linear_gradient, prelude::*, pulsating_between, + KeyContext, Pixels, Subscription, Task, UpdateGlobal, WeakEntity, prelude::*, + pulsating_between, }; use language::LanguageRegistry; use language_model::{ - ConfigurationError, LanguageModelProviderTosView, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID, + ConfigurationError, ConfiguredModel, LanguageModelProviderTosView, LanguageModelRegistry, }; -use project::{Project, ProjectPath, Worktree}; +use project::{DisableAiSettings, Project, ProjectPath, Worktree}; use prompt_store::{PromptBuilder, PromptStore, UserPromptId}; -use proto::Plan; use rules_library::{RulesLibrary, open_rules_library}; use search::{BufferSearchBar, buffer_search}; use settings::{Settings, update_settings_file}; @@ -62,8 +67,8 @@ use theme::ThemeSettings; use time::UtcOffset; use ui::utils::WithRemSize; use ui::{ - Banner, Callout, CheckboxWithLabel, ContextMenu, ElevationIndex, KeyBinding, PopoverMenu, - PopoverMenuHandle, ProgressBar, Tab, Tooltip, Vector, VectorName, prelude::*, + Banner, Callout, ContextMenu, ContextMenuEntry, ElevationIndex, KeyBinding, PopoverMenu, + PopoverMenuHandle, ProgressBar, Tab, Tooltip, prelude::*, }; use util::ResultExt as _; use workspace::{ @@ -72,10 +77,9 @@ use workspace::{ }; use zed_actions::{ DecreaseBufferFontSize, IncreaseBufferFontSize, ResetBufferFontSize, - agent::{OpenConfiguration, OpenOnboardingModal, ResetOnboarding}, + agent::{OpenOnboardingModal, OpenSettings, ResetOnboarding, ToggleModelSelector}, assistant::{OpenRulesLibrary, ToggleFocus}, }; -use zed_llm_client::{CompletionIntent, UsageLimit}; const AGENT_PANEL_KEY: &str = "agent_panel"; @@ -100,7 +104,7 @@ pub fn init(cx: &mut App) { panel.update(cx, |panel, cx| panel.open_history(window, cx)); } }) - .register_action(|workspace, _: &OpenConfiguration, window, cx| { + .register_action(|workspace, _: &OpenSettings, window, cx| { if let Some(panel) = workspace.panel::(cx) { workspace.focus_panel::(window, cx); panel.update(cx, |panel, cx| panel.open_configuration(window, cx)); @@ -112,10 +116,12 @@ pub fn init(cx: &mut App) { panel.update(cx, |panel, cx| panel.new_prompt_editor(window, cx)); } }) - .register_action(|workspace, _: &NewAcpThread, window, cx| { + .register_action(|workspace, action: &NewExternalAgentThread, window, cx| { if let Some(panel) = workspace.panel::(cx) { workspace.focus_panel::(window, cx); - panel.update(cx, |panel, cx| panel.new_gemini_thread(window, cx)); + panel.update(cx, |panel, cx| { + panel.new_external_thread(action.agent, window, cx) + }); } }) .register_action(|workspace, action: &OpenRulesLibrary, window, cx| { @@ -134,7 +140,7 @@ pub fn init(cx: &mut App) { let thread = thread.read(cx).thread().clone(); AgentDiffPane::deploy_in_workspace(thread, workspace, window, cx); } - ActiveView::AcpThread { .. } + ActiveView::ExternalAgentThread { .. } | ActiveView::TextThread { .. 
} | ActiveView::History | ActiveView::Configuration => {} @@ -181,7 +187,7 @@ pub fn init(cx: &mut App) { window.refresh(); }) .register_action(|_workspace, _: &ResetTrialUpsell, _window, cx| { - Upsell::set_dismissed(false, cx); + OnboardingUpsell::set_dismissed(false, cx); }) .register_action(|_workspace, _: &ResetTrialEndUpsell, _window, cx| { TrialEndUpsell::set_dismissed(false, cx); @@ -198,7 +204,7 @@ enum ActiveView { message_editor: Entity, _subscriptions: Vec, }, - AcpThread { + ExternalAgentThread { thread_view: Entity, }, TextThread { @@ -220,9 +226,9 @@ enum WhichFontSize { impl ActiveView { pub fn which_font_size_used(&self) -> WhichFontSize { match self { - ActiveView::Thread { .. } | ActiveView::AcpThread { .. } | ActiveView::History => { - WhichFontSize::AgentFont - } + ActiveView::Thread { .. } + | ActiveView::ExternalAgentThread { .. } + | ActiveView::History => WhichFontSize::AgentFont, ActiveView::TextThread { .. } => WhichFontSize::BufferFont, ActiveView::Configuration => WhichFontSize::None, } @@ -253,7 +259,7 @@ impl ActiveView { thread.scroll_to_bottom(cx); }); } - ActiveView::AcpThread { .. } => {} + ActiveView::ExternalAgentThread { .. } => {} ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {} @@ -432,6 +438,8 @@ pub struct AgentPanel { configuration_subscription: Option, local_timezone: UtcOffset, active_view: ActiveView, + acp_message_history: + Rc>>>, previous_view: Option, history_store: Entity, history: Entity, @@ -444,7 +452,7 @@ pub struct AgentPanel { height: Option, zoomed: bool, pending_serialization: Option>>, - hide_upsell: bool, + onboarding: Entity, } impl AgentPanel { @@ -546,6 +554,7 @@ impl AgentPanel { let user_store = workspace.app_state().user_store.clone(); let project = workspace.project(); let language_registry = project.read(cx).languages().clone(); + let client = workspace.client().clone(); let workspace = workspace.weak_handle(); let weak_self = cx.entity().downgrade(); @@ -554,31 +563,32 @@ impl AgentPanel { let inline_assist_context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), Some(thread_store.downgrade()))); + let thread_id = thread.read(cx).id().clone(); + + let history_store = cx.new(|cx| { + HistoryStore::new( + thread_store.clone(), + context_store.clone(), + [HistoryEntryId::Thread(thread_id)], + cx, + ) + }); + let message_editor = cx.new(|cx| { MessageEditor::new( fs.clone(), workspace.clone(), - user_store.clone(), message_editor_context_store.clone(), prompt_store.clone(), thread_store.downgrade(), context_store.downgrade(), + Some(history_store.downgrade()), thread.clone(), window, cx, ) }); - let thread_id = thread.read(cx).id().clone(); - let history_store = cx.new(|cx| { - HistoryStore::new( - thread_store.clone(), - context_store.clone(), - [HistoryEntryId::Thread(thread_id)], - cx, - ) - }); - cx.observe(&history_store, |_, _, cx| cx.notify()).detach(); let active_thread = cx.new(|cx| { @@ -624,7 +634,7 @@ impl AgentPanel { } }; - AgentDiff::set_active_thread(&workspace, &thread, window, cx); + AgentDiff::set_active_thread(&workspace, thread.clone(), window, cx); let weak_panel = weak_self.clone(); @@ -670,7 +680,7 @@ impl AgentPanel { .clone() .update(cx, |thread, cx| thread.get_or_init_configured_model(cx)); } - ActiveView::AcpThread { .. } + ActiveView::ExternalAgentThread { .. } | ActiveView::TextThread { .. 
} | ActiveView::History | ActiveView::Configuration => {} @@ -679,6 +689,17 @@ impl AgentPanel { }, ); + let onboarding = cx.new(|cx| { + AgentPanelOnboarding::new( + user_store.clone(), + client, + |_window, cx| { + OnboardingUpsell::set_dismissed(true, cx); + }, + cx, + ) + }); + Self { active_view, workspace, @@ -698,6 +719,7 @@ impl AgentPanel { .unwrap(), inline_assist_context_store, previous_view: None, + acp_message_history: Default::default(), history_store: history_store.clone(), history: cx.new(|cx| ThreadHistory::new(weak_self, history_store, window, cx)), hovered_recent_history_item: None, @@ -709,7 +731,7 @@ impl AgentPanel { height: None, zoomed: false, pending_serialization: None, - hide_upsell: false, + onboarding, } } @@ -722,6 +744,7 @@ impl AgentPanel { if workspace .panel::(cx) .is_some_and(|panel| panel.read(cx).enabled(cx)) + && !DisableAiSettings::get_global(cx).disable_ai { workspace.toggle_panel_focus::(window, cx); } @@ -752,7 +775,7 @@ impl AgentPanel { ActiveView::Thread { thread, .. } => { thread.update(cx, |thread, cx| thread.cancel_last_completion(window, cx)); } - ActiveView::AcpThread { thread_view, .. } => { + ActiveView::ExternalAgentThread { thread_view, .. } => { thread_view.update(cx, |thread_element, cx| thread_element.cancel(cx)); } ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {} @@ -762,7 +785,7 @@ impl AgentPanel { fn active_message_editor(&self) -> Option<&Entity> { match &self.active_view { ActiveView::Thread { message_editor, .. } => Some(message_editor), - ActiveView::AcpThread { .. } + ActiveView::ExternalAgentThread { .. } | ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => None, @@ -823,11 +846,11 @@ impl AgentPanel { MessageEditor::new( self.fs.clone(), self.workspace.clone(), - self.user_store.clone(), context_store.clone(), self.prompt_store.clone(), self.thread_store.downgrade(), self.context_store.downgrade(), + Some(self.history_store.downgrade()), thread.clone(), window, cx, @@ -845,7 +868,7 @@ impl AgentPanel { let thread_view = ActiveView::thread(active_thread.clone(), message_editor, window, cx); self.set_active_view(thread_view, window, cx); - AgentDiff::set_active_thread(&self.workspace, &thread, window, cx); + AgentDiff::set_active_thread(&self.workspace, thread.clone(), window, cx); } fn new_prompt_editor(&mut self, window: &mut Window, cx: &mut Context) { @@ -884,19 +907,73 @@ impl AgentPanel { context_editor.focus_handle(cx).focus(window); } - fn new_gemini_thread(&mut self, window: &mut Window, cx: &mut Context) { + fn new_external_thread( + &mut self, + agent_choice: Option, + window: &mut Window, + cx: &mut Context, + ) { let workspace = self.workspace.clone(); let project = self.project.clone(); + let message_history = self.acp_message_history.clone(); + + const LAST_USED_EXTERNAL_AGENT_KEY: &str = "agent_panel__last_used_external_agent"; + + #[derive(Default, Serialize, Deserialize)] + struct LastUsedExternalAgent { + agent: crate::ExternalAgent, + } cx.spawn_in(window, async move |this, cx| { - let thread_view = cx.new_window_entity(|window, cx| { - crate::acp::AcpThreadView::new(workspace, project, window, cx) - })?; + let server: Rc = match agent_choice { + Some(agent) => { + cx.background_spawn(async move { + if let Some(serialized) = + serde_json::to_string(&LastUsedExternalAgent { agent }).log_err() + { + KEY_VALUE_STORE + .write_kvp(LAST_USED_EXTERNAL_AGENT_KEY.to_string(), serialized) + .await + .log_err(); + } + }) + .detach(); + + 
agent.server() + } + None => cx + .background_spawn(async move { + KEY_VALUE_STORE.read_kvp(LAST_USED_EXTERNAL_AGENT_KEY) + }) + .await + .log_err() + .flatten() + .and_then(|value| { + serde_json::from_str::(&value).log_err() + }) + .unwrap_or_default() + .agent + .server(), + }; + this.update_in(cx, |this, window, cx| { - this.set_active_view(ActiveView::AcpThread { thread_view }, window, cx); + let thread_view = cx.new(|cx| { + crate::acp::AcpThreadView::new( + server, + workspace.clone(), + project, + message_history, + MIN_EDITOR_LINES, + Some(MAX_EDITOR_LINES), + window, + cx, + ) + }); + + this.set_active_view(ActiveView::ExternalAgentThread { thread_view }, window, cx); }) }) - .detach(); + .detach_and_log_err(cx); } fn deploy_rules_library( @@ -1036,11 +1113,11 @@ impl AgentPanel { MessageEditor::new( self.fs.clone(), self.workspace.clone(), - self.user_store.clone(), context_store, self.prompt_store.clone(), self.thread_store.downgrade(), self.context_store.downgrade(), + Some(self.history_store.downgrade()), thread.clone(), window, cx, @@ -1050,7 +1127,7 @@ impl AgentPanel { let thread_view = ActiveView::thread(active_thread.clone(), message_editor, window, cx); self.set_active_view(thread_view, window, cx); - AgentDiff::set_active_thread(&self.workspace, &thread, window, cx); + AgentDiff::set_active_thread(&self.workspace, thread.clone(), window, cx); } pub fn go_back(&mut self, _: &workspace::GoBack, window: &mut Window, cx: &mut Context) { @@ -1063,7 +1140,7 @@ impl AgentPanel { ActiveView::Thread { message_editor, .. } => { message_editor.focus_handle(cx).focus(window); } - ActiveView::AcpThread { thread_view } => { + ActiveView::ExternalAgentThread { thread_view } => { thread_view.focus_handle(cx).focus(window); } ActiveView::TextThread { context_editor, .. } => { @@ -1181,11 +1258,16 @@ impl AgentPanel { let thread = thread.read(cx).thread().clone(); self.workspace .update(cx, |workspace, cx| { - AgentDiffPane::deploy_in_workspace(thread, workspace, window, cx) + AgentDiffPane::deploy_in_workspace( + AgentDiffThread::Native(thread), + workspace, + window, + cx, + ) }) .log_err(); } - ActiveView::AcpThread { .. } + ActiveView::ExternalAgentThread { .. } | ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {} @@ -1241,7 +1323,7 @@ impl AgentPanel { ) .detach_and_log_err(cx); } - ActiveView::AcpThread { thread_view } => { + ActiveView::ExternalAgentThread { thread_view } => { thread_view .update(cx, |thread_view, cx| { thread_view.open_thread_as_markdown(workspace, window, cx) @@ -1275,6 +1357,19 @@ impl AgentPanel { } self.new_thread(&NewThread::default(), window, cx); + if let Some((thread, model)) = + self.active_thread(cx).zip(provider.default_model(cx)) + { + thread.update(cx, |thread, cx| { + thread.set_configured_model( + Some(ConfiguredModel { + provider: provider.clone(), + model, + }), + cx, + ); + }); + } } } } @@ -1402,7 +1497,7 @@ impl AgentPanel { } }) } - ActiveView::AcpThread { .. } => {} + ActiveView::ExternalAgentThread { .. } => {} ActiveView::History | ActiveView::Configuration => {} } @@ -1417,6 +1512,8 @@ impl AgentPanel { self.active_view = new_view; } + self.acp_message_history.borrow_mut().reset_position(); + self.focus_handle(cx).focus(window); } @@ -1489,7 +1586,7 @@ impl Focusable for AgentPanel { fn focus_handle(&self, cx: &App) -> FocusHandle { match &self.active_view { ActiveView::Thread { message_editor, .. } => message_editor.focus_handle(cx), - ActiveView::AcpThread { thread_view, .. 
} => thread_view.focus_handle(cx), + ActiveView::ExternalAgentThread { thread_view, .. } => thread_view.focus_handle(cx), ActiveView::History => self.history.focus_handle(cx), ActiveView::TextThread { context_editor, .. } => context_editor.focus_handle(cx), ActiveView::Configuration => { @@ -1579,7 +1676,7 @@ impl Panel for AgentPanel { } fn enabled(&self, cx: &App) -> bool { - AgentSettings::get_global(cx).enabled + DisableAiSettings::get_global(cx).disable_ai.not() && AgentSettings::get_global(cx).enabled } fn is_zoomed(&self, _window: &Window, _cx: &App) -> bool { @@ -1646,9 +1743,11 @@ impl AgentPanel { .into_any_element(), } } - ActiveView::AcpThread { thread_view } => Label::new(thread_view.read(cx).title(cx)) - .truncate() - .into_any_element(), + ActiveView::ExternalAgentThread { thread_view } => { + Label::new(thread_view.read(cx).title(cx)) + .truncate() + .into_any_element() + } ActiveView::TextThread { title_editor, context_editor, @@ -1775,15 +1874,15 @@ impl AgentPanel { }), ); - let zoom_in_label = if self.is_zoomed(window, cx) { - "Zoom Out" + let full_screen_label = if self.is_zoomed(window, cx) { + "Disable Full Screen" } else { - "Zoom In" + "Enable Full Screen" }; let active_thread = match &self.active_view { ActiveView::Thread { thread, .. } => Some(thread.read(cx).thread().clone()), - ActiveView::AcpThread { .. } + ActiveView::ExternalAgentThread { .. } | ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => None, @@ -1796,35 +1895,112 @@ impl AgentPanel { ) .anchor(Corner::TopRight) .with_handle(self.new_thread_menu_handle.clone()) - .menu(move |window, cx| { - let active_thread = active_thread.clone(); - Some(ContextMenu::build(window, cx, |mut menu, _window, cx| { - menu = menu - .when(cx.has_flag::(), |this| { - this.header("Zed Agent") - }) - .action("New Thread", NewThread::default().boxed_clone()) - .action("New Text Thread", NewTextThread.boxed_clone()) - .when_some(active_thread, |this, active_thread| { - let thread = active_thread.read(cx); - if !thread.is_empty() { - this.action( - "New From Summary", - Box::new(NewThread { - from_thread_id: Some(thread.id().clone()), + .menu({ + let focus_handle = focus_handle.clone(); + move |window, cx| { + let active_thread = active_thread.clone(); + Some(ContextMenu::build(window, cx, |mut menu, _window, cx| { + menu = menu + .context(focus_handle.clone()) + .when(cx.has_flag::(), |this| { + this.header("Zed Agent") + }) + .when_some(active_thread, |this, active_thread| { + let thread = active_thread.read(cx); + + if !thread.is_empty() { + let thread_id = thread.id().clone(); + this.item( + ContextMenuEntry::new("New From Summary") + .icon(IconName::ThreadFromSummary) + .icon_color(Color::Muted) + .handler(move |window, cx| { + window.dispatch_action( + Box::new(NewThread { + from_thread_id: Some(thread_id.clone()), + }), + cx, + ); + }), + ) + } else { + this + } + }) + .item( + ContextMenuEntry::new("New Thread") + .icon(IconName::Thread) + .icon_color(Color::Muted) + .action(NewThread::default().boxed_clone()) + .handler(move |window, cx| { + window.dispatch_action( + NewThread::default().boxed_clone(), + cx, + ); }), - ) - } else { - this - } - }) - .when(cx.has_flag::(), |this| { - this.separator() - .header("External Agents") - .action("New Gemini Thread", NewAcpThread.boxed_clone()) - }); - menu - })) + ) + .item( + ContextMenuEntry::new("New Text Thread") + .icon(IconName::TextThread) + .icon_color(Color::Muted) + .action(NewTextThread.boxed_clone()) + .handler(move |window, cx| { + 
window.dispatch_action(NewTextThread.boxed_clone(), cx); + }), + ) + .when(cx.has_flag::(), |this| { + this.separator() + .header("External Agents") + .item( + ContextMenuEntry::new("New Gemini Thread") + .icon(IconName::AiGemini) + .icon_color(Color::Muted) + .handler(move |window, cx| { + window.dispatch_action( + NewExternalAgentThread { + agent: Some(crate::ExternalAgent::Gemini), + } + .boxed_clone(), + cx, + ); + }), + ) + .item( + ContextMenuEntry::new("New Claude Code Thread") + .icon(IconName::AiClaude) + .icon_color(Color::Muted) + .handler(move |window, cx| { + window.dispatch_action( + NewExternalAgentThread { + agent: Some( + crate::ExternalAgent::ClaudeCode, + ), + } + .boxed_clone(), + cx, + ); + }), + ) + .item( + ContextMenuEntry::new("New Native Agent Thread") + .icon(IconName::ZedAssistant) + .icon_color(Color::Muted) + .handler(move |window, cx| { + window.dispatch_action( + NewExternalAgentThread { + agent: Some( + crate::ExternalAgent::NativeAgent, + ), + } + .boxed_clone(), + cx, + ); + }), + ) + }); + menu + })) + } }); let agent_panel_menu = PopoverMenu::new("agent-options-menu") @@ -1846,64 +2022,70 @@ impl AgentPanel { ) .anchor(Corner::TopRight) .with_handle(self.agent_panel_menu_handle.clone()) - .menu(move |window, cx| { - Some(ContextMenu::build(window, cx, |mut menu, _window, _| { - if let Some(usage) = usage { - menu = menu - .header_with_link("Prompt Usage", "Manage", account_url.clone()) - .custom_entry( - move |_window, cx| { - let used_percentage = match usage.limit { - UsageLimit::Limited(limit) => { - Some((usage.amount as f32 / limit as f32) * 100.) - } - UsageLimit::Unlimited => None, - }; + .menu({ + let focus_handle = focus_handle.clone(); + move |window, cx| { + Some(ContextMenu::build(window, cx, |mut menu, _window, _| { + menu = menu.context(focus_handle.clone()); + if let Some(usage) = usage { + menu = menu + .header_with_link("Prompt Usage", "Manage", account_url.clone()) + .custom_entry( + move |_window, cx| { + let used_percentage = match usage.limit { + UsageLimit::Limited(limit) => { + Some((usage.amount as f32 / limit as f32) * 100.) 
+ } + UsageLimit::Unlimited => None, + }; + + h_flex() + .flex_1() + .gap_1p5() + .children(used_percentage.map(|percent| { + ProgressBar::new("usage", percent, 100., cx) + })) + .child( + Label::new(match usage.limit { + UsageLimit::Limited(limit) => { + format!("{} / {limit}", usage.amount) + } + UsageLimit::Unlimited => { + format!("{} / ∞", usage.amount) + } + }) + .size(LabelSize::Small) + .color(Color::Muted), + ) + .into_any_element() + }, + move |_, cx| cx.open_url(&zed_urls::account_url(cx)), + ) + .separator() + } - h_flex() - .flex_1() - .gap_1p5() - .children(used_percentage.map(|percent| { - ProgressBar::new("usage", percent, 100., cx) - })) - .child( - Label::new(match usage.limit { - UsageLimit::Limited(limit) => { - format!("{} / {limit}", usage.amount) - } - UsageLimit::Unlimited => { - format!("{} / ∞", usage.amount) - } - }) - .size(LabelSize::Small) - .color(Color::Muted), - ) - .into_any_element() - }, - move |_, cx| cx.open_url(&zed_urls::account_url(cx)), + menu = menu + .header("MCP Servers") + .action( + "View Server Extensions", + Box::new(zed_actions::Extensions { + category_filter: Some( + zed_actions::ExtensionCategoryFilter::ContextServers, + ), + id: None, + }), ) - .separator() - } + .action("Add Custom Server…", Box::new(AddContextServer)) + .separator(); - menu = menu - .header("MCP Servers") - .action( - "View Server Extensions", - Box::new(zed_actions::Extensions { - category_filter: Some( - zed_actions::ExtensionCategoryFilter::ContextServers, - ), - }), - ) - .action("Add Custom Server…", Box::new(AddContextServer)) - .separator(); - - menu = menu - .action("Rules…", Box::new(OpenRulesLibrary::default())) - .action("Settings", Box::new(OpenConfiguration)) - .action(zoom_in_label, Box::new(ToggleZoom)); - menu - })) + menu = menu + .action("Rules…", Box::new(OpenRulesLibrary::default())) + .action("Settings", Box::new(OpenSettings)) + .separator() + .action(full_screen_label, Box::new(ToggleZoom)); + menu + })) + } }); h_flex() @@ -1946,48 +2128,45 @@ impl AgentPanel { } fn render_token_count(&self, cx: &App) -> Option { - let (active_thread, message_editor) = match &self.active_view { + match &self.active_view { ActiveView::Thread { thread, message_editor, .. - } => (thread.read(cx), message_editor.read(cx)), - ActiveView::AcpThread { .. } => { - return None; - } - ActiveView::TextThread { .. 
} | ActiveView::History | ActiveView::Configuration => { - return None; - } - }; + } => { + let active_thread = thread.read(cx); + let message_editor = message_editor.read(cx); - let editor_empty = message_editor.is_editor_fully_empty(cx); + let editor_empty = message_editor.is_editor_fully_empty(cx); - if active_thread.is_empty() && editor_empty { - return None; - } + if active_thread.is_empty() && editor_empty { + return None; + } - let thread = active_thread.thread().read(cx); - let is_generating = thread.is_generating(); - let conversation_token_usage = thread.total_token_usage()?; + let thread = active_thread.thread().read(cx); + let is_generating = thread.is_generating(); + let conversation_token_usage = thread.total_token_usage()?; - let (total_token_usage, is_estimating) = - if let Some((editing_message_id, unsent_tokens)) = active_thread.editing_message_id() { - let combined = thread - .token_usage_up_to_message(editing_message_id) - .add(unsent_tokens); + let (total_token_usage, is_estimating) = + if let Some((editing_message_id, unsent_tokens)) = + active_thread.editing_message_id() + { + let combined = thread + .token_usage_up_to_message(editing_message_id) + .add(unsent_tokens); - (combined, unsent_tokens > 0) - } else { - let unsent_tokens = message_editor.last_estimated_token_count().unwrap_or(0); - let combined = conversation_token_usage.add(unsent_tokens); + (combined, unsent_tokens > 0) + } else { + let unsent_tokens = + message_editor.last_estimated_token_count().unwrap_or(0); + let combined = conversation_token_usage.add(unsent_tokens); - (combined, unsent_tokens > 0) - }; + (combined, unsent_tokens > 0) + }; - let is_waiting_to_update_token_count = message_editor.is_waiting_to_update_token_count(); + let is_waiting_to_update_token_count = + message_editor.is_waiting_to_update_token_count(); - match &self.active_view { - ActiveView::Thread { .. } => { if total_token_usage.total == 0 { return None; } @@ -2064,7 +2243,11 @@ impl AgentPanel { Some(element.into_any_element()) } - _ => None, + ActiveView::ExternalAgentThread { .. } + | ActiveView::History + | ActiveView::Configuration => { + return None; + } } } @@ -2073,191 +2256,101 @@ impl AgentPanel { return false; } - let plan = self.user_store.read(cx).current_plan(); - let has_previous_trial = self.user_store.read(cx).trial_started_at().is_some(); - - matches!(plan, Some(Plan::Free)) && has_previous_trial - } - - fn should_render_upsell(&self, cx: &mut Context) -> bool { match &self.active_view { ActiveView::Thread { thread, .. } => { - let is_using_zed_provider = thread + if thread .read(cx) .thread() .read(cx) .configured_model() - .map_or(false, |model| model.provider.id() == ZED_CLOUD_PROVIDER_ID); - - if !is_using_zed_provider { + .map_or(false, |model| { + model.provider.id() != language_model::ZED_CLOUD_PROVIDER_ID + }) + { return false; } } - ActiveView::AcpThread { .. } => { - return false; - } - ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => { - return false; + ActiveView::TextThread { .. } => { + if LanguageModelRegistry::global(cx) + .read(cx) + .default_model() + .map_or(false, |model| { + model.provider.id() != language_model::ZED_CLOUD_PROVIDER_ID + }) + { + return false; + } } - }; - - if self.hide_upsell || Upsell::dismissed() { - return false; - } - - let plan = self.user_store.read(cx).current_plan(); - if matches!(plan, Some(Plan::ZedPro | Plan::ZedProTrial)) { - return false; + ActiveView::ExternalAgentThread { .. 
} + | ActiveView::History + | ActiveView::Configuration => return false, } + let plan = self.user_store.read(cx).plan(); let has_previous_trial = self.user_store.read(cx).trial_started_at().is_some(); - if has_previous_trial { + + matches!(plan, Some(Plan::ZedFree)) && has_previous_trial + } + + fn should_render_onboarding(&self, cx: &mut Context) -> bool { + if OnboardingUpsell::dismissed() { return false; } - true + match &self.active_view { + ActiveView::Thread { .. } | ActiveView::TextThread { .. } => { + let history_is_empty = self + .history_store + .update(cx, |store, cx| store.recent_entries(1, cx).is_empty()); + + let has_configured_non_zed_providers = LanguageModelRegistry::read_global(cx) + .providers() + .iter() + .any(|provider| { + provider.is_authenticated(cx) + && provider.id() != language_model::ZED_CLOUD_PROVIDER_ID + }); + + history_is_empty || !has_configured_non_zed_providers + } + ActiveView::ExternalAgentThread { .. } + | ActiveView::History + | ActiveView::Configuration => false, + } } - fn render_upsell( + fn render_onboarding( &self, _window: &mut Window, cx: &mut Context, ) -> Option { - if !self.should_render_upsell(cx) { + if !self.should_render_onboarding(cx) { return None; } - if self.user_store.read(cx).account_too_young() { - Some(self.render_young_account_upsell(cx).into_any_element()) - } else { - Some(self.render_trial_upsell(cx).into_any_element()) - } - } - - fn render_young_account_upsell(&self, cx: &mut Context) -> impl IntoElement { - let checkbox = CheckboxWithLabel::new( - "dont-show-again", - Label::new("Don't show again").color(Color::Muted), - ToggleState::Unselected, - move |toggle_state, _window, cx| { - let toggle_state_bool = toggle_state.selected(); - - Upsell::set_dismissed(toggle_state_bool, cx); - }, - ); - - let contents = div() - .size_full() - .gap_2() - .flex() - .flex_col() - .child(Headline::new("Build better with Zed Pro").size(HeadlineSize::Small)) - .child( - Label::new("Your GitHub account was created less than 30 days ago, so we can't offer you a free trial.") - .size(LabelSize::Small), - ) - .child( - Label::new( - "Use your own API keys, upgrade to Zed Pro or send an email to billing-support@zed.dev.", - ) - .color(Color::Muted), - ) - .child( - h_flex() - .w_full() - .px_neg_1() - .justify_between() - .items_center() - .child(h_flex().items_center().gap_1().child(checkbox)) - .child( - h_flex() - .gap_2() - .child( - Button::new("dismiss-button", "Not Now") - .style(ButtonStyle::Transparent) - .color(Color::Muted) - .on_click({ - let agent_panel = cx.entity(); - move |_, _, cx| { - agent_panel.update(cx, |this, cx| { - this.hide_upsell = true; - cx.notify(); - }); - } - }), - ) - .child( - Button::new("cta-button", "Upgrade to Zed Pro") - .style(ButtonStyle::Transparent) - .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx))), - ), - ), - ); + let thread_view = matches!(&self.active_view, ActiveView::Thread { .. }); + let text_thread_view = matches!(&self.active_view, ActiveView::TextThread { .. 
}); - self.render_upsell_container(cx, contents) + Some( + div() + .when(thread_view, |this| { + this.size_full().bg(cx.theme().colors().panel_background) + }) + .when(text_thread_view, |this| { + this.bg(cx.theme().colors().editor_background) + }) + .child(self.onboarding.clone()), + ) } - fn render_trial_upsell(&self, cx: &mut Context) -> impl IntoElement { - let checkbox = CheckboxWithLabel::new( - "dont-show-again", - Label::new("Don't show again").color(Color::Muted), - ToggleState::Unselected, - move |toggle_state, _window, cx| { - let toggle_state_bool = toggle_state.selected(); - - Upsell::set_dismissed(toggle_state_bool, cx); - }, - ); - - let contents = div() + fn render_backdrop(&self, cx: &mut Context) -> impl IntoElement { + div() .size_full() - .gap_2() - .flex() - .flex_col() - .child(Headline::new("Build better with Zed Pro").size(HeadlineSize::Small)) - .child( - Label::new("Try Zed Pro for free for 14 days - no credit card required.") - .size(LabelSize::Small), - ) - .child( - Label::new( - "Use your own API keys or enable usage-based billing once you hit the cap.", - ) - .color(Color::Muted), - ) - .child( - h_flex() - .w_full() - .px_neg_1() - .justify_between() - .items_center() - .child(h_flex().items_center().gap_1().child(checkbox)) - .child( - h_flex() - .gap_2() - .child( - Button::new("dismiss-button", "Not Now") - .style(ButtonStyle::Transparent) - .color(Color::Muted) - .on_click({ - let agent_panel = cx.entity(); - move |_, _, cx| { - agent_panel.update(cx, |this, cx| { - this.hide_upsell = true; - cx.notify(); - }); - } - }), - ) - .child( - Button::new("cta-button", "Start Trial") - .style(ButtonStyle::Transparent) - .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx))), - ), - ), - ); - - self.render_upsell_container(cx, contents) + .absolute() + .inset_0() + .bg(cx.theme().colors().panel_background) + .opacity(0.8) + .block_mouse_except_scroll() } fn render_trial_end_upsell( @@ -2270,140 +2363,45 @@ impl AgentPanel { } Some( - self.render_upsell_container( - cx, - div() - .size_full() - .gap_2() - .flex() - .flex_col() - .child( - Headline::new("Your Zed Pro trial has expired.").size(HeadlineSize::Small), - ) - .child( - Label::new("You've been automatically reset to the free plan.") - .size(LabelSize::Small), - ) - .child( - h_flex() - .w_full() - .px_neg_1() - .justify_between() - .items_center() - .child(div()) - .child( - h_flex() - .gap_2() - .child( - Button::new("dismiss-button", "Stay on Free") - .style(ButtonStyle::Transparent) - .color(Color::Muted) - .on_click({ - let agent_panel = cx.entity(); - move |_, _, cx| { - agent_panel.update(cx, |_this, cx| { - TrialEndUpsell::set_dismissed(true, cx); - cx.notify(); - }); - } - }), - ) - .child( - Button::new("cta-button", "Upgrade to Zed Pro") - .style(ButtonStyle::Transparent) - .on_click(|_, _, cx| { - cx.open_url(&zed_urls::account_url(cx)) - }), - ), - ), - ), - ), + v_flex() + .absolute() + .inset_0() + .size_full() + .bg(cx.theme().colors().panel_background) + .opacity(0.85) + .block_mouse_except_scroll() + .child(EndTrialUpsell::new(Arc::new({ + let this = cx.entity(); + move |_, cx| { + this.update(cx, |_this, cx| { + TrialEndUpsell::set_dismissed(true, cx); + cx.notify(); + }); + } + }))), ) } - fn render_upsell_container(&self, cx: &mut Context, content: Div) -> Div { - div().p_2().child( - v_flex() - .w_full() - .elevation_2(cx) - .rounded(px(8.)) - .bg(cx.theme().colors().background.alpha(0.5)) - .p(px(3.)) - .child( - div() - .gap_2() - .flex() - .flex_col() - .size_full() - 
.border_1() - .rounded(px(5.)) - .border_color(cx.theme().colors().text.alpha(0.1)) - .overflow_hidden() - .relative() - .bg(cx.theme().colors().panel_background) - .px_4() - .py_3() - .child( - div() - .absolute() - .top_0() - .right(px(-1.0)) - .w(px(441.)) - .h(px(167.)) - .child( - Vector::new( - VectorName::Grid, - rems_from_px(441.), - rems_from_px(167.), - ) - .color(ui::Color::Custom(cx.theme().colors().text.alpha(0.1))), - ), - ) - .child( - div() - .absolute() - .top(px(-8.0)) - .right_0() - .w(px(400.)) - .h(px(92.)) - .child( - Vector::new( - VectorName::AiGrid, - rems_from_px(400.), - rems_from_px(92.), - ) - .color(ui::Color::Custom(cx.theme().colors().text.alpha(0.32))), - ), - ) - // .child( - // div() - // .absolute() - // .top_0() - // .right(px(360.)) - // .size(px(401.)) - // .overflow_hidden() - // .bg(cx.theme().colors().panel_background) - // ) - .child( - div() - .absolute() - .top_0() - .right_0() - .w(px(660.)) - .h(px(401.)) - .overflow_hidden() - .bg(linear_gradient( - 75., - linear_color_stop( - cx.theme().colors().panel_background.alpha(0.01), - 1.0, - ), - linear_color_stop(cx.theme().colors().panel_background, 0.45), - )), - ) - .child(content), - ), - ) + fn render_empty_state_section_header( + &self, + label: impl Into, + action_slot: Option, + cx: &mut Context, + ) -> impl IntoElement { + h_flex() + .mt_2() + .pl_1p5() + .pb_1() + .w_full() + .justify_between() + .border_b_1() + .border_color(cx.theme().colors().border_variant) + .child( + Label::new(label.into()) + .size(LabelSize::Small) + .color(Color::Muted), + ) + .children(action_slot) } fn render_thread_empty_state( @@ -2416,8 +2414,10 @@ impl AgentPanel { .update(cx, |this, cx| this.recent_entries(6, cx)); let model_registry = LanguageModelRegistry::read_global(cx); + let configuration_error = model_registry.configuration_error(model_registry.default_model(), cx); + let no_error = configuration_error.is_none(); let focus_handle = self.focus_handle(cx); @@ -2425,11 +2425,9 @@ impl AgentPanel { .size_full() .bg(cx.theme().colors().panel_background) .when(recent_history.is_empty(), |this| { - let configuration_error_ref = &configuration_error; this.child( v_flex() .size_full() - .max_w_80() .mx_auto() .justify_center() .items_center() @@ -2437,156 +2435,100 @@ impl AgentPanel { .child(h_flex().child(Headline::new("Welcome to the Agent Panel"))) .when(no_error, |parent| { parent - .child( - h_flex().child( - Label::new("Ask and build anything.") - .color(Color::Muted) - .mb_2p5(), - ), - ) - .child( - Button::new("new-thread", "Start New Thread") - .icon(IconName::Plus) - .icon_position(IconPosition::Start) - .icon_size(IconSize::Small) - .icon_color(Color::Muted) - .full_width() - .key_binding(KeyBinding::for_action_in( - &NewThread::default(), - &focus_handle, - window, - cx, - )) - .on_click(|_event, window, cx| { - window.dispatch_action( - NewThread::default().boxed_clone(), - cx, - ) - }), - ) - .child( - Button::new("context", "Add Context") - .icon(IconName::FileCode) - .icon_position(IconPosition::Start) - .icon_size(IconSize::Small) - .icon_color(Color::Muted) - .full_width() - .key_binding(KeyBinding::for_action_in( - &ToggleContextPicker, - &focus_handle, - window, - cx, - )) - .on_click(|_event, window, cx| { - window.dispatch_action( - ToggleContextPicker.boxed_clone(), - cx, - ) - }), - ) - .child( - Button::new("mode", "Switch Model") - .icon(IconName::DatabaseZap) - .icon_position(IconPosition::Start) - .icon_size(IconSize::Small) - .icon_color(Color::Muted) - .full_width() - 
.key_binding(KeyBinding::for_action_in( - &ToggleModelSelector, - &focus_handle, - window, - cx, - )) - .on_click(|_event, window, cx| { - window.dispatch_action( - ToggleModelSelector.boxed_clone(), - cx, - ) - }), - ) - .child( - Button::new("settings", "View Settings") - .icon(IconName::Settings) - .icon_position(IconPosition::Start) - .icon_size(IconSize::Small) - .icon_color(Color::Muted) - .full_width() - .key_binding(KeyBinding::for_action_in( - &OpenConfiguration, - &focus_handle, - window, - cx, - )) - .on_click(|_event, window, cx| { - window.dispatch_action( - OpenConfiguration.boxed_clone(), - cx, - ) - }), - ) - }) - .map(|parent| match configuration_error_ref { - Some( - err @ (ConfigurationError::ModelNotFound - | ConfigurationError::ProviderNotAuthenticated(_) - | ConfigurationError::NoProvider), - ) => parent .child(h_flex().child( - Label::new(err.to_string()).color(Color::Muted).mb_2p5(), + Label::new("Ask and build anything.").color(Color::Muted), )) .child( - Button::new("settings", "Configure a Provider") - .icon(IconName::Settings) - .icon_position(IconPosition::Start) - .icon_size(IconSize::Small) - .icon_color(Color::Muted) - .full_width() - .key_binding(KeyBinding::for_action_in( - &OpenConfiguration, - &focus_handle, - window, - cx, - )) - .on_click(|_event, window, cx| { - window.dispatch_action( - OpenConfiguration.boxed_clone(), - cx, - ) - }), - ), - Some(ConfigurationError::ProviderPendingTermsAcceptance(provider)) => { - parent.children(provider.render_accept_terms( - LanguageModelProviderTosView::ThreadFreshStart, - cx, - )) - } - None => parent, + v_flex() + .mt_2() + .gap_1() + .max_w_48() + .child( + Button::new("context", "Add Context") + .label_size(LabelSize::Small) + .icon(IconName::FileCode) + .icon_position(IconPosition::Start) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .full_width() + .key_binding(KeyBinding::for_action_in( + &ToggleContextPicker, + &focus_handle, + window, + cx, + )) + .on_click(|_event, window, cx| { + window.dispatch_action( + ToggleContextPicker.boxed_clone(), + cx, + ) + }), + ) + .child( + Button::new("mode", "Switch Model") + .label_size(LabelSize::Small) + .icon(IconName::DatabaseZap) + .icon_position(IconPosition::Start) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .full_width() + .key_binding(KeyBinding::for_action_in( + &ToggleModelSelector, + &focus_handle, + window, + cx, + )) + .on_click(|_event, window, cx| { + window.dispatch_action( + ToggleModelSelector.boxed_clone(), + cx, + ) + }), + ) + .child( + Button::new("settings", "View Settings") + .label_size(LabelSize::Small) + .icon(IconName::Settings) + .icon_position(IconPosition::Start) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .full_width() + .key_binding(KeyBinding::for_action_in( + &OpenSettings, + &focus_handle, + window, + cx, + )) + .on_click(|_event, window, cx| { + window.dispatch_action( + OpenSettings.boxed_clone(), + cx, + ) + }), + ), + ) + }) + .when_some(configuration_error.as_ref(), |this, err| { + this.child(self.render_configuration_error( + err, + &focus_handle, + window, + cx, + )) }), ) }) .when(!recent_history.is_empty(), |parent| { let focus_handle = focus_handle.clone(); - let configuration_error_ref = &configuration_error; - parent .overflow_hidden() .p_1p5() .justify_end() .gap_1() .child( - h_flex() - .pl_1p5() - .pb_1() - .w_full() - .justify_between() - .border_b_1() - .border_color(cx.theme().colors().border_variant) - .child( - Label::new("Recent") - .size(LabelSize::Small) - 
.color(Color::Muted), - ) - .child( + self.render_empty_state_section_header( + "Recent", + Some( Button::new("view-history", "View All") .style(ButtonStyle::Subtle) .label_size(LabelSize::Small) @@ -2601,8 +2543,11 @@ impl AgentPanel { ) .on_click(move |_event, window, cx| { window.dispatch_action(OpenHistory.boxed_clone(), cx); - }), + }) + .into_any_element(), ), + cx, + ), ) .child( v_flex() @@ -2630,49 +2575,182 @@ impl AgentPanel { }, )), ) - .map(|parent| match configuration_error_ref { - Some( - err @ (ConfigurationError::ModelNotFound - | ConfigurationError::ProviderNotAuthenticated(_) - | ConfigurationError::NoProvider), - ) => parent.child( - Banner::new() - .severity(ui::Severity::Warning) - .child(Label::new(err.to_string()).size(LabelSize::Small)) - .action_slot( - Button::new("settings", "Configure Provider") - .style(ButtonStyle::Tinted(ui::TintColor::Warning)) - .label_size(LabelSize::Small) - .key_binding( - KeyBinding::for_action_in( - &OpenConfiguration, - &focus_handle, - window, - cx, + .child(self.render_empty_state_section_header("Start", None, cx)) + .child( + v_flex() + .p_1() + .gap_2() + .child( + h_flex() + .w_full() + .gap_2() + .child( + NewThreadButton::new( + "new-thread-btn", + "New Thread", + IconName::Thread, + ) + .keybinding(KeyBinding::for_action_in( + &NewThread::default(), + &self.focus_handle(cx), + window, + cx, + )) + .on_click( + |window, cx| { + window.dispatch_action( + NewThread::default().boxed_clone(), + cx, + ) + }, + ), + ) + .child( + NewThreadButton::new( + "new-text-thread-btn", + "New Text Thread", + IconName::TextThread, + ) + .keybinding(KeyBinding::for_action_in( + &NewTextThread, + &self.focus_handle(cx), + window, + cx, + )) + .on_click( + |window, cx| { + window.dispatch_action(Box::new(NewTextThread), cx) + }, + ), + ), + ) + .when(cx.has_flag::(), |this| { + this.child( + h_flex() + .w_full() + .gap_2() + .child( + NewThreadButton::new( + "new-gemini-thread-btn", + "New Gemini Thread", + IconName::AiGemini, ) - .map(|kb| kb.size(rems_from_px(12.))), + // .keybinding(KeyBinding::for_action_in( + // &OpenHistory, + // &self.focus_handle(cx), + // window, + // cx, + // )) + .on_click( + |window, cx| { + window.dispatch_action( + Box::new(NewExternalAgentThread { + agent: Some( + crate::ExternalAgent::Gemini, + ), + }), + cx, + ) + }, + ), ) - .on_click(|_event, window, cx| { - window.dispatch_action( - OpenConfiguration.boxed_clone(), - cx, + .child( + NewThreadButton::new( + "new-claude-thread-btn", + "New Claude Code Thread", + IconName::AiClaude, ) - }), - ), - ), - Some(ConfigurationError::ProviderPendingTermsAcceptance(provider)) => { - parent.child(Banner::new().severity(ui::Severity::Warning).child( - h_flex().w_full().children(provider.render_accept_terms( - LanguageModelProviderTosView::ThreadEmptyState, - cx, - )), - )) - } - None => parent, + // .keybinding(KeyBinding::for_action_in( + // &OpenHistory, + // &self.focus_handle(cx), + // window, + // cx, + // )) + .on_click( + |window, cx| { + window.dispatch_action( + Box::new(NewExternalAgentThread { + agent: Some( + crate::ExternalAgent::ClaudeCode, + ), + }), + cx, + ) + }, + ), + ) + .child( + NewThreadButton::new( + "new-native-agent-thread-btn", + "New Native Agent Thread", + IconName::ZedAssistant, + ) + // .keybinding(KeyBinding::for_action_in( + // &OpenHistory, + // &self.focus_handle(cx), + // window, + // cx, + // )) + .on_click( + |window, cx| { + window.dispatch_action( + Box::new(NewExternalAgentThread { + agent: Some( + crate::ExternalAgent::NativeAgent, 
+ ), + }), + cx, + ) + }, + ), + ), + ) + }), + ) + .when_some(configuration_error.as_ref(), |this, err| { + this.child(self.render_configuration_error(err, &focus_handle, window, cx)) }) }) } + fn render_configuration_error( + &self, + configuration_error: &ConfigurationError, + focus_handle: &FocusHandle, + window: &mut Window, + cx: &mut App, + ) -> impl IntoElement { + match configuration_error { + ConfigurationError::ModelNotFound + | ConfigurationError::ProviderNotAuthenticated(_) + | ConfigurationError::NoProvider => Banner::new() + .severity(ui::Severity::Warning) + .child(Label::new(configuration_error.to_string())) + .action_slot( + Button::new("settings", "Configure Provider") + .style(ButtonStyle::Tinted(ui::TintColor::Warning)) + .label_size(LabelSize::Small) + .key_binding( + KeyBinding::for_action_in(&OpenSettings, &focus_handle, window, cx) + .map(|kb| kb.size(rems_from_px(12.))), + ) + .on_click(|_event, window, cx| { + window.dispatch_action(OpenSettings.boxed_clone(), cx) + }), + ), + ConfigurationError::ProviderPendingTermsAcceptance(provider) => { + Banner::new().severity(ui::Severity::Warning).child( + h_flex().w_full().children( + provider.render_accept_terms( + LanguageModelProviderTosView::ThreadEmptyState, + cx, + ), + ), + ) + } + } + } + fn render_tool_use_limit_reached( &self, window: &mut Window, @@ -2680,7 +2758,7 @@ impl AgentPanel { ) -> Option { let active_thread = match &self.active_view { ActiveView::Thread { thread, .. } => thread, - ActiveView::AcpThread { .. } => { + ActiveView::ExternalAgentThread { .. } => { return None; } ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => { @@ -2805,7 +2883,7 @@ impl AgentPanel { this.clear_last_error(); }); - cx.open_url(&zed_urls::account_url(cx)); + cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)); cx.notify(); } })) @@ -2851,7 +2929,7 @@ impl AgentPanel { ) -> AnyElement { let error_message = match plan { Plan::ZedPro => "Upgrade to usage-based billing for more prompts.", - Plan::ZedProTrial | Plan::Free => "Upgrade to Zed Pro for more prompts.", + Plan::ZedProTrial | Plan::ZedFree => "Upgrade to Zed Pro for more prompts.", }; let icon = Icon::new(IconName::XCircle) @@ -2887,6 +2965,23 @@ impl AgentPanel { .size(IconSize::Small) .color(Color::Error); + let retry_button = Button::new("retry", "Retry") + .icon(IconName::RotateCw) + .icon_position(IconPosition::Start) + .icon_size(IconSize::Small) + .label_size(LabelSize::Small) + .on_click({ + let thread = thread.clone(); + move |_, window, cx| { + thread.update(cx, |thread, cx| { + thread.clear_last_error(); + thread.thread().update(cx, |thread, cx| { + thread.retry_last_completion(Some(window.window_handle()), cx); + }); + }); + } + }); + div() .border_t_1() .border_color(cx.theme().colors().border) @@ -2895,13 +2990,76 @@ impl AgentPanel { .icon(icon) .title(header) .description(message.clone()) - .primary_action(self.dismiss_error_button(thread, cx)) - .secondary_action(self.create_copy_button(message_with_header)) + .primary_action(retry_button) + .secondary_action(self.dismiss_error_button(thread, cx)) + .tertiary_action(self.create_copy_button(message_with_header)) .bg_color(self.error_callout_bg(cx)), ) .into_any_element() } + fn render_retryable_error( + &self, + message: SharedString, + can_enable_burn_mode: bool, + thread: &Entity, + cx: &mut Context, + ) -> AnyElement { + let icon = Icon::new(IconName::XCircle) + .size(IconSize::Small) + .color(Color::Error); + + let retry_button = Button::new("retry", "Retry") + 
.icon(IconName::RotateCw) + .icon_position(IconPosition::Start) + .icon_size(IconSize::Small) + .label_size(LabelSize::Small) + .on_click({ + let thread = thread.clone(); + move |_, window, cx| { + thread.update(cx, |thread, cx| { + thread.clear_last_error(); + thread.thread().update(cx, |thread, cx| { + thread.retry_last_completion(Some(window.window_handle()), cx); + }); + }); + } + }); + + let mut callout = Callout::new() + .icon(icon) + .title("Error") + .description(message.clone()) + .bg_color(self.error_callout_bg(cx)) + .primary_action(retry_button); + + if can_enable_burn_mode { + let burn_mode_button = Button::new("enable_burn_retry", "Enable Burn Mode and Retry") + .icon(IconName::ZedBurnMode) + .icon_position(IconPosition::Start) + .icon_size(IconSize::Small) + .label_size(LabelSize::Small) + .on_click({ + let thread = thread.clone(); + move |_, window, cx| { + thread.update(cx, |thread, cx| { + thread.clear_last_error(); + thread.thread().update(cx, |thread, cx| { + thread.enable_burn_mode_and_retry(Some(window.window_handle()), cx); + }); + }); + } + }); + callout = callout.secondary_action(burn_mode_button); + } + + div() + .border_t_1() + .border_color(cx.theme().colors().border) + .child(callout) + .into_any_element() + } + fn render_prompt_editor( &self, context_editor: &Entity, @@ -3029,7 +3187,7 @@ impl AgentPanel { .detach(); }); } - ActiveView::AcpThread { .. } => { + ActiveView::ExternalAgentThread { .. } => { unimplemented!() } ActiveView::TextThread { context_editor, .. } => { @@ -3051,7 +3209,7 @@ impl AgentPanel { let mut key_context = KeyContext::new_with_defaults(); key_context.add("AgentPanel"); match &self.active_view { - ActiveView::AcpThread { .. } => key_context.add("acp_thread"), + ActiveView::ExternalAgentThread { .. } => key_context.add("external_agent_thread"), ActiveView::TextThread { .. } => key_context.add("prompt_editor"), ActiveView::Thread { .. } | ActiveView::History | ActiveView::Configuration => {} } @@ -3071,9 +3229,10 @@ impl Render for AgentPanel { // - Scrolling in all views works as expected // - Files can be dropped into the panel let content = v_flex() - .key_context(self.key_context()) - .justify_between() + .relative() .size_full() + .justify_between() + .key_context(self.key_context()) .on_action(cx.listener(Self::cancel)) .on_action(cx.listener(|this, action: &NewThread, window, cx| { this.new_thread(action, window, cx); @@ -3081,7 +3240,7 @@ impl Render for AgentPanel { .on_action(cx.listener(|this, _: &OpenHistory, window, cx| { this.open_history(window, cx); })) - .on_action(cx.listener(|this, _: &OpenConfiguration, window, cx| { + .on_action(cx.listener(|this, _: &OpenSettings, window, cx| { this.open_configuration(window, cx); })) .on_action(cx.listener(Self::open_active_thread_as_markdown)) @@ -3107,7 +3266,7 @@ impl Render for AgentPanel { }); this.continue_conversation(window, cx); } - ActiveView::AcpThread { .. } => {} + ActiveView::ExternalAgentThread { .. } => {} ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {} @@ -3115,21 +3274,21 @@ impl Render for AgentPanel { })) .on_action(cx.listener(Self::toggle_burn_mode)) .child(self.render_toolbar(window, cx)) - .children(self.render_upsell(window, cx)) - .children(self.render_trial_end_upsell(window, cx)) + .children(self.render_onboarding(window, cx)) .map(|parent| match &self.active_view { ActiveView::Thread { thread, message_editor, .. 
} => parent - .relative() - .child(if thread.read(cx).is_empty() { - self.render_thread_empty_state(window, cx) - .into_any_element() - } else { - thread.clone().into_any_element() - }) + .child( + if thread.read(cx).is_empty() && !self.should_render_onboarding(cx) { + self.render_thread_empty_state(window, cx) + .into_any_element() + } else { + thread.clone().into_any_element() + }, + ) .children(self.render_tool_use_limit_reached(window, cx)) .when_some(thread.read(cx).last_error(), |this, last_error| { this.child( @@ -3143,14 +3302,25 @@ impl Render for AgentPanel { ThreadError::Message { header, message } => { self.render_error_message(header, message, thread, cx) } + ThreadError::RetryableError { + message, + can_enable_burn_mode, + } => self.render_retryable_error( + message, + can_enable_burn_mode, + thread, + cx, + ), }) .into_any(), ) }) - .child(h_flex().child(message_editor.clone())) + .child(h_flex().relative().child(message_editor.clone()).when( + !LanguageModelRegistry::read_global(cx).has_authenticated_provider(cx), + |this| this.child(self.render_backdrop(cx)), + )) .child(self.render_drag_target(cx)), - ActiveView::AcpThread { thread_view, .. } => parent - .relative() + ActiveView::ExternalAgentThread { thread_view, .. } => parent .child(thread_view.clone()) .child(self.render_drag_target(cx)), ActiveView::History => parent.child(self.history.clone()), @@ -3158,14 +3328,39 @@ impl Render for AgentPanel { context_editor, buffer_search_bar, .. - } => parent.child(self.render_prompt_editor( - context_editor, - buffer_search_bar, - window, - cx, - )), + } => { + let model_registry = LanguageModelRegistry::read_global(cx); + let configuration_error = + model_registry.configuration_error(model_registry.default_model(), cx); + parent + .map(|this| { + if !self.should_render_onboarding(cx) + && let Some(err) = configuration_error.as_ref() + { + this.child( + div().bg(cx.theme().colors().editor_background).p_2().child( + self.render_configuration_error( + err, + &self.focus_handle(cx), + window, + cx, + ), + ), + ) + } else { + this + } + }) + .child(self.render_prompt_editor( + context_editor, + buffer_search_bar, + window, + cx, + )) + } ActiveView::Configuration => parent.children(self.configuration.clone()), - }); + }) + .children(self.render_trial_end_upsell(window, cx)); match self.active_view.which_font_size_used() { WhichFontSize::AgentFont => { @@ -3332,9 +3527,9 @@ impl AgentPanelDelegate for ConcreteAssistantPanelDelegate { } } -struct Upsell; +struct OnboardingUpsell; -impl Dismissable for Upsell { +impl Dismissable for OnboardingUpsell { const KEY: &'static str = "dismissed-trial-upsell"; } diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index 3170ec4a267d76791968b410e9426079a6ae1f2d..fceb8f4c45157bfc47285e65c69f685bcb156eee 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -25,12 +25,14 @@ mod thread_history; mod tool_compatibility; mod ui; +use std::rc::Rc; use std::sync::Arc; use agent::{Thread, ThreadId}; use agent_settings::{AgentProfileId, AgentSettings, LanguageModelSelection}; use assistant_slash_command::SlashCommandRegistry; use client::Client; +use command_palette_hooks::CommandPaletteFilter; use feature_flags::FeatureFlagAppExt as _; use fs::Fs; use gpui::{Action, App, Entity, actions}; @@ -38,10 +40,12 @@ use language::LanguageRegistry; use language_model::{ ConfiguredModel, LanguageModel, LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, }; +use project::DisableAiSettings; 
use prompt_store::PromptBuilder; use schemars::JsonSchema; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use settings::{Settings as _, SettingsStore}; +use std::any::TypeId; pub use crate::active_thread::ActiveThread; use crate::agent_configuration::{ConfigureContextServerModal, ManageProfilesModal}; @@ -51,14 +55,13 @@ use crate::slash_command_settings::SlashCommandSettings; pub use agent_diff::{AgentDiffPane, AgentDiffToolbar}; pub use text_thread_editor::{AgentPanelDelegate, TextThreadEditor}; pub use ui::preview::{all_agent_previews, get_agent_preview}; +use zed_actions; actions!( agent, [ /// Creates a new text-based conversation thread. NewTextThread, - /// Creates a new external agent conversation thread. - NewAcpThread, /// Toggles the context picker interface for adding files, symbols, or other context. ToggleContextPicker, /// Toggles the navigation menu for switching between threads and views. @@ -133,6 +136,34 @@ pub struct NewThread { from_thread_id: Option, } +/// Creates a new external agent conversation thread. +#[derive(Default, Clone, PartialEq, Deserialize, JsonSchema, Action)] +#[action(namespace = agent)] +#[serde(deny_unknown_fields)] +pub struct NewExternalAgentThread { + /// Which agent to use for the conversation. + agent: Option, +} + +#[derive(Default, Clone, Copy, PartialEq, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "snake_case")] +enum ExternalAgent { + #[default] + Gemini, + ClaudeCode, + NativeAgent, +} + +impl ExternalAgent { + pub fn server(&self) -> Rc { + match self { + ExternalAgent::Gemini => Rc::new(agent_servers::Gemini), + ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode), + ExternalAgent::NativeAgent => Rc::new(agent2::NativeAgentServer), + } + } +} + /// Opens the profile management interface for configuring agent tools and settings. 
#[derive(PartialEq, Clone, Default, Debug, Deserialize, JsonSchema, Action)] #[action(namespace = agent)] @@ -216,6 +247,69 @@ pub fn init( }) .detach(); cx.observe_new(ManageProfilesModal::register).detach(); + + // Update command palette filter based on AI settings + update_command_palette_filter(cx); + + // Watch for settings changes + cx.observe_global::(|app_cx| { + // When settings change, update the command palette filter + update_command_palette_filter(app_cx); + }) + .detach(); +} + +fn update_command_palette_filter(cx: &mut App) { + let disable_ai = DisableAiSettings::get_global(cx).disable_ai; + CommandPaletteFilter::update_global(cx, |filter, _| { + if disable_ai { + filter.hide_namespace("agent"); + filter.hide_namespace("assistant"); + filter.hide_namespace("copilot"); + filter.hide_namespace("supermaven"); + filter.hide_namespace("zed_predict_onboarding"); + filter.hide_namespace("edit_prediction"); + + use editor::actions::{ + AcceptEditPrediction, AcceptPartialEditPrediction, NextEditPrediction, + PreviousEditPrediction, ShowEditPrediction, ToggleEditPrediction, + }; + let edit_prediction_actions = [ + TypeId::of::(), + TypeId::of::(), + TypeId::of::(), + TypeId::of::(), + TypeId::of::(), + TypeId::of::(), + ]; + filter.hide_action_types(&edit_prediction_actions); + filter.hide_action_types(&[TypeId::of::()]); + } else { + filter.show_namespace("agent"); + filter.show_namespace("assistant"); + filter.show_namespace("copilot"); + filter.show_namespace("zed_predict_onboarding"); + + filter.show_namespace("edit_prediction"); + + use editor::actions::{ + AcceptEditPrediction, AcceptPartialEditPrediction, NextEditPrediction, + PreviousEditPrediction, ShowEditPrediction, ToggleEditPrediction, + }; + let edit_prediction_actions = [ + TypeId::of::(), + TypeId::of::(), + TypeId::of::(), + TypeId::of::(), + TypeId::of::(), + TypeId::of::(), + ]; + filter.show_action_types(edit_prediction_actions.iter()); + + filter + .show_action_types([TypeId::of::()].iter()); + } + }); } fn init_language_model_settings(cx: &mut App) { diff --git a/crates/agent_ui/src/buffer_codegen.rs b/crates/agent_ui/src/buffer_codegen.rs index 64498e928130d0debfd8a30bdcbcc010c0de48a1..615142b73dfd6eed59f635af780310290e3f6f25 100644 --- a/crates/agent_ui/src/buffer_codegen.rs +++ b/crates/agent_ui/src/buffer_codegen.rs @@ -6,6 +6,7 @@ use agent::{ use agent_settings::AgentSettings; use anyhow::{Context as _, Result}; use client::telemetry::Telemetry; +use cloud_llm_client::CompletionIntent; use collections::HashSet; use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint}; use futures::{ @@ -35,7 +36,6 @@ use std::{ }; use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff}; use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase}; -use zed_llm_client::CompletionIntent; pub struct BufferCodegen { alternatives: Vec>, diff --git a/crates/agent_ui/src/context_picker.rs b/crates/agent_ui/src/context_picker.rs index 5cc56b014e140ac1cba91606f3ceddfa9b477dbd..32f9a096d9bf39e251545e3c82d3d4ff77ccb17e 100644 --- a/crates/agent_ui/src/context_picker.rs +++ b/crates/agent_ui/src/context_picker.rs @@ -148,7 +148,7 @@ impl ContextPickerMode { Self::File => IconName::File, Self::Symbol => IconName::Code, Self::Fetch => IconName::Globe, - Self::Thread => IconName::MessageBubbles, + Self::Thread => IconName::Thread, Self::Rules => RULES_ICON, } } diff --git a/crates/agent_ui/src/context_picker/completion_provider.rs 
b/crates/agent_ui/src/context_picker/completion_provider.rs index b377e40b193d090a61b88232098fd45645a2ab4f..5ca0913be7b923136298590861a8ca13107952af 100644 --- a/crates/agent_ui/src/context_picker/completion_provider.rs +++ b/crates/agent_ui/src/context_picker/completion_provider.rs @@ -423,7 +423,7 @@ impl ContextPickerCompletionProvider { let icon_for_completion = if recent { IconName::HistoryRerun } else { - IconName::MessageBubbles + IconName::Thread }; let new_text = format!("{} ", MentionLink::for_thread(&thread_entry)); let new_text_len = new_text.len(); @@ -436,7 +436,7 @@ impl ContextPickerCompletionProvider { source: project::CompletionSource::Custom, icon_path: Some(icon_for_completion.path().into()), confirm: Some(confirm_completion_callback( - IconName::MessageBubbles.path().into(), + IconName::Thread.path().into(), thread_entry.title().clone(), excerpt_id, source_range.start, diff --git a/crates/agent_ui/src/context_picker/thread_context_picker.rs b/crates/agent_ui/src/context_picker/thread_context_picker.rs index cb2e97a493b64cde5f05f93e68f03b04e9f755f6..15cc731f8f2b7c82885c566273bc1cda9f3c156a 100644 --- a/crates/agent_ui/src/context_picker/thread_context_picker.rs +++ b/crates/agent_ui/src/context_picker/thread_context_picker.rs @@ -253,7 +253,7 @@ pub fn render_thread_context_entry( .gap_1p5() .max_w_72() .child( - Icon::new(IconName::MessageBubbles) + Icon::new(IconName::Thread) .size(IconSize::XSmall) .color(Color::Muted), ) diff --git a/crates/agent_ui/src/context_strip.rs b/crates/agent_ui/src/context_strip.rs index 080ffd2ea0108400b691c6a614fcdb4f81952856..369964f165dc4d4460fd446c949538ec820fb82e 100644 --- a/crates/agent_ui/src/context_strip.rs +++ b/crates/agent_ui/src/context_strip.rs @@ -504,7 +504,7 @@ impl Render for ContextStrip { ) .on_click({ Rc::new(cx.listener(move |this, event: &ClickEvent, window, cx| { - if event.down.click_count > 1 { + if event.click_count() > 1 { this.open_context(&context, window, cx); } else { this.focused_index = Some(i); diff --git a/crates/agent_ui/src/debug.rs b/crates/agent_ui/src/debug.rs index ff6538dc85a45f0072b805b033952da78255f8b7..bd34659210e933ad99357e7e1ceeedb6b53c5ee0 100644 --- a/crates/agent_ui/src/debug.rs +++ b/crates/agent_ui/src/debug.rs @@ -1,10 +1,10 @@ #![allow(unused, dead_code)] use client::{ModelRequestUsage, RequestUsage}; +use cloud_llm_client::{Plan, UsageLimit}; use gpui::Global; use std::ops::{Deref, DerefMut}; use ui::prelude::*; -use zed_llm_client::{Plan, UsageLimit}; /// Debug only: Used for testing various account states /// diff --git a/crates/agent_ui/src/inline_assistant.rs b/crates/agent_ui/src/inline_assistant.rs index c9c173a68be5191e77690e826378ca52d3db9684..4a4a747899ecc310666685de336bacffc4b271e6 100644 --- a/crates/agent_ui/src/inline_assistant.rs +++ b/crates/agent_ui/src/inline_assistant.rs @@ -39,7 +39,7 @@ use language_model::{ }; use multi_buffer::MultiBufferRow; use parking_lot::Mutex; -use project::{CodeAction, LspAction, Project, ProjectTransaction}; +use project::{CodeAction, DisableAiSettings, LspAction, Project, ProjectTransaction}; use prompt_store::{PromptBuilder, PromptStore}; use settings::{Settings, SettingsStore}; use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase}; @@ -48,7 +48,7 @@ use text::{OffsetRangeExt, ToPoint as _}; use ui::prelude::*; use util::{RangeExt, ResultExt, maybe}; use workspace::{ItemHandle, Toast, Workspace, dock::Panel, notifications::NotificationId}; -use zed_actions::agent::OpenConfiguration; +use 
zed_actions::agent::OpenSettings; pub fn init( fs: Arc, @@ -57,6 +57,17 @@ pub fn init( cx: &mut App, ) { cx.set_global(InlineAssistant::new(fs, prompt_builder, telemetry)); + + cx.observe_global::(|cx| { + if DisableAiSettings::get_global(cx).disable_ai { + // Hide any active inline assist UI when AI is disabled + InlineAssistant::update_global(cx, |assistant, cx| { + assistant.cancel_all_active_completions(cx); + }); + } + }) + .detach(); + cx.observe_new(|_workspace: &mut Workspace, window, cx| { let Some(window) = window else { return; @@ -141,6 +152,26 @@ impl InlineAssistant { .detach(); } + /// Hides all active inline assists when AI is disabled + pub fn cancel_all_active_completions(&mut self, cx: &mut App) { + // Cancel all active completions in editors + for (editor_handle, _) in self.assists_by_editor.iter() { + if let Some(editor) = editor_handle.upgrade() { + let windows = cx.windows(); + if !windows.is_empty() { + let window = windows[0]; + let _ = window.update(cx, |_, window, cx| { + editor.update(cx, |editor, cx| { + if editor.has_active_edit_prediction() { + editor.cancel(&Default::default(), window, cx); + } + }); + }); + } + } + } + } + fn handle_workspace_event( &mut self, workspace: Entity, @@ -176,7 +207,7 @@ impl InlineAssistant { window: &mut Window, cx: &mut App, ) { - let is_assistant2_enabled = true; + let is_assistant2_enabled = !DisableAiSettings::get_global(cx).disable_ai; if let Some(editor) = item.act_as::(cx) { editor.update(cx, |editor, cx| { @@ -199,6 +230,13 @@ impl InlineAssistant { cx, ); + if DisableAiSettings::get_global(cx).disable_ai { + // Cancel any active edit predictions + if editor.has_active_edit_prediction() { + editor.cancel(&Default::default(), window, cx); + } + } + // Remove the Assistant1 code action provider, as it still might be registered. 
editor.remove_code_action_provider("assistant".into(), window, cx); } else { @@ -219,7 +257,7 @@ impl InlineAssistant { cx: &mut Context, ) { let settings = AgentSettings::get_global(cx); - if !settings.enabled { + if !settings.enabled || DisableAiSettings::get_global(cx).disable_ai { return; } @@ -307,7 +345,7 @@ impl InlineAssistant { if let Some(answer) = answer { if answer == 0 { cx.update(|window, cx| { - window.dispatch_action(Box::new(OpenConfiguration), cx) + window.dispatch_action(Box::new(OpenSettings), cx) }) .ok(); } @@ -660,7 +698,6 @@ impl InlineAssistant { height: Some(prompt_editor_height), render: build_assist_editor_renderer(prompt_editor), priority: 0, - render_in_minimap: false, }, BlockProperties { style: BlockStyle::Sticky, @@ -675,7 +712,6 @@ impl InlineAssistant { .into_any_element() }), priority: 0, - render_in_minimap: false, }, ]; @@ -1451,7 +1487,6 @@ impl InlineAssistant { .into_any_element() }), priority: 0, - render_in_minimap: false, }); } diff --git a/crates/agent_ui/src/inline_prompt_editor.rs b/crates/agent_ui/src/inline_prompt_editor.rs index 7a61eef7486de92bc181a3f28e032865a4452fe2..a5f90edb57ca9142a3661bdf64629e6f05e96191 100644 --- a/crates/agent_ui/src/inline_prompt_editor.rs +++ b/crates/agent_ui/src/inline_prompt_editor.rs @@ -2,7 +2,6 @@ use crate::agent_model_selector::AgentModelSelector; use crate::buffer_codegen::BufferCodegen; use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider}; use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind}; -use crate::language_model_selector::ToggleModelSelector; use crate::message_editor::{ContextCreasesAddon, extract_message_creases, insert_message_creases}; use crate::terminal_codegen::TerminalCodegen; use crate::{CycleNextInlineAssist, CyclePreviousInlineAssist, ModelUsageContext}; @@ -38,6 +37,7 @@ use ui::{ CheckboxWithLabel, IconButtonShape, KeyBinding, Popover, PopoverMenuHandle, Tooltip, prelude::*, }; use workspace::Workspace; +use zed_actions::agent::ToggleModelSelector; pub struct PromptEditor { pub editor: Entity, @@ -541,7 +541,7 @@ impl PromptEditor { match &self.mode { PromptEditorMode::Terminal { .. } => vec![ accept, - IconButton::new("confirm", IconName::Play) + IconButton::new("confirm", IconName::PlayOutlined) .icon_color(Color::Info) .shape(IconButtonShape::Square) .tooltip(|window, cx| { diff --git a/crates/agent_ui/src/language_model_selector.rs b/crates/agent_ui/src/language_model_selector.rs index ff18a95f3f8b84eb0876a099cb664aa0908bed8f..7121624c87f6e44ba73f8380bfdf60227cba5b90 100644 --- a/crates/agent_ui/src/language_model_selector.rs +++ b/crates/agent_ui/src/language_model_selector.rs @@ -3,9 +3,7 @@ use std::{cmp::Reverse, sync::Arc}; use collections::{HashSet, IndexMap}; use feature_flags::ZedProFeatureFlag; use fuzzy::{StringMatch, StringMatchCandidate, match_strings}; -use gpui::{ - Action, AnyElement, App, BackgroundExecutor, DismissEvent, Subscription, Task, actions, -}; +use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, Subscription, Task}; use language_model::{ AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId, LanguageModelRegistry, @@ -15,15 +13,6 @@ use picker::{Picker, PickerDelegate}; use proto::Plan; use ui::{ListItem, ListItemSpacing, prelude::*}; -actions!( - agent, - [ - /// Toggles the language model selector dropdown. 
- #[action(deprecated_aliases = ["assistant::ToggleModelSelector", "assistant2::ToggleModelSelector"])] - ToggleModelSelector - ] -); - const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro"; type OnModelChanged = Arc, &mut App) + 'static>; @@ -587,7 +576,7 @@ impl PickerDelegate for LanguageModelPickerDelegate { .icon_position(IconPosition::Start) .on_click(|_, window, cx| { window.dispatch_action( - zed_actions::agent::OpenConfiguration.boxed_clone(), + zed_actions::agent::OpenSettings.boxed_clone(), cx, ); }), diff --git a/crates/agent_ui/src/message_editor.rs b/crates/agent_ui/src/message_editor.rs index 25c62c5fb32c805d72b4dfd46d7aa6f38579b07c..2185885347478269eb260668aaf3cfaaf2c365c9 100644 --- a/crates/agent_ui/src/message_editor.rs +++ b/crates/agent_ui/src/message_editor.rs @@ -2,20 +2,22 @@ use std::collections::BTreeMap; use std::rc::Rc; use std::sync::Arc; +use crate::agent_diff::AgentDiffThread; use crate::agent_model_selector::AgentModelSelector; -use crate::language_model_selector::ToggleModelSelector; use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip}; use crate::ui::{ MaxModeTooltip, preview::{AgentPreview, UsageCallout}, }; +use agent::history_store::HistoryStore; use agent::{ context::{AgentContextKey, ContextLoadResult, load_context}, context_store::ContextStoreEvent, }; use agent_settings::{AgentSettings, CompletionMode}; +use ai_onboarding::ApiKeysWithProviders; use buffer_diff::BufferDiff; -use client::UserStore; +use cloud_llm_client::CompletionIntent; use collections::{HashMap, HashSet}; use editor::actions::{MoveUp, Paste}; use editor::display_map::CreaseId; @@ -28,17 +30,18 @@ use fs::Fs; use futures::future::Shared; use futures::{FutureExt as _, future}; use gpui::{ - Animation, AnimationExt, App, Entity, EventEmitter, Focusable, Subscription, Task, TextStyle, - WeakEntity, linear_color_stop, linear_gradient, point, pulsating_between, + Animation, AnimationExt, App, Entity, EventEmitter, Focusable, IntoElement, KeyContext, + Subscription, Task, TextStyle, WeakEntity, linear_color_stop, linear_gradient, point, + pulsating_between, }; use language::{Buffer, Language, Point}; use language_model::{ - ConfiguredModel, LanguageModelRequestMessage, MessageContent, ZED_CLOUD_PROVIDER_ID, + ConfiguredModel, LanguageModelRegistry, LanguageModelRequestMessage, MessageContent, + ZED_CLOUD_PROVIDER_ID, }; use multi_buffer; use project::Project; use prompt_store::PromptStore; -use proto::Plan; use settings::Settings; use std::time::Duration; use theme::ThemeSettings; @@ -48,7 +51,7 @@ use ui::{ use util::ResultExt as _; use workspace::{CollaboratorId, Workspace}; use zed_actions::agent::Chat; -use zed_llm_client::CompletionIntent; +use zed_actions::agent::ToggleModelSelector; use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider, crease_for_mention}; use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind}; @@ -64,6 +67,9 @@ use agent::{ thread_store::{TextThreadStore, ThreadStore}, }; +pub const MIN_EDITOR_LINES: usize = 4; +pub const MAX_EDITOR_LINES: usize = 8; + #[derive(RegisterComponent)] pub struct MessageEditor { thread: Entity, @@ -71,9 +77,9 @@ pub struct MessageEditor { editor: Entity, workspace: WeakEntity, project: Entity, - user_store: Entity, context_store: Entity, prompt_store: Option>, + history_store: Option>, context_strip: Entity, context_picker_menu_handle: PopoverMenuHandle, model_selector: Entity, @@ -87,9 +93,6 @@ pub struct MessageEditor { _subscriptions: Vec, } -const 
MIN_EDITOR_LINES: usize = 4; -const MAX_EDITOR_LINES: usize = 8; - pub(crate) fn create_editor( workspace: WeakEntity, context_store: WeakEntity, @@ -131,6 +134,7 @@ pub(crate) fn create_editor( placement: Some(ContextMenuPlacement::Above), }); editor.register_addon(ContextCreasesAddon::new()); + editor.register_addon(MessageEditorAddon::new()); editor }); @@ -152,11 +156,11 @@ impl MessageEditor { pub fn new( fs: Arc, workspace: WeakEntity, - user_store: Entity, context_store: Entity, prompt_store: Option>, thread_store: WeakEntity, text_thread_store: WeakEntity, + history_store: Option>, thread: Entity, window: &mut Window, cx: &mut Context, @@ -223,12 +227,12 @@ impl MessageEditor { Self { editor: editor.clone(), project: thread.read(cx).project().clone(), - user_store, thread, incompatible_tools_state: incompatible_tools.clone(), workspace, context_store, prompt_store, + history_store, context_strip, context_picker_menu_handle, load_context_task: None, @@ -475,9 +479,12 @@ impl MessageEditor { window: &mut Window, cx: &mut Context, ) { - if let Ok(diff) = - AgentDiffPane::deploy(self.thread.clone(), self.workspace.clone(), window, cx) - { + if let Ok(diff) = AgentDiffPane::deploy( + AgentDiffThread::Native(self.thread.clone()), + self.workspace.clone(), + window, + cx, + ) { let path_key = multi_buffer::PathKey::for_buffer(&buffer, cx); diff.update(cx, |diff, cx| diff.move_to_path(path_key, window, cx)); } @@ -605,7 +612,11 @@ impl MessageEditor { ) } - fn render_follow_toggle(&self, cx: &mut Context) -> impl IntoElement { + fn render_follow_toggle( + &self, + is_model_selected: bool, + cx: &mut Context, + ) -> impl IntoElement { let following = self .workspace .read_with(cx, |workspace, _| { @@ -614,6 +625,7 @@ impl MessageEditor { .unwrap_or(false); IconButton::new("follow-agent", IconName::Crosshair) + .disabled(!is_model_selected) .icon_size(IconSize::Small) .icon_color(Color::Muted) .toggle_state(following) @@ -701,11 +713,11 @@ impl MessageEditor { cx.listener(|this, _: &RejectAll, window, cx| this.handle_reject_all(window, cx)), ) .capture_action(cx.listener(Self::paste)) - .gap_2() .p_2() - .bg(editor_bg_color) + .gap_2() .border_t_1() .border_color(cx.theme().colors().border) + .bg(editor_bg_color) .child( h_flex() .justify_between() @@ -782,7 +794,7 @@ impl MessageEditor { .justify_between() .child( h_flex() - .child(self.render_follow_toggle(cx)) + .child(self.render_follow_toggle(is_model_selected, cx)) .children(self.render_burn_mode_toggle(cx)), ) .child( @@ -898,6 +910,10 @@ impl MessageEditor { .on_click({ let focus_handle = focus_handle.clone(); move |_event, window, cx| { + telemetry::event!( + "Agent Message Sent", + agent = "zed", + ); focus_handle.dispatch_action( &Chat, window, cx, ); @@ -1266,24 +1282,12 @@ impl MessageEditor { return None; } - let user_store = self.user_store.read(cx); - - let ubb_enable = user_store - .usage_based_billing_enabled() - .map_or(false, |enabled| enabled); - - if ubb_enable { + let user_store = self.project.read(cx).user_store().read(cx); + if user_store.is_usage_based_billing_enabled() { return None; } - let plan = user_store - .current_plan() - .map(|plan| match plan { - Plan::Free => zed_llm_client::Plan::ZedFree, - Plan::ZedPro => zed_llm_client::Plan::ZedPro, - Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial, - }) - .unwrap_or(zed_llm_client::Plan::ZedFree); + let plan = user_store.plan().unwrap_or(cloud_llm_client::Plan::ZedFree); let usage = user_store.model_request_usage()?; @@ -1485,6 +1489,31 @@ pub struct 
ContextCreasesAddon { _subscription: Option, } +pub struct MessageEditorAddon {} + +impl MessageEditorAddon { + pub fn new() -> Self { + Self {} + } +} + +impl Addon for MessageEditorAddon { + fn to_any(&self) -> &dyn std::any::Any { + self + } + + fn to_any_mut(&mut self) -> Option<&mut dyn std::any::Any> { + Some(self) + } + + fn extend_key_context(&self, key_context: &mut KeyContext, cx: &App) { + let settings = agent_settings::AgentSettings::get_global(cx); + if settings.use_modifier_to_send { + key_context.add("use_modifier_to_send"); + } + } +} + impl Addon for ContextCreasesAddon { fn to_any(&self) -> &dyn std::any::Any { self @@ -1620,9 +1649,38 @@ impl Render for MessageEditor { let line_height = TextSize::Small.rems(cx).to_pixels(window.rem_size()) * 1.5; + let has_configured_providers = LanguageModelRegistry::read_global(cx) + .providers() + .iter() + .filter(|provider| { + provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID + }) + .count() + > 0; + + let is_signed_out = self + .workspace + .read_with(cx, |workspace, _| { + workspace.client().status().borrow().is_signed_out() + }) + .unwrap_or(true); + + let has_history = self + .history_store + .as_ref() + .and_then(|hs| hs.update(cx, |hs, cx| hs.entries(cx).len() > 0).ok()) + .unwrap_or(false) + || self + .thread + .read_with(cx, |thread, _| thread.messages().len() > 0); + v_flex() .size_full() .bg(cx.theme().colors().panel_background) + .when( + !has_history && is_signed_out && has_configured_providers, + |this| this.child(cx.new(ApiKeysWithProviders::new)), + ) .when(changed_buffers.len() > 0, |parent| { parent.child(self.render_edits_bar(&changed_buffers, window, cx)) }) @@ -1694,7 +1752,6 @@ impl AgentPreview for MessageEditor { ) -> Option { if let Some(workspace) = workspace.upgrade() { let fs = workspace.read(cx).app_state().fs.clone(); - let user_store = workspace.read(cx).app_state().user_store.clone(); let project = workspace.read(cx).project().clone(); let weak_project = project.downgrade(); let context_store = cx.new(|_cx| ContextStore::new(weak_project, None)); @@ -1707,11 +1764,11 @@ impl AgentPreview for MessageEditor { MessageEditor::new( fs, workspace.downgrade(), - user_store, context_store, None, thread_store.downgrade(), text_thread_store.downgrade(), + None, thread, window, cx, diff --git a/crates/agent_ui/src/terminal_inline_assistant.rs b/crates/agent_ui/src/terminal_inline_assistant.rs index 91867957cdcd1b3cb2ff9c40d385737b74d969f1..bcbc308c99da7b80e716fce9e60461352dcb814c 100644 --- a/crates/agent_ui/src/terminal_inline_assistant.rs +++ b/crates/agent_ui/src/terminal_inline_assistant.rs @@ -10,6 +10,7 @@ use agent::{ use agent_settings::AgentSettings; use anyhow::{Context as _, Result}; use client::telemetry::Telemetry; +use cloud_llm_client::CompletionIntent; use collections::{HashMap, VecDeque}; use editor::{MultiBuffer, actions::SelectAll}; use fs::Fs; @@ -27,7 +28,6 @@ use terminal_view::TerminalView; use ui::prelude::*; use util::ResultExt; use workspace::{Toast, Workspace, notifications::NotificationId}; -use zed_llm_client::CompletionIntent; pub fn init( fs: Arc, diff --git a/crates/agent_ui/src/text_thread_editor.rs b/crates/agent_ui/src/text_thread_editor.rs index de7606dbfb333e524c81435432050cfba1b71831..4836a95c8efd1dafcd7654802e55c54694280a87 100644 --- a/crates/agent_ui/src/text_thread_editor.rs +++ b/crates/agent_ui/src/text_thread_editor.rs @@ -1,8 +1,6 @@ use crate::{ burn_mode_tooltip::BurnModeTooltip, - language_model_selector::{ - LanguageModelSelector, 
ToggleModelSelector, language_model_selector, - }, + language_model_selector::{LanguageModelSelector, language_model_selector}, }; use agent_settings::{AgentSettings, CompletionMode}; use anyhow::Result; @@ -14,7 +12,7 @@ use assistant_slash_commands::{ use client::{proto, zed_urls}; use collections::{BTreeSet, HashMap, HashSet, hash_map}; use editor::{ - Anchor, Editor, EditorEvent, MenuInlineCompletionsPolicy, MultiBuffer, MultiBufferSnapshot, + Anchor, Editor, EditorEvent, MenuEditPredictionsPolicy, MultiBuffer, MultiBufferSnapshot, RowExt, ToOffset as _, ToPoint, actions::{MoveToEndOfLine, Newline, ShowCompletions}, display_map::{ @@ -38,8 +36,7 @@ use language::{ language_settings::{SoftWrap, all_language_settings}, }; use language_model::{ - ConfigurationError, LanguageModelExt, LanguageModelImage, LanguageModelProviderTosView, - LanguageModelRegistry, Role, + ConfigurationError, LanguageModelExt, LanguageModelImage, LanguageModelRegistry, Role, }; use multi_buffer::MultiBufferRow; use picker::{Picker, popover_menu::PickerPopoverMenu}; @@ -74,6 +71,7 @@ use workspace::{ pane, searchable::{SearchEvent, SearchableItem}, }; +use zed_actions::agent::ToggleModelSelector; use crate::{slash_command::SlashCommandCompletionProvider, slash_command_picker}; use assistant_context::{ @@ -256,7 +254,7 @@ impl TextThreadEditor { editor.set_show_wrap_guides(false, cx); editor.set_show_indent_guides(false, cx); editor.set_completion_provider(Some(Rc::new(completion_provider))); - editor.set_menu_inline_completions_policy(MenuInlineCompletionsPolicy::Never); + editor.set_menu_edit_predictions_policy(MenuEditPredictionsPolicy::Never); editor.set_collaboration_hub(Box::new(project.clone())); let show_edit_predictions = all_language_settings(None, cx) @@ -1256,7 +1254,6 @@ impl TextThreadEditor { ), priority: usize::MAX, render: render_block(MessageMetadata::from(message)), - render_in_minimap: false, }; let mut new_blocks = vec![]; let mut block_index_to_message = vec![]; @@ -1858,7 +1855,6 @@ impl TextThreadEditor { .into_any_element() }), priority: 0, - render_in_minimap: false, }) }) .collect::>(); @@ -1897,108 +1893,6 @@ impl TextThreadEditor { .update(cx, |context, cx| context.summarize(true, cx)); } - fn render_notice(&self, cx: &mut Context) -> Option { - // This was previously gated behind the `zed-pro` feature flag. Since we - // aren't planning to ship that right now, we're just hard-coding this - // value to not show the nudge. - let nudge = Some(false); - - let model_registry = LanguageModelRegistry::read_global(cx); - - if nudge.map_or(false, |value| value) { - Some( - h_flex() - .p_3() - .border_b_1() - .border_color(cx.theme().colors().border_variant) - .bg(cx.theme().colors().editor_background) - .justify_between() - .child( - h_flex() - .gap_3() - .child(Icon::new(IconName::ZedAssistant).color(Color::Accent)) - .child(Label::new("Zed AI is here! 
Get started by signing in →")), - ) - .child( - Button::new("sign-in", "Sign in") - .size(ButtonSize::Compact) - .style(ButtonStyle::Filled) - .on_click(cx.listener(|this, _event, _window, cx| { - let client = this - .workspace - .read_with(cx, |workspace, _| workspace.client().clone()) - .log_err(); - - if let Some(client) = client { - cx.spawn(async move |context_editor, cx| { - match client.authenticate_and_connect(true, cx).await { - util::ConnectionResult::Timeout => { - log::error!("Authentication timeout") - } - util::ConnectionResult::ConnectionReset => { - log::error!("Connection reset") - } - util::ConnectionResult::Result(r) => { - if r.log_err().is_some() { - context_editor - .update(cx, |_, cx| cx.notify()) - .ok(); - } - } - } - }) - .detach() - } - })), - ) - .into_any_element(), - ) - } else if let Some(configuration_error) = - model_registry.configuration_error(model_registry.default_model(), cx) - { - Some( - h_flex() - .px_3() - .py_2() - .border_b_1() - .border_color(cx.theme().colors().border_variant) - .bg(cx.theme().colors().editor_background) - .justify_between() - .child( - h_flex() - .gap_3() - .child( - Icon::new(IconName::Warning) - .size(IconSize::Small) - .color(Color::Warning), - ) - .child(Label::new(configuration_error.to_string())), - ) - .child( - Button::new("open-configuration", "Configure Providers") - .size(ButtonSize::Compact) - .icon(Some(IconName::SlidersVertical)) - .icon_size(IconSize::Small) - .icon_position(IconPosition::Start) - .style(ButtonStyle::Filled) - .on_click({ - let focus_handle = self.focus_handle(cx).clone(); - move |_event, window, cx| { - focus_handle.dispatch_action( - &zed_actions::agent::OpenConfiguration, - window, - cx, - ); - } - }), - ) - .into_any_element(), - ) - } else { - None - } - } - fn render_send_button(&self, window: &mut Window, cx: &mut Context) -> impl IntoElement { let focus_handle = self.focus_handle(cx).clone(); @@ -2130,12 +2024,13 @@ impl TextThreadEditor { .map(|default| default.model); let model_name = match active_model { Some(model) => model.name().0, - None => SharedString::from("No model selected"), + None => SharedString::from("Select Model"), }; let active_provider = LanguageModelRegistry::read_global(cx) .default_model() .map(|default| default.provider); + let provider_icon = match active_provider { Some(provider) => provider.icon(), None => IconName::Ai, @@ -2583,20 +2478,7 @@ impl EventEmitter for TextThreadEditor {} impl Render for TextThreadEditor { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { - let provider = LanguageModelRegistry::read_global(cx) - .default_model() - .map(|default| default.provider); - - let accept_terms = if self.show_accept_terms { - provider.as_ref().and_then(|provider| { - provider.render_accept_terms(LanguageModelProviderTosView::PromptEditorPopup, cx) - }) - } else { - None - }; - let language_model_selector = self.language_model_selector_menu_handle.clone(); - let burn_mode_toggle = self.render_burn_mode_toggle(cx); v_flex() .key_context("ContextEditor") @@ -2613,28 +2495,12 @@ impl Render for TextThreadEditor { language_model_selector.toggle(window, cx); }) .size_full() - .children(self.render_notice(cx)) .child( div() .flex_grow() .bg(cx.theme().colors().editor_background) .child(self.editor.clone()), ) - .when_some(accept_terms, |this, element| { - this.child( - div() - .absolute() - .right_3() - .bottom_12() - .max_w_96() - .py_2() - .px_3() - .elevation_2(cx) - .bg(cx.theme().colors().surface_background) - .occlude() - 
.child(element), - ) - }) .children(self.render_last_error(cx)) .child( h_flex() @@ -2651,7 +2517,7 @@ impl Render for TextThreadEditor { h_flex() .gap_0p5() .child(self.render_inject_context_menu(cx)) - .when_some(burn_mode_toggle, |this, element| this.child(element)), + .children(self.render_burn_mode_toggle(cx)), ) .child( h_flex() diff --git a/crates/agent_ui/src/thread_history.rs b/crates/agent_ui/src/thread_history.rs index a2ee816f7315dd0f99266b30e31f9b9e9eb6534e..b8d1db88d6e3164b32ade0f2137ad7ca37a0650a 100644 --- a/crates/agent_ui/src/thread_history.rs +++ b/crates/agent_ui/src/thread_history.rs @@ -701,7 +701,7 @@ impl RenderOnce for HistoryEntryElement { .on_hover(self.on_hover) .end_slot::(if self.hovered || self.selected { Some( - IconButton::new("delete", IconName::TrashAlt) + IconButton::new("delete", IconName::Trash) .shape(IconButtonShape::Square) .icon_size(IconSize::XSmall) .icon_color(Color::Muted) diff --git a/crates/agent_ui/src/ui.rs b/crates/agent_ui/src/ui.rs index 43cd0f5e8937d860ce0f453d40ece8d230f7d16d..b477a8c385c5f8aee85b54cf5f82cdd49e2e2484 100644 --- a/crates/agent_ui/src/ui.rs +++ b/crates/agent_ui/src/ui.rs @@ -1,11 +1,14 @@ mod agent_notification; mod burn_mode_tooltip; mod context_pill; +mod end_trial_upsell; +mod new_thread_button; mod onboarding_modal; pub mod preview; -mod upsell; pub use agent_notification::*; pub use burn_mode_tooltip::*; pub use context_pill::*; +pub use end_trial_upsell::*; +pub use new_thread_button::*; pub use onboarding_modal::*; diff --git a/crates/agent_ui/src/ui/end_trial_upsell.rs b/crates/agent_ui/src/ui/end_trial_upsell.rs new file mode 100644 index 0000000000000000000000000000000000000000..3a8a119800543ad033efd563d7896ccc80add373 --- /dev/null +++ b/crates/agent_ui/src/ui/end_trial_upsell.rs @@ -0,0 +1,117 @@ +use std::sync::Arc; + +use ai_onboarding::{AgentPanelOnboardingCard, PlanDefinitions}; +use client::zed_urls; +use gpui::{AnyElement, App, IntoElement, RenderOnce, Window}; +use ui::{Divider, Tooltip, prelude::*}; + +#[derive(IntoElement, RegisterComponent)] +pub struct EndTrialUpsell { + dismiss_upsell: Arc, +} + +impl EndTrialUpsell { + pub fn new(dismiss_upsell: Arc) -> Self { + Self { dismiss_upsell } + } +} + +impl RenderOnce for EndTrialUpsell { + fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { + let plan_definitions = PlanDefinitions; + + let pro_section = v_flex() + .gap_1() + .child( + h_flex() + .gap_2() + .child( + Label::new("Pro") + .size(LabelSize::Small) + .color(Color::Accent) + .buffer_font(cx), + ) + .child(Divider::horizontal()), + ) + .child(plan_definitions.pro_plan(false)) + .child( + Button::new("cta-button", "Upgrade to Zed Pro") + .full_width() + .style(ButtonStyle::Tinted(ui::TintColor::Accent)) + .on_click(move |_, _window, cx| { + telemetry::event!("Upgrade To Pro Clicked", state = "end-of-trial"); + cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)) + }), + ); + + let free_section = v_flex() + .mt_1p5() + .gap_1() + .child( + h_flex() + .gap_2() + .child( + Label::new("Free") + .size(LabelSize::Small) + .color(Color::Muted) + .buffer_font(cx), + ) + .child( + Label::new("(Current Plan)") + .size(LabelSize::Small) + .color(Color::Custom(cx.theme().colors().text_muted.opacity(0.6))) + .buffer_font(cx), + ) + .child(Divider::horizontal()), + ) + .child(plan_definitions.free_plan()); + + AgentPanelOnboardingCard::new() + .child(Headline::new("Your Zed Pro Trial has expired")) + .child( + Label::new("You've been automatically reset to the Free plan.") + 
.color(Color::Muted) + .mb_2(), + ) + .child(pro_section) + .child(free_section) + .child( + h_flex().absolute().top_4().right_4().child( + IconButton::new("dismiss_onboarding", IconName::Close) + .icon_size(IconSize::Small) + .tooltip(Tooltip::text("Dismiss")) + .on_click({ + let callback = self.dismiss_upsell.clone(); + move |_, window, cx| { + telemetry::event!("Banner Dismissed", source = "AI Onboarding"); + callback(window, cx) + } + }), + ), + ) + } +} + +impl Component for EndTrialUpsell { + fn scope() -> ComponentScope { + ComponentScope::Onboarding + } + + fn name() -> &'static str { + "End of Trial Upsell Banner" + } + + fn sort_name() -> &'static str { + "End of Trial Upsell Banner" + } + + fn preview(_window: &mut Window, _cx: &mut App) -> Option { + Some( + v_flex() + .child(EndTrialUpsell { + dismiss_upsell: Arc::new(|_, _| {}), + }) + .into_any_element(), + ) + } +} diff --git a/crates/agent_ui/src/ui/new_thread_button.rs b/crates/agent_ui/src/ui/new_thread_button.rs new file mode 100644 index 0000000000000000000000000000000000000000..7764144150762f9b828ea98f1c917332759bd5ad --- /dev/null +++ b/crates/agent_ui/src/ui/new_thread_button.rs @@ -0,0 +1,75 @@ +use gpui::{ClickEvent, ElementId, IntoElement, ParentElement, Styled}; +use ui::prelude::*; + +#[derive(IntoElement)] +pub struct NewThreadButton { + id: ElementId, + label: SharedString, + icon: IconName, + keybinding: Option, + on_click: Option>, +} + +impl NewThreadButton { + pub fn new(id: impl Into, label: impl Into, icon: IconName) -> Self { + Self { + id: id.into(), + label: label.into(), + icon, + keybinding: None, + on_click: None, + } + } + + pub fn keybinding(mut self, keybinding: Option) -> Self { + self.keybinding = keybinding; + self + } + + pub fn on_click(mut self, handler: F) -> Self + where + F: Fn(&mut Window, &mut App) + 'static, + { + self.on_click = Some(Box::new( + move |_: &ClickEvent, window: &mut Window, cx: &mut App| handler(window, cx), + )); + self + } +} + +impl RenderOnce for NewThreadButton { + fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { + h_flex() + .id(self.id) + .w_full() + .py_1p5() + .px_2() + .gap_1() + .justify_between() + .rounded_md() + .border_1() + .border_color(cx.theme().colors().border.opacity(0.4)) + .bg(cx.theme().colors().element_active.opacity(0.2)) + .hover(|style| { + style + .bg(cx.theme().colors().element_hover) + .border_color(cx.theme().colors().border) + }) + .child( + h_flex() + .gap_1p5() + .child( + Icon::new(self.icon) + .size(IconSize::XSmall) + .color(Color::Muted), + ) + .child(Label::new(self.label).size(LabelSize::Small)), + ) + .when_some(self.keybinding, |this, keybinding| { + this.child(keybinding.size(rems_from_px(10.))) + }) + .when_some(self.on_click, |this, on_click| { + this.on_click(move |event, window, cx| on_click(event, window, cx)) + }) + } +} diff --git a/crates/agent_ui/src/ui/preview/usage_callouts.rs b/crates/agent_ui/src/ui/preview/usage_callouts.rs index 45af41395b52afc8655c7cdd748a3228868b2d0f..64869a6ec71cdbe8e3532983c48784136b3dcb36 100644 --- a/crates/agent_ui/src/ui/preview/usage_callouts.rs +++ b/crates/agent_ui/src/ui/preview/usage_callouts.rs @@ -1,8 +1,8 @@ use client::{ModelRequestUsage, RequestUsage, zed_urls}; +use cloud_llm_client::{Plan, UsageLimit}; use component::{empty_example, example_group_with_title, single_example}; use gpui::{AnyElement, App, IntoElement, RenderOnce, Window}; use ui::{Callout, prelude::*}; -use zed_llm_client::{Plan, UsageLimit}; #[derive(IntoElement, RegisterComponent)] 
pub struct UsageCallout { diff --git a/crates/agent_ui/src/ui/upsell.rs b/crates/agent_ui/src/ui/upsell.rs deleted file mode 100644 index f311aade22770d534189f25494b3f06588a70d9d..0000000000000000000000000000000000000000 --- a/crates/agent_ui/src/ui/upsell.rs +++ /dev/null @@ -1,163 +0,0 @@ -use component::{Component, ComponentScope, single_example}; -use gpui::{ - AnyElement, App, ClickEvent, IntoElement, ParentElement, RenderOnce, SharedString, Styled, - Window, -}; -use theme::ActiveTheme; -use ui::{ - Button, ButtonCommon, ButtonStyle, Checkbox, Clickable, Color, Label, LabelCommon, - RegisterComponent, ToggleState, h_flex, v_flex, -}; - -/// A component that displays an upsell message with a call-to-action button -/// -/// # Example -/// ``` -/// let upsell = Upsell::new( -/// "Upgrade to Zed Pro", -/// "Get access to advanced AI features and more", -/// "Upgrade Now", -/// Box::new(|_, _window, cx| { -/// cx.open_url("https://zed.dev/pricing"); -/// }), -/// Box::new(|_, _window, cx| { -/// // Handle dismiss -/// }), -/// Box::new(|checked, window, cx| { -/// // Handle don't show again -/// }), -/// ); -/// ``` -#[derive(IntoElement, RegisterComponent)] -pub struct Upsell { - title: SharedString, - message: SharedString, - cta_text: SharedString, - on_click: Box, - on_dismiss: Box, - on_dont_show_again: Box, -} - -impl Upsell { - /// Create a new upsell component - pub fn new( - title: impl Into, - message: impl Into, - cta_text: impl Into, - on_click: Box, - on_dismiss: Box, - on_dont_show_again: Box, - ) -> Self { - Self { - title: title.into(), - message: message.into(), - cta_text: cta_text.into(), - on_click, - on_dismiss, - on_dont_show_again, - } - } -} - -impl RenderOnce for Upsell { - fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { - v_flex() - .w_full() - .p_4() - .gap_3() - .bg(cx.theme().colors().surface_background) - .rounded_md() - .border_1() - .border_color(cx.theme().colors().border) - .child( - v_flex() - .gap_1() - .child( - Label::new(self.title) - .size(ui::LabelSize::Large) - .weight(gpui::FontWeight::BOLD), - ) - .child(Label::new(self.message).color(Color::Muted)), - ) - .child( - h_flex() - .w_full() - .justify_between() - .items_center() - .child( - h_flex() - .items_center() - .gap_1() - .child( - Checkbox::new("dont-show-again", ToggleState::Unselected).on_click( - move |_, window, cx| { - (self.on_dont_show_again)(true, window, cx); - }, - ), - ) - .child( - Label::new("Don't show again") - .color(Color::Muted) - .size(ui::LabelSize::Small), - ), - ) - .child( - h_flex() - .gap_2() - .child( - Button::new("dismiss-button", "No Thanks") - .style(ButtonStyle::Subtle) - .on_click(self.on_dismiss), - ) - .child( - Button::new("cta-button", self.cta_text) - .style(ButtonStyle::Filled) - .on_click(self.on_click), - ), - ), - ) - } -} - -impl Component for Upsell { - fn scope() -> ComponentScope { - ComponentScope::Agent - } - - fn name() -> &'static str { - "Upsell" - } - - fn description() -> Option<&'static str> { - Some("A promotional component that displays a message with a call-to-action.") - } - - fn preview(window: &mut Window, cx: &mut App) -> Option { - let examples = vec![ - single_example( - "Default", - Upsell::new( - "Upgrade to Zed Pro", - "Get unlimited access to AI features and more with Zed Pro. 
Unlock advanced AI capabilities and other premium features.", - "Upgrade Now", - Box::new(|_, _, _| {}), - Box::new(|_, _, _| {}), - Box::new(|_, _, _| {}), - ).render(window, cx).into_any_element(), - ), - single_example( - "Short Message", - Upsell::new( - "Try Zed Pro for free", - "Start your 7-day trial today.", - "Start Trial", - Box::new(|_, _, _| {}), - Box::new(|_, _, _| {}), - Box::new(|_, _, _| {}), - ).render(window, cx).into_any_element(), - ), - ]; - - Some(v_flex().gap_4().children(examples).into_any_element()) - } -} diff --git a/crates/ai_onboarding/Cargo.toml b/crates/ai_onboarding/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..95a45b1a6fbe103f02532d33c21af707f2f51d45 --- /dev/null +++ b/crates/ai_onboarding/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "ai_onboarding" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/ai_onboarding.rs" + +[features] +default = [] + +[dependencies] +client.workspace = true +cloud_llm_client.workspace = true +component.workspace = true +gpui.workspace = true +language_model.workspace = true +serde.workspace = true +smallvec.workspace = true +telemetry.workspace = true +ui.workspace = true +workspace-hack.workspace = true +zed_actions.workspace = true diff --git a/crates/inline_completion_button/LICENSE-GPL b/crates/ai_onboarding/LICENSE-GPL similarity index 100% rename from crates/inline_completion_button/LICENSE-GPL rename to crates/ai_onboarding/LICENSE-GPL diff --git a/crates/ai_onboarding/src/agent_api_keys_onboarding.rs b/crates/ai_onboarding/src/agent_api_keys_onboarding.rs new file mode 100644 index 0000000000000000000000000000000000000000..b55ad4c89549a8843fe2d8273da60236400cb565 --- /dev/null +++ b/crates/ai_onboarding/src/agent_api_keys_onboarding.rs @@ -0,0 +1,141 @@ +use gpui::{Action, IntoElement, ParentElement, RenderOnce, point}; +use language_model::{LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID}; +use ui::{Divider, List, ListBulletItem, prelude::*}; + +pub struct ApiKeysWithProviders { + configured_providers: Vec<(IconName, SharedString)>, +} + +impl ApiKeysWithProviders { + pub fn new(cx: &mut Context) -> Self { + cx.subscribe( + &LanguageModelRegistry::global(cx), + |this: &mut Self, _registry, event: &language_model::Event, cx| match event { + language_model::Event::ProviderStateChanged + | language_model::Event::AddedProvider(_) + | language_model::Event::RemovedProvider(_) => { + this.configured_providers = Self::compute_configured_providers(cx) + } + _ => {} + }, + ) + .detach(); + + Self { + configured_providers: Self::compute_configured_providers(cx), + } + } + + fn compute_configured_providers(cx: &App) -> Vec<(IconName, SharedString)> { + LanguageModelRegistry::read_global(cx) + .providers() + .iter() + .filter(|provider| { + provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID + }) + .map(|provider| (provider.icon(), provider.name().0.clone())) + .collect() + } +} + +impl Render for ApiKeysWithProviders { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + let configured_providers_list = + self.configured_providers + .iter() + .cloned() + .map(|(icon, name)| { + h_flex() + .gap_1p5() + .child(Icon::new(icon).size(IconSize::XSmall).color(Color::Muted)) + .child(Label::new(name)) + }); + div() + .mx_2p5() + .p_1() + .pb_0() + .gap_2() + .rounded_t_lg() + .border_t_1() + .border_x_1() + 
.border_color(cx.theme().colors().border.opacity(0.5)) + .bg(cx.theme().colors().background.alpha(0.5)) + .shadow(vec![gpui::BoxShadow { + color: gpui::black().opacity(0.15), + offset: point(px(1.), px(-1.)), + blur_radius: px(3.), + spread_radius: px(0.), + }]) + .child( + h_flex() + .px_2p5() + .py_1p5() + .gap_2() + .flex_wrap() + .rounded_t(px(5.)) + .overflow_hidden() + .border_t_1() + .border_x_1() + .border_color(cx.theme().colors().border) + .bg(cx.theme().colors().panel_background) + .child( + h_flex() + .min_w_0() + .gap_2() + .child( + Icon::new(IconName::Info) + .size(IconSize::XSmall) + .color(Color::Muted) + ) + .child( + div() + .w_full() + .child( + Label::new("Start now using API keys from your environment for the following providers:") + .color(Color::Muted) + ) + ) + ) + .children(configured_providers_list) + ) + } +} + +#[derive(IntoElement)] +pub struct ApiKeysWithoutProviders; + +impl ApiKeysWithoutProviders { + pub fn new() -> Self { + Self + } +} + +impl RenderOnce for ApiKeysWithoutProviders { + fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { + v_flex() + .mt_2() + .gap_1() + .child( + h_flex() + .gap_2() + .child( + Label::new("API Keys") + .size(LabelSize::Small) + .color(Color::Muted) + .buffer_font(cx), + ) + .child(Divider::horizontal()), + ) + .child(List::new().child(ListBulletItem::new( + "Add your own keys to use AI without signing in.", + ))) + .child( + Button::new("configure-providers", "Configure Providers") + .full_width() + .style(ButtonStyle::Outlined) + .on_click(move |_, window, cx| { + window.dispatch_action(zed_actions::agent::OpenSettings.boxed_clone(), cx); + }), + ) + } +} diff --git a/crates/ai_onboarding/src/agent_panel_onboarding_card.rs b/crates/ai_onboarding/src/agent_panel_onboarding_card.rs new file mode 100644 index 0000000000000000000000000000000000000000..c63c5926428ab47f80afd2e157f90f8852dbf4ee --- /dev/null +++ b/crates/ai_onboarding/src/agent_panel_onboarding_card.rs @@ -0,0 +1,83 @@ +use gpui::{AnyElement, IntoElement, ParentElement, linear_color_stop, linear_gradient}; +use smallvec::SmallVec; +use ui::{Vector, VectorName, prelude::*}; + +#[derive(IntoElement)] +pub struct AgentPanelOnboardingCard { + children: SmallVec<[AnyElement; 2]>, +} + +impl AgentPanelOnboardingCard { + pub fn new() -> Self { + Self { + children: SmallVec::new(), + } + } +} + +impl ParentElement for AgentPanelOnboardingCard { + fn extend(&mut self, elements: impl IntoIterator) { + self.children.extend(elements) + } +} + +impl RenderOnce for AgentPanelOnboardingCard { + fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { + div() + .m_2p5() + .p(px(3.)) + .elevation_2(cx) + .rounded_lg() + .bg(cx.theme().colors().background.alpha(0.5)) + .child( + v_flex() + .relative() + .size_full() + .px_4() + .py_3() + .gap_2() + .border_1() + .rounded(px(5.)) + .border_color(cx.theme().colors().text.alpha(0.1)) + .overflow_hidden() + .bg(cx.theme().colors().panel_background) + .child( + div() + .opacity(0.5) + .absolute() + .top(px(-8.0)) + .right_0() + .w(px(400.)) + .h(px(92.)) + .rounded_md() + .child( + Vector::new( + VectorName::AiGrid, + rems_from_px(400.), + rems_from_px(92.), + ) + .color(Color::Custom(cx.theme().colors().text.alpha(0.32))), + ), + ) + .child( + div() + .absolute() + .top_0p5() + .right_0p5() + .w(px(660.)) + .h(px(401.)) + .overflow_hidden() + .rounded_md() + .bg(linear_gradient( + 75., + linear_color_stop( + cx.theme().colors().panel_background.alpha(0.01), + 1.0, + ), + 
linear_color_stop(cx.theme().colors().panel_background, 0.45), + )), + ) + .children(self.children), + ) + } +} diff --git a/crates/ai_onboarding/src/agent_panel_onboarding_content.rs b/crates/ai_onboarding/src/agent_panel_onboarding_content.rs new file mode 100644 index 0000000000000000000000000000000000000000..f1629eeff81ef51bf2ff823eef0db64c1585a669 --- /dev/null +++ b/crates/ai_onboarding/src/agent_panel_onboarding_content.rs @@ -0,0 +1,84 @@ +use std::sync::Arc; + +use client::{Client, UserStore}; +use cloud_llm_client::Plan; +use gpui::{Entity, IntoElement, ParentElement}; +use language_model::{LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID}; +use ui::prelude::*; + +use crate::{AgentPanelOnboardingCard, ApiKeysWithoutProviders, ZedAiOnboarding}; + +pub struct AgentPanelOnboarding { + user_store: Entity, + client: Arc, + configured_providers: Vec<(IconName, SharedString)>, + continue_with_zed_ai: Arc, +} + +impl AgentPanelOnboarding { + pub fn new( + user_store: Entity, + client: Arc, + continue_with_zed_ai: impl Fn(&mut Window, &mut App) + 'static, + cx: &mut Context, + ) -> Self { + cx.subscribe( + &LanguageModelRegistry::global(cx), + |this: &mut Self, _registry, event: &language_model::Event, cx| match event { + language_model::Event::ProviderStateChanged + | language_model::Event::AddedProvider(_) + | language_model::Event::RemovedProvider(_) => { + this.configured_providers = Self::compute_available_providers(cx) + } + _ => {} + }, + ) + .detach(); + + Self { + user_store, + client, + configured_providers: Self::compute_available_providers(cx), + continue_with_zed_ai: Arc::new(continue_with_zed_ai), + } + } + + fn compute_available_providers(cx: &App) -> Vec<(IconName, SharedString)> { + LanguageModelRegistry::read_global(cx) + .providers() + .iter() + .filter(|provider| { + provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID + }) + .map(|provider| (provider.icon(), provider.name().0.clone())) + .collect() + } +} + +impl Render for AgentPanelOnboarding { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + let enrolled_in_trial = self.user_store.read(cx).plan() == Some(Plan::ZedProTrial); + let is_pro_user = self.user_store.read(cx).plan() == Some(Plan::ZedPro); + + AgentPanelOnboardingCard::new() + .child( + ZedAiOnboarding::new( + self.client.clone(), + &self.user_store, + self.continue_with_zed_ai.clone(), + cx, + ) + .with_dismiss({ + let callback = self.continue_with_zed_ai.clone(); + move |window, cx| callback(window, cx) + }), + ) + .map(|this| { + if enrolled_in_trial || is_pro_user || self.configured_providers.len() >= 1 { + this + } else { + this.child(ApiKeysWithoutProviders::new()) + } + }) + } +} diff --git a/crates/ai_onboarding/src/ai_onboarding.rs b/crates/ai_onboarding/src/ai_onboarding.rs new file mode 100644 index 0000000000000000000000000000000000000000..b9a1e49a4acdfb3f4a94b6313d1e6fb3ef969adc --- /dev/null +++ b/crates/ai_onboarding/src/ai_onboarding.rs @@ -0,0 +1,436 @@ +mod agent_api_keys_onboarding; +mod agent_panel_onboarding_card; +mod agent_panel_onboarding_content; +mod ai_upsell_card; +mod edit_prediction_onboarding_content; +mod plan_definitions; +mod young_account_banner; + +pub use agent_api_keys_onboarding::{ApiKeysWithProviders, ApiKeysWithoutProviders}; +pub use agent_panel_onboarding_card::AgentPanelOnboardingCard; +pub use agent_panel_onboarding_content::AgentPanelOnboarding; +pub use ai_upsell_card::AiUpsellCard; +use cloud_llm_client::Plan; +pub use 
edit_prediction_onboarding_content::EditPredictionOnboarding; +pub use plan_definitions::PlanDefinitions; +pub use young_account_banner::YoungAccountBanner; + +use std::sync::Arc; + +use client::{Client, UserStore, zed_urls}; +use gpui::{AnyElement, Entity, IntoElement, ParentElement}; +use ui::{Divider, RegisterComponent, TintColor, Tooltip, prelude::*}; + +#[derive(PartialEq)] +pub enum SignInStatus { + SignedIn, + SigningIn, + SignedOut, +} + +impl From for SignInStatus { + fn from(status: client::Status) -> Self { + if status.is_signing_in() { + Self::SigningIn + } else if status.is_signed_out() { + Self::SignedOut + } else { + Self::SignedIn + } + } +} + +#[derive(RegisterComponent, IntoElement)] +pub struct ZedAiOnboarding { + pub sign_in_status: SignInStatus, + pub has_accepted_terms_of_service: bool, + pub plan: Option, + pub account_too_young: bool, + pub continue_with_zed_ai: Arc, + pub sign_in: Arc, + pub accept_terms_of_service: Arc, + pub dismiss_onboarding: Option>, +} + +impl ZedAiOnboarding { + pub fn new( + client: Arc, + user_store: &Entity, + continue_with_zed_ai: Arc, + cx: &mut App, + ) -> Self { + let store = user_store.read(cx); + let status = *client.status().borrow(); + + Self { + sign_in_status: status.into(), + has_accepted_terms_of_service: store.has_accepted_terms_of_service(), + plan: store.plan(), + account_too_young: store.account_too_young(), + continue_with_zed_ai, + accept_terms_of_service: Arc::new({ + let store = user_store.clone(); + move |_window, cx| { + let task = store.update(cx, |store, cx| store.accept_terms_of_service(cx)); + task.detach_and_log_err(cx); + } + }), + sign_in: Arc::new(move |_window, cx| { + cx.spawn({ + let client = client.clone(); + async move |cx| client.sign_in_with_optional_connect(true, cx).await + }) + .detach_and_log_err(cx); + }), + dismiss_onboarding: None, + } + } + + pub fn with_dismiss( + mut self, + dismiss_callback: impl Fn(&mut Window, &mut App) + 'static, + ) -> Self { + self.dismiss_onboarding = Some(Arc::new(dismiss_callback)); + self + } + + fn render_accept_terms_of_service(&self) -> AnyElement { + v_flex() + .gap_1() + .w_full() + .child(Headline::new("Accept Terms of Service")) + .child( + Label::new("We don’t sell your data, track you across the web, or compromise your privacy.") + .color(Color::Muted) + .mb_2(), + ) + .child( + Button::new("terms_of_service", "Review Terms of Service") + .full_width() + .style(ButtonStyle::Outlined) + .icon(IconName::ArrowUpRight) + .icon_color(Color::Muted) + .icon_size(IconSize::XSmall) + .on_click(move |_, _window, cx| { + telemetry::event!("Review Terms of Service Clicked"); + cx.open_url(&zed_urls::terms_of_service(cx)) + }), + ) + .child( + Button::new("accept_terms", "Accept") + .full_width() + .style(ButtonStyle::Tinted(TintColor::Accent)) + .on_click({ + let callback = self.accept_terms_of_service.clone(); + move |_, window, cx| { + telemetry::event!("Terms of Service Accepted"); + (callback)(window, cx)} + }), + ) + .into_any_element() + } + + fn render_sign_in_disclaimer(&self, _cx: &mut App) -> AnyElement { + let signing_in = matches!(self.sign_in_status, SignInStatus::SigningIn); + let plan_definitions = PlanDefinitions; + + v_flex() + .gap_1() + .child(Headline::new("Welcome to Zed AI")) + .child( + Label::new("Sign in to try Zed Pro for 14 days, no credit card required.") + .color(Color::Muted) + .mb_2(), + ) + .child(plan_definitions.pro_plan(false)) + .child( + Button::new("sign_in", "Try Zed Pro for Free") + .disabled(signing_in) + .full_width() + 
.style(ButtonStyle::Tinted(ui::TintColor::Accent)) + .on_click({ + let callback = self.sign_in.clone(); + move |_, window, cx| { + telemetry::event!("Start Trial Clicked", state = "pre-sign-in"); + callback(window, cx) + } + }), + ) + .into_any_element() + } + + fn render_free_plan_state(&self, cx: &mut App) -> AnyElement { + let young_account_banner = YoungAccountBanner; + let plan_definitions = PlanDefinitions; + + if self.account_too_young { + v_flex() + .relative() + .max_w_full() + .gap_1() + .child(Headline::new("Welcome to Zed AI")) + .child(young_account_banner) + .child( + v_flex() + .mt_2() + .gap_1() + .child( + h_flex() + .gap_2() + .child( + Label::new("Pro") + .size(LabelSize::Small) + .color(Color::Accent) + .buffer_font(cx), + ) + .child(Divider::horizontal()), + ) + .child(plan_definitions.pro_plan(true)) + .child( + Button::new("pro", "Get Started") + .full_width() + .style(ButtonStyle::Tinted(ui::TintColor::Accent)) + .on_click(move |_, _window, cx| { + telemetry::event!( + "Upgrade To Pro Clicked", + state = "young-account" + ); + cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)) + }), + ), + ) + .into_any_element() + } else { + v_flex() + .relative() + .gap_1() + .child(Headline::new("Welcome to Zed AI")) + .child( + v_flex() + .mt_2() + .gap_1() + .child( + h_flex() + .gap_2() + .child( + Label::new("Free") + .size(LabelSize::Small) + .color(Color::Muted) + .buffer_font(cx), + ) + .child( + Label::new("(Current Plan)") + .size(LabelSize::Small) + .color(Color::Custom( + cx.theme().colors().text_muted.opacity(0.6), + )) + .buffer_font(cx), + ) + .child(Divider::horizontal()), + ) + .child(plan_definitions.free_plan()), + ) + .when_some( + self.dismiss_onboarding.as_ref(), + |this, dismiss_callback| { + let callback = dismiss_callback.clone(); + + this.child( + h_flex().absolute().top_0().right_0().child( + IconButton::new("dismiss_onboarding", IconName::Close) + .icon_size(IconSize::Small) + .tooltip(Tooltip::text("Dismiss")) + .on_click(move |_, window, cx| { + telemetry::event!( + "Banner Dismissed", + source = "AI Onboarding", + ); + callback(window, cx) + }), + ), + ) + }, + ) + .child( + v_flex() + .mt_2() + .gap_1() + .child( + h_flex() + .gap_2() + .child( + Label::new("Pro Trial") + .size(LabelSize::Small) + .color(Color::Accent) + .buffer_font(cx), + ) + .child(Divider::horizontal()), + ) + .child(plan_definitions.pro_trial(true)) + .child( + Button::new("pro", "Start Free Trial") + .full_width() + .style(ButtonStyle::Tinted(ui::TintColor::Accent)) + .on_click(move |_, _window, cx| { + telemetry::event!( + "Start Trial Clicked", + state = "post-sign-in" + ); + cx.open_url(&zed_urls::start_trial_url(cx)) + }), + ), + ) + .into_any_element() + } + } + + fn render_trial_state(&self, _cx: &mut App) -> AnyElement { + let plan_definitions = PlanDefinitions; + + v_flex() + .relative() + .gap_1() + .child(Headline::new("Welcome to the Zed Pro Trial")) + .child( + Label::new("Here's what you get for the next 14 days:") + .color(Color::Muted) + .mb_2(), + ) + .child(plan_definitions.pro_trial(false)) + .when_some( + self.dismiss_onboarding.as_ref(), + |this, dismiss_callback| { + let callback = dismiss_callback.clone(); + this.child( + h_flex().absolute().top_0().right_0().child( + IconButton::new("dismiss_onboarding", IconName::Close) + .icon_size(IconSize::Small) + .tooltip(Tooltip::text("Dismiss")) + .on_click(move |_, window, cx| { + telemetry::event!( + "Banner Dismissed", + source = "AI Onboarding", + ); + callback(window, cx) + }), + ), + ) + }, + ) + 
.into_any_element() + } + + fn render_pro_plan_state(&self, _cx: &mut App) -> AnyElement { + let plan_definitions = PlanDefinitions; + + v_flex() + .gap_1() + .child(Headline::new("Welcome to Zed Pro")) + .child( + Label::new("Here's what you get:") + .color(Color::Muted) + .mb_2(), + ) + .child(plan_definitions.pro_plan(false)) + .child( + Button::new("pro", "Continue with Zed Pro") + .full_width() + .style(ButtonStyle::Outlined) + .on_click({ + let callback = self.continue_with_zed_ai.clone(); + move |_, window, cx| { + telemetry::event!("Banner Dismissed", source = "AI Onboarding"); + callback(window, cx) + } + }), + ) + .into_any_element() + } +} + +impl RenderOnce for ZedAiOnboarding { + fn render(self, _window: &mut ui::Window, cx: &mut App) -> impl IntoElement { + if matches!(self.sign_in_status, SignInStatus::SignedIn) { + if self.has_accepted_terms_of_service { + match self.plan { + None | Some(Plan::ZedFree) => self.render_free_plan_state(cx), + Some(Plan::ZedProTrial) => self.render_trial_state(cx), + Some(Plan::ZedPro) => self.render_pro_plan_state(cx), + } + } else { + self.render_accept_terms_of_service() + } + } else { + self.render_sign_in_disclaimer(cx) + } + } +} + +impl Component for ZedAiOnboarding { + fn scope() -> ComponentScope { + ComponentScope::Onboarding + } + + fn name() -> &'static str { + "Agent Panel Banners" + } + + fn sort_name() -> &'static str { + "Agent Panel Banners" + } + + fn preview(_window: &mut Window, _cx: &mut App) -> Option { + fn onboarding( + sign_in_status: SignInStatus, + has_accepted_terms_of_service: bool, + plan: Option, + account_too_young: bool, + ) -> AnyElement { + ZedAiOnboarding { + sign_in_status, + has_accepted_terms_of_service, + plan, + account_too_young, + continue_with_zed_ai: Arc::new(|_, _| {}), + sign_in: Arc::new(|_, _| {}), + accept_terms_of_service: Arc::new(|_, _| {}), + dismiss_onboarding: None, + } + .into_any_element() + } + + Some( + v_flex() + .gap_4() + .items_center() + .max_w_4_5() + .children(vec![ + single_example( + "Not Signed-in", + onboarding(SignInStatus::SignedOut, false, None, false), + ), + single_example( + "Not Accepted ToS", + onboarding(SignInStatus::SignedIn, false, None, false), + ), + single_example( + "Young Account", + onboarding(SignInStatus::SignedIn, true, None, true), + ), + single_example( + "Free Plan", + onboarding(SignInStatus::SignedIn, true, Some(Plan::ZedFree), false), + ), + single_example( + "Pro Trial", + onboarding(SignInStatus::SignedIn, true, Some(Plan::ZedProTrial), false), + ), + single_example( + "Pro Plan", + onboarding(SignInStatus::SignedIn, true, Some(Plan::ZedPro), false), + ), + ]) + .into_any_element(), + ) + } +} diff --git a/crates/ai_onboarding/src/ai_upsell_card.rs b/crates/ai_onboarding/src/ai_upsell_card.rs new file mode 100644 index 0000000000000000000000000000000000000000..e9639ca075d1190ef6ab13f1bb01dd7333010d86 --- /dev/null +++ b/crates/ai_onboarding/src/ai_upsell_card.rs @@ -0,0 +1,366 @@ +use std::{sync::Arc, time::Duration}; + +use client::{Client, UserStore, zed_urls}; +use cloud_llm_client::Plan; +use gpui::{ + Animation, AnimationExt, AnyElement, App, Entity, IntoElement, RenderOnce, Transformation, + Window, percentage, +}; +use ui::{Divider, Vector, VectorName, prelude::*}; + +use crate::{SignInStatus, YoungAccountBanner, plan_definitions::PlanDefinitions}; + +#[derive(IntoElement, RegisterComponent)] +pub struct AiUpsellCard { + pub sign_in_status: SignInStatus, + pub sign_in: Arc, + pub account_too_young: bool, + pub user_plan: Option, + pub 
tab_index: Option, +} + +impl AiUpsellCard { + pub fn new( + client: Arc, + user_store: &Entity, + user_plan: Option, + cx: &mut App, + ) -> Self { + let status = *client.status().borrow(); + let store = user_store.read(cx); + + Self { + user_plan, + sign_in_status: status.into(), + sign_in: Arc::new(move |_window, cx| { + cx.spawn({ + let client = client.clone(); + async move |cx| client.sign_in_with_optional_connect(true, cx).await + }) + .detach_and_log_err(cx); + }), + account_too_young: store.account_too_young(), + tab_index: None, + } + } +} + +impl RenderOnce for AiUpsellCard { + fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { + let plan_definitions = PlanDefinitions; + let young_account_banner = YoungAccountBanner; + + let pro_section = v_flex() + .flex_grow() + .w_full() + .gap_1() + .child( + h_flex() + .gap_2() + .child( + Label::new("Pro") + .size(LabelSize::Small) + .color(Color::Accent) + .buffer_font(cx), + ) + .child(Divider::horizontal()), + ) + .child(plan_definitions.pro_plan(false)); + + let free_section = v_flex() + .flex_grow() + .w_full() + .gap_1() + .child( + h_flex() + .gap_2() + .child( + Label::new("Free") + .size(LabelSize::Small) + .color(Color::Muted) + .buffer_font(cx), + ) + .child(Divider::horizontal()), + ) + .child(plan_definitions.free_plan()); + + let grid_bg = h_flex().absolute().inset_0().w_full().h(px(240.)).child( + Vector::new(VectorName::Grid, rems_from_px(500.), rems_from_px(240.)) + .color(Color::Custom(cx.theme().colors().border.opacity(0.05))), + ); + + let gradient_bg = div() + .absolute() + .inset_0() + .size_full() + .bg(gpui::linear_gradient( + 180., + gpui::linear_color_stop( + cx.theme().colors().elevated_surface_background.opacity(0.8), + 0., + ), + gpui::linear_color_stop( + cx.theme().colors().elevated_surface_background.opacity(0.), + 0.8, + ), + )); + + let description = PlanDefinitions::AI_DESCRIPTION; + + let card = v_flex() + .relative() + .flex_grow() + .p_4() + .pt_3() + .border_1() + .border_color(cx.theme().colors().border) + .rounded_lg() + .overflow_hidden() + .child(grid_bg) + .child(gradient_bg); + + let plans_section = h_flex() + .w_full() + .mt_1p5() + .mb_2p5() + .items_start() + .gap_6() + .child(free_section) + .child(pro_section); + + let footer_container = v_flex().items_center().gap_1(); + + let certified_user_stamp = div() + .absolute() + .top_2() + .right_2() + .size(rems_from_px(72.)) + .child( + Vector::new( + VectorName::ProUserStamp, + rems_from_px(72.), + rems_from_px(72.), + ) + .color(Color::Custom(cx.theme().colors().text_accent.alpha(0.3))) + .with_animation( + "loading_stamp", + Animation::new(Duration::from_secs(10)).repeat(), + |this, delta| this.transform(Transformation::rotate(percentage(delta))), + ), + ); + + let pro_trial_stamp = div() + .absolute() + .top_2() + .right_2() + .size(rems_from_px(72.)) + .child( + Vector::new( + VectorName::ProTrialStamp, + rems_from_px(72.), + rems_from_px(72.), + ) + .color(Color::Custom(cx.theme().colors().text.alpha(0.2))), + ); + + match self.sign_in_status { + SignInStatus::SignedIn => match self.user_plan { + None | Some(Plan::ZedFree) => card + .child(Label::new("Try Zed AI").size(LabelSize::Large)) + .map(|this| { + if self.account_too_young { + this.child(young_account_banner).child( + v_flex() + .mt_2() + .gap_1() + .child( + h_flex() + .gap_2() + .child( + Label::new("Pro") + .size(LabelSize::Small) + .color(Color::Accent) + .buffer_font(cx), + ) + .child(Divider::horizontal()), + ) + .child(plan_definitions.pro_plan(true)) + 
.child( + Button::new("pro", "Get Started") + .full_width() + .style(ButtonStyle::Tinted(ui::TintColor::Accent)) + .on_click(move |_, _window, cx| { + telemetry::event!( + "Upgrade To Pro Clicked", + state = "young-account" + ); + cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)) + }), + ), + ) + } else { + this.child( + div() + .max_w_3_4() + .mb_2() + .child(Label::new(description).color(Color::Muted)), + ) + .child(plans_section) + .child( + footer_container + .child( + Button::new("start_trial", "Start 14-day Free Pro Trial") + .full_width() + .style(ButtonStyle::Tinted(ui::TintColor::Accent)) + .when_some(self.tab_index, |this, tab_index| { + this.tab_index(tab_index) + }) + .on_click(move |_, _window, cx| { + telemetry::event!( + "Start Trial Clicked", + state = "post-sign-in" + ); + cx.open_url(&zed_urls::start_trial_url(cx)) + }), + ) + .child( + Label::new("No credit card required") + .size(LabelSize::Small) + .color(Color::Muted), + ), + ) + } + }), + Some(Plan::ZedProTrial) => card + .child(pro_trial_stamp) + .child(Label::new("You're in the Zed Pro Trial").size(LabelSize::Large)) + .child( + Label::new("Here's what you get for the next 14 days:") + .color(Color::Muted) + .mb_2(), + ) + .child(plan_definitions.pro_trial(false)), + Some(Plan::ZedPro) => card + .child(certified_user_stamp) + .child(Label::new("You're in the Zed Pro plan").size(LabelSize::Large)) + .child( + Label::new("Here's what you get:") + .color(Color::Muted) + .mb_2(), + ) + .child(plan_definitions.pro_plan(false)), + }, + // Signed Out State + _ => card + .child(Label::new("Try Zed AI").size(LabelSize::Large)) + .child( + div() + .max_w_3_4() + .mb_2() + .child(Label::new(description).color(Color::Muted)), + ) + .child(plans_section) + .child( + Button::new("sign_in", "Sign In") + .full_width() + .style(ButtonStyle::Tinted(ui::TintColor::Accent)) + .when_some(self.tab_index, |this, tab_index| this.tab_index(tab_index)) + .on_click({ + let callback = self.sign_in.clone(); + move |_, window, cx| { + telemetry::event!("Start Trial Clicked", state = "pre-sign-in"); + callback(window, cx) + } + }), + ), + } + } +} + +impl Component for AiUpsellCard { + fn scope() -> ComponentScope { + ComponentScope::Onboarding + } + + fn name() -> &'static str { + "AI Upsell Card" + } + + fn sort_name() -> &'static str { + "AI Upsell Card" + } + + fn description() -> Option<&'static str> { + Some("A card presenting the Zed AI product during user's first-open onboarding flow.") + } + + fn preview(_window: &mut Window, _cx: &mut App) -> Option { + Some( + v_flex() + .gap_4() + .items_center() + .max_w_4_5() + .child(single_example( + "Signed Out State", + AiUpsellCard { + sign_in_status: SignInStatus::SignedOut, + sign_in: Arc::new(|_, _| {}), + account_too_young: false, + user_plan: None, + tab_index: Some(0), + } + .into_any_element(), + )) + .child(example_group_with_title( + "Signed In States", + vec![ + single_example( + "Free Plan", + AiUpsellCard { + sign_in_status: SignInStatus::SignedIn, + sign_in: Arc::new(|_, _| {}), + account_too_young: false, + user_plan: Some(Plan::ZedFree), + tab_index: Some(1), + } + .into_any_element(), + ), + single_example( + "Free Plan but Young Account", + AiUpsellCard { + sign_in_status: SignInStatus::SignedIn, + sign_in: Arc::new(|_, _| {}), + account_too_young: true, + user_plan: Some(Plan::ZedFree), + tab_index: Some(1), + } + .into_any_element(), + ), + single_example( + "Pro Trial", + AiUpsellCard { + sign_in_status: SignInStatus::SignedIn, + sign_in: Arc::new(|_, _| {}), + 
account_too_young: false, + user_plan: Some(Plan::ZedProTrial), + tab_index: Some(1), + } + .into_any_element(), + ), + single_example( + "Pro Plan", + AiUpsellCard { + sign_in_status: SignInStatus::SignedIn, + sign_in: Arc::new(|_, _| {}), + account_too_young: false, + user_plan: Some(Plan::ZedPro), + tab_index: Some(1), + } + .into_any_element(), + ), + ], + )) + .into_any_element(), + ) + } +} diff --git a/crates/ai_onboarding/src/edit_prediction_onboarding_content.rs b/crates/ai_onboarding/src/edit_prediction_onboarding_content.rs new file mode 100644 index 0000000000000000000000000000000000000000..e883d8da8ce01bfea3f08676666c308a90f6d650 --- /dev/null +++ b/crates/ai_onboarding/src/edit_prediction_onboarding_content.rs @@ -0,0 +1,73 @@ +use std::sync::Arc; + +use client::{Client, UserStore}; +use gpui::{Entity, IntoElement, ParentElement}; +use ui::prelude::*; + +use crate::ZedAiOnboarding; + +pub struct EditPredictionOnboarding { + user_store: Entity, + client: Arc, + copilot_is_configured: bool, + continue_with_zed_ai: Arc, + continue_with_copilot: Arc, +} + +impl EditPredictionOnboarding { + pub fn new( + user_store: Entity, + client: Arc, + copilot_is_configured: bool, + continue_with_zed_ai: Arc, + continue_with_copilot: Arc, + _cx: &mut Context, + ) -> Self { + Self { + user_store, + copilot_is_configured, + client, + continue_with_zed_ai, + continue_with_copilot, + } + } +} + +impl Render for EditPredictionOnboarding { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + let github_copilot = v_flex() + .gap_1() + .child(Label::new(if self.copilot_is_configured { + "Alternatively, you can continue to use GitHub Copilot as that's already set up." + } else { + "Alternatively, you can use GitHub Copilot as your edit prediction provider." 
+ })) + .child( + Button::new( + "configure-copilot", + if self.copilot_is_configured { + "Use Copilot" + } else { + "Configure Copilot" + }, + ) + .full_width() + .style(ButtonStyle::Outlined) + .on_click({ + let callback = self.continue_with_copilot.clone(); + move |_, window, cx| callback(window, cx) + }), + ); + + v_flex() + .gap_2() + .child(ZedAiOnboarding::new( + self.client.clone(), + &self.user_store, + self.continue_with_zed_ai.clone(), + cx, + )) + .child(ui::Divider::horizontal()) + .child(github_copilot) + } +} diff --git a/crates/ai_onboarding/src/plan_definitions.rs b/crates/ai_onboarding/src/plan_definitions.rs new file mode 100644 index 0000000000000000000000000000000000000000..8d66f6c3563c482b2356e081b5786219f5bf1de3 --- /dev/null +++ b/crates/ai_onboarding/src/plan_definitions.rs @@ -0,0 +1,39 @@ +use gpui::{IntoElement, ParentElement}; +use ui::{List, ListBulletItem, prelude::*}; + +/// Centralized definitions for Zed AI plans +pub struct PlanDefinitions; + +impl PlanDefinitions { + pub const AI_DESCRIPTION: &'static str = "Zed offers a complete agentic experience, with robust editing and reviewing features to collaborate with AI."; + + pub fn free_plan(&self) -> impl IntoElement { + List::new() + .child(ListBulletItem::new("50 prompts with Claude models")) + .child(ListBulletItem::new("2,000 accepted edit predictions")) + } + + pub fn pro_trial(&self, period: bool) -> impl IntoElement { + List::new() + .child(ListBulletItem::new("150 prompts with Claude models")) + .child(ListBulletItem::new( + "Unlimited edit predictions with Zeta, our open-source model", + )) + .when(period, |this| { + this.child(ListBulletItem::new( + "Try it out for 14 days for free, no credit card required", + )) + }) + } + + pub fn pro_plan(&self, price: bool) -> impl IntoElement { + List::new() + .child(ListBulletItem::new("500 prompts with Claude models")) + .child(ListBulletItem::new( + "Unlimited edit predictions with Zeta, our open-source model", + )) + .when(price, |this| { + this.child(ListBulletItem::new("$20 USD per month")) + }) + } +} diff --git a/crates/ai_onboarding/src/young_account_banner.rs b/crates/ai_onboarding/src/young_account_banner.rs new file mode 100644 index 0000000000000000000000000000000000000000..54f563e4aac8ca71fff16199cd6c2e8f81ad5376 --- /dev/null +++ b/crates/ai_onboarding/src/young_account_banner.rs @@ -0,0 +1,22 @@ +use gpui::{IntoElement, ParentElement}; +use ui::{Banner, prelude::*}; + +#[derive(IntoElement)] +pub struct YoungAccountBanner; + +impl RenderOnce for YoungAccountBanner { + fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { + const YOUNG_ACCOUNT_DISCLAIMER: &str = "To prevent abuse of our service, we cannot offer plans to GitHub accounts created fewer than 30 days ago. 
To request an exception, reach out to billing-support@zed.dev."; + + let label = div() + .w_full() + .text_sm() + .text_color(cx.theme().colors().text_muted) + .child(YOUNG_ACCOUNT_DISCLAIMER); + + div() + .max_w_full() + .my_1() + .child(Banner::new().severity(ui::Severity::Warning).child(label)) + } +} diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs index c73f6060458783f22b4d846dbe6d4a619d7e791c..3ff1666755d439cf52a14ea635a06a7c3414d9f6 100644 --- a/crates/anthropic/src/anthropic.rs +++ b/crates/anthropic/src/anthropic.rs @@ -36,11 +36,18 @@ pub enum AnthropicModelMode { pub enum Model { #[serde(rename = "claude-opus-4", alias = "claude-opus-4-latest")] ClaudeOpus4, + #[serde(rename = "claude-opus-4-1", alias = "claude-opus-4-1-latest")] + ClaudeOpus4_1, #[serde( rename = "claude-opus-4-thinking", alias = "claude-opus-4-thinking-latest" )] ClaudeOpus4Thinking, + #[serde( + rename = "claude-opus-4-1-thinking", + alias = "claude-opus-4-1-thinking-latest" + )] + ClaudeOpus4_1Thinking, #[default] #[serde(rename = "claude-sonnet-4", alias = "claude-sonnet-4-latest")] ClaudeSonnet4, @@ -91,10 +98,18 @@ impl Model { } pub fn from_id(id: &str) -> Result { + if id.starts_with("claude-opus-4-1-thinking") { + return Ok(Self::ClaudeOpus4_1Thinking); + } + if id.starts_with("claude-opus-4-thinking") { return Ok(Self::ClaudeOpus4Thinking); } + if id.starts_with("claude-opus-4-1") { + return Ok(Self::ClaudeOpus4_1); + } + if id.starts_with("claude-opus-4") { return Ok(Self::ClaudeOpus4); } @@ -141,7 +156,9 @@ impl Model { pub fn id(&self) -> &str { match self { Self::ClaudeOpus4 => "claude-opus-4-latest", + Self::ClaudeOpus4_1 => "claude-opus-4-1-latest", Self::ClaudeOpus4Thinking => "claude-opus-4-thinking-latest", + Self::ClaudeOpus4_1Thinking => "claude-opus-4-1-thinking-latest", Self::ClaudeSonnet4 => "claude-sonnet-4-latest", Self::ClaudeSonnet4Thinking => "claude-sonnet-4-thinking-latest", Self::Claude3_5Sonnet => "claude-3-5-sonnet-latest", @@ -159,6 +176,7 @@ impl Model { pub fn request_id(&self) -> &str { match self { Self::ClaudeOpus4 | Self::ClaudeOpus4Thinking => "claude-opus-4-20250514", + Self::ClaudeOpus4_1 | Self::ClaudeOpus4_1Thinking => "claude-opus-4-1-20250805", Self::ClaudeSonnet4 | Self::ClaudeSonnet4Thinking => "claude-sonnet-4-20250514", Self::Claude3_5Sonnet => "claude-3-5-sonnet-latest", Self::Claude3_7Sonnet | Self::Claude3_7SonnetThinking => "claude-3-7-sonnet-latest", @@ -173,7 +191,9 @@ impl Model { pub fn display_name(&self) -> &str { match self { Self::ClaudeOpus4 => "Claude Opus 4", + Self::ClaudeOpus4_1 => "Claude Opus 4.1", Self::ClaudeOpus4Thinking => "Claude Opus 4 Thinking", + Self::ClaudeOpus4_1Thinking => "Claude Opus 4.1 Thinking", Self::ClaudeSonnet4 => "Claude Sonnet 4", Self::ClaudeSonnet4Thinking => "Claude Sonnet 4 Thinking", Self::Claude3_7Sonnet => "Claude 3.7 Sonnet", @@ -192,7 +212,9 @@ impl Model { pub fn cache_configuration(&self) -> Option { match self { Self::ClaudeOpus4 + | Self::ClaudeOpus4_1 | Self::ClaudeOpus4Thinking + | Self::ClaudeOpus4_1Thinking | Self::ClaudeSonnet4 | Self::ClaudeSonnet4Thinking | Self::Claude3_5Sonnet @@ -215,7 +237,9 @@ impl Model { pub fn max_token_count(&self) -> u64 { match self { Self::ClaudeOpus4 + | Self::ClaudeOpus4_1 | Self::ClaudeOpus4Thinking + | Self::ClaudeOpus4_1Thinking | Self::ClaudeSonnet4 | Self::ClaudeSonnet4Thinking | Self::Claude3_5Sonnet @@ -232,7 +256,9 @@ impl Model { pub fn max_output_tokens(&self) -> u64 { match self { Self::ClaudeOpus4 + | Self::ClaudeOpus4_1 | 
Self::ClaudeOpus4Thinking + | Self::ClaudeOpus4_1Thinking | Self::ClaudeSonnet4 | Self::ClaudeSonnet4Thinking | Self::Claude3_5Sonnet @@ -249,7 +275,9 @@ impl Model { pub fn default_temperature(&self) -> f32 { match self { Self::ClaudeOpus4 + | Self::ClaudeOpus4_1 | Self::ClaudeOpus4Thinking + | Self::ClaudeOpus4_1Thinking | Self::ClaudeSonnet4 | Self::ClaudeSonnet4Thinking | Self::Claude3_5Sonnet @@ -269,6 +297,7 @@ impl Model { pub fn mode(&self) -> AnthropicModelMode { match self { Self::ClaudeOpus4 + | Self::ClaudeOpus4_1 | Self::ClaudeSonnet4 | Self::Claude3_5Sonnet | Self::Claude3_7Sonnet @@ -277,6 +306,7 @@ impl Model { | Self::Claude3Sonnet | Self::Claude3Haiku => AnthropicModelMode::Default, Self::ClaudeOpus4Thinking + | Self::ClaudeOpus4_1Thinking | Self::ClaudeSonnet4Thinking | Self::Claude3_7SonnetThinking => AnthropicModelMode::Thinking { budget_tokens: Some(4_096), diff --git a/crates/assistant_context/Cargo.toml b/crates/assistant_context/Cargo.toml index f35dc43340b98dd8445da9335e27745ec8e35cb8..8f5ff98790f2319c398b7acf04214cd2b3f577f4 100644 --- a/crates/assistant_context/Cargo.toml +++ b/crates/assistant_context/Cargo.toml @@ -19,6 +19,7 @@ assistant_slash_commands.workspace = true chrono.workspace = true client.workspace = true clock.workspace = true +cloud_llm_client.workspace = true collections.workspace = true context_server.workspace = true fs.workspace = true @@ -48,7 +49,6 @@ util.workspace = true uuid.workspace = true workspace-hack.workspace = true workspace.workspace = true -zed_llm_client.workspace = true [dev-dependencies] indoc.workspace = true diff --git a/crates/assistant_context/src/assistant_context.rs b/crates/assistant_context/src/assistant_context.rs index 136468e084593ef6b6475d29d8526d683b1bdc7b..557f9592e4d12e86c4e73d1bc742dfa74535d66c 100644 --- a/crates/assistant_context/src/assistant_context.rs +++ b/crates/assistant_context/src/assistant_context.rs @@ -2,15 +2,16 @@ mod assistant_context_tests; mod context_store; -use agent_settings::AgentSettings; +use agent_settings::{AgentSettings, SUMMARIZE_THREAD_PROMPT}; use anyhow::{Context as _, Result, bail}; use assistant_slash_command::{ SlashCommandContent, SlashCommandEvent, SlashCommandLine, SlashCommandOutputSection, SlashCommandResult, SlashCommandWorkingSet, }; use assistant_slash_commands::FileCommandMetadata; -use client::{self, Client, proto, telemetry::Telemetry}; +use client::{self, Client, ModelRequestUsage, RequestUsage, proto, telemetry::Telemetry}; use clock::ReplicaId; +use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit}; use collections::{HashMap, HashSet}; use fs::{Fs, RenameOptions}; use futures::{FutureExt, StreamExt, future::Shared}; @@ -46,7 +47,6 @@ use text::{BufferSnapshot, ToPoint}; use ui::IconName; use util::{ResultExt, TryFutureExt, post_inc}; use uuid::Uuid; -use zed_llm_client::CompletionIntent; pub use crate::context_store::*; @@ -2080,7 +2080,18 @@ impl AssistantContext { }); match event { - LanguageModelCompletionEvent::StatusUpdate { .. } => {} + LanguageModelCompletionEvent::StatusUpdate(status_update) => { + match status_update { + CompletionRequestStatus::UsageUpdated { amount, limit } => { + this.update_model_request_usage( + amount as u32, + limit, + cx, + ); + } + _ => {} + } + } LanguageModelCompletionEvent::StartMessage { .. 
} => {} LanguageModelCompletionEvent::Stop(reason) => { stop_reason = reason; @@ -2677,10 +2688,7 @@ impl AssistantContext { let mut request = self.to_completion_request(Some(&model.model), cx); request.messages.push(LanguageModelRequestMessage { role: Role::User, - content: vec![ - "Generate a concise 3-7 word title for this conversation, omitting punctuation. Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`" - .into(), - ], + content: vec![SUMMARIZE_THREAD_PROMPT.into()], cache: false, }); @@ -2956,6 +2964,21 @@ impl AssistantContext { summary.text = custom_summary; cx.emit(ContextEvent::SummaryChanged); } + + fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut App) { + let Some(project) = &self.project else { + return; + }; + project.read(cx).user_store().update(cx, |user_store, cx| { + user_store.update_model_request_usage( + ModelRequestUsage(RequestUsage { + amount: amount as i32, + limit, + }), + cx, + ) + }); + } } #[derive(Debug, Default)] diff --git a/crates/assistant_context/src/assistant_context_tests.rs b/crates/assistant_context/src/assistant_context_tests.rs index dba3bfde61bb997a25d25f29651ad0a7aa2c2708..efcad8ed9654449c747ee4853c7e7aa689c0568b 100644 --- a/crates/assistant_context/src/assistant_context_tests.rs +++ b/crates/assistant_context/src/assistant_context_tests.rs @@ -1210,8 +1210,8 @@ async fn test_summarization(cx: &mut TestAppContext) { }); cx.run_until_parked(); - fake_model.stream_last_completion_response("Brief"); - fake_model.stream_last_completion_response(" Introduction"); + fake_model.send_last_completion_stream_text_chunk("Brief"); + fake_model.send_last_completion_stream_text_chunk(" Introduction"); fake_model.end_last_completion_stream(); cx.run_until_parked(); @@ -1274,7 +1274,7 @@ async fn test_thread_summary_error_retry(cx: &mut TestAppContext) { }); cx.run_until_parked(); - fake_model.stream_last_completion_response("A successful summary"); + fake_model.send_last_completion_stream_text_chunk("A successful summary"); fake_model.end_last_completion_stream(); cx.run_until_parked(); @@ -1323,7 +1323,7 @@ fn setup_context_editor_with_fake_model( ) -> (Entity, Arc) { let registry = Arc::new(LanguageRegistry::test(cx.executor().clone())); - let fake_provider = Arc::new(FakeLanguageModelProvider); + let fake_provider = Arc::new(FakeLanguageModelProvider::default()); let fake_model = Arc::new(fake_provider.test_model()); cx.update(|cx| { @@ -1356,7 +1356,7 @@ fn setup_context_editor_with_fake_model( fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) { cx.run_until_parked(); - fake_model.stream_last_completion_response("Assistant response"); + fake_model.send_last_completion_stream_text_chunk("Assistant response"); fake_model.end_last_completion_stream(); cx.run_until_parked(); } diff --git a/crates/assistant_context/src/context_store.rs b/crates/assistant_context/src/context_store.rs index 3400913eb86ed0717ca29681511fd0d2cb506603..3090a7b23439de9ae8fe1bd5287f439895f17a98 100644 --- a/crates/assistant_context/src/context_store.rs +++ b/crates/assistant_context/src/context_store.rs @@ -767,6 +767,11 @@ impl ContextStore { fn reload(&mut self, cx: &mut Context) -> Task> { let fs = self.fs.clone(); cx.spawn(async move |this, cx| { + pub static ZED_STATELESS: LazyLock = + LazyLock::new(|| std::env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty())); + if *ZED_STATELESS { + return Ok(()); + } fs.create_dir(contexts_dir()).await?; let mut 
paths = fs.read_dir(contexts_dir()).await?;
diff --git a/crates/assistant_tool/Cargo.toml b/crates/assistant_tool/Cargo.toml
index 5a54e86eac15c2846e7e72ee45b47ab014cd69e6..acbe674b02cfe31a08f63e01f7dae1a2448c453e 100644
--- a/crates/assistant_tool/Cargo.toml
+++ b/crates/assistant_tool/Cargo.toml
@@ -40,6 +40,7 @@ collections = { workspace = true, features = ["test-support"] }
 clock = { workspace = true, features = ["test-support"] }
 ctor.workspace = true
 gpui = { workspace = true, features = ["test-support"] }
+indoc.workspace = true
 language = { workspace = true, features = ["test-support"] }
 language_model = { workspace = true, features = ["test-support"] }
 log.workspace = true
diff --git a/crates/assistant_tool/src/action_log.rs b/crates/assistant_tool/src/action_log.rs
index 2071a1f444b547197fabc252fddb1f9bd165ae67..025aba060d9380390b06478d9ddc0ad9c4f52e5a 100644
--- a/crates/assistant_tool/src/action_log.rs
+++ b/crates/assistant_tool/src/action_log.rs
@@ -8,7 +8,10 @@ use language::{Anchor, Buffer, BufferEvent, DiskState, Point, ToPoint};
 use project::{Project, ProjectItem, lsp_store::OpenLspBufferHandle};
 use std::{cmp, ops::Range, sync::Arc};
 use text::{Edit, Patch, Rope};
-use util::RangeExt;
+use util::{
+    RangeExt, ResultExt as _,
+    paths::{PathStyle, RemotePathBuf},
+};
 
 /// Tracks actions performed by tools in a thread
 pub struct ActionLog {
@@ -18,8 +21,6 @@ pub struct ActionLog {
     edited_since_project_diagnostics_check: bool,
     /// The project this action log is associated with
     project: Entity<Project>,
-    /// Tracks which buffer versions have already been notified as changed externally
-    notified_versions: BTreeMap<Entity<Buffer>, clock::Global>,
 }
 
 impl ActionLog {
@@ -29,7 +30,6 @@ impl ActionLog {
             tracked_buffers: BTreeMap::default(),
             edited_since_project_diagnostics_check: false,
             project,
-            notified_versions: BTreeMap::default(),
         }
     }
 
@@ -47,6 +47,65 @@ impl ActionLog {
         self.edited_since_project_diagnostics_check
     }
 
+    pub fn latest_snapshot(&self, buffer: &Entity<Buffer>) -> Option<text::BufferSnapshot> {
+        Some(self.tracked_buffers.get(buffer)?.snapshot.clone())
+    }
+
+    /// Return a unified diff patch with user edits made since last read or notification
+    pub fn unnotified_user_edits(&self, cx: &Context<Self>) -> Option<String> {
+        let diffs = self
+            .tracked_buffers
+            .values()
+            .filter_map(|tracked| {
+                if !tracked.may_have_unnotified_user_edits {
+                    return None;
+                }
+
+                let text_with_latest_user_edits = tracked.diff_base.to_string();
+                let text_with_last_seen_user_edits = tracked.last_seen_base.to_string();
+                if text_with_latest_user_edits == text_with_last_seen_user_edits {
+                    return None;
+                }
+                let patch = language::unified_diff(
+                    &text_with_last_seen_user_edits,
+                    &text_with_latest_user_edits,
+                );
+
+                let buffer = tracked.buffer.clone();
+                let file_path = buffer
+                    .read(cx)
+                    .file()
+                    .map(|file| RemotePathBuf::new(file.full_path(cx), PathStyle::Posix).to_proto())
+                    .unwrap_or_else(|| format!("buffer_{}", buffer.entity_id()));
+
+                let mut result = String::new();
+                result.push_str(&format!("--- a/{}\n", file_path));
+                result.push_str(&format!("+++ b/{}\n", file_path));
+                result.push_str(&patch);
+
+                Some(result)
+            })
+            .collect::<Vec<_>>();
+
+        if diffs.is_empty() {
+            return None;
+        }
+
+        let unified_diff = diffs.join("\n\n");
+        Some(unified_diff)
+    }
+
+    /// Return a unified diff patch with user edits made since last read/notification
+    /// and mark them as notified
+    pub fn flush_unnotified_user_edits(&mut self, cx: &Context<Self>) -> Option<String> {
+        let patch = self.unnotified_user_edits(cx);
+        self.tracked_buffers.values_mut().for_each(|tracked| {
+ tracked.may_have_unnotified_user_edits = false; + tracked.last_seen_base = tracked.diff_base.clone(); + }); + patch + } + fn track_buffer_internal( &mut self, buffer: Entity, @@ -55,7 +114,6 @@ impl ActionLog { ) -> &mut TrackedBuffer { let status = if is_created { if let Some(tracked) = self.tracked_buffers.remove(&buffer) { - self.notified_versions.remove(&buffer); match tracked.status { TrackedBufferStatus::Created { existing_file_content, @@ -97,26 +155,31 @@ impl ActionLog { let diff = cx.new(|cx| BufferDiff::new(&text_snapshot, cx)); let (diff_update_tx, diff_update_rx) = mpsc::unbounded(); let diff_base; + let last_seen_base; let unreviewed_edits; if is_created { diff_base = Rope::default(); + last_seen_base = Rope::default(); unreviewed_edits = Patch::new(vec![Edit { old: 0..1, new: 0..text_snapshot.max_point().row + 1, }]) } else { diff_base = buffer.read(cx).as_rope().clone(); + last_seen_base = diff_base.clone(); unreviewed_edits = Patch::default(); } TrackedBuffer { buffer: buffer.clone(), diff_base, + last_seen_base, unreviewed_edits, snapshot: text_snapshot.clone(), status, version: buffer.read(cx).version(), diff, diff_update: diff_update_tx, + may_have_unnotified_user_edits: false, _open_lsp_handle: open_lsp_handle, _maintain_diff: cx.spawn({ let buffer = buffer.clone(); @@ -170,7 +233,6 @@ impl ActionLog { // If the buffer had been edited by a tool, but it got // deleted externally, we want to stop tracking it. self.tracked_buffers.remove(&buffer); - self.notified_versions.remove(&buffer); } cx.notify(); } @@ -184,7 +246,6 @@ impl ActionLog { // resurrected externally, we want to clear the edits we // were tracking and reset the buffer's state. self.tracked_buffers.remove(&buffer); - self.notified_versions.remove(&buffer); self.track_buffer_internal(buffer, false, cx); } cx.notify(); @@ -258,10 +319,10 @@ impl ActionLog { buffer_snapshot: text::BufferSnapshot, cx: &mut AsyncApp, ) -> Result<()> { - let rebase = this.read_with(cx, |this, cx| { + let rebase = this.update(cx, |this, cx| { let tracked_buffer = this .tracked_buffers - .get(buffer) + .get_mut(buffer) .context("buffer not tracked")?; let rebase = cx.background_spawn({ @@ -269,23 +330,35 @@ impl ActionLog { let old_snapshot = tracked_buffer.snapshot.clone(); let new_snapshot = buffer_snapshot.clone(); let unreviewed_edits = tracked_buffer.unreviewed_edits.clone(); + let edits = diff_snapshots(&old_snapshot, &new_snapshot); + let mut has_user_changes = false; async move { - let edits = diff_snapshots(&old_snapshot, &new_snapshot); if let ChangeAuthor::User = author { - apply_non_conflicting_edits( + has_user_changes = apply_non_conflicting_edits( &unreviewed_edits, edits, &mut base_text, new_snapshot.as_rope(), ); } - (Arc::new(base_text.to_string()), base_text) + + (Arc::new(base_text.to_string()), base_text, has_user_changes) } }); anyhow::Ok(rebase) })??; - let (new_base_text, new_diff_base) = rebase.await; + let (new_base_text, new_diff_base, has_user_changes) = rebase.await; + + this.update(cx, |this, _| { + let tracked_buffer = this + .tracked_buffers + .get_mut(buffer) + .context("buffer not tracked") + .unwrap(); + tracked_buffer.may_have_unnotified_user_edits |= has_user_changes; + })?; + Self::update_diff( this, buffer, @@ -490,7 +563,6 @@ impl ActionLog { match tracked_buffer.status { TrackedBufferStatus::Created { .. 
} => { self.tracked_buffers.remove(&buffer); - self.notified_versions.remove(&buffer); cx.notify(); } TrackedBufferStatus::Modified => { @@ -516,7 +588,6 @@ impl ActionLog { match tracked_buffer.status { TrackedBufferStatus::Deleted => { self.tracked_buffers.remove(&buffer); - self.notified_versions.remove(&buffer); cx.notify(); } _ => { @@ -559,6 +630,11 @@ impl ActionLog { false } }); + if tracked_buffer.unreviewed_edits.is_empty() { + if let TrackedBufferStatus::Created { .. } = &mut tracked_buffer.status { + tracked_buffer.status = TrackedBufferStatus::Modified; + } + } tracked_buffer.schedule_diff_update(ChangeAuthor::User, cx); } } @@ -625,7 +701,6 @@ impl ActionLog { }; self.tracked_buffers.remove(&buffer); - self.notified_versions.remove(&buffer); cx.notify(); task } @@ -639,7 +714,6 @@ impl ActionLog { // Clear all tracked edits for this buffer and start over as if we just read it. self.tracked_buffers.remove(&buffer); - self.notified_versions.remove(&buffer); self.buffer_read(buffer.clone(), cx); cx.notify(); save @@ -706,6 +780,9 @@ impl ActionLog { .retain(|_buffer, tracked_buffer| match tracked_buffer.status { TrackedBufferStatus::Deleted => false, _ => { + if let TrackedBufferStatus::Created { .. } = &mut tracked_buffer.status { + tracked_buffer.status = TrackedBufferStatus::Modified; + } tracked_buffer.unreviewed_edits.clear(); tracked_buffer.diff_base = tracked_buffer.snapshot.as_rope().clone(); tracked_buffer.schedule_diff_update(ChangeAuthor::User, cx); @@ -715,6 +792,22 @@ impl ActionLog { cx.notify(); } + pub fn reject_all_edits(&mut self, cx: &mut Context) -> Task<()> { + let futures = self.changed_buffers(cx).into_keys().map(|buffer| { + let reject = self.reject_edits_in_ranges(buffer, vec![Anchor::MIN..Anchor::MAX], cx); + + async move { + reject.await.log_err(); + } + }); + + let task = futures::future::join_all(futures); + + cx.spawn(async move |_, _| { + task.await; + }) + } + /// Returns the set of buffers that contain edits that haven't been reviewed by the user. 
pub fn changed_buffers(&self, cx: &App) -> BTreeMap, Entity> { self.tracked_buffers @@ -724,33 +817,6 @@ impl ActionLog { .collect() } - /// Returns stale buffers that haven't been notified yet - pub fn unnotified_stale_buffers<'a>( - &'a self, - cx: &'a App, - ) -> impl Iterator> { - self.stale_buffers(cx).filter(|buffer| { - let buffer_entity = buffer.read(cx); - self.notified_versions - .get(buffer) - .map_or(true, |notified_version| { - *notified_version != buffer_entity.version - }) - }) - } - - /// Marks the given buffers as notified at their current versions - pub fn mark_buffers_as_notified( - &mut self, - buffers: impl IntoIterator>, - cx: &App, - ) { - for buffer in buffers { - let version = buffer.read(cx).version.clone(); - self.notified_versions.insert(buffer, version); - } - } - /// Iterate over buffers changed since last read or edited by the model pub fn stale_buffers<'a>(&'a self, cx: &'a App) -> impl Iterator> { self.tracked_buffers @@ -772,11 +838,12 @@ fn apply_non_conflicting_edits( edits: Vec>, old_text: &mut Rope, new_text: &Rope, -) { +) -> bool { let mut old_edits = patch.edits().iter().cloned().peekable(); let mut new_edits = edits.into_iter().peekable(); let mut applied_delta = 0i32; let mut rebased_delta = 0i32; + let mut has_made_changes = false; while let Some(mut new_edit) = new_edits.next() { let mut conflict = false; @@ -826,8 +893,10 @@ fn apply_non_conflicting_edits( &new_text.chunks_in_range(new_bytes).collect::(), ); applied_delta += new_edit.new_len() as i32 - new_edit.old_len() as i32; + has_made_changes = true; } } + has_made_changes } fn diff_snapshots( @@ -894,12 +963,14 @@ enum TrackedBufferStatus { struct TrackedBuffer { buffer: Entity, diff_base: Rope, + last_seen_base: Rope, unreviewed_edits: Patch, status: TrackedBufferStatus, version: clock::Global, diff: Entity, snapshot: text::BufferSnapshot, diff_update: mpsc::UnboundedSender<(ChangeAuthor, text::BufferSnapshot)>, + may_have_unnotified_user_edits: bool, _open_lsp_handle: OpenLspBufferHandle, _maintain_diff: Task<()>, _subscription: Subscription, @@ -930,6 +1001,7 @@ mod tests { use super::*; use buffer_diff::DiffHunkStatusKind; use gpui::TestAppContext; + use indoc::indoc; use language::Point; use project::{FakeFs, Fs, Project, RemoveOptions}; use rand::prelude::*; @@ -1212,6 +1284,110 @@ mod tests { assert_eq!(unreviewed_hunks(&action_log, cx), vec![]); } + #[gpui::test(iterations = 10)] + async fn test_user_edits_notifications(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/dir"), + json!({"file": indoc! {" + abc + def + ghi + jkl + mno"}}), + ) + .await; + let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let file_path = project + .read_with(cx, |project, cx| project.find_project_path("dir/file", cx)) + .unwrap(); + let buffer = project + .update(cx, |project, cx| project.open_buffer(file_path, cx)) + .await + .unwrap(); + + // Agent edits + cx.update(|cx| { + action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx)); + buffer.update(cx, |buffer, cx| { + buffer + .edit([(Point::new(1, 2)..Point::new(2, 3), "F\nGHI")], None, cx) + .unwrap() + }); + action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); + }); + cx.run_until_parked(); + assert_eq!( + buffer.read_with(cx, |buffer, _| buffer.text()), + indoc! 
{" + abc + deF + GHI + jkl + mno"} + ); + assert_eq!( + unreviewed_hunks(&action_log, cx), + vec![( + buffer.clone(), + vec![HunkStatus { + range: Point::new(1, 0)..Point::new(3, 0), + diff_status: DiffHunkStatusKind::Modified, + old_text: "def\nghi\n".into(), + }], + )] + ); + + // User edits + buffer.update(cx, |buffer, cx| { + buffer.edit( + [ + (Point::new(0, 2)..Point::new(0, 2), "X"), + (Point::new(3, 0)..Point::new(3, 0), "Y"), + ], + None, + cx, + ) + }); + cx.run_until_parked(); + assert_eq!( + buffer.read_with(cx, |buffer, _| buffer.text()), + indoc! {" + abXc + deF + GHI + Yjkl + mno"} + ); + + // User edits should be stored separately from agent's + let user_edits = action_log.update(cx, |log, cx| log.unnotified_user_edits(cx)); + assert_eq!( + user_edits.expect("should have some user edits"), + indoc! {" + --- a/dir/file + +++ b/dir/file + @@ -1,5 +1,5 @@ + -abc + +abXc + def + ghi + -jkl + +Yjkl + mno + "} + ); + + action_log.update(cx, |log, cx| { + log.keep_edits_in_range(buffer.clone(), Point::new(0, 0)..Point::new(1, 0), cx) + }); + cx.run_until_parked(); + assert_eq!(unreviewed_hunks(&action_log, cx), vec![]); + } + #[gpui::test(iterations = 10)] async fn test_creating_files(cx: &mut TestAppContext) { init_test(cx); @@ -1907,6 +2083,134 @@ mod tests { assert_eq!(content, "ai content\nuser added this line"); } + #[gpui::test] + async fn test_reject_after_accepting_hunk_on_created_file(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + + let file_path = project + .read_with(cx, |project, cx| { + project.find_project_path("dir/new_file", cx) + }) + .unwrap(); + let buffer = project + .update(cx, |project, cx| project.open_buffer(file_path.clone(), cx)) + .await + .unwrap(); + + // AI creates file with initial content + cx.update(|cx| { + action_log.update(cx, |log, cx| log.buffer_created(buffer.clone(), cx)); + buffer.update(cx, |buffer, cx| buffer.set_text("ai content v1", cx)); + action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); + }); + project + .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)) + .await + .unwrap(); + cx.run_until_parked(); + assert_ne!(unreviewed_hunks(&action_log, cx), vec![]); + + // User accepts the single hunk + action_log.update(cx, |log, cx| { + log.keep_edits_in_range(buffer.clone(), Anchor::MIN..Anchor::MAX, cx) + }); + cx.run_until_parked(); + assert_eq!(unreviewed_hunks(&action_log, cx), vec![]); + assert!(fs.is_file(path!("/dir/new_file").as_ref()).await); + + // AI modifies the file + cx.update(|cx| { + buffer.update(cx, |buffer, cx| buffer.set_text("ai content v2", cx)); + action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); + }); + project + .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)) + .await + .unwrap(); + cx.run_until_parked(); + assert_ne!(unreviewed_hunks(&action_log, cx), vec![]); + + // User rejects the hunk + action_log + .update(cx, |log, cx| { + log.reject_edits_in_ranges(buffer.clone(), vec![Anchor::MIN..Anchor::MAX], cx) + }) + .await + .unwrap(); + cx.run_until_parked(); + assert!(fs.is_file(path!("/dir/new_file").as_ref()).await,); + assert_eq!( + buffer.read_with(cx, |buffer, _| buffer.text()), + "ai content v1" + ); + assert_eq!(unreviewed_hunks(&action_log, cx), vec![]); + } + + #[gpui::test] + async fn test_reject_edits_on_previously_accepted_created_file(cx: &mut 
TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + + let file_path = project + .read_with(cx, |project, cx| { + project.find_project_path("dir/new_file", cx) + }) + .unwrap(); + let buffer = project + .update(cx, |project, cx| project.open_buffer(file_path.clone(), cx)) + .await + .unwrap(); + + // AI creates file with initial content + cx.update(|cx| { + action_log.update(cx, |log, cx| log.buffer_created(buffer.clone(), cx)); + buffer.update(cx, |buffer, cx| buffer.set_text("ai content v1", cx)); + action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); + }); + project + .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)) + .await + .unwrap(); + cx.run_until_parked(); + + // User clicks "Accept All" + action_log.update(cx, |log, cx| log.keep_all_edits(cx)); + cx.run_until_parked(); + assert!(fs.is_file(path!("/dir/new_file").as_ref()).await); + assert_eq!(unreviewed_hunks(&action_log, cx), vec![]); // Hunks are cleared + + // AI modifies file again + cx.update(|cx| { + buffer.update(cx, |buffer, cx| buffer.set_text("ai content v2", cx)); + action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); + }); + project + .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)) + .await + .unwrap(); + cx.run_until_parked(); + assert_ne!(unreviewed_hunks(&action_log, cx), vec![]); + + // User clicks "Reject All" + action_log + .update(cx, |log, cx| log.reject_all_edits(cx)) + .await; + cx.run_until_parked(); + assert!(fs.is_file(path!("/dir/new_file").as_ref()).await); + assert_eq!( + buffer.read_with(cx, |buffer, _| buffer.text()), + "ai content v1" + ); + assert_eq!(unreviewed_hunks(&action_log, cx), vec![]); + } + #[gpui::test(iterations = 100)] async fn test_random_diffs(mut rng: StdRng, cx: &mut TestAppContext) { init_test(cx); @@ -2201,4 +2505,61 @@ mod tests { .collect() }) } + + #[gpui::test] + async fn test_format_patch(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/dir"), + json!({"test.txt": "line 1\nline 2\nline 3\n"}), + ) + .await; + let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + + let file_path = project + .read_with(cx, |project, cx| { + project.find_project_path("dir/test.txt", cx) + }) + .unwrap(); + let buffer = project + .update(cx, |project, cx| project.open_buffer(file_path, cx)) + .await + .unwrap(); + + cx.update(|cx| { + // Track the buffer and mark it as read first + action_log.update(cx, |log, cx| { + log.buffer_read(buffer.clone(), cx); + }); + + // Make some edits to create a patch + buffer.update(cx, |buffer, cx| { + buffer + .edit([(Point::new(1, 0)..Point::new(1, 6), "CHANGED")], None, cx) + .unwrap(); // Replace "line2" with "CHANGED" + }); + }); + + cx.run_until_parked(); + + // Get the patch + let patch = action_log.update(cx, |log, cx| log.unnotified_user_edits(cx)); + + // Verify the patch format contains expected unified diff elements + assert_eq!( + patch.unwrap(), + indoc! 
{" + --- a/dir/test.txt + +++ b/dir/test.txt + @@ -1,3 +1,3 @@ + line 1 + -line 2 + +CHANGED + line 3 + "} + ); + } } diff --git a/crates/assistant_tool/src/assistant_tool.rs b/crates/assistant_tool/src/assistant_tool.rs index 554b3f3f3cf7eb0bc369ee6fed67722755704443..22cbaac3f8b0df95df3c14a6237092cf83ae35ac 100644 --- a/crates/assistant_tool/src/assistant_tool.rs +++ b/crates/assistant_tool/src/assistant_tool.rs @@ -216,7 +216,12 @@ pub trait Tool: 'static + Send + Sync { /// Returns true if the tool needs the users's confirmation /// before having permission to run. - fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool; + fn needs_confirmation( + &self, + input: &serde_json::Value, + project: &Entity, + cx: &App, + ) -> bool; /// Returns true if the tool may perform edits. fn may_perform_edits(&self) -> bool; diff --git a/crates/assistant_tool/src/tool_working_set.rs b/crates/assistant_tool/src/tool_working_set.rs index 9a6ec49914eea3cd22f014ce2a5c014d1dca1220..c0a358917b499908d85fbc157212cf6db5b5e0eb 100644 --- a/crates/assistant_tool/src/tool_working_set.rs +++ b/crates/assistant_tool/src/tool_working_set.rs @@ -375,7 +375,12 @@ mod tests { false } - fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool { + fn needs_confirmation( + &self, + _input: &serde_json::Value, + _project: &Entity, + _cx: &App, + ) -> bool { true } diff --git a/crates/assistant_tools/Cargo.toml b/crates/assistant_tools/Cargo.toml index 2b8958feb1bddc719bcd085058cdb5162fd777b1..d4b8fa3afc3dc3311599a2d9e3e97f2984ebde40 100644 --- a/crates/assistant_tools/Cargo.toml +++ b/crates/assistant_tools/Cargo.toml @@ -20,9 +20,12 @@ anyhow.workspace = true assistant_tool.workspace = true buffer_diff.workspace = true chrono.workspace = true +client.workspace = true +cloud_llm_client.workspace = true collections.workspace = true component.workspace = true derive_more.workspace = true +diffy = "0.4.2" editor.workspace = true feature_flags.workspace = true futures.workspace = true @@ -62,7 +65,6 @@ web_search.workspace = true which.workspace = true workspace-hack.workspace = true workspace.workspace = true -zed_llm_client.workspace = true [dev-dependencies] lsp = { workspace = true, features = ["test-support"] } diff --git a/crates/assistant_tools/src/assistant_tools.rs b/crates/assistant_tools/src/assistant_tools.rs index eef792f526fb684e83752241194d293064a9f4f7..90bb2e9b7c0d937ef4cb0d844f02e90babbe819f 100644 --- a/crates/assistant_tools/src/assistant_tools.rs +++ b/crates/assistant_tools/src/assistant_tools.rs @@ -20,14 +20,13 @@ mod thinking_tool; mod ui; mod web_search_tool; -use std::sync::Arc; - use assistant_tool::ToolRegistry; use copy_path_tool::CopyPathTool; use gpui::{App, Entity}; use http_client::HttpClientWithUrl; use language_model::LanguageModelRegistry; use move_path_tool::MovePathTool; +use std::sync::Arc; use web_search_tool::WebSearchTool; pub(crate) use templates::*; @@ -37,13 +36,12 @@ use crate::delete_path_tool::DeletePathTool; use crate::diagnostics_tool::DiagnosticsTool; use crate::edit_file_tool::EditFileTool; use crate::fetch_tool::FetchTool; -use crate::find_path_tool::FindPathTool; use crate::list_directory_tool::ListDirectoryTool; use crate::now_tool::NowTool; use crate::thinking_tool::ThinkingTool; pub use edit_file_tool::{EditFileMode, EditFileToolInput}; -pub use find_path_tool::FindPathToolInput; +pub use find_path_tool::*; pub use grep_tool::{GrepTool, GrepToolInput}; pub use open_tool::OpenTool; pub use 
project_notifications_tool::ProjectNotificationsTool; diff --git a/crates/assistant_tools/src/copy_path_tool.rs b/crates/assistant_tools/src/copy_path_tool.rs index 1922b5677a94e0eff8fef2bc12bdab8a0971f395..e34ae9ff9305689593241b45fe986414a211ec3b 100644 --- a/crates/assistant_tools/src/copy_path_tool.rs +++ b/crates/assistant_tools/src/copy_path_tool.rs @@ -44,7 +44,7 @@ impl Tool for CopyPathTool { "copy_path".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } diff --git a/crates/assistant_tools/src/create_directory_tool.rs b/crates/assistant_tools/src/create_directory_tool.rs index 224e8357e5a6de98b088aede62daaa8524f2b6c2..11d969d234228e32a0f4baff2c9cad055a488993 100644 --- a/crates/assistant_tools/src/create_directory_tool.rs +++ b/crates/assistant_tools/src/create_directory_tool.rs @@ -37,7 +37,7 @@ impl Tool for CreateDirectoryTool { include_str!("./create_directory_tool/description.md").into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } diff --git a/crates/assistant_tools/src/delete_path_tool.rs b/crates/assistant_tools/src/delete_path_tool.rs index b13f9863c9f7203ceb5e236c8a06903be4b93b68..9e69c18b65d2f78618ac54b28c4808401c08bd72 100644 --- a/crates/assistant_tools/src/delete_path_tool.rs +++ b/crates/assistant_tools/src/delete_path_tool.rs @@ -33,7 +33,7 @@ impl Tool for DeletePathTool { "delete_path".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } diff --git a/crates/assistant_tools/src/diagnostics_tool.rs b/crates/assistant_tools/src/diagnostics_tool.rs index 84595a37b7069a194694cb70482928148116d465..12ab97f820d89e2d66deba3b58ab388b7f1c886e 100644 --- a/crates/assistant_tools/src/diagnostics_tool.rs +++ b/crates/assistant_tools/src/diagnostics_tool.rs @@ -46,7 +46,7 @@ impl Tool for DiagnosticsTool { "diagnostics".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } diff --git a/crates/assistant_tools/src/edit_agent.rs b/crates/assistant_tools/src/edit_agent.rs index 0184dff36c0a4130ce2880f3e7e84acb013aadfd..715d106a267e4013955c29613262cde9b92d0a8a 100644 --- a/crates/assistant_tools/src/edit_agent.rs +++ b/crates/assistant_tools/src/edit_agent.rs @@ -7,6 +7,7 @@ mod streaming_fuzzy_matcher; use crate::{Template, Templates}; use anyhow::Result; use assistant_tool::ActionLog; +use cloud_llm_client::CompletionIntent; use create_file_parser::{CreateFileParser, CreateFileParserEvent}; pub use edit_parser::EditFormat; use edit_parser::{EditParser, EditParserEvent, EditParserMetrics}; @@ -29,7 +30,6 @@ use std::{cmp, iter, mem, ops::Range, path::PathBuf, pin::Pin, sync::Arc, task:: use streaming_diff::{CharOperation, StreamingDiff}; use streaming_fuzzy_matcher::StreamingFuzzyMatcher; use util::debug_panic; -use zed_llm_client::CompletionIntent; #[derive(Serialize)] struct CreateFilePromptTemplate { @@ -962,7 +962,7 @@ mod tests { ); cx.run_until_parked(); - model.stream_last_completion_response("a"); + model.send_last_completion_stream_text_chunk("a"); cx.run_until_parked(); assert_eq!(drain_events(&mut events), vec![]); assert_eq!( @@ -974,7 +974,7 @@ mod tests { None ); - 
model.stream_last_completion_response("bc"); + model.send_last_completion_stream_text_chunk("bc"); cx.run_until_parked(); assert_eq!( drain_events(&mut events), @@ -996,7 +996,7 @@ mod tests { }) ); - model.stream_last_completion_response("abX"); + model.send_last_completion_stream_text_chunk("abX"); cx.run_until_parked(); assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]); assert_eq!( @@ -1011,7 +1011,7 @@ mod tests { }) ); - model.stream_last_completion_response("cY"); + model.send_last_completion_stream_text_chunk("cY"); cx.run_until_parked(); assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]); assert_eq!( @@ -1026,8 +1026,8 @@ mod tests { }) ); - model.stream_last_completion_response(""); - model.stream_last_completion_response("hall"); + model.send_last_completion_stream_text_chunk(""); + model.send_last_completion_stream_text_chunk("hall"); cx.run_until_parked(); assert_eq!(drain_events(&mut events), vec![]); assert_eq!( @@ -1042,8 +1042,8 @@ mod tests { }) ); - model.stream_last_completion_response("ucinated old"); - model.stream_last_completion_response(""); + model.send_last_completion_stream_text_chunk("ucinated old"); + model.send_last_completion_stream_text_chunk(""); cx.run_until_parked(); assert_eq!( drain_events(&mut events), @@ -1061,8 +1061,8 @@ mod tests { }) ); - model.stream_last_completion_response("hallucinated new"); + model.send_last_completion_stream_text_chunk("hallucinated new"); cx.run_until_parked(); assert_eq!(drain_events(&mut events), vec![]); assert_eq!( @@ -1077,7 +1077,7 @@ mod tests { }) ); - model.stream_last_completion_response("\nghi\nj"); + model.send_last_completion_stream_text_chunk("\nghi\nj"); cx.run_until_parked(); assert_eq!( drain_events(&mut events), @@ -1099,8 +1099,8 @@ mod tests { }) ); - model.stream_last_completion_response("kl"); - model.stream_last_completion_response(""); + model.send_last_completion_stream_text_chunk("kl"); + model.send_last_completion_stream_text_chunk(""); cx.run_until_parked(); assert_eq!( drain_events(&mut events), @@ -1122,7 +1122,7 @@ mod tests { }) ); - model.stream_last_completion_response("GHI"); + model.send_last_completion_stream_text_chunk("GHI"); cx.run_until_parked(); assert_eq!( drain_events(&mut events), @@ -1367,7 +1367,9 @@ mod tests { cx.background_spawn(async move { for chunk in chunks { executor.simulate_random_delay().await; - model.as_fake().stream_last_completion_response(chunk); + model + .as_fake() + .send_last_completion_stream_text_chunk(chunk); } model.as_fake().end_last_completion_stream(); }) diff --git a/crates/assistant_tools/src/edit_agent/evals.rs b/crates/assistant_tools/src/edit_agent/evals.rs index d2ee03f08f142b024b69eeaea739ba121c35b375..9a8e7624559e9a1284ace7c932f428c7389b6254 100644 --- a/crates/assistant_tools/src/edit_agent/evals.rs +++ b/crates/assistant_tools/src/edit_agent/evals.rs @@ -12,6 +12,7 @@ use collections::HashMap; use fs::FakeFs; use futures::{FutureExt, future::LocalBoxFuture}; use gpui::{AppContext, TestAppContext, Timer}; +use http_client::StatusCode; use indoc::{formatdoc, indoc}; use language_model::{ LanguageModelRegistry, LanguageModelRequestTool, LanguageModelToolResult, @@ -365,17 +366,23 @@ fn eval_disable_cursor_blinking() { // Model | Pass rate // ============================================ // - // claude-3.7-sonnet | 0.99 (2025-06-14) - // claude-sonnet-4 | 0.85 (2025-06-14) - // gemini-2.5-pro-preview-latest | 0.97 (2025-06-16) - // gemini-2.5-flash-preview-04-17 | - // gpt-4.1 | + // claude-3.7-sonnet | 
0.59 (2025-07-14) + // claude-sonnet-4 | 0.81 (2025-07-14) + // gemini-2.5-pro | 0.95 (2025-07-14) + // gemini-2.5-flash-preview-04-17 | 0.78 (2025-07-14) + // gpt-4.1 | 0.00 (2025-07-14) (follows edit_description too literally) let input_file_path = "root/editor.rs"; let input_file_content = include_str!("evals/fixtures/disable_cursor_blinking/before.rs"); let edit_description = "Comment out the call to `BlinkManager::enable`"; + let possible_diffs = vec![ + include_str!("evals/fixtures/disable_cursor_blinking/possible-01.diff"), + include_str!("evals/fixtures/disable_cursor_blinking/possible-02.diff"), + include_str!("evals/fixtures/disable_cursor_blinking/possible-03.diff"), + include_str!("evals/fixtures/disable_cursor_blinking/possible-04.diff"), + ]; eval( 100, - 0.95, + 0.51, 0.05, EvalInput::from_conversation( vec![ @@ -433,11 +440,7 @@ fn eval_disable_cursor_blinking() { ), ], Some(input_file_content.into()), - EvalAssertion::judge_diff(indoc! {" - - Calls to BlinkManager in `observe_window_activation` were commented out - - The call to `blink_manager.enable` above the call to show_cursor_names was commented out - - All the edits have valid indentation - "}), + EvalAssertion::assert_diff_any(possible_diffs), ), ); } @@ -1655,28 +1658,61 @@ impl EditAgentTest { } async fn retry_on_rate_limit(mut request: impl AsyncFnMut() -> Result) -> Result { + const MAX_RETRIES: usize = 20; let mut attempt = 0; + loop { attempt += 1; - match request().await { - Ok(result) => return Ok(result), - Err(err) => match err.downcast::() { - Ok(err) => match &err { + let response = request().await; + + if attempt >= MAX_RETRIES { + return response; + } + + let retry_delay = match &response { + Ok(_) => None, + Err(err) => match err.downcast_ref::() { + Some(err) => match &err { LanguageModelCompletionError::RateLimitExceeded { retry_after, .. } | LanguageModelCompletionError::ServerOverloaded { retry_after, .. } => { - let retry_after = retry_after.unwrap_or(Duration::from_secs(5)); - // Wait for the duration supplied, with some jitter to avoid all requests being made at the same time. - let jitter = retry_after.mul_f64(rand::thread_rng().gen_range(0.0..1.0)); - eprintln!( - "Attempt #{attempt}: {err}. Retry after {retry_after:?} + jitter of {jitter:?}" - ); - Timer::after(retry_after + jitter).await; - continue; + Some(retry_after.unwrap_or(Duration::from_secs(5))) + } + LanguageModelCompletionError::UpstreamProviderError { + status, + retry_after, + .. + } => { + // Only retry for specific status codes + let should_retry = matches!( + *status, + StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE + ) || status.as_u16() == 529; + + if should_retry { + // Use server-provided retry_after if available, otherwise use default + Some(retry_after.unwrap_or(Duration::from_secs(5))) + } else { + None + } + } + LanguageModelCompletionError::ApiReadResponseError { .. } + | LanguageModelCompletionError::ApiInternalServerError { .. } + | LanguageModelCompletionError::HttpSend { .. 
} => { + // Exponential backoff for transient I/O and internal server errors + Some(Duration::from_secs(2_u64.pow((attempt - 1) as u32).min(30))) } - _ => return Err(err.into()), + _ => None, }, - Err(err) => return Err(err), + _ => None, }, + }; + + if let Some(retry_after) = retry_delay { + let jitter = retry_after.mul_f64(rand::thread_rng().gen_range(0.0..1.0)); + eprintln!("Attempt #{attempt}: Retry after {retry_after:?} + jitter of {jitter:?}"); + Timer::after(retry_after + jitter).await; + } else { + return response; } } } diff --git a/crates/assistant_tools/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-01.diff b/crates/assistant_tools/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-01.diff new file mode 100644 index 0000000000000000000000000000000000000000..1a38a1967f94c974de491c712babb7882020d697 --- /dev/null +++ b/crates/assistant_tools/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-01.diff @@ -0,0 +1,28 @@ +--- before.rs 2025-07-07 11:37:48.434629001 +0300 ++++ expected.rs 2025-07-14 10:33:53.346906775 +0300 +@@ -1780,11 +1780,11 @@ + cx.observe_window_activation(window, |editor, window, cx| { + let active = window.is_window_active(); + editor.blink_manager.update(cx, |blink_manager, cx| { +- if active { +- blink_manager.enable(cx); +- } else { +- blink_manager.disable(cx); +- } ++ // if active { ++ // blink_manager.enable(cx); ++ // } else { ++ // blink_manager.disable(cx); ++ // } + }); + }), + ], +@@ -18463,7 +18463,7 @@ + } + + self.blink_manager.update(cx, |blink_manager, cx| { +- blink_manager.enable(cx); ++ // blink_manager.enable(cx); + }); + self.show_cursor_names(window, cx); + self.buffer.update(cx, |buffer, cx| { diff --git a/crates/assistant_tools/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-02.diff b/crates/assistant_tools/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-02.diff new file mode 100644 index 0000000000000000000000000000000000000000..b484cce48f71b232ddaa947a73940b8bf11846c6 --- /dev/null +++ b/crates/assistant_tools/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-02.diff @@ -0,0 +1,29 @@ +@@ -1778,13 +1778,13 @@ + cx.observe_global_in::(window, Self::settings_changed), + observe_buffer_font_size_adjustment(cx, |_, cx| cx.notify()), + cx.observe_window_activation(window, |editor, window, cx| { +- let active = window.is_window_active(); ++ // let active = window.is_window_active(); + editor.blink_manager.update(cx, |blink_manager, cx| { +- if active { +- blink_manager.enable(cx); +- } else { +- blink_manager.disable(cx); +- } ++ // if active { ++ // blink_manager.enable(cx); ++ // } else { ++ // blink_manager.disable(cx); ++ // } + }); + }), + ], +@@ -18463,7 +18463,7 @@ + } + + self.blink_manager.update(cx, |blink_manager, cx| { +- blink_manager.enable(cx); ++ // blink_manager.enable(cx); + }); + self.show_cursor_names(window, cx); + self.buffer.update(cx, |buffer, cx| { diff --git a/crates/assistant_tools/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-03.diff b/crates/assistant_tools/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-03.diff new file mode 100644 index 0000000000000000000000000000000000000000..431e34e48a250bff80efbd5a2cc20ecc25be1020 --- /dev/null +++ b/crates/assistant_tools/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-03.diff @@ -0,0 +1,34 @@ +@@ -1774,17 +1774,17 @@ + cx.observe(&buffer, Self::on_buffer_changed), + cx.subscribe_in(&buffer, window, Self::on_buffer_event), + 
cx.observe_in(&display_map, window, Self::on_display_map_changed), +- cx.observe(&blink_manager, |_, _, cx| cx.notify()), ++ // cx.observe(&blink_manager, |_, _, cx| cx.notify()), + cx.observe_global_in::(window, Self::settings_changed), + observe_buffer_font_size_adjustment(cx, |_, cx| cx.notify()), + cx.observe_window_activation(window, |editor, window, cx| { +- let active = window.is_window_active(); ++ // let active = window.is_window_active(); + editor.blink_manager.update(cx, |blink_manager, cx| { +- if active { +- blink_manager.enable(cx); +- } else { +- blink_manager.disable(cx); +- } ++ // if active { ++ // blink_manager.enable(cx); ++ // } else { ++ // blink_manager.disable(cx); ++ // } + }); + }), + ], +@@ -18463,7 +18463,7 @@ + } + + self.blink_manager.update(cx, |blink_manager, cx| { +- blink_manager.enable(cx); ++ // blink_manager.enable(cx); + }); + self.show_cursor_names(window, cx); + self.buffer.update(cx, |buffer, cx| { diff --git a/crates/assistant_tools/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-04.diff b/crates/assistant_tools/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-04.diff new file mode 100644 index 0000000000000000000000000000000000000000..64a6b85dd3751407db65da74656b66ee1beaf58b --- /dev/null +++ b/crates/assistant_tools/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-04.diff @@ -0,0 +1,33 @@ +@@ -1774,17 +1774,17 @@ + cx.observe(&buffer, Self::on_buffer_changed), + cx.subscribe_in(&buffer, window, Self::on_buffer_event), + cx.observe_in(&display_map, window, Self::on_display_map_changed), +- cx.observe(&blink_manager, |_, _, cx| cx.notify()), ++ // cx.observe(&blink_manager, |_, _, cx| cx.notify()), + cx.observe_global_in::(window, Self::settings_changed), + observe_buffer_font_size_adjustment(cx, |_, cx| cx.notify()), + cx.observe_window_activation(window, |editor, window, cx| { + let active = window.is_window_active(); + editor.blink_manager.update(cx, |blink_manager, cx| { +- if active { +- blink_manager.enable(cx); +- } else { +- blink_manager.disable(cx); +- } ++ // if active { ++ // blink_manager.enable(cx); ++ // } else { ++ // blink_manager.disable(cx); ++ // } + }); + }), + ], +@@ -18463,7 +18463,7 @@ + } + + self.blink_manager.update(cx, |blink_manager, cx| { +- blink_manager.enable(cx); ++ // blink_manager.enable(cx); + }); + self.show_cursor_names(window, cx); + self.buffer.update(cx, |buffer, cx| { diff --git a/crates/assistant_tools/src/edit_file_tool.rs b/crates/assistant_tools/src/edit_file_tool.rs index 0423f56145bc484108ff958419353e2378a3779a..dce9f49abdde7e0ea00d976a4d9f029f98f7a067 100644 --- a/crates/assistant_tools/src/edit_file_tool.rs +++ b/crates/assistant_tools/src/edit_file_tool.rs @@ -25,6 +25,7 @@ use language::{ }; use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; use markdown::{Markdown, MarkdownElement, MarkdownStyle}; +use paths; use project::{ Project, ProjectPath, lsp_store::{FormatTrigger, LspFormatTarget}, @@ -126,8 +127,47 @@ impl Tool for EditFileTool { "edit_file".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { - false + fn needs_confirmation( + &self, + input: &serde_json::Value, + project: &Entity, + cx: &App, + ) -> bool { + if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions { + return false; + } + + let Ok(input) = serde_json::from_value::(input.clone()) else { + // If it's not valid JSON, it's going to error and confirming won't do anything. 
+ return false; + }; + + // If any path component matches the local settings folder, then this could affect + // the editor in ways beyond the project source, so prompt. + let local_settings_folder = paths::local_settings_folder_relative_path(); + let path = Path::new(&input.path); + if path + .components() + .any(|component| component.as_os_str() == local_settings_folder.as_os_str()) + { + return true; + } + + // It's also possible that the global config dir is configured to be inside the project, + // so check for that edge case too. + if let Ok(canonical_path) = std::fs::canonicalize(&input.path) { + if canonical_path.starts_with(paths::config_dir()) { + return true; + } + } + + // Check if path is inside the global config directory + // First check if it's already inside project - if not, try to canonicalize + let project_path = project.read(cx).find_project_path(&input.path, cx); + + // If the path is inside the project, and it's not one of the above edge cases, + // then no confirmation is necessary. Otherwise, confirmation is necessary. + project_path.is_none() } fn may_perform_edits(&self) -> bool { @@ -148,7 +188,25 @@ impl Tool for EditFileTool { fn ui_text(&self, input: &serde_json::Value) -> String { match serde_json::from_value::(input.clone()) { - Ok(input) => input.display_description, + Ok(input) => { + let path = Path::new(&input.path); + let mut description = input.display_description.clone(); + + // Add context about why confirmation may be needed + let local_settings_folder = paths::local_settings_folder_relative_path(); + if path + .components() + .any(|c| c.as_os_str() == local_settings_folder.as_os_str()) + { + description.push_str(" (local settings)"); + } else if let Ok(canonical_path) = std::fs::canonicalize(&input.path) { + if canonical_path.starts_with(paths::config_dir()) { + description.push_str(" (global settings)"); + } + } + + description + } Err(_) => "Editing file".to_string(), } } @@ -278,6 +336,9 @@ impl Tool for EditFileTool { .unwrap_or(false); if format_on_save_enabled { + action_log.update(cx, |log, cx| { + log.buffer_edited(buffer.clone(), cx); + })?; let format_task = project.update(cx, |project, cx| { project.format( HashSet::from_iter([buffer.clone()]), @@ -1172,19 +1233,20 @@ async fn build_buffer_diff( #[cfg(test)] mod tests { use super::*; + use ::fs::Fs; use client::TelemetrySettings; - use fs::{FakeFs, Fs}; use gpui::{TestAppContext, UpdateGlobal}; use language_model::fake_provider::FakeLanguageModel; use serde_json::json; use settings::SettingsStore; + use std::fs; use util::path; #[gpui::test] async fn test_edit_nonexistent_file(cx: &mut TestAppContext) { init_test(cx); - let fs = FakeFs::new(cx.executor()); + let fs = project::FakeFs::new(cx.executor()); fs.insert_tree("/root", json!({})).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let action_log = cx.new(|_| ActionLog::new(project.clone())); @@ -1274,7 +1336,7 @@ mod tests { ) -> anyhow::Result { init_test(cx); - let fs = FakeFs::new(cx.executor()); + let fs = project::FakeFs::new(cx.executor()); fs.insert_tree( "/root", json!({ @@ -1381,6 +1443,21 @@ mod tests { cx.set_global(settings_store); language::init(cx); TelemetrySettings::register(cx); + agent_settings::AgentSettings::register(cx); + Project::init_settings(cx); + }); + } + + fn init_test_with_config(cx: &mut TestAppContext, data_dir: &Path) { + cx.update(|cx| { + // Set custom data directory (config will be under data_dir/config) + paths::set_custom_data_dir(data_dir.to_str().unwrap()); 
+ + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + TelemetrySettings::register(cx); + agent_settings::AgentSettings::register(cx); Project::init_settings(cx); }); } @@ -1389,7 +1466,7 @@ mod tests { async fn test_format_on_save(cx: &mut TestAppContext) { init_test(cx); - let fs = FakeFs::new(cx.executor()); + let fs = project::FakeFs::new(cx.executor()); fs.insert_tree("/root", json!({"src": {}})).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; @@ -1500,7 +1577,7 @@ mod tests { // Stream the unformatted content cx.executor().run_until_parked(); - model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string()); + model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string()); model.end_last_completion_stream(); edit_task.await @@ -1564,7 +1641,7 @@ mod tests { // Stream the unformatted content cx.executor().run_until_parked(); - model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string()); + model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string()); model.end_last_completion_stream(); edit_task.await @@ -1588,7 +1665,7 @@ mod tests { async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) { init_test(cx); - let fs = FakeFs::new(cx.executor()); + let fs = project::FakeFs::new(cx.executor()); fs.insert_tree("/root", json!({"src": {}})).await; // Create a simple file with trailing whitespace @@ -1643,7 +1720,9 @@ mod tests { // Stream the content with trailing whitespace cx.executor().run_until_parked(); - model.stream_last_completion_response(CONTENT_WITH_TRAILING_WHITESPACE.to_string()); + model.send_last_completion_stream_text_chunk( + CONTENT_WITH_TRAILING_WHITESPACE.to_string(), + ); model.end_last_completion_stream(); edit_task.await @@ -1700,7 +1779,9 @@ mod tests { // Stream the content with trailing whitespace cx.executor().run_until_parked(); - model.stream_last_completion_response(CONTENT_WITH_TRAILING_WHITESPACE.to_string()); + model.send_last_completion_stream_text_chunk( + CONTENT_WITH_TRAILING_WHITESPACE.to_string(), + ); model.end_last_completion_stream(); edit_task.await @@ -1720,4 +1801,641 @@ mod tests { "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled" ); } + + #[gpui::test] + async fn test_needs_confirmation(cx: &mut TestAppContext) { + init_test(cx); + let tool = Arc::new(EditFileTool); + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree("/root", json!({})).await; + + // Test 1: Path with .zed component should require confirmation + let input_with_zed = json!({ + "display_description": "Edit settings", + "path": ".zed/settings.json", + "mode": "edit" + }); + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + cx.update(|cx| { + assert!( + tool.needs_confirmation(&input_with_zed, &project, cx), + "Path with .zed component should require confirmation" + ); + }); + + // Test 2: Absolute path should require confirmation + let input_absolute = json!({ + "display_description": "Edit file", + "path": "/etc/hosts", + "mode": "edit" + }); + cx.update(|cx| { + assert!( + tool.needs_confirmation(&input_absolute, &project, cx), + "Absolute path should require confirmation" + ); + }); + + // Test 3: Relative path without .zed should not require confirmation + let input_relative = json!({ + "display_description": "Edit file", + "path": "root/src/main.rs", + "mode": "edit" + }); + cx.update(|cx| { + assert!( + !tool.needs_confirmation(&input_relative, 
&project, cx), + "Relative path without .zed should not require confirmation" + ); + }); + + // Test 4: Path with .zed in the middle should require confirmation + let input_zed_middle = json!({ + "display_description": "Edit settings", + "path": "root/.zed/tasks.json", + "mode": "edit" + }); + cx.update(|cx| { + assert!( + tool.needs_confirmation(&input_zed_middle, &project, cx), + "Path with .zed in any component should require confirmation" + ); + }); + + // Test 5: When always_allow_tool_actions is enabled, no confirmation needed + cx.update(|cx| { + let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); + settings.always_allow_tool_actions = true; + agent_settings::AgentSettings::override_global(settings, cx); + + assert!( + !tool.needs_confirmation(&input_with_zed, &project, cx), + "When always_allow_tool_actions is true, no confirmation should be needed" + ); + assert!( + !tool.needs_confirmation(&input_absolute, &project, cx), + "When always_allow_tool_actions is true, no confirmation should be needed for absolute paths" + ); + }); + } + + #[gpui::test] + async fn test_ui_text_shows_correct_context(cx: &mut TestAppContext) { + // Set up a custom config directory for testing + let temp_dir = tempfile::tempdir().unwrap(); + init_test_with_config(cx, temp_dir.path()); + + let tool = Arc::new(EditFileTool); + + // Test ui_text shows context for various paths + let test_cases = vec![ + ( + json!({ + "display_description": "Update config", + "path": ".zed/settings.json", + "mode": "edit" + }), + "Update config (local settings)", + ".zed path should show local settings context", + ), + ( + json!({ + "display_description": "Fix bug", + "path": "src/.zed/local.json", + "mode": "edit" + }), + "Fix bug (local settings)", + "Nested .zed path should show local settings context", + ), + ( + json!({ + "display_description": "Update readme", + "path": "README.md", + "mode": "edit" + }), + "Update readme", + "Normal path should not show additional context", + ), + ( + json!({ + "display_description": "Edit config", + "path": "config.zed", + "mode": "edit" + }), + "Edit config", + ".zed as extension should not show context", + ), + ]; + + for (input, expected_text, description) in test_cases { + cx.update(|_cx| { + let ui_text = tool.ui_text(&input); + assert_eq!(ui_text, expected_text, "Failed for case: {}", description); + }); + } + } + + #[gpui::test] + async fn test_needs_confirmation_outside_project(cx: &mut TestAppContext) { + init_test(cx); + let tool = Arc::new(EditFileTool); + let fs = project::FakeFs::new(cx.executor()); + + // Create a project in /project directory + fs.insert_tree("/project", json!({})).await; + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + + // Test file outside project requires confirmation + let input_outside = json!({ + "display_description": "Edit file", + "path": "/outside/file.txt", + "mode": "edit" + }); + cx.update(|cx| { + assert!( + tool.needs_confirmation(&input_outside, &project, cx), + "File outside project should require confirmation" + ); + }); + + // Test file inside project doesn't require confirmation + let input_inside = json!({ + "display_description": "Edit file", + "path": "project/file.txt", + "mode": "edit" + }); + cx.update(|cx| { + assert!( + !tool.needs_confirmation(&input_inside, &project, cx), + "File inside project should not require confirmation" + ); + }); + } + + #[gpui::test] + async fn test_needs_confirmation_config_paths(cx: &mut TestAppContext) { + // Set up a custom data 
directory for testing + let temp_dir = tempfile::tempdir().unwrap(); + init_test_with_config(cx, temp_dir.path()); + + let tool = Arc::new(EditFileTool); + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree("/home/user/myproject", json!({})).await; + let project = Project::test(fs.clone(), [path!("/home/user/myproject").as_ref()], cx).await; + + // Get the actual local settings folder name + let local_settings_folder = paths::local_settings_folder_relative_path(); + + // Test various config path patterns + let test_cases = vec![ + ( + format!("{}/settings.json", local_settings_folder.display()), + true, + "Top-level local settings file".to_string(), + ), + ( + format!( + "myproject/{}/settings.json", + local_settings_folder.display() + ), + true, + "Local settings in project path".to_string(), + ), + ( + format!("src/{}/config.toml", local_settings_folder.display()), + true, + "Local settings in subdirectory".to_string(), + ), + ( + ".zed.backup/file.txt".to_string(), + true, + ".zed.backup is outside project".to_string(), + ), + ( + "my.zed/file.txt".to_string(), + true, + "my.zed is outside project".to_string(), + ), + ( + "myproject/src/file.zed".to_string(), + false, + ".zed as file extension".to_string(), + ), + ( + "myproject/normal/path/file.rs".to_string(), + false, + "Normal file without config paths".to_string(), + ), + ]; + + for (path, should_confirm, description) in test_cases { + let input = json!({ + "display_description": "Edit file", + "path": path, + "mode": "edit" + }); + cx.update(|cx| { + assert_eq!( + tool.needs_confirmation(&input, &project, cx), + should_confirm, + "Failed for case: {} - path: {}", + description, + path + ); + }); + } + } + + #[gpui::test] + async fn test_needs_confirmation_global_config(cx: &mut TestAppContext) { + // Set up a custom data directory for testing + let temp_dir = tempfile::tempdir().unwrap(); + init_test_with_config(cx, temp_dir.path()); + + let tool = Arc::new(EditFileTool); + let fs = project::FakeFs::new(cx.executor()); + + // Create test files in the global config directory + let global_config_dir = paths::config_dir(); + fs::create_dir_all(&global_config_dir).unwrap(); + let global_settings_path = global_config_dir.join("settings.json"); + fs::write(&global_settings_path, "{}").unwrap(); + + fs.insert_tree("/project", json!({})).await; + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + + // Test global config paths + let test_cases = vec![ + ( + global_settings_path.to_str().unwrap().to_string(), + true, + "Global settings file should require confirmation", + ), + ( + global_config_dir + .join("keymap.json") + .to_str() + .unwrap() + .to_string(), + true, + "Global keymap file should require confirmation", + ), + ( + "project/normal_file.rs".to_string(), + false, + "Normal project file should not require confirmation", + ), + ]; + + for (path, should_confirm, description) in test_cases { + let input = json!({ + "display_description": "Edit file", + "path": path, + "mode": "edit" + }); + cx.update(|cx| { + assert_eq!( + tool.needs_confirmation(&input, &project, cx), + should_confirm, + "Failed for case: {}", + description + ); + }); + } + } + + #[gpui::test] + async fn test_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) { + init_test(cx); + let tool = Arc::new(EditFileTool); + let fs = project::FakeFs::new(cx.executor()); + + // Create multiple worktree directories + fs.insert_tree( + "/workspace/frontend", + json!({ + "src": { + "main.js": 
"console.log('frontend');" + } + }), + ) + .await; + fs.insert_tree( + "/workspace/backend", + json!({ + "src": { + "main.rs": "fn main() {}" + } + }), + ) + .await; + fs.insert_tree( + "/workspace/shared", + json!({ + ".zed": { + "settings.json": "{}" + } + }), + ) + .await; + + // Create project with multiple worktrees + let project = Project::test( + fs.clone(), + [ + path!("/workspace/frontend").as_ref(), + path!("/workspace/backend").as_ref(), + path!("/workspace/shared").as_ref(), + ], + cx, + ) + .await; + + // Test files in different worktrees + let test_cases = vec![ + ("frontend/src/main.js", false, "File in first worktree"), + ("backend/src/main.rs", false, "File in second worktree"), + ( + "shared/.zed/settings.json", + true, + ".zed file in third worktree", + ), + ("/etc/hosts", true, "Absolute path outside all worktrees"), + ( + "../outside/file.txt", + true, + "Relative path outside worktrees", + ), + ]; + + for (path, should_confirm, description) in test_cases { + let input = json!({ + "display_description": "Edit file", + "path": path, + "mode": "edit" + }); + cx.update(|cx| { + assert_eq!( + tool.needs_confirmation(&input, &project, cx), + should_confirm, + "Failed for case: {} - path: {}", + description, + path + ); + }); + } + } + + #[gpui::test] + async fn test_needs_confirmation_edge_cases(cx: &mut TestAppContext) { + init_test(cx); + let tool = Arc::new(EditFileTool); + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + "/project", + json!({ + ".zed": { + "settings.json": "{}" + }, + "src": { + ".zed": { + "local.json": "{}" + } + } + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + + // Test edge cases + let test_cases = vec![ + // Empty path - find_project_path returns Some for empty paths + ("", false, "Empty path is treated as project root"), + // Root directory + ("/", true, "Root directory should be outside project"), + // Parent directory references - find_project_path resolves these + ( + "project/../other", + false, + "Path with .. is resolved by find_project_path", + ), + ( + "project/./src/file.rs", + false, + "Path with . 
should work normally", + ), + // Windows-style paths (if on Windows) + #[cfg(target_os = "windows")] + ("C:\\Windows\\System32\\hosts", true, "Windows system path"), + #[cfg(target_os = "windows")] + ("project\\src\\main.rs", false, "Windows-style project path"), + ]; + + for (path, should_confirm, description) in test_cases { + let input = json!({ + "display_description": "Edit file", + "path": path, + "mode": "edit" + }); + cx.update(|cx| { + assert_eq!( + tool.needs_confirmation(&input, &project, cx), + should_confirm, + "Failed for case: {} - path: {}", + description, + path + ); + }); + } + } + + #[gpui::test] + async fn test_ui_text_with_all_path_types(cx: &mut TestAppContext) { + init_test(cx); + let tool = Arc::new(EditFileTool); + + // Test UI text for various scenarios + let test_cases = vec![ + ( + json!({ + "display_description": "Update config", + "path": ".zed/settings.json", + "mode": "edit" + }), + "Update config (local settings)", + ".zed path should show local settings context", + ), + ( + json!({ + "display_description": "Fix bug", + "path": "src/.zed/local.json", + "mode": "edit" + }), + "Fix bug (local settings)", + "Nested .zed path should show local settings context", + ), + ( + json!({ + "display_description": "Update readme", + "path": "README.md", + "mode": "edit" + }), + "Update readme", + "Normal path should not show additional context", + ), + ( + json!({ + "display_description": "Edit config", + "path": "config.zed", + "mode": "edit" + }), + "Edit config", + ".zed as extension should not show context", + ), + ]; + + for (input, expected_text, description) in test_cases { + cx.update(|_cx| { + let ui_text = tool.ui_text(&input); + assert_eq!(ui_text, expected_text, "Failed for case: {}", description); + }); + } + } + + #[gpui::test] + async fn test_needs_confirmation_with_different_modes(cx: &mut TestAppContext) { + init_test(cx); + let tool = Arc::new(EditFileTool); + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + "/project", + json!({ + "existing.txt": "content", + ".zed": { + "settings.json": "{}" + } + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + + // Test different EditFileMode values + let modes = vec![ + EditFileMode::Edit, + EditFileMode::Create, + EditFileMode::Overwrite, + ]; + + for mode in modes { + // Test .zed path with different modes + let input_zed = json!({ + "display_description": "Edit settings", + "path": "project/.zed/settings.json", + "mode": mode + }); + cx.update(|cx| { + assert!( + tool.needs_confirmation(&input_zed, &project, cx), + ".zed path should require confirmation regardless of mode: {:?}", + mode + ); + }); + + // Test outside path with different modes + let input_outside = json!({ + "display_description": "Edit file", + "path": "/outside/file.txt", + "mode": mode + }); + cx.update(|cx| { + assert!( + tool.needs_confirmation(&input_outside, &project, cx), + "Outside path should require confirmation regardless of mode: {:?}", + mode + ); + }); + + // Test normal path with different modes + let input_normal = json!({ + "display_description": "Edit file", + "path": "project/normal.txt", + "mode": mode + }); + cx.update(|cx| { + assert!( + !tool.needs_confirmation(&input_normal, &project, cx), + "Normal path should not require confirmation regardless of mode: {:?}", + mode + ); + }); + } + } + + #[gpui::test] + async fn test_always_allow_tool_actions_bypasses_all_checks(cx: &mut TestAppContext) { + // Set up with custom directories for deterministic 
testing + let temp_dir = tempfile::tempdir().unwrap(); + init_test_with_config(cx, temp_dir.path()); + + let tool = Arc::new(EditFileTool); + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree("/project", json!({})).await; + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + + // Enable always_allow_tool_actions + cx.update(|cx| { + let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); + settings.always_allow_tool_actions = true; + agent_settings::AgentSettings::override_global(settings, cx); + }); + + // Test that all paths that normally require confirmation are bypassed + let global_settings_path = paths::config_dir().join("settings.json"); + fs::create_dir_all(paths::config_dir()).unwrap(); + fs::write(&global_settings_path, "{}").unwrap(); + + let test_cases = vec![ + ".zed/settings.json", + "project/.zed/config.toml", + global_settings_path.to_str().unwrap(), + "/etc/hosts", + "/absolute/path/file.txt", + "../outside/project.txt", + ]; + + for path in test_cases { + let input = json!({ + "display_description": "Edit file", + "path": path, + "mode": "edit" + }); + cx.update(|cx| { + assert!( + !tool.needs_confirmation(&input, &project, cx), + "Path {} should not require confirmation when always_allow_tool_actions is true", + path + ); + }); + } + + // Disable always_allow_tool_actions and verify confirmation is required again + cx.update(|cx| { + let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); + settings.always_allow_tool_actions = false; + agent_settings::AgentSettings::override_global(settings, cx); + }); + + // Verify .zed path requires confirmation again + let input = json!({ + "display_description": "Edit file", + "path": ".zed/settings.json", + "mode": "edit" + }); + cx.update(|cx| { + assert!( + tool.needs_confirmation(&input, &project, cx), + ".zed path should require confirmation when always_allow_tool_actions is false" + ); + }); + } } diff --git a/crates/assistant_tools/src/fetch_tool.rs b/crates/assistant_tools/src/fetch_tool.rs index c8fa600e831e60ea22bfbaa8b54a9be3b142c567..a31ec39268d7afcf6072a783c5faa40f2a2e0f78 100644 --- a/crates/assistant_tools/src/fetch_tool.rs +++ b/crates/assistant_tools/src/fetch_tool.rs @@ -69,10 +69,9 @@ impl FetchTool { .to_str() .context("invalid Content-Type header")?; let content_type = match content_type { - "text/html" => ContentType::Html, - "text/plain" => ContentType::Plaintext, + "text/html" | "application/xhtml+xml" => ContentType::Html, "application/json" => ContentType::Json, - _ => ContentType::Html, + _ => ContentType::Plaintext, }; match content_type { @@ -117,7 +116,7 @@ impl Tool for FetchTool { "fetch".to_string() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } diff --git a/crates/assistant_tools/src/find_path_tool.rs b/crates/assistant_tools/src/find_path_tool.rs index fd0e44e42cbe6fad373de21be6e263620c07d3d6..affc01941735e5e97b3c3a509cacbce5fd3c7264 100644 --- a/crates/assistant_tools/src/find_path_tool.rs +++ b/crates/assistant_tools/src/find_path_tool.rs @@ -55,7 +55,7 @@ impl Tool for FindPathTool { "find_path".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } diff --git a/crates/assistant_tools/src/grep_tool.rs b/crates/assistant_tools/src/grep_tool.rs index 
053273d71bc01191c19fa1e498290d77e8caac7c..43c3d1d9904e486a9a4309ff24a9e6d4be0dfdca 100644 --- a/crates/assistant_tools/src/grep_tool.rs +++ b/crates/assistant_tools/src/grep_tool.rs @@ -57,7 +57,7 @@ impl Tool for GrepTool { "grep".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } diff --git a/crates/assistant_tools/src/list_directory_tool.rs b/crates/assistant_tools/src/list_directory_tool.rs index 723416e2ce1048d42ceca2af18667817a467d1f2..b1980615d677894264ce2f785068b9e99cb55a61 100644 --- a/crates/assistant_tools/src/list_directory_tool.rs +++ b/crates/assistant_tools/src/list_directory_tool.rs @@ -45,7 +45,7 @@ impl Tool for ListDirectoryTool { "list_directory".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } diff --git a/crates/assistant_tools/src/move_path_tool.rs b/crates/assistant_tools/src/move_path_tool.rs index 27ae10151d4e91f951e198e850e5ff6fc2fb331b..c1cbbf848d53d4e0341bad84fbc7d8cf90f142ac 100644 --- a/crates/assistant_tools/src/move_path_tool.rs +++ b/crates/assistant_tools/src/move_path_tool.rs @@ -42,7 +42,7 @@ impl Tool for MovePathTool { "move_path".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } diff --git a/crates/assistant_tools/src/now_tool.rs b/crates/assistant_tools/src/now_tool.rs index b6b1cf90a43b487684b9c8f0d4f6a69a14af6455..b51b91d3d51b6cc15e54faab55be50287815d96c 100644 --- a/crates/assistant_tools/src/now_tool.rs +++ b/crates/assistant_tools/src/now_tool.rs @@ -33,7 +33,7 @@ impl Tool for NowTool { "now".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } diff --git a/crates/assistant_tools/src/open_tool.rs b/crates/assistant_tools/src/open_tool.rs index 97a4769e19e60758fe509fab56bf7329ac7f30b6..8fddbb0431aee7c8d9d7508c535b598887998225 100644 --- a/crates/assistant_tools/src/open_tool.rs +++ b/crates/assistant_tools/src/open_tool.rs @@ -23,7 +23,7 @@ impl Tool for OpenTool { "open".to_string() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { true } fn may_perform_edits(&self) -> bool { diff --git a/crates/assistant_tools/src/project_notifications_tool.rs b/crates/assistant_tools/src/project_notifications_tool.rs index 168ec61ae98529e1c82dcbe1d4334436457bab44..03487e5419002f0fe08458c49e325f7202612d29 100644 --- a/crates/assistant_tools/src/project_notifications_tool.rs +++ b/crates/assistant_tools/src/project_notifications_tool.rs @@ -6,8 +6,7 @@ use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchem use project::Project; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use std::fmt::Write as _; -use std::sync::Arc; +use std::{fmt::Write, sync::Arc}; use ui::IconName; #[derive(Debug, Serialize, Deserialize, JsonSchema)] @@ -20,7 +19,7 @@ impl Tool for ProjectNotificationsTool { "project_notifications".to_string() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } fn may_perform_edits(&self) -> bool { @@ 
-52,32 +51,105 @@ impl Tool for ProjectNotificationsTool { _window: Option, cx: &mut App, ) -> ToolResult { - let mut stale_files = String::new(); - let mut notified_buffers = Vec::new(); + let Some(user_edits_diff) = + action_log.update(cx, |log, cx| log.flush_unnotified_user_edits(cx)) + else { + return result("No new notifications"); + }; - for stale_file in action_log.read(cx).unnotified_stale_buffers(cx) { - if let Some(file) = stale_file.read(cx).file() { - writeln!(&mut stale_files, "- {}", file.path().display()).ok(); - notified_buffers.push(stale_file.clone()); + // NOTE: Changes to this prompt require a symmetric update in the LLM Worker + const HEADER: &str = include_str!("./project_notifications_tool/prompt_header.txt"); + const MAX_BYTES: usize = 8000; + let diff = fit_patch_to_size(&user_edits_diff, MAX_BYTES); + result(&format!("{HEADER}\n\n```diff\n{diff}\n```\n").replace("\r\n", "\n")) + } +} + +fn result(response: &str) -> ToolResult { + Task::ready(Ok(response.to_string().into())).into() +} + +/// Make sure that the patch fits into the size limit (in bytes). +/// Compress the patch by omitting some parts if needed. +/// Unified diff format is assumed. +fn fit_patch_to_size(patch: &str, max_size: usize) -> String { + if patch.len() <= max_size { + return patch.to_string(); + } + + // Compression level 1: remove context lines in diff bodies, but + // leave the counts and positions of inserted/deleted lines + let mut current_size = patch.len(); + let mut file_patches = split_patch(&patch); + file_patches.sort_by_key(|patch| patch.len()); + let compressed_patches = file_patches + .iter() + .rev() + .map(|patch| { + if current_size > max_size { + let compressed = compress_patch(patch).unwrap_or_else(|_| patch.to_string()); + current_size -= patch.len() - compressed.len(); + compressed + } else { + patch.to_string() } - } + }) + .collect::>(); - if !notified_buffers.is_empty() { - action_log.update(cx, |log, cx| { - log.mark_buffers_as_notified(notified_buffers, cx); - }); + if current_size <= max_size { + return compressed_patches.join("\n\n"); + } + + // Compression level 2: list paths of the changed files only + let filenames = file_patches + .iter() + .map(|patch| { + let patch = diffy::Patch::from_str(patch).unwrap(); + let path = patch + .modified() + .and_then(|path| path.strip_prefix("b/")) + .unwrap_or_default(); + format!("- {path}\n") + }) + .collect::>(); + + filenames.join("") +} + +/// Split a potentially multi-file patch into multiple single-file patches +fn split_patch(patch: &str) -> Vec { + let mut result = Vec::new(); + let mut current_patch = String::new(); + + for line in patch.lines() { + if line.starts_with("---") && !current_patch.is_empty() { + result.push(current_patch.trim_end_matches('\n').into()); + current_patch = String::new(); } + current_patch.push_str(line); + current_patch.push('\n'); + } - let response = if stale_files.is_empty() { - "No new notifications".to_string() - } else { - // NOTE: Changes to this prompt require a symmetric update in the LLM Worker - const HEADER: &str = include_str!("./project_notifications_tool/prompt_header.txt"); - format!("{HEADER}{stale_files}").replace("\r\n", "\n") - }; + if !current_patch.is_empty() { + result.push(current_patch.trim_end_matches('\n').into()); + } - Task::ready(Ok(response.into())).into() + result +} + +fn compress_patch(patch: &str) -> anyhow::Result { + let patch = diffy::Patch::from_str(patch)?; + let mut out = String::new(); + + writeln!(out, "--- {}", patch.original().unwrap_or("a"))?; 
+ writeln!(out, "+++ {}", patch.modified().unwrap_or("b"))?; + + for hunk in patch.hunks() { + writeln!(out, "@@ -{} +{} @@", hunk.old_range(), hunk.new_range())?; + writeln!(out, "[...skipped...]")?; } + + Ok(out) } #[cfg(test)] @@ -85,6 +157,7 @@ mod tests { use super::*; use assistant_tool::ToolResultContent; use gpui::{AppContext, TestAppContext}; + use indoc::indoc; use language_model::{LanguageModelRequest, fake_provider::FakeLanguageModelProvider}; use project::{FakeFs, Project}; use serde_json::json; @@ -123,10 +196,11 @@ mod tests { action_log.update(cx, |log, cx| { log.buffer_read(buffer.clone(), cx); }); + cx.run_until_parked(); // Run the tool before any changes let tool = Arc::new(ProjectNotificationsTool); - let provider = Arc::new(FakeLanguageModelProvider); + let provider = Arc::new(FakeLanguageModelProvider::default()); let model: Arc = Arc::new(provider.test_model()); let request = Arc::new(LanguageModelRequest::default()); let tool_input = json!({}); @@ -142,6 +216,7 @@ mod tests { cx, ) }); + cx.run_until_parked(); let response = result.output.await.unwrap(); let response_text = match &response.content { @@ -158,6 +233,7 @@ mod tests { buffer.update(cx, |buffer, cx| { buffer.edit([(1..1, "\nChange!\n")], None, cx); }); + cx.run_until_parked(); // Run the tool again let result = cx.update(|cx| { @@ -171,6 +247,7 @@ mod tests { cx, ) }); + cx.run_until_parked(); // This time the buffer is stale, so the tool should return a notification let response = result.output.await.unwrap(); @@ -179,10 +256,12 @@ mod tests { _ => panic!("Expected text response"), }; - let expected_content = "[The following is an auto-generated notification; do not reply]\n\nThese files have changed since the last read:\n- code.rs\n"; - assert_eq!( - response_text.as_str(), - expected_content, + assert!( + response_text.contains("These files have changed"), + "Tool should return the stale buffer notification" + ); + assert!( + response_text.contains("test/code.rs"), "Tool should return the stale buffer notification" ); @@ -198,6 +277,7 @@ mod tests { cx, ) }); + cx.run_until_parked(); let response = result.output.await.unwrap(); let response_text = match &response.content { @@ -212,6 +292,61 @@ mod tests { ); } + #[test] + fn test_patch_compression() { + // Given a patch that doesn't fit into the size budget + let patch = indoc! {" + --- a/dir/test.txt + +++ b/dir/test.txt + @@ -1,3 +1,3 @@ + line 1 + -line 2 + +CHANGED + line 3 + @@ -10,2 +10,2 @@ + line 10 + -line 11 + +line eleven + + + --- a/dir/another.txt + +++ b/dir/another.txt + @@ -100,1 +1,1 @@ + -before + +after + "}; + + // When the size deficit can be compensated by dropping the body, + // then the body should be trimmed for larger files first + let limit = patch.len() - 10; + let compressed = fit_patch_to_size(patch, limit); + let expected = indoc! {" + --- a/dir/test.txt + +++ b/dir/test.txt + @@ -1,3 +1,3 @@ + [...skipped...] + @@ -10,2 +10,2 @@ + [...skipped...] + + + --- a/dir/another.txt + +++ b/dir/another.txt + @@ -100,1 +1,1 @@ + -before + +after"}; + assert_eq!(compressed, expected); + + // When the size deficit is too large, then only file paths + // should be returned + let limit = 10; + let compressed = fit_patch_to_size(patch, limit); + let expected = indoc! 
{" + - dir/another.txt + - dir/test.txt + "}; + assert_eq!(compressed, expected); + } + fn init_test(cx: &mut TestAppContext) { cx.update(|cx| { let settings_store = SettingsStore::test(cx); diff --git a/crates/assistant_tools/src/read_file_tool.rs b/crates/assistant_tools/src/read_file_tool.rs index 5b9546981734abdba865346896750348c9c9515c..ee38273cc04338180d36eb1b64e78dd0235ccfa0 100644 --- a/crates/assistant_tools/src/read_file_tool.rs +++ b/crates/assistant_tools/src/read_file_tool.rs @@ -18,7 +18,6 @@ use serde::{Deserialize, Serialize}; use settings::Settings; use std::sync::Arc; use ui::IconName; -use util::markdown::MarkdownInlineCode; /// If the model requests to read a file whose size exceeds this, then #[derive(Debug, Serialize, Deserialize, JsonSchema)] @@ -55,7 +54,7 @@ impl Tool for ReadFileTool { "read_file".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } @@ -78,11 +77,21 @@ impl Tool for ReadFileTool { fn ui_text(&self, input: &serde_json::Value) -> String { match serde_json::from_value::(input.clone()) { Ok(input) => { - let path = MarkdownInlineCode(&input.path); + let path = &input.path; match (input.start_line, input.end_line) { - (Some(start), None) => format!("Read file {path} (from line {start})"), - (Some(start), Some(end)) => format!("Read file {path} (lines {start}-{end})"), - _ => format!("Read file {path}"), + (Some(start), Some(end)) => { + format!( + "[Read file `{}` (lines {}-{})](@selection:{}:({}-{}))", + path, start, end, path, start, end + ) + } + (Some(start), None) => { + format!( + "[Read file `{}` (from line {})](@selection:{}:({}-{}))", + path, start, path, start, start + ) + } + _ => format!("[Read file `{}`](@file:{})", path, path), } } Err(_) => "Read file".to_string(), @@ -276,7 +285,10 @@ impl Tool for ReadFileTool { Using the line numbers in this outline, you can call this tool again while specifying the start_line and end_line fields to see the - implementations of symbols in the outline." + implementations of symbols in the outline. + + Alternatively, you can fall back to the `grep` tool (if available) + to search the file for specific content." 
} .into()) } diff --git a/crates/assistant_tools/src/terminal_tool.rs b/crates/assistant_tools/src/terminal_tool.rs index 03e76f6a5b657a706c2337087984757b62d0ab84..58833c520848aa3bb345db0af7cccf70429784ea 100644 --- a/crates/assistant_tools/src/terminal_tool.rs +++ b/crates/assistant_tools/src/terminal_tool.rs @@ -77,7 +77,7 @@ impl Tool for TerminalTool { Self::NAME.to_string() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { true } diff --git a/crates/assistant_tools/src/thinking_tool.rs b/crates/assistant_tools/src/thinking_tool.rs index 422204f97d46a487032534a846fce455c5bdc0b3..76c6e6c0bad745609386421ea7d720f9d6fefa07 100644 --- a/crates/assistant_tools/src/thinking_tool.rs +++ b/crates/assistant_tools/src/thinking_tool.rs @@ -24,7 +24,7 @@ impl Tool for ThinkingTool { "thinking".to_string() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } @@ -37,7 +37,7 @@ impl Tool for ThinkingTool { } fn icon(&self) -> IconName { - IconName::ToolBulb + IconName::ToolThink } fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { diff --git a/crates/assistant_tools/src/web_search_tool.rs b/crates/assistant_tools/src/web_search_tool.rs index 24bc8e9cba36d09a301a5a398e268ff530bdd072..d4a12f22c56796c5dfbb2f1935e2ad7646ce7c5a 100644 --- a/crates/assistant_tools/src/web_search_tool.rs +++ b/crates/assistant_tools/src/web_search_tool.rs @@ -6,6 +6,7 @@ use anyhow::{Context as _, Result, anyhow}; use assistant_tool::{ ActionLog, Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput, ToolUseStatus, }; +use cloud_llm_client::{WebSearchResponse, WebSearchResult}; use futures::{Future, FutureExt, TryFutureExt}; use gpui::{ AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window, @@ -17,7 +18,6 @@ use serde::{Deserialize, Serialize}; use ui::{IconName, Tooltip, prelude::*}; use web_search::WebSearchRegistry; use workspace::Workspace; -use zed_llm_client::{WebSearchResponse, WebSearchResult}; #[derive(Debug, Serialize, Deserialize, JsonSchema)] pub struct WebSearchToolInput { @@ -32,7 +32,7 @@ impl Tool for WebSearchTool { "web_search".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } diff --git a/crates/audio/Cargo.toml b/crates/audio/Cargo.toml index 960aaf8e08d864f7bf3b1883951d0f7d22ad56ed..d857a3eb2f6c112b9d6a9851715718047f72ccbf 100644 --- a/crates/audio/Cargo.toml +++ b/crates/audio/Cargo.toml @@ -18,6 +18,6 @@ collections.workspace = true derive_more.workspace = true gpui.workspace = true parking_lot.workspace = true -rodio = { version = "0.20.0", default-features = false, features = ["wav"] } +rodio = { version = "0.21.1", default-features = false, features = ["wav", "playback", "tracing"] } util.workspace = true workspace-hack.workspace = true diff --git a/crates/audio/src/assets.rs b/crates/audio/src/assets.rs index 02da79dc24f067795b6636fc7fa031bce95cf935..fd5c935d875960f4fd9bf30494301f4811b22448 100644 --- a/crates/audio/src/assets.rs +++ b/crates/audio/src/assets.rs @@ -3,12 +3,9 @@ use std::{io::Cursor, sync::Arc}; use anyhow::{Context as _, Result}; use collections::HashMap; use gpui::{App, AssetSource, Global}; -use rodio::{ - Decoder, Source, - source::{Buffered, SamplesConverter}, -}; 
+use rodio::{Decoder, Source, source::Buffered}; -type Sound = Buffered>>, f32>>; +type Sound = Buffered>>>; pub struct SoundRegistry { cache: Arc>>, @@ -48,7 +45,7 @@ impl SoundRegistry { .with_context(|| format!("No asset available for path {path}"))?? .into_owned(); let cursor = Cursor::new(bytes); - let source = Decoder::new(cursor)?.convert_samples::().buffered(); + let source = Decoder::new(cursor)?.buffered(); self.cache.lock().insert(name.to_string(), source.clone()); diff --git a/crates/audio/src/audio.rs b/crates/audio/src/audio.rs index e7b9a59e8f281e9fb19481b118990b07c439448f..44baa16aa20a3e4b7651744974cfc085dcde7fb1 100644 --- a/crates/audio/src/audio.rs +++ b/crates/audio/src/audio.rs @@ -1,7 +1,7 @@ use assets::SoundRegistry; use derive_more::{Deref, DerefMut}; use gpui::{App, AssetSource, BorrowAppContext, Global}; -use rodio::{OutputStream, OutputStreamHandle}; +use rodio::{OutputStream, OutputStreamBuilder}; use util::ResultExt; mod assets; @@ -37,8 +37,7 @@ impl Sound { #[derive(Default)] pub struct Audio { - _output_stream: Option, - output_handle: Option, + output_handle: Option, } #[derive(Deref, DerefMut)] @@ -51,11 +50,9 @@ impl Audio { Self::default() } - fn ensure_output_exists(&mut self) -> Option<&OutputStreamHandle> { + fn ensure_output_exists(&mut self) -> Option<&OutputStream> { if self.output_handle.is_none() { - let (_output_stream, output_handle) = OutputStream::try_default().log_err().unzip(); - self.output_handle = output_handle; - self._output_stream = _output_stream; + self.output_handle = OutputStreamBuilder::open_default_stream().log_err(); } self.output_handle.as_ref() @@ -69,7 +66,7 @@ impl Audio { cx.update_global::(|this, cx| { let output_handle = this.ensure_output_exists()?; let source = SoundRegistry::global(cx).get(sound.file()).log_err()?; - output_handle.play_raw(source).log_err()?; + output_handle.mixer().add(source); Some(()) }); } @@ -80,7 +77,6 @@ impl Audio { } cx.update_global::(|this, _| { - this._output_stream.take(); this.output_handle.take(); }); } diff --git a/crates/auto_update/src/auto_update.rs b/crates/auto_update/src/auto_update.rs index d62a9cdbe330964759fa5362689349c28cd2b713..074aaa6fea7033134bd0cc35d87bc0951e25b663 100644 --- a/crates/auto_update/src/auto_update.rs +++ b/crates/auto_update/src/auto_update.rs @@ -134,10 +134,15 @@ impl Settings for AutoUpdateSetting { type FileContent = Option; fn load(sources: SettingsSources, _: &mut App) -> Result { - let auto_update = [sources.server, sources.release_channel, sources.user] - .into_iter() - .find_map(|value| value.copied().flatten()) - .unwrap_or(sources.default.ok_or_else(Self::missing_default)?); + let auto_update = [ + sources.server, + sources.release_channel, + sources.operating_system, + sources.user, + ] + .into_iter() + .find_map(|value| value.copied().flatten()) + .unwrap_or(sources.default.ok_or_else(Self::missing_default)?); Ok(Self(auto_update.0)) } diff --git a/crates/aws_http_client/Cargo.toml b/crates/aws_http_client/Cargo.toml index 3760f70fe02973fb9f83ea401f4f9807309cf2d8..2749286d4c1361d9dbdb50d6566e3b4043f97b2e 100644 --- a/crates/aws_http_client/Cargo.toml +++ b/crates/aws_http_client/Cargo.toml @@ -17,7 +17,5 @@ default = [] [dependencies] aws-smithy-runtime-api.workspace = true aws-smithy-types.workspace = true -futures.workspace = true http_client.workspace = true -tokio = { workspace = true, features = ["rt", "rt-multi-thread"] } workspace-hack.workspace = true diff --git a/crates/aws_http_client/src/aws_http_client.rs 
b/crates/aws_http_client/src/aws_http_client.rs index 6adb995747317c33a3e3a177d0240b45ce03b8f9..d08c8e64a792a06126ecdd8e3833a87bb0ace7ab 100644 --- a/crates/aws_http_client/src/aws_http_client.rs +++ b/crates/aws_http_client/src/aws_http_client.rs @@ -11,14 +11,11 @@ use aws_smithy_runtime_api::client::result::ConnectorError; use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents; use aws_smithy_runtime_api::http::{Headers, StatusCode}; use aws_smithy_types::body::SdkBody; -use futures::AsyncReadExt; -use http_client::{AsyncBody, Inner}; +use http_client::AsyncBody; use http_client::{HttpClient, Request}; -use tokio::runtime::Handle; struct AwsHttpConnector { client: Arc, - handle: Handle, } impl std::fmt::Debug for AwsHttpConnector { @@ -42,18 +39,17 @@ impl AwsConnector for AwsHttpConnector { .client .send(Request::from_parts(parts, convert_to_async_body(body))); - let handle = self.handle.clone(); - HttpConnectorFuture::new(async move { let response = match response.await { Ok(response) => response, Err(err) => return Err(ConnectorError::other(err.into(), None)), }; let (parts, body) = response.into_parts(); - let body = convert_to_sdk_body(body, handle).await; - let mut response = - HttpResponse::new(StatusCode::try_from(parts.status.as_u16()).unwrap(), body); + let mut response = HttpResponse::new( + StatusCode::try_from(parts.status.as_u16()).unwrap(), + convert_to_sdk_body(body), + ); let headers = match Headers::try_from(parts.headers) { Ok(headers) => headers, @@ -70,7 +66,6 @@ impl AwsConnector for AwsHttpConnector { #[derive(Clone)] pub struct AwsHttpClient { client: Arc, - handler: Handle, } impl std::fmt::Debug for AwsHttpClient { @@ -80,11 +75,8 @@ impl std::fmt::Debug for AwsHttpClient { } impl AwsHttpClient { - pub fn new(client: Arc, handle: Handle) -> Self { - Self { - client, - handler: handle, - } + pub fn new(client: Arc) -> Self { + Self { client } } } @@ -96,25 +88,12 @@ impl AwsClient for AwsHttpClient { ) -> SharedHttpConnector { SharedHttpConnector::new(AwsHttpConnector { client: self.client.clone(), - handle: self.handler.clone(), }) } } -pub async fn convert_to_sdk_body(body: AsyncBody, handle: Handle) -> SdkBody { - match body.0 { - Inner::Empty => SdkBody::empty(), - Inner::Bytes(bytes) => SdkBody::from(bytes.into_inner()), - Inner::AsyncReader(mut reader) => { - let buffer = handle.spawn(async move { - let mut buffer = Vec::new(); - let _ = reader.read_to_end(&mut buffer).await; - buffer - }); - - SdkBody::from(buffer.await.unwrap_or_default()) - } - } +pub fn convert_to_sdk_body(body: AsyncBody) -> SdkBody { + SdkBody::from_body_1_x(body) } pub fn convert_to_async_body(body: SdkBody) -> AsyncBody { diff --git a/crates/bedrock/src/models.rs b/crates/bedrock/src/models.rs index b6eeafa2d6b273a8cc3f0c6cc7a18ea0589c4ba2..69d2ffb84569ef848f88de47f5394a6b25b18e02 100644 --- a/crates/bedrock/src/models.rs +++ b/crates/bedrock/src/models.rs @@ -32,11 +32,18 @@ pub enum Model { ClaudeSonnet4Thinking, #[serde(rename = "claude-opus-4", alias = "claude-opus-4-latest")] ClaudeOpus4, + #[serde(rename = "claude-opus-4-1", alias = "claude-opus-4-1-latest")] + ClaudeOpus4_1, #[serde( rename = "claude-opus-4-thinking", alias = "claude-opus-4-thinking-latest" )] ClaudeOpus4Thinking, + #[serde( + rename = "claude-opus-4-1-thinking", + alias = "claude-opus-4-1-thinking-latest" + )] + ClaudeOpus4_1Thinking, #[serde(rename = "claude-3-5-sonnet-v2", alias = "claude-3-5-sonnet-latest")] Claude3_5SonnetV2, #[serde(rename = "claude-3-7-sonnet", alias = 
"claude-3-7-sonnet-latest")] @@ -147,7 +154,9 @@ impl Model { Model::ClaudeSonnet4 => "claude-4-sonnet", Model::ClaudeSonnet4Thinking => "claude-4-sonnet-thinking", Model::ClaudeOpus4 => "claude-4-opus", + Model::ClaudeOpus4_1 => "claude-4-opus-1", Model::ClaudeOpus4Thinking => "claude-4-opus-thinking", + Model::ClaudeOpus4_1Thinking => "claude-4-opus-1-thinking", Model::Claude3_5SonnetV2 => "claude-3-5-sonnet-v2", Model::Claude3_5Sonnet => "claude-3-5-sonnet", Model::Claude3Opus => "claude-3-opus", @@ -208,6 +217,9 @@ impl Model { Model::ClaudeOpus4 | Model::ClaudeOpus4Thinking => { "anthropic.claude-opus-4-20250514-v1:0" } + Model::ClaudeOpus4_1 | Model::ClaudeOpus4_1Thinking => { + "anthropic.claude-opus-4-1-20250805-v1:0" + } Model::Claude3_5SonnetV2 => "anthropic.claude-3-5-sonnet-20241022-v2:0", Model::Claude3_5Sonnet => "anthropic.claude-3-5-sonnet-20240620-v1:0", Model::Claude3Opus => "anthropic.claude-3-opus-20240229-v1:0", @@ -266,7 +278,9 @@ impl Model { Self::ClaudeSonnet4 => "Claude Sonnet 4", Self::ClaudeSonnet4Thinking => "Claude Sonnet 4 Thinking", Self::ClaudeOpus4 => "Claude Opus 4", + Self::ClaudeOpus4_1 => "Claude Opus 4.1", Self::ClaudeOpus4Thinking => "Claude Opus 4 Thinking", + Self::ClaudeOpus4_1Thinking => "Claude Opus 4.1 Thinking", Self::Claude3_5SonnetV2 => "Claude 3.5 Sonnet v2", Self::Claude3_5Sonnet => "Claude 3.5 Sonnet", Self::Claude3Opus => "Claude 3 Opus", @@ -330,8 +344,10 @@ impl Model { | Self::Claude3_7Sonnet | Self::ClaudeSonnet4 | Self::ClaudeOpus4 + | Self::ClaudeOpus4_1 | Self::ClaudeSonnet4Thinking - | Self::ClaudeOpus4Thinking => 200_000, + | Self::ClaudeOpus4Thinking + | Self::ClaudeOpus4_1Thinking => 200_000, Self::AmazonNovaPremier => 1_000_000, Self::PalmyraWriterX5 => 1_000_000, Self::PalmyraWriterX4 => 128_000, @@ -348,7 +364,9 @@ impl Model { | Self::ClaudeSonnet4 | Self::ClaudeSonnet4Thinking | Self::ClaudeOpus4 - | Model::ClaudeOpus4Thinking => 128_000, + | Model::ClaudeOpus4Thinking + | Self::ClaudeOpus4_1 + | Model::ClaudeOpus4_1Thinking => 128_000, Self::Claude3_5SonnetV2 | Self::PalmyraWriterX4 | Self::PalmyraWriterX5 => 8_192, Self::Custom { max_output_tokens, .. 
@@ -366,6 +384,8 @@ impl Model { | Self::Claude3_7Sonnet | Self::ClaudeOpus4 | Self::ClaudeOpus4Thinking + | Self::ClaudeOpus4_1 + | Self::ClaudeOpus4_1Thinking | Self::ClaudeSonnet4 | Self::ClaudeSonnet4Thinking => 1.0, Self::Custom { @@ -387,6 +407,8 @@ impl Model { | Self::Claude3_7SonnetThinking | Self::ClaudeOpus4 | Self::ClaudeOpus4Thinking + | Self::ClaudeOpus4_1 + | Self::ClaudeOpus4_1Thinking | Self::ClaudeSonnet4 | Self::ClaudeSonnet4Thinking | Self::Claude3_5Haiku => true, @@ -420,7 +442,9 @@ impl Model { | Self::ClaudeSonnet4 | Self::ClaudeSonnet4Thinking | Self::ClaudeOpus4 - | Self::ClaudeOpus4Thinking => true, + | Self::ClaudeOpus4Thinking + | Self::ClaudeOpus4_1 + | Self::ClaudeOpus4_1Thinking => true, // Custom models - check if they have cache configuration Self::Custom { @@ -440,7 +464,9 @@ impl Model { | Self::ClaudeSonnet4 | Self::ClaudeSonnet4Thinking | Self::ClaudeOpus4 - | Self::ClaudeOpus4Thinking => Some(BedrockModelCacheConfiguration { + | Self::ClaudeOpus4Thinking + | Self::ClaudeOpus4_1 + | Self::ClaudeOpus4_1Thinking => Some(BedrockModelCacheConfiguration { max_cache_anchors: 4, min_total_token: 1024, }), @@ -467,9 +493,11 @@ impl Model { Model::ClaudeSonnet4Thinking => BedrockModelMode::Thinking { budget_tokens: Some(4096), }, - Model::ClaudeOpus4Thinking => BedrockModelMode::Thinking { - budget_tokens: Some(4096), - }, + Model::ClaudeOpus4Thinking | Model::ClaudeOpus4_1Thinking => { + BedrockModelMode::Thinking { + budget_tokens: Some(4096), + } + } _ => BedrockModelMode::Default, } } @@ -518,6 +546,8 @@ impl Model { | Model::ClaudeSonnet4Thinking | Model::ClaudeOpus4 | Model::ClaudeOpus4Thinking + | Model::ClaudeOpus4_1 + | Model::ClaudeOpus4_1Thinking | Model::Claude3Haiku | Model::Claude3Opus | Model::Claude3Sonnet diff --git a/crates/buffer_diff/src/buffer_diff.rs b/crates/buffer_diff/src/buffer_diff.rs index ee09fda46e008c903120eb0430ff18fae57dc3da..97f529fe377c0eaa3d74dca1600e1b3f0c3499db 100644 --- a/crates/buffer_diff/src/buffer_diff.rs +++ b/crates/buffer_diff/src/buffer_diff.rs @@ -343,8 +343,7 @@ impl BufferDiffInner { .. 
} in hunks.iter().cloned() { - let preceding_pending_hunks = - old_pending_hunks.slice(&buffer_range.start, Bias::Left, buffer); + let preceding_pending_hunks = old_pending_hunks.slice(&buffer_range.start, Bias::Left); pending_hunks.append(preceding_pending_hunks, buffer); // Skip all overlapping or adjacent old pending hunks @@ -355,7 +354,7 @@ impl BufferDiffInner { .cmp(&buffer_range.end, buffer) .is_le() }) { - old_pending_hunks.next(buffer); + old_pending_hunks.next(); } if (stage && secondary_status == DiffHunkSecondaryStatus::NoSecondaryHunk) @@ -379,10 +378,10 @@ impl BufferDiffInner { ); } // append the remainder - pending_hunks.append(old_pending_hunks.suffix(buffer), buffer); + pending_hunks.append(old_pending_hunks.suffix(), buffer); let mut unstaged_hunk_cursor = unstaged_diff.hunks.cursor::(buffer); - unstaged_hunk_cursor.next(buffer); + unstaged_hunk_cursor.next(); // then, iterate over all pending hunks (both new ones and the existing ones) and compute the edits let mut prev_unstaged_hunk_buffer_end = 0; @@ -397,8 +396,7 @@ impl BufferDiffInner { }) = pending_hunks_iter.next() { // Advance unstaged_hunk_cursor to skip unstaged hunks before current hunk - let skipped_unstaged = - unstaged_hunk_cursor.slice(&buffer_range.start, Bias::Left, buffer); + let skipped_unstaged = unstaged_hunk_cursor.slice(&buffer_range.start, Bias::Left); if let Some(unstaged_hunk) = skipped_unstaged.last() { prev_unstaged_hunk_base_text_end = unstaged_hunk.diff_base_byte_range.end; @@ -425,7 +423,7 @@ impl BufferDiffInner { buffer_offset_range.end = buffer_offset_range.end.max(unstaged_hunk_offset_range.end); - unstaged_hunk_cursor.next(buffer); + unstaged_hunk_cursor.next(); continue; } } @@ -514,7 +512,7 @@ impl BufferDiffInner { }); let anchor_iter = iter::from_fn(move || { - cursor.next(buffer); + cursor.next(); cursor.item() }) .flat_map(move |hunk| { @@ -531,12 +529,12 @@ impl BufferDiffInner { }); let mut pending_hunks_cursor = self.pending_hunks.cursor::(buffer); - pending_hunks_cursor.next(buffer); + pending_hunks_cursor.next(); let mut secondary_cursor = None; if let Some(secondary) = secondary.as_ref() { let mut cursor = secondary.hunks.cursor::(buffer); - cursor.next(buffer); + cursor.next(); secondary_cursor = Some(cursor); } @@ -564,7 +562,7 @@ impl BufferDiffInner { .cmp(&pending_hunks_cursor.start().buffer_range.start, buffer) .is_gt() { - pending_hunks_cursor.seek_forward(&start_anchor, Bias::Left, buffer); + pending_hunks_cursor.seek_forward(&start_anchor, Bias::Left); } if let Some(pending_hunk) = pending_hunks_cursor.item() { @@ -590,7 +588,7 @@ impl BufferDiffInner { .cmp(&secondary_cursor.start().buffer_range.start, buffer) .is_gt() { - secondary_cursor.seek_forward(&start_anchor, Bias::Left, buffer); + secondary_cursor.seek_forward(&start_anchor, Bias::Left); } if let Some(secondary_hunk) = secondary_cursor.item() { @@ -635,7 +633,7 @@ impl BufferDiffInner { }); iter::from_fn(move || { - cursor.prev(buffer); + cursor.prev(); let hunk = cursor.item()?; let range = hunk.buffer_range.to_point(buffer); @@ -653,8 +651,8 @@ impl BufferDiffInner { fn compare(&self, old: &Self, new_snapshot: &text::BufferSnapshot) -> Option> { let mut new_cursor = self.hunks.cursor::<()>(new_snapshot); let mut old_cursor = old.hunks.cursor::<()>(new_snapshot); - old_cursor.next(new_snapshot); - new_cursor.next(new_snapshot); + old_cursor.next(); + new_cursor.next(); let mut start = None; let mut end = None; @@ -669,7 +667,7 @@ impl BufferDiffInner { Ordering::Less => { 
start.get_or_insert(new_hunk.buffer_range.start); end.replace(new_hunk.buffer_range.end); - new_cursor.next(new_snapshot); + new_cursor.next(); } Ordering::Equal => { if new_hunk != old_hunk { @@ -686,25 +684,25 @@ impl BufferDiffInner { } } - new_cursor.next(new_snapshot); - old_cursor.next(new_snapshot); + new_cursor.next(); + old_cursor.next(); } Ordering::Greater => { start.get_or_insert(old_hunk.buffer_range.start); end.replace(old_hunk.buffer_range.end); - old_cursor.next(new_snapshot); + old_cursor.next(); } } } (Some(new_hunk), None) => { start.get_or_insert(new_hunk.buffer_range.start); end.replace(new_hunk.buffer_range.end); - new_cursor.next(new_snapshot); + new_cursor.next(); } (None, Some(old_hunk)) => { start.get_or_insert(old_hunk.buffer_range.start); end.replace(old_hunk.buffer_range.end); - old_cursor.next(new_snapshot); + old_cursor.next(); } (None, None) => break, } diff --git a/crates/call/src/call_impl/room.rs b/crates/call/src/call_impl/room.rs index 31ca144cf8a61946318dc518e7ffee29b4c06d6f..afeee4c924feb2990668f953d5b2f7dfcff26f34 100644 --- a/crates/call/src/call_impl/room.rs +++ b/crates/call/src/call_impl/room.rs @@ -11,15 +11,18 @@ use client::{ use collections::{BTreeMap, HashMap, HashSet}; use fs::Fs; use futures::{FutureExt, StreamExt}; -use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity}; +use gpui::{ + App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, ScreenCaptureSource, + ScreenCaptureStream, Task, WeakEntity, +}; use gpui_tokio::Tokio; use language::LanguageRegistry; use livekit::{LocalTrackPublication, ParticipantIdentity, RoomEvent}; -use livekit_client::{self as livekit, TrackSid}; +use livekit_client::{self as livekit, AudioStream, TrackSid}; use postage::{sink::Sink, stream::Stream, watch}; use project::Project; use settings::Settings as _; -use std::{any::Any, future::Future, mem, rc::Rc, sync::Arc, time::Duration}; +use std::{future::Future, mem, rc::Rc, sync::Arc, time::Duration}; use util::{ResultExt, TryFutureExt, post_inc}; pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30); @@ -1251,12 +1254,21 @@ impl Room { }) } - pub fn is_screen_sharing(&self) -> bool { + pub fn is_sharing_screen(&self) -> bool { self.live_kit.as_ref().map_or(false, |live_kit| { !matches!(live_kit.screen_track, LocalTrack::None) }) } + pub fn shared_screen_id(&self) -> Option { + self.live_kit.as_ref().and_then(|lk| match lk.screen_track { + LocalTrack::Published { ref _stream, .. 
} => { + _stream.metadata().ok().map(|meta| meta.id) + } + _ => None, + }) + } + pub fn is_sharing_mic(&self) -> bool { self.live_kit.as_ref().map_or(false, |live_kit| { !matches!(live_kit.microphone_track, LocalTrack::None) @@ -1369,11 +1381,15 @@ impl Room { }) } - pub fn share_screen(&mut self, cx: &mut Context) -> Task> { + pub fn share_screen( + &mut self, + source: Rc, + cx: &mut Context, + ) -> Task> { if self.status.is_offline() { return Task::ready(Err(anyhow!("room is offline"))); } - if self.is_screen_sharing() { + if self.is_sharing_screen() { return Task::ready(Err(anyhow!("screen was already shared"))); } @@ -1386,13 +1402,8 @@ impl Room { return Task::ready(Err(anyhow!("live-kit was not initialized"))); }; - let sources = cx.screen_capture_sources(); - cx.spawn(async move |this, cx| { - let sources = sources.await??; - let source = sources.first().context("no display found")?; - - let publication = participant.publish_screenshare_track(&**source, cx).await; + let publication = participant.publish_screenshare_track(&*source, cx).await; this.update(cx, |this, cx| { let live_kit = this @@ -1419,7 +1430,7 @@ impl Room { } else { live_kit.screen_track = LocalTrack::Published { track_publication: publication, - _stream: Box::new(stream), + _stream: stream, }; cx.notify(); } @@ -1485,7 +1496,7 @@ impl Room { } } - pub fn unshare_screen(&mut self, cx: &mut Context) -> Result<()> { + pub fn unshare_screen(&mut self, play_sound: bool, cx: &mut Context) -> Result<()> { anyhow::ensure!(!self.status.is_offline(), "room is offline"); let live_kit = self @@ -1509,7 +1520,10 @@ impl Room { cx.notify(); } - Audio::play_sound(Sound::StopScreenshare, cx); + if play_sound { + Audio::play_sound(Sound::StopScreenshare, cx); + } + Ok(()) } } @@ -1617,8 +1631,8 @@ fn spawn_room_connection( struct LiveKitRoom { room: Rc, - screen_track: LocalTrack, - microphone_track: LocalTrack, + screen_track: LocalTrack, + microphone_track: LocalTrack, /// Tracks whether we're currently in a muted state due to auto-mute from deafening or manual mute performed by user. 
muted_by_user: bool, deafened: bool, @@ -1656,18 +1670,18 @@ impl LiveKitRoom { } } -enum LocalTrack { +enum LocalTrack { None, Pending { publish_id: usize, }, Published { track_publication: LocalTrackPublication, - _stream: Box, + _stream: Box, }, } -impl Default for LocalTrack { +impl Default for LocalTrack { fn default() -> Self { Self::None } diff --git a/crates/channel/src/channel_chat.rs b/crates/channel/src/channel_chat.rs index 8394972d43754e07d0f197a315a4e17879aa17fe..4ac37ffd14ca2602756afecc788aae9f6065cad9 100644 --- a/crates/channel/src/channel_chat.rs +++ b/crates/channel/src/channel_chat.rs @@ -13,7 +13,7 @@ use std::{ ops::{ControlFlow, Range}, sync::Arc, }; -use sum_tree::{Bias, SumTree}; +use sum_tree::{Bias, Dimensions, SumTree}; use time::OffsetDateTime; use util::{ResultExt as _, TryFutureExt, post_inc}; @@ -331,9 +331,11 @@ impl ChannelChat { .update(&mut cx, |chat, cx| { if let Some(first_id) = chat.first_loaded_message_id() { if first_id <= message_id { - let mut cursor = chat.messages.cursor::<(ChannelMessageId, Count)>(&()); + let mut cursor = chat + .messages + .cursor::>(&()); let message_id = ChannelMessageId::Saved(message_id); - cursor.seek(&message_id, Bias::Left, &()); + cursor.seek(&message_id, Bias::Left); return ControlFlow::Break( if cursor .item() @@ -499,7 +501,7 @@ impl ChannelChat { pub fn message(&self, ix: usize) -> &ChannelMessage { let mut cursor = self.messages.cursor::(&()); - cursor.seek(&Count(ix), Bias::Right, &()); + cursor.seek(&Count(ix), Bias::Right); cursor.item().unwrap() } @@ -516,13 +518,13 @@ impl ChannelChat { pub fn messages_in_range(&self, range: Range) -> impl Iterator { let mut cursor = self.messages.cursor::(&()); - cursor.seek(&Count(range.start), Bias::Right, &()); + cursor.seek(&Count(range.start), Bias::Right); cursor.take(range.len()) } pub fn pending_messages(&self) -> impl Iterator { let mut cursor = self.messages.cursor::(&()); - cursor.seek(&ChannelMessageId::Pending(0), Bias::Left, &()); + cursor.seek(&ChannelMessageId::Pending(0), Bias::Left); cursor } @@ -587,10 +589,12 @@ impl ChannelChat { .map(|m| m.nonce) .collect::>(); - let mut old_cursor = self.messages.cursor::<(ChannelMessageId, Count)>(&()); - let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left, &()); + let mut old_cursor = self + .messages + .cursor::>(&()); + let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left); let start_ix = old_cursor.start().1.0; - let removed_messages = old_cursor.slice(&last_message.id, Bias::Right, &()); + let removed_messages = old_cursor.slice(&last_message.id, Bias::Right); let removed_count = removed_messages.summary().count; let new_count = messages.summary().count; let end_ix = start_ix + removed_count; @@ -599,10 +603,10 @@ impl ChannelChat { let mut ranges = Vec::>::new(); if new_messages.last().unwrap().is_pending() { - new_messages.append(old_cursor.suffix(&()), &()); + new_messages.append(old_cursor.suffix(), &()); } else { new_messages.append( - old_cursor.slice(&ChannelMessageId::Pending(0), Bias::Left, &()), + old_cursor.slice(&ChannelMessageId::Pending(0), Bias::Left), &(), ); @@ -617,7 +621,7 @@ impl ChannelChat { } else { new_messages.push(message.clone(), &()); } - old_cursor.next(&()); + old_cursor.next(); } } @@ -641,12 +645,12 @@ impl ChannelChat { fn message_removed(&mut self, id: u64, cx: &mut Context) { let mut cursor = self.messages.cursor::(&()); - let mut messages = cursor.slice(&ChannelMessageId::Saved(id), Bias::Left, &()); + let mut messages = 
cursor.slice(&ChannelMessageId::Saved(id), Bias::Left); if let Some(item) = cursor.item() { if item.id == ChannelMessageId::Saved(id) { let deleted_message_ix = messages.summary().count; - cursor.next(&()); - messages.append(cursor.suffix(&()), &()); + cursor.next(); + messages.append(cursor.suffix(), &()); drop(cursor); self.messages = messages; @@ -680,7 +684,7 @@ impl ChannelChat { cx: &mut Context, ) { let mut cursor = self.messages.cursor::(&()); - let mut messages = cursor.slice(&id, Bias::Left, &()); + let mut messages = cursor.slice(&id, Bias::Left); let ix = messages.summary().count; if let Some(mut message_to_update) = cursor.item().cloned() { @@ -688,10 +692,10 @@ impl ChannelChat { message_to_update.mentions = mentions; message_to_update.edited_at = edited_at; messages.push(message_to_update, &()); - cursor.next(&()); + cursor.next(); } - messages.append(cursor.suffix(&()), &()); + messages.append(cursor.suffix(), &()); drop(cursor); self.messages = messages; diff --git a/crates/channel/src/channel_store.rs b/crates/channel/src/channel_store.rs index b7ba811421d63be6d288606eb2e7d2fa1199f983..4ad156b9fb08e8af95e5ea49132c4c4786e348a1 100644 --- a/crates/channel/src/channel_store.rs +++ b/crates/channel/src/channel_store.rs @@ -126,7 +126,7 @@ impl ChannelMembership { proto::channel_member::Kind::Member => 0, proto::channel_member::Kind::Invitee => 1, }, - username_order: self.user.github_login.as_str(), + username_order: &self.user.github_login, } } } diff --git a/crates/channel/src/channel_store_tests.rs b/crates/channel/src/channel_store_tests.rs index f8f5de3c39d2414be59b42aa72b0817bd7878584..c92226eeebd131170b0a5b04e4ed7f42c19a64fc 100644 --- a/crates/channel/src/channel_store_tests.rs +++ b/crates/channel/src/channel_store_tests.rs @@ -259,20 +259,6 @@ async fn test_channel_messages(cx: &mut TestAppContext) { assert_channels(&channel_store, &[(0, "the-channel".to_string())], cx); }); - let get_users = server.receive::().await.unwrap(); - assert_eq!(get_users.payload.user_ids, vec![5]); - server.respond( - get_users.receipt(), - proto::UsersResponse { - users: vec![proto::User { - id: 5, - github_login: "nathansobo".into(), - avatar_url: "http://avatar.com/nathansobo".into(), - name: None, - }], - }, - ); - // Join a channel and populate its existing messages. 
let channel = channel_store.update(cx, |store, cx| { let channel_id = store.ordered_channels().next().unwrap().1.id; @@ -334,7 +320,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { .map(|message| (message.sender.github_login.clone(), message.body.clone())) .collect::>(), &[ - ("nathansobo".into(), "a".into()), + ("user-5".into(), "a".into()), ("maxbrunsfeld".into(), "b".into()) ] ); @@ -437,7 +423,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { .map(|message| (message.sender.github_login.clone(), message.body.clone())) .collect::>(), &[ - ("nathansobo".into(), "y".into()), + ("user-5".into(), "y".into()), ("maxbrunsfeld".into(), "z".into()) ] ); diff --git a/crates/cli/src/main.rs b/crates/cli/src/main.rs index d6ddf79ea6cafd53efcad1d9979afbc47f3d6083..287c62b753f1ce875ca38a9f2caa62b906e6ee27 100644 --- a/crates/cli/src/main.rs +++ b/crates/cli/src/main.rs @@ -315,19 +315,19 @@ fn main() -> Result<()> { }); let stdin_pipe_handle: Option>> = - stdin_tmp_file.map(|tmp_file| { + stdin_tmp_file.map(|mut tmp_file| { thread::spawn(move || { - let stdin = std::io::stdin().lock(); - if io::IsTerminal::is_terminal(&stdin) { - return Ok(()); + let mut stdin = std::io::stdin().lock(); + if !io::IsTerminal::is_terminal(&stdin) { + io::copy(&mut stdin, &mut tmp_file)?; } - return pipe_to_tmp(stdin, tmp_file); + Ok(()) }) }); - let anonymous_fd_pipe_handles: Vec>> = anonymous_fd_tmp_files + let anonymous_fd_pipe_handles: Vec<_> = anonymous_fd_tmp_files .into_iter() - .map(|(file, tmp_file)| thread::spawn(move || pipe_to_tmp(file, tmp_file))) + .map(|(mut file, mut tmp_file)| thread::spawn(move || io::copy(&mut file, &mut tmp_file))) .collect(); if args.foreground { @@ -349,22 +349,6 @@ fn main() -> Result<()> { Ok(()) } -fn pipe_to_tmp(mut src: impl io::Read, mut dest: fs::File) -> Result<()> { - let mut buffer = [0; 8 * 1024]; - loop { - let bytes_read = match src.read(&mut buffer) { - Err(err) if err.kind() == io::ErrorKind::Interrupted => continue, - res => res?, - }; - if bytes_read == 0 { - break; - } - io::Write::write_all(&mut dest, &buffer[..bytes_read])?; - } - io::Write::flush(&mut dest)?; - Ok(()) -} - fn anonymous_fd(path: &str) -> Option { #[cfg(target_os = "linux")] { diff --git a/crates/client/Cargo.toml b/crates/client/Cargo.toml index b741f515fd1681a721048d831d01db8ec0f889e6..365625b44535e474baecf058c98f54aaf05b5e49 100644 --- a/crates/client/Cargo.toml +++ b/crates/client/Cargo.toml @@ -17,11 +17,12 @@ test-support = ["clock/test-support", "collections/test-support", "gpui/test-sup [dependencies] anyhow.workspace = true -async-recursion = "0.3" async-tungstenite = { workspace = true, features = ["tokio", "tokio-rustls-manual-roots"] } base64.workspace = true chrono = { workspace = true, features = ["serde"] } clock.workspace = true +cloud_api_client.workspace = true +cloud_llm_client.workspace = true collections.workspace = true credentials_provider.workspace = true derive_more.workspace = true @@ -33,8 +34,8 @@ http_client.workspace = true http_client_tls.workspace = true httparse = "1.10" log.workspace = true -paths.workspace = true parking_lot.workspace = true +paths.workspace = true postage.workspace = true rand.workspace = true regex.workspace = true @@ -46,19 +47,18 @@ serde_json.workspace = true settings.workspace = true sha2.workspace = true smol.workspace = true +telemetry.workspace = true telemetry_events.workspace = true text.workspace = true thiserror.workspace = true time.workspace = true tiny_http.workspace = true tokio-socks = { version = 
"0.5.2", default-features = false, features = ["futures-io"] } +tokio.workspace = true url.workspace = true util.workspace = true -worktree.workspace = true -telemetry.workspace = true -tokio.workspace = true workspace-hack.workspace = true -zed_llm_client.workspace = true +worktree.workspace = true [dev-dependencies] clock = { workspace = true, features = ["test-support"] } diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index c4211f72c819cfed5c0ee2f555356aa970968bc5..f09c012a858e3cf97166dae9dbdbeb3da51b96b6 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -6,22 +6,23 @@ pub mod telemetry; pub mod user; pub mod zed_urls; -use anyhow::{Context as _, Result, anyhow, bail}; -use async_recursion::async_recursion; +use anyhow::{Context as _, Result, anyhow}; use async_tungstenite::tungstenite::{ client::IntoClientRequest, error::Error as WebsocketError, http::{HeaderValue, Request, StatusCode}, }; -use chrono::{DateTime, Utc}; use clock::SystemClock; +use cloud_api_client::CloudApiClient; +use cloud_api_client::websocket_protocol::MessageToClient; use credentials_provider::CredentialsProvider; +use feature_flags::FeatureFlagAppExt as _; use futures::{ AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt, TryFutureExt as _, TryStreamExt, channel::oneshot, future::BoxFuture, }; use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions}; -use http_client::{AsyncBody, HttpClient, HttpClientWithUrl}; +use http_client::{HttpClient, HttpClientWithUrl, http}; use parking_lot::RwLock; use postage::watch; use proxy::connect_proxy_stream; @@ -31,7 +32,6 @@ use rpc::proto::{AnyTypedEnvelope, EnvelopedMessage, PeerId, RequestMessage}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsSources}; -use std::pin::Pin; use std::{ any::TypeId, convert::TryFrom, @@ -45,6 +45,7 @@ use std::{ }, time::{Duration, Instant}, }; +use std::{cmp, pin::Pin}; use telemetry::Telemetry; use thiserror::Error; use tokio::net::TcpStream; @@ -78,7 +79,7 @@ pub static ZED_ALWAYS_ACTIVE: LazyLock = LazyLock::new(|| std::env::var("ZED_ALWAYS_ACTIVE").map_or(false, |e| !e.is_empty())); pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(500); -pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(10); +pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(30); pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(20); actions!( @@ -161,20 +162,8 @@ pub fn init(client: &Arc, cx: &mut App) { let client = client.clone(); move |_: &SignIn, cx| { if let Some(client) = client.upgrade() { - cx.spawn( - async move |cx| match client.authenticate_and_connect(true, &cx).await { - ConnectionResult::Timeout => { - log::error!("Initial authentication timed out"); - } - ConnectionResult::ConnectionReset => { - log::error!("Initial authentication connection reset"); - } - ConnectionResult::Result(r) => { - r.log_err(); - } - }, - ) - .detach(); + cx.spawn(async move |cx| client.sign_in_with_optional_connect(true, &cx).await) + .detach_and_log_err(cx); } } }); @@ -204,6 +193,8 @@ pub fn init(client: &Arc, cx: &mut App) { }); } +pub type MessageToClientHandler = Box; + struct GlobalClient(Arc); impl Global for GlobalClient {} @@ -212,10 +203,12 @@ pub struct Client { id: AtomicU64, peer: Arc, http: Arc, + cloud_client: Arc, telemetry: Arc, credentials_provider: ClientCredentialsProvider, state: RwLock, handler_set: parking_lot::Mutex, + message_to_client_handlers: parking_lot::Mutex>, 
#[allow(clippy::type_complexity)] #[cfg(any(test, feature = "test-support"))] @@ -282,6 +275,8 @@ pub enum Status { SignedOut, UpgradeRequired, Authenticating, + Authenticated, + AuthenticationError, Connecting, ConnectionError, Connected { @@ -301,6 +296,13 @@ impl Status { matches!(self, Self::Connected { .. }) } + pub fn is_signing_in(&self) -> bool { + matches!( + self, + Self::Authenticating | Self::Reauthenticating | Self::Connecting | Self::Reconnecting + ) + } + pub fn is_signed_out(&self) -> bool { matches!(self, Self::SignedOut | Self::UpgradeRequired) } @@ -551,10 +553,12 @@ impl Client { id: AtomicU64::new(0), peer: Peer::new(0), telemetry: Telemetry::new(clock, http.clone(), cx), + cloud_client: Arc::new(CloudApiClient::new(http.clone())), http, credentials_provider: ClientCredentialsProvider::new(cx), state: Default::default(), handler_set: Default::default(), + message_to_client_handlers: parking_lot::Mutex::new(Vec::new()), #[cfg(any(test, feature = "test-support"))] authenticate: Default::default(), @@ -583,6 +587,10 @@ impl Client { self.http.clone() } + pub fn cloud_client(&self) -> Arc { + self.cloud_client.clone() + } + pub fn set_id(&self, id: u64) -> &Self { self.id.store(id, Ordering::SeqCst); self @@ -669,7 +677,7 @@ impl Client { let mut delay = INITIAL_RECONNECTION_DELAY; loop { - match client.authenticate_and_connect(true, &cx).await { + match client.connect(true, &cx).await { ConnectionResult::Timeout => { log::error!("client connect attempt timed out") } @@ -685,18 +693,20 @@ impl Client { } } - if matches!(*client.status().borrow(), Status::ConnectionError) { + if matches!( + *client.status().borrow(), + Status::AuthenticationError | Status::ConnectionError + ) { client.set_status( Status::ReconnectionError { next_reconnection: Instant::now() + delay, }, &cx, ); - cx.background_executor().timer(delay).await; - delay = delay - .mul_f32(rng.gen_range(0.5..=2.5)) - .max(INITIAL_RECONNECTION_DELAY) - .min(MAX_RECONNECTION_DELAY); + let jitter = + Duration::from_millis(rng.gen_range(0..delay.as_millis() as u64)); + cx.background_executor().timer(delay + jitter).await; + delay = cmp::min(delay * 2, MAX_RECONNECTION_DELAY); } else { break; } @@ -840,40 +850,37 @@ impl Client { .is_some() } - #[async_recursion(?Send)] - pub async fn authenticate_and_connect( + pub async fn sign_in( self: &Arc, try_provider: bool, cx: &AsyncApp, - ) -> ConnectionResult<()> { - let was_disconnected = match *self.status().borrow() { - Status::SignedOut => true, - Status::ConnectionError - | Status::ConnectionLost - | Status::Authenticating { .. } - | Status::Reauthenticating { .. } - | Status::ReconnectionError { .. } => false, - Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => { - return ConnectionResult::Result(Ok(())); - } - Status::UpgradeRequired => { - return ConnectionResult::Result( - Err(EstablishConnectionError::UpgradeRequired) - .context("client auth and connect"), - ); - } - }; - if was_disconnected { + ) -> Result { + if self.status().borrow().is_signed_out() { self.set_status(Status::Authenticating, cx); } else { - self.set_status(Status::Reauthenticating, cx) + self.set_status(Status::Reauthenticating, cx); + } + + let mut credentials = None; + + let old_credentials = self.state.read().credentials.clone(); + if let Some(old_credentials) = old_credentials { + if self.validate_credentials(&old_credentials, cx).await? 
{ + credentials = Some(old_credentials); + } } - let mut read_from_provider = false; - let mut credentials = self.state.read().credentials.clone(); if credentials.is_none() && try_provider { - credentials = self.credentials_provider.read_credentials(cx).await; - read_from_provider = credentials.is_some(); + if let Some(stored_credentials) = self.credentials_provider.read_credentials(cx).await { + if self.validate_credentials(&stored_credentials, cx).await? { + credentials = Some(stored_credentials); + } else { + self.credentials_provider + .delete_credentials(cx) + .await + .log_err(); + } + } } if credentials.is_none() { @@ -882,20 +889,158 @@ impl Client { futures::select_biased! { authenticate = self.authenticate(cx).fuse() => { match authenticate { - Ok(creds) => credentials = Some(creds), + Ok(creds) => { + if IMPERSONATE_LOGIN.is_none() { + self.credentials_provider + .write_credentials(creds.user_id, creds.access_token.clone(), cx) + .await + .log_err(); + } + + credentials = Some(creds); + }, Err(err) => { - self.set_status(Status::ConnectionError, cx); - return ConnectionResult::Result(Err(err)); + self.set_status(Status::AuthenticationError, cx); + return Err(err); } } } _ = status_rx.next().fuse() => { - return ConnectionResult::Result(Err(anyhow!("authentication canceled"))); + return Err(anyhow!("authentication canceled")); } } } + let credentials = credentials.unwrap(); self.set_id(credentials.user_id); + self.cloud_client + .set_credentials(credentials.user_id as u32, credentials.access_token.clone()); + self.state.write().credentials = Some(credentials.clone()); + self.set_status(Status::Authenticated, cx); + + Ok(credentials) + } + + async fn validate_credentials( + self: &Arc, + credentials: &Credentials, + cx: &AsyncApp, + ) -> Result { + match self + .cloud_client + .validate_credentials(credentials.user_id as u32, &credentials.access_token) + .await + { + Ok(valid) => Ok(valid), + Err(err) => { + self.set_status(Status::AuthenticationError, cx); + Err(anyhow!("failed to validate credentials: {}", err)) + } + } + } + + /// Establishes a WebSocket connection with Cloud for receiving updates from the server. + async fn connect_to_cloud(self: &Arc, cx: &AsyncApp) -> Result<()> { + let connect_task = cx.update({ + let cloud_client = self.cloud_client.clone(); + move |cx| cloud_client.connect(cx) + })??; + let connection = connect_task.await?; + + let (mut messages, task) = cx.update(|cx| connection.spawn(cx))?; + task.detach(); + + cx.spawn({ + let this = self.clone(); + async move |cx| { + while let Some(message) = messages.next().await { + if let Some(message) = message.log_err() { + this.handle_message_to_client(message, cx); + } + } + } + }) + .detach(); + + Ok(()) + } + + /// Performs a sign-in and also (optionally) connects to Collab. + /// + /// Only Zed staff automatically connect to Collab. 
+ pub async fn sign_in_with_optional_connect( + self: &Arc, + try_provider: bool, + cx: &AsyncApp, + ) -> Result<()> { + let (is_staff_tx, is_staff_rx) = oneshot::channel::(); + let mut is_staff_tx = Some(is_staff_tx); + cx.update(|cx| { + cx.on_flags_ready(move |state, _cx| { + if let Some(is_staff_tx) = is_staff_tx.take() { + is_staff_tx.send(state.is_staff).log_err(); + } + }) + .detach(); + }) + .log_err(); + + let credentials = self.sign_in(try_provider, cx).await?; + + self.connect_to_cloud(cx).await.log_err(); + + cx.update(move |cx| { + cx.spawn({ + let client = self.clone(); + async move |cx| { + let is_staff = is_staff_rx.await?; + if is_staff { + match client.connect_with_credentials(credentials, cx).await { + ConnectionResult::Timeout => Err(anyhow!("connection timed out")), + ConnectionResult::ConnectionReset => Err(anyhow!("connection reset")), + ConnectionResult::Result(result) => { + result.context("client auth and connect") + } + } + } else { + Ok(()) + } + } + }) + .detach_and_log_err(cx); + }) + .log_err(); + + Ok(()) + } + + pub async fn connect( + self: &Arc, + try_provider: bool, + cx: &AsyncApp, + ) -> ConnectionResult<()> { + let was_disconnected = match *self.status().borrow() { + Status::SignedOut | Status::Authenticated => true, + Status::ConnectionError + | Status::ConnectionLost + | Status::Authenticating { .. } + | Status::AuthenticationError + | Status::Reauthenticating { .. } + | Status::ReconnectionError { .. } => false, + Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => { + return ConnectionResult::Result(Ok(())); + } + Status::UpgradeRequired => { + return ConnectionResult::Result( + Err(EstablishConnectionError::UpgradeRequired) + .context("client auth and connect"), + ); + } + }; + let credentials = match self.sign_in(try_provider, cx).await { + Ok(credentials) => credentials, + Err(err) => return ConnectionResult::Result(Err(err)), + }; if was_disconnected { self.set_status(Status::Connecting, cx); @@ -903,17 +1048,20 @@ impl Client { self.set_status(Status::Reconnecting, cx); } + self.connect_with_credentials(credentials, cx).await + } + + async fn connect_with_credentials( + self: &Arc, + credentials: Credentials, + cx: &AsyncApp, + ) -> ConnectionResult<()> { let mut timeout = futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT)); futures::select_biased! { connection = self.establish_connection(&credentials, cx).fuse() => { match connection { Ok(conn) => { - self.state.write().credentials = Some(credentials.clone()); - if !read_from_provider && IMPERSONATE_LOGIN.is_none() { - self.credentials_provider.write_credentials(credentials.user_id, credentials.access_token, cx).await.log_err(); - } - futures::select_biased! 
{ result = self.set_connection(conn, cx).fuse() => { match result.context("client auth and connect") { @@ -931,15 +1079,8 @@ impl Client { } } Err(EstablishConnectionError::Unauthorized) => { - self.state.write().credentials.take(); - if read_from_provider { - self.credentials_provider.delete_credentials(cx).await.log_err(); - self.set_status(Status::SignedOut, cx); - self.authenticate_and_connect(false, cx).await - } else { - self.set_status(Status::ConnectionError, cx); - ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect")) - } + self.set_status(Status::ConnectionError, cx); + ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect")) } Err(EstablishConnectionError::UpgradeRequired) => { self.set_status(Status::UpgradeRequired, cx); @@ -1103,7 +1244,7 @@ impl Client { .to_str() .map_err(EstablishConnectionError::other)? .to_string(); - Url::parse(&collab_url).with_context(|| format!("parsing colab rpc url {collab_url}")) + Url::parse(&collab_url).with_context(|| format!("parsing collab rpc url {collab_url}")) } } @@ -1123,6 +1264,7 @@ impl Client { let http = self.http.clone(); let proxy = http.proxy().cloned(); + let user_agent = http.user_agent().cloned(); let credentials = credentials.clone(); let rpc_url = self.rpc_url(http, release_channel); let system_id = self.telemetry.system_id(); @@ -1174,7 +1316,7 @@ impl Client { // We then modify the request to add our desired headers. let request_headers = request.headers_mut(); request_headers.insert( - "Authorization", + http::header::AUTHORIZATION, HeaderValue::from_str(&credentials.authorization_header())?, ); request_headers.insert( @@ -1186,6 +1328,9 @@ impl Client { "x-zed-release-channel", HeaderValue::from_str(release_channel.map(|r| r.dev_name()).unwrap_or("unknown"))?, ); + if let Some(user_agent) = user_agent { + request_headers.insert(http::header::USER_AGENT, user_agent); + } if let Some(system_id) = system_id { request_headers.insert("x-zed-system-id", HeaderValue::from_str(&system_id)?); } @@ -1330,96 +1475,31 @@ impl Client { self: &Arc, http: Arc, login: String, - mut api_token: String, + api_token: String, ) -> Result { - #[derive(Deserialize)] - struct AuthenticatedUserResponse { - user: User, + #[derive(Serialize)] + struct ImpersonateUserBody { + github_login: String, } #[derive(Deserialize)] - struct User { - id: u64, + struct ImpersonateUserResponse { + user_id: u64, + access_token: String, } - let github_user = { - #[derive(Deserialize)] - struct GithubUser { - id: i32, - login: String, - created_at: DateTime, - } - - let request = { - let mut request_builder = - Request::get(&format!("https://api.github.com/users/{login}")); - if let Ok(github_token) = std::env::var("GITHUB_TOKEN") { - request_builder = - request_builder.header("Authorization", format!("Bearer {}", github_token)); - } - - request_builder.body(AsyncBody::empty())? 
- }; - - let mut response = http - .send(request) - .await - .context("error fetching GitHub user")?; - - let mut body = Vec::new(); - response - .body_mut() - .read_to_end(&mut body) - .await - .context("error reading GitHub user")?; - - if !response.status().is_success() { - let text = String::from_utf8_lossy(body.as_slice()); - bail!( - "status error {}, response: {text:?}", - response.status().as_u16() - ); - } - - serde_json::from_slice::(body.as_slice()).map_err(|err| { - log::error!("Error deserializing: {:?}", err); - log::error!( - "GitHub API response text: {:?}", - String::from_utf8_lossy(body.as_slice()) - ); - anyhow!("error deserializing GitHub user") - })? - }; - - let query_params = [ - ("github_login", &github_user.login), - ("github_user_id", &github_user.id.to_string()), - ( - "github_user_created_at", - &github_user.created_at.to_rfc3339(), - ), - ]; - - // Use the collab server's admin API to retrieve the ID - // of the impersonated user. - let mut url = self.rpc_url(http.clone(), None).await?; - url.set_path("/user"); - url.set_query(Some( - &query_params - .iter() - .map(|(key, value)| { - format!( - "{}={}", - key, - url::form_urlencoded::byte_serialize(value.as_bytes()).collect::() - ) - }) - .collect::>() - .join("&"), - )); - let request: http_client::Request = Request::get(url.as_str()) - .header("Authorization", format!("token {api_token}")) - .body("".into())?; + let url = self + .http + .build_zed_cloud_url("/internal/users/impersonate", &[])?; + let request = Request::post(url.as_str()) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {api_token}")) + .body( + serde_json::to_string(&ImpersonateUserBody { + github_login: login, + })? + .into(), + )?; let mut response = http.send(request).await?; let mut body = String::new(); @@ -1430,18 +1510,17 @@ impl Client { response.status().as_u16(), body, ); - let response: AuthenticatedUserResponse = serde_json::from_str(&body)?; + let response: ImpersonateUserResponse = serde_json::from_str(&body)?; - // Use the admin API token to authenticate as the impersonated user. 
- api_token.insert_str(0, "ADMIN_TOKEN:"); Ok(Credentials { - user_id: response.user.id, - access_token: api_token, + user_id: response.user_id, + access_token: response.access_token, }) } pub async fn sign_out(self: &Arc, cx: &AsyncApp) { self.state.write().credentials = None; + self.cloud_client.clear_credentials(); self.disconnect(cx); if self.has_credentials(cx).await { @@ -1603,6 +1682,24 @@ impl Client { } } + pub fn add_message_to_client_handler( + self: &Arc, + handler: impl Fn(&MessageToClient, &mut App) + Send + Sync + 'static, + ) { + self.message_to_client_handlers + .lock() + .push(Box::new(handler)); + } + + fn handle_message_to_client(self: &Arc, message: MessageToClient, cx: &AsyncApp) { + cx.update(|cx| { + for handler in self.message_to_client_handlers.lock().iter() { + handler(&message, cx); + } + }) + .ok(); + } + pub fn telemetry(&self) -> &Arc { &self.telemetry } @@ -1670,7 +1767,7 @@ pub fn parse_zed_link<'a>(link: &'a str, cx: &App) -> Option<&'a str> { #[cfg(test)] mod tests { use super::*; - use crate::test::FakeServer; + use crate::test::{FakeServer, parse_authorization_header}; use clock::FakeSystemClock; use gpui::{AppContext as _, BackgroundExecutor, TestAppContext}; @@ -1721,6 +1818,46 @@ mod tests { assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token } + #[gpui::test(iterations = 10)] + async fn test_auth_failure_during_reconnection(cx: &mut TestAppContext) { + init_test(cx); + let http_client = FakeHttpClient::with_200_response(); + let client = + cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client.clone(), cx)); + let server = FakeServer::for_client(42, &client, cx).await; + let mut status = client.status(); + assert!(matches!( + status.next().await, + Some(Status::Connected { .. }) + )); + assert_eq!(server.auth_count(), 1); + + // Simulate an auth failure during reconnection. + http_client + .as_fake() + .replace_handler(|_, _request| async move { + Ok(http_client::Response::builder() + .status(503) + .body("".into()) + .unwrap()) + }); + server.disconnect(); + while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {} + + // Restore the ability to authenticate. + http_client + .as_fake() + .replace_handler(|_, _request| async move { + Ok(http_client::Response::builder() + .status(200) + .body("".into()) + .unwrap()) + }); + cx.executor().advance_clock(Duration::from_secs(10)); + while !matches!(status.next().await, Some(Status::Connected { .. 
})) {} + assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting + } + #[gpui::test(iterations = 10)] async fn test_connection_timeout(executor: BackgroundExecutor, cx: &mut TestAppContext) { init_test(cx); @@ -1751,7 +1888,7 @@ mod tests { }); let auth_and_connect = cx.spawn({ let client = client.clone(); - |cx| async move { client.authenticate_and_connect(false, &cx).await } + |cx| async move { client.connect(false, &cx).await } }); executor.run_until_parked(); assert!(matches!(status.next().await, Some(Status::Connecting))); @@ -1796,6 +1933,75 @@ mod tests { )); } + #[gpui::test(iterations = 10)] + async fn test_reauthenticate_only_if_unauthorized(cx: &mut TestAppContext) { + init_test(cx); + let auth_count = Arc::new(Mutex::new(0)); + let http_client = FakeHttpClient::create(|_request| async move { + Ok(http_client::Response::builder() + .status(200) + .body("".into()) + .unwrap()) + }); + let client = + cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client.clone(), cx)); + client.override_authenticate({ + let auth_count = auth_count.clone(); + move |cx| { + let auth_count = auth_count.clone(); + cx.background_spawn(async move { + *auth_count.lock() += 1; + Ok(Credentials { + user_id: 1, + access_token: auth_count.lock().to_string(), + }) + }) + } + }); + + let credentials = client.sign_in(false, &cx.to_async()).await.unwrap(); + assert_eq!(*auth_count.lock(), 1); + assert_eq!(credentials.access_token, "1"); + + // If credentials are still valid, signing in doesn't trigger authentication. + let credentials = client.sign_in(false, &cx.to_async()).await.unwrap(); + assert_eq!(*auth_count.lock(), 1); + assert_eq!(credentials.access_token, "1"); + + // If the server is unavailable, signing in doesn't trigger authentication. + http_client + .as_fake() + .replace_handler(|_, _request| async move { + Ok(http_client::Response::builder() + .status(503) + .body("".into()) + .unwrap()) + }); + client.sign_in(false, &cx.to_async()).await.unwrap_err(); + assert_eq!(*auth_count.lock(), 1); + + // If credentials became invalid, signing in triggers authentication. 
+ http_client + .as_fake() + .replace_handler(|_, request| async move { + let credentials = parse_authorization_header(&request).unwrap(); + if credentials.access_token == "2" { + Ok(http_client::Response::builder() + .status(200) + .body("".into()) + .unwrap()) + } else { + Ok(http_client::Response::builder() + .status(401) + .body("".into()) + .unwrap()) + } + }); + let credentials = client.sign_in(false, &cx.to_async()).await.unwrap(); + assert_eq!(*auth_count.lock(), 2); + assert_eq!(credentials.access_token, "2"); + } + #[gpui::test(iterations = 10)] async fn test_authenticating_more_than_once( cx: &mut TestAppContext, @@ -1828,7 +2034,7 @@ mod tests { let _authenticate = cx.spawn({ let client = client.clone(); - move |cx| async move { client.authenticate_and_connect(false, &cx).await } + move |cx| async move { client.connect(false, &cx).await } }); executor.run_until_parked(); assert_eq!(*auth_count.lock(), 1); @@ -1836,7 +2042,7 @@ mod tests { let _authenticate = cx.spawn({ let client = client.clone(); - |cx| async move { client.authenticate_and_connect(false, &cx).await } + |cx| async move { client.connect(false, &cx).await } }); executor.run_until_parked(); assert_eq!(*auth_count.lock(), 2); diff --git a/crates/client/src/telemetry.rs b/crates/client/src/telemetry.rs index 4983fda5efa034c73326c627f555180afe753dfa..43a1a0b7a4f85fbee43c05292e354bc257e5f941 100644 --- a/crates/client/src/telemetry.rs +++ b/crates/client/src/telemetry.rs @@ -74,6 +74,12 @@ static ZED_CLIENT_CHECKSUM_SEED: LazyLock>> = LazyLock::new(|| { }) }); +pub static MINIDUMP_ENDPOINT: LazyLock> = LazyLock::new(|| { + option_env!("ZED_MINIDUMP_ENDPOINT") + .map(|s| s.to_owned()) + .or_else(|| env::var("ZED_MINIDUMP_ENDPOINT").ok()) +}); + static DOTNET_PROJECT_FILES_REGEX: LazyLock = LazyLock::new(|| { Regex::new(r"^(global\.json|Directory\.Build\.props|.*\.(csproj|fsproj|vbproj|sln))$").unwrap() }); @@ -358,13 +364,13 @@ impl Telemetry { worktree_id: WorktreeId, updated_entries_set: &UpdatedEntriesSet, ) { - let Some(project_type_names) = self.detect_project_types(worktree_id, updated_entries_set) + let Some(project_types) = self.detect_project_types(worktree_id, updated_entries_set) else { return; }; - for project_type_name in project_type_names { - telemetry::event!("Project Opened", project_type = project_type_name); + for project_type in project_types { + telemetry::event!("Project Opened", project_type = project_type); } } diff --git a/crates/client/src/test.rs b/crates/client/src/test.rs index 6ce79fa9c53494f1da97d861bcdca78a2a9dbf1f..439fb100d2244499fa59a81495e282673305e00b 100644 --- a/crates/client/src/test.rs +++ b/crates/client/src/test.rs @@ -1,8 +1,11 @@ use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore}; use anyhow::{Context as _, Result, anyhow}; use chrono::Duration; +use cloud_api_client::{AuthenticatedUser, GetAuthenticatedUserResponse, PlanInfo}; +use cloud_llm_client::{CurrentUsage, Plan, UsageData, UsageLimit}; use futures::{StreamExt, stream::BoxStream}; use gpui::{AppContext as _, BackgroundExecutor, Entity, TestAppContext}; +use http_client::{AsyncBody, Method, Request, http}; use parking_lot::Mutex; use rpc::{ ConnectionId, Peer, Receipt, TypedEnvelope, @@ -39,6 +42,44 @@ impl FakeServer { executor: cx.executor(), }; + client.http_client().as_fake().replace_handler({ + let state = server.state.clone(); + move |old_handler, req| { + let state = state.clone(); + let old_handler = old_handler.clone(); + async move { + match (req.method(), req.uri().path()) { + 
(&Method::GET, "/client/users/me") => { + let credentials = parse_authorization_header(&req); + if credentials + != Some(Credentials { + user_id: client_user_id, + access_token: state.lock().access_token.to_string(), + }) + { + return Ok(http_client::Response::builder() + .status(401) + .body("Unauthorized".into()) + .unwrap()); + } + + Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&make_get_authenticated_user_response( + client_user_id as i32, + format!("user-{client_user_id}"), + )) + .unwrap() + .into(), + ) + .unwrap()) + } + _ => old_handler(req).await, + } + } + } + }); client .override_authenticate({ let state = Arc::downgrade(&server.state); @@ -105,7 +146,7 @@ impl FakeServer { }); client - .authenticate_and_connect(false, &cx.to_async()) + .connect(false, &cx.to_async()) .await .into_response() .unwrap(); @@ -223,3 +264,54 @@ impl Drop for FakeServer { self.disconnect(); } } + +pub fn parse_authorization_header(req: &Request) -> Option { + let mut auth_header = req + .headers() + .get(http::header::AUTHORIZATION)? + .to_str() + .ok()? + .split_whitespace(); + let user_id = auth_header.next()?.parse().ok()?; + let access_token = auth_header.next()?; + Some(Credentials { + user_id, + access_token: access_token.to_string(), + }) +} + +pub fn make_get_authenticated_user_response( + user_id: i32, + github_login: String, +) -> GetAuthenticatedUserResponse { + GetAuthenticatedUserResponse { + user: AuthenticatedUser { + id: user_id, + metrics_id: format!("metrics-id-{user_id}"), + avatar_url: "".to_string(), + github_login, + name: None, + is_staff: false, + accepted_tos_at: None, + }, + feature_flags: vec![], + plan: PlanInfo { + plan: Plan::ZedPro, + subscription_period: None, + usage: CurrentUsage { + model_requests: UsageData { + used: 0, + limit: UsageLimit::Limited(500), + }, + edit_predictions: UsageData { + used: 250, + limit: UsageLimit::Unlimited, + }, + }, + trial_started_at: None, + is_usage_based_billing_enabled: false, + is_account_too_young: false, + has_overdue_invoices: false, + }, + } +} diff --git a/crates/client/src/user.rs b/crates/client/src/user.rs index 61e3064eb496b59910ce8ab25797b9b4b4848201..9f76dd7ad08034e814d4be0ff825f37415a3bccd 100644 --- a/crates/client/src/user.rs +++ b/crates/client/src/user.rs @@ -1,6 +1,12 @@ use super::{Client, Status, TypedEnvelope, proto}; use anyhow::{Context as _, Result, anyhow}; use chrono::{DateTime, Utc}; +use cloud_api_client::websocket_protocol::MessageToClient; +use cloud_api_client::{GetAuthenticatedUserResponse, PlanInfo}; +use cloud_llm_client::{ + EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME, + MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit, +}; use collections::{HashMap, HashSet, hash_map::Entry}; use derive_more::Deref; use feature_flags::FeatureFlagAppExt; @@ -16,11 +22,7 @@ use std::{ sync::{Arc, Weak}, }; use text::ReplicaId; -use util::{TryFutureExt as _, maybe}; -use zed_llm_client::{ - EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME, - MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit, -}; +use util::{ResultExt, TryFutureExt as _}; pub type UserId = u64; @@ -55,7 +57,7 @@ pub struct ParticipantIndex(pub u32); #[derive(Default, Debug)] pub struct User { pub id: UserId, - pub github_login: String, + pub github_login: SharedString, pub avatar_uri: SharedUri, pub name: Option, } @@ -107,19 +109,14 @@ pub enum 
ContactRequestStatus { pub struct UserStore { users: HashMap>, - by_github_login: HashMap, + by_github_login: HashMap, participant_indices: HashMap, update_contacts_tx: mpsc::UnboundedSender, - current_plan: Option, - subscription_period: Option<(DateTime, DateTime)>, - trial_started_at: Option>, model_request_usage: Option, edit_prediction_usage: Option, - is_usage_based_billing_enabled: Option, - account_too_young: Option, - has_overdue_invoices: Option, + plan_info: Option, current_user: watch::Receiver>>, - accepted_tos_at: Option>>, + accepted_tos_at: Option>, contacts: Vec>, incoming_contact_requests: Vec>, outgoing_contact_requests: Vec>, @@ -145,6 +142,7 @@ pub enum Event { ShowContacts, ParticipantIndicesChanged, PrivateUserInfoUpdated, + PlanUpdated, } #[derive(Clone, Copy)] @@ -184,18 +182,19 @@ impl UserStore { client.add_message_handler(cx.weak_entity(), Self::handle_update_invite_info), client.add_message_handler(cx.weak_entity(), Self::handle_show_contacts), ]; + + client.add_message_to_client_handler({ + let this = cx.weak_entity(); + move |message, cx| Self::handle_message_to_client(this.clone(), message, cx) + }); + Self { users: Default::default(), by_github_login: Default::default(), current_user: current_user_rx, - current_plan: None, - subscription_period: None, - trial_started_at: None, + plan_info: None, model_request_usage: None, edit_prediction_usage: None, - is_usage_based_billing_enabled: None, - account_too_young: None, - has_overdue_invoices: None, accepted_tos_at: None, contacts: Default::default(), incoming_contact_requests: Default::default(), @@ -225,53 +224,30 @@ impl UserStore { return Ok(()); }; match status { - Status::Connected { .. } => { + Status::Authenticated | Status::Connected { .. } => { if let Some(user_id) = client.user_id() { - let fetch_user = if let Ok(fetch_user) = - this.update(cx, |this, cx| this.get_user(user_id, cx).log_err()) - { - fetch_user - } else { - break; - }; - let fetch_private_user_info = - client.request(proto::GetPrivateUserInfo {}).log_err(); - let (user, info) = - futures::join!(fetch_user, fetch_private_user_info); - + let response = client.cloud_client().get_authenticated_user().await; + let mut current_user = None; cx.update(|cx| { - if let Some(info) = info { - let staff = - info.staff && !*feature_flags::ZED_DISABLE_STAFF; - cx.update_flags(staff, info.flags); - client.telemetry.set_authenticated_user_info( - Some(info.metrics_id.clone()), - staff, - ); - + if let Some(response) = response.log_err() { + let user = Arc::new(User { + id: user_id, + github_login: response.user.github_login.clone().into(), + avatar_uri: response.user.avatar_url.clone().into(), + name: response.user.name.clone(), + }); + current_user = Some(user.clone()); this.update(cx, |this, cx| { - let accepted_tos_at = { - #[cfg(debug_assertions)] - if std::env::var("ZED_IGNORE_ACCEPTED_TOS").is_ok() - { - None - } else { - info.accepted_tos_at - } - - #[cfg(not(debug_assertions))] - info.accepted_tos_at - }; - - this.set_current_user_accepted_tos_at(accepted_tos_at); - cx.emit(Event::PrivateUserInfoUpdated); + this.by_github_login + .insert(user.github_login.clone(), user_id); + this.users.insert(user_id, user); + this.update_authenticated_user(response, cx) }) } else { anyhow::Ok(()) } })??; - - current_user_tx.send(user).await.ok(); + current_user_tx.send(current_user).await.ok(); this.update(cx, |_, cx| cx.notify())?; } @@ -352,59 +328,22 @@ impl UserStore { async fn handle_update_plan( this: Entity, - message: TypedEnvelope, + _message: 
TypedEnvelope, mut cx: AsyncApp, ) -> Result<()> { - this.update(&mut cx, |this, cx| { - this.current_plan = Some(message.payload.plan()); - this.subscription_period = maybe!({ - let period = message.payload.subscription_period?; - let started_at = DateTime::from_timestamp(period.started_at as i64, 0)?; - let ended_at = DateTime::from_timestamp(period.ended_at as i64, 0)?; + let client = this + .read_with(&cx, |this, _| this.client.upgrade())? + .context("client was dropped")?; - Some((started_at, ended_at)) - }); - this.trial_started_at = message - .payload - .trial_started_at - .and_then(|trial_started_at| DateTime::from_timestamp(trial_started_at as i64, 0)); - this.is_usage_based_billing_enabled = message.payload.is_usage_based_billing_enabled; - this.account_too_young = message.payload.account_too_young; - this.has_overdue_invoices = message.payload.has_overdue_invoices; - - if let Some(usage) = message.payload.usage { - // limits are always present even though they are wrapped in Option - this.model_request_usage = usage - .model_requests_usage_limit - .and_then(|limit| { - RequestUsage::from_proto(usage.model_requests_usage_amount, limit) - }) - .map(ModelRequestUsage); - this.edit_prediction_usage = usage - .edit_predictions_usage_limit - .and_then(|limit| { - RequestUsage::from_proto(usage.model_requests_usage_amount, limit) - }) - .map(EditPredictionUsage); - } + let response = client + .cloud_client() + .get_authenticated_user() + .await + .context("failed to fetch authenticated user")?; - cx.notify(); - })?; - Ok(()) - } - - pub fn update_model_request_usage(&mut self, usage: ModelRequestUsage, cx: &mut Context) { - self.model_request_usage = Some(usage); - cx.notify(); - } - - pub fn update_edit_prediction_usage( - &mut self, - usage: EditPredictionUsage, - cx: &mut Context, - ) { - self.edit_prediction_usage = Some(usage); - cx.notify(); + this.update(&mut cx, |this, cx| { + this.update_authenticated_user(response, cx); + }) } fn update_contacts(&mut self, message: UpdateContacts, cx: &Context) -> Task> { @@ -763,47 +702,157 @@ impl UserStore { self.current_user.borrow().clone() } - pub fn current_plan(&self) -> Option { - self.current_plan + pub fn plan(&self) -> Option { + #[cfg(debug_assertions)] + if let Ok(plan) = std::env::var("ZED_SIMULATE_PLAN").as_ref() { + return match plan.as_str() { + "free" => Some(cloud_llm_client::Plan::ZedFree), + "trial" => Some(cloud_llm_client::Plan::ZedProTrial), + "pro" => Some(cloud_llm_client::Plan::ZedPro), + _ => { + panic!("ZED_SIMULATE_PLAN must be one of 'free', 'trial', or 'pro'"); + } + }; + } + + self.plan_info.as_ref().map(|info| info.plan) } pub fn subscription_period(&self) -> Option<(DateTime, DateTime)> { - self.subscription_period + self.plan_info + .as_ref() + .and_then(|plan| plan.subscription_period) + .map(|subscription_period| { + ( + subscription_period.started_at.0, + subscription_period.ended_at.0, + ) + }) } pub fn trial_started_at(&self) -> Option> { - self.trial_started_at + self.plan_info + .as_ref() + .and_then(|plan| plan.trial_started_at) + .map(|trial_started_at| trial_started_at.0) + } + + /// Returns whether the user's account is too new to use the service. + pub fn account_too_young(&self) -> bool { + self.plan_info + .as_ref() + .map(|plan| plan.is_account_too_young) + .unwrap_or_default() } - pub fn usage_based_billing_enabled(&self) -> Option { - self.is_usage_based_billing_enabled + /// Returns whether the current user has overdue invoices and usage should be blocked. 
+ pub fn has_overdue_invoices(&self) -> bool { + self.plan_info + .as_ref() + .map(|plan| plan.has_overdue_invoices) + .unwrap_or_default() + } + + pub fn is_usage_based_billing_enabled(&self) -> bool { + self.plan_info + .as_ref() + .map(|plan| plan.is_usage_based_billing_enabled) + .unwrap_or_default() } pub fn model_request_usage(&self) -> Option { self.model_request_usage } + pub fn update_model_request_usage(&mut self, usage: ModelRequestUsage, cx: &mut Context) { + self.model_request_usage = Some(usage); + cx.notify(); + } + pub fn edit_prediction_usage(&self) -> Option { self.edit_prediction_usage } - pub fn watch_current_user(&self) -> watch::Receiver>> { - self.current_user.clone() + pub fn update_edit_prediction_usage( + &mut self, + usage: EditPredictionUsage, + cx: &mut Context, + ) { + self.edit_prediction_usage = Some(usage); + cx.notify(); } - /// Returns whether the user's account is too new to use the service. - pub fn account_too_young(&self) -> bool { - self.account_too_young.unwrap_or(false) + fn update_authenticated_user( + &mut self, + response: GetAuthenticatedUserResponse, + cx: &mut Context, + ) { + let staff = response.user.is_staff && !*feature_flags::ZED_DISABLE_STAFF; + cx.update_flags(staff, response.feature_flags); + if let Some(client) = self.client.upgrade() { + client + .telemetry + .set_authenticated_user_info(Some(response.user.metrics_id.clone()), staff); + } + + let accepted_tos_at = { + #[cfg(debug_assertions)] + if std::env::var("ZED_IGNORE_ACCEPTED_TOS").is_ok() { + None + } else { + response.user.accepted_tos_at + } + + #[cfg(not(debug_assertions))] + response.user.accepted_tos_at + }; + + self.accepted_tos_at = Some(accepted_tos_at); + self.model_request_usage = Some(ModelRequestUsage(RequestUsage { + limit: response.plan.usage.model_requests.limit, + amount: response.plan.usage.model_requests.used as i32, + })); + self.edit_prediction_usage = Some(EditPredictionUsage(RequestUsage { + limit: response.plan.usage.edit_predictions.limit, + amount: response.plan.usage.edit_predictions.used as i32, + })); + self.plan_info = Some(response.plan); + cx.emit(Event::PrivateUserInfoUpdated); } - /// Returns whether the current user has overdue invoices and usage should be blocked. - pub fn has_overdue_invoices(&self) -> bool { - self.has_overdue_invoices.unwrap_or(false) + fn handle_message_to_client(this: WeakEntity, message: &MessageToClient, cx: &App) { + cx.spawn(async move |cx| { + match message { + MessageToClient::UserUpdated => { + let cloud_client = cx + .update(|cx| { + this.read_with(cx, |this, _cx| { + this.client.upgrade().map(|client| client.cloud_client()) + }) + })?? 
+ .ok_or(anyhow::anyhow!("Failed to get Cloud client"))?; + + let response = cloud_client.get_authenticated_user().await?; + cx.update(|cx| { + this.update(cx, |this, cx| { + this.update_authenticated_user(response, cx); + }) + })??; + } + } + + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + } + + pub fn watch_current_user(&self) -> watch::Receiver>> { + self.current_user.clone() } - pub fn current_user_has_accepted_terms(&self) -> Option { + pub fn has_accepted_terms_of_service(&self) -> bool { self.accepted_tos_at - .map(|accepted_tos_at| accepted_tos_at.is_some()) + .map_or(false, |accepted_tos_at| accepted_tos_at.is_some()) } pub fn accept_terms_of_service(&self, cx: &Context) -> Task> { @@ -815,23 +864,18 @@ impl UserStore { cx.spawn(async move |this, cx| -> anyhow::Result<()> { let client = client.upgrade().context("client not found")?; let response = client - .request(proto::AcceptTermsOfService {}) + .cloud_client() + .accept_terms_of_service() .await .context("error accepting tos")?; this.update(cx, |this, cx| { - this.set_current_user_accepted_tos_at(Some(response.accepted_tos_at)); + this.accepted_tos_at = Some(response.user.accepted_tos_at); cx.emit(Event::PrivateUserInfoUpdated); })?; Ok(()) }) } - fn set_current_user_accepted_tos_at(&mut self, accepted_tos_at: Option) { - self.accepted_tos_at = Some( - accepted_tos_at.and_then(|timestamp| DateTime::from_timestamp(timestamp as i64, 0)), - ); - } - fn load_users( &self, request: impl RequestMessage, @@ -890,7 +934,7 @@ impl UserStore { let mut missing_user_ids = Vec::new(); for id in user_ids { if let Some(github_login) = self.get_cached_user(id).map(|u| u.github_login.clone()) { - ret.insert(id, github_login.into()); + ret.insert(id, github_login); } else { missing_user_ids.push(id) } @@ -911,7 +955,7 @@ impl User { fn new(message: proto::User) -> Arc { Arc::new(User { id: message.id, - github_login: message.github_login, + github_login: message.github_login.into(), avatar_uri: message.avatar_url.into(), name: message.name, }) diff --git a/crates/client/src/zed_urls.rs b/crates/client/src/zed_urls.rs index bfdae468fbb6cc9d829d820a7d9cb0828a8763dd..693c7bf836330fc8c6cd36ca72ee862a9e2b865b 100644 --- a/crates/client/src/zed_urls.rs +++ b/crates/client/src/zed_urls.rs @@ -17,3 +17,21 @@ fn server_url(cx: &App) -> &str { pub fn account_url(cx: &App) -> String { format!("{server_url}/account", server_url = server_url(cx)) } + +/// Returns the URL to the start trial page on zed.dev. +pub fn start_trial_url(cx: &App) -> String { + format!( + "{server_url}/account/start-trial", + server_url = server_url(cx) + ) +} + +/// Returns the URL to the upgrade page on zed.dev. +pub fn upgrade_to_zed_pro_url(cx: &App) -> String { + format!("{server_url}/account/upgrade", server_url = server_url(cx)) +} + +/// Returns the URL to Zed's terms of service. 
+pub fn terms_of_service(cx: &App) -> String { + format!("{server_url}/terms-of-service", server_url = server_url(cx)) +} diff --git a/crates/cloud_api_client/Cargo.toml b/crates/cloud_api_client/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..8e50ccb191373fe2cfadce2e4fd12cc3e397357f --- /dev/null +++ b/crates/cloud_api_client/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "cloud_api_client" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "Apache-2.0" + +[lints] +workspace = true + +[lib] +path = "src/cloud_api_client.rs" + +[dependencies] +anyhow.workspace = true +cloud_api_types.workspace = true +futures.workspace = true +gpui.workspace = true +gpui_tokio.workspace = true +http_client.workspace = true +parking_lot.workspace = true +serde_json.workspace = true +workspace-hack.workspace = true +yawc.workspace = true diff --git a/crates/cloud_api_client/LICENSE-APACHE b/crates/cloud_api_client/LICENSE-APACHE new file mode 120000 index 0000000000000000000000000000000000000000..1cd601d0a3affae83854be02a0afdec3b7a9ec4d --- /dev/null +++ b/crates/cloud_api_client/LICENSE-APACHE @@ -0,0 +1 @@ +../../LICENSE-APACHE \ No newline at end of file diff --git a/crates/cloud_api_client/src/cloud_api_client.rs b/crates/cloud_api_client/src/cloud_api_client.rs new file mode 100644 index 0000000000000000000000000000000000000000..ef9a1a9a553596baf737c4e1ee60d9b3344f4ecf --- /dev/null +++ b/crates/cloud_api_client/src/cloud_api_client.rs @@ -0,0 +1,231 @@ +mod websocket; + +use std::sync::Arc; + +use anyhow::{Context, Result, anyhow}; +use cloud_api_types::websocket_protocol::{PROTOCOL_VERSION, PROTOCOL_VERSION_HEADER_NAME}; +pub use cloud_api_types::*; +use futures::AsyncReadExt as _; +use gpui::{App, Task}; +use gpui_tokio::Tokio; +use http_client::http::request; +use http_client::{AsyncBody, HttpClientWithUrl, Method, Request, StatusCode}; +use parking_lot::RwLock; +use yawc::WebSocket; + +use crate::websocket::Connection; + +struct Credentials { + user_id: u32, + access_token: String, +} + +pub struct CloudApiClient { + credentials: RwLock>, + http_client: Arc, +} + +impl CloudApiClient { + pub fn new(http_client: Arc) -> Self { + Self { + credentials: RwLock::new(None), + http_client, + } + } + + pub fn has_credentials(&self) -> bool { + self.credentials.read().is_some() + } + + pub fn set_credentials(&self, user_id: u32, access_token: String) { + *self.credentials.write() = Some(Credentials { + user_id, + access_token, + }); + } + + pub fn clear_credentials(&self) { + *self.credentials.write() = None; + } + + fn build_request( + &self, + req: request::Builder, + body: impl Into, + ) -> Result> { + let credentials = self.credentials.read(); + let credentials = credentials.as_ref().context("no credentials provided")?; + build_request(req, body, credentials) + } + + pub async fn get_authenticated_user(&self) -> Result { + let request = self.build_request( + Request::builder().method(Method::GET).uri( + self.http_client + .build_zed_cloud_url("/client/users/me", &[])? + .as_ref(), + ), + AsyncBody::default(), + )?; + + let mut response = self.http_client.send(request).await?; + + if !response.status().is_success() { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + anyhow::bail!( + "Failed to get authenticated user.\nStatus: {:?}\nBody: {body}", + response.status() + ) + } + + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + Ok(serde_json::from_str(&body)?) 
+ } + + pub fn connect(&self, cx: &App) -> Result>> { + let mut connect_url = self + .http_client + .build_zed_cloud_url("/client/users/connect", &[])?; + connect_url + .set_scheme(match connect_url.scheme() { + "https" => "wss", + "http" => "ws", + scheme => Err(anyhow!("invalid URL scheme: {scheme}"))?, + }) + .map_err(|_| anyhow!("failed to set URL scheme"))?; + + let credentials = self.credentials.read(); + let credentials = credentials.as_ref().context("no credentials provided")?; + let authorization_header = format!("{} {}", credentials.user_id, credentials.access_token); + + Ok(cx.spawn(async move |cx| { + let handle = cx + .update(|cx| Tokio::handle(cx)) + .ok() + .context("failed to get Tokio handle")?; + let _guard = handle.enter(); + + let ws = WebSocket::connect(connect_url) + .with_request( + request::Builder::new() + .header("Authorization", authorization_header) + .header(PROTOCOL_VERSION_HEADER_NAME, PROTOCOL_VERSION.to_string()), + ) + .await?; + + Ok(Connection::new(ws)) + })) + } + + pub async fn accept_terms_of_service(&self) -> Result { + let request = self.build_request( + Request::builder().method(Method::POST).uri( + self.http_client + .build_zed_cloud_url("/client/terms_of_service/accept", &[])? + .as_ref(), + ), + AsyncBody::default(), + )?; + + let mut response = self.http_client.send(request).await?; + + if !response.status().is_success() { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + anyhow::bail!( + "Failed to accept terms of service.\nStatus: {:?}\nBody: {body}", + response.status() + ) + } + + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + Ok(serde_json::from_str(&body)?) + } + + pub async fn create_llm_token( + &self, + system_id: Option, + ) -> Result { + let mut request_builder = Request::builder().method(Method::POST).uri( + self.http_client + .build_zed_cloud_url("/client/llm_tokens", &[])? + .as_ref(), + ); + + if let Some(system_id) = system_id { + request_builder = request_builder.header(ZED_SYSTEM_ID_HEADER_NAME, system_id); + } + + let request = self.build_request(request_builder, AsyncBody::default())?; + + let mut response = self.http_client.send(request).await?; + + if !response.status().is_success() { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + anyhow::bail!( + "Failed to create LLM token.\nStatus: {:?}\nBody: {body}", + response.status() + ) + } + + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + Ok(serde_json::from_str(&body)?) + } + + pub async fn validate_credentials(&self, user_id: u32, access_token: &str) -> Result { + let request = build_request( + Request::builder().method(Method::GET).uri( + self.http_client + .build_zed_cloud_url("/client/users/me", &[])? 
+ .as_ref(), + ), + AsyncBody::default(), + &Credentials { + user_id, + access_token: access_token.into(), + }, + )?; + + let mut response = self.http_client.send(request).await?; + + if response.status().is_success() { + Ok(true) + } else { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + if response.status() == StatusCode::UNAUTHORIZED { + return Ok(false); + } else { + return Err(anyhow!( + "Failed to get authenticated user.\nStatus: {:?}\nBody: {body}", + response.status() + )); + } + } + } +} + +fn build_request( + req: request::Builder, + body: impl Into, + credentials: &Credentials, +) -> Result> { + Ok(req + .header("Content-Type", "application/json") + .header( + "Authorization", + format!("{} {}", credentials.user_id, credentials.access_token), + ) + .body(body.into())?) +} diff --git a/crates/cloud_api_client/src/websocket.rs b/crates/cloud_api_client/src/websocket.rs new file mode 100644 index 0000000000000000000000000000000000000000..48a628db78b2fedc12bdce4dbe806c0e7bb37e63 --- /dev/null +++ b/crates/cloud_api_client/src/websocket.rs @@ -0,0 +1,73 @@ +use std::pin::Pin; +use std::time::Duration; + +use anyhow::Result; +use cloud_api_types::websocket_protocol::MessageToClient; +use futures::channel::mpsc::unbounded; +use futures::stream::{SplitSink, SplitStream}; +use futures::{FutureExt as _, SinkExt as _, Stream, StreamExt as _, TryStreamExt as _, pin_mut}; +use gpui::{App, BackgroundExecutor, Task}; +use yawc::WebSocket; +use yawc::frame::{FrameView, OpCode}; + +const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1); + +pub type MessageStream = Pin>>>; + +pub struct Connection { + tx: SplitSink, + rx: SplitStream, +} + +impl Connection { + pub fn new(ws: WebSocket) -> Self { + let (tx, rx) = ws.split(); + + Self { tx, rx } + } + + pub fn spawn(self, cx: &App) -> (MessageStream, Task<()>) { + let (mut tx, rx) = (self.tx, self.rx); + + let (message_tx, message_rx) = unbounded(); + + let handle_io = |executor: BackgroundExecutor| async move { + // Send messages on this frequency so the connection isn't closed. + let keepalive_timer = executor.timer(KEEPALIVE_INTERVAL).fuse(); + futures::pin_mut!(keepalive_timer); + + let rx = rx.fuse(); + pin_mut!(rx); + + loop { + futures::select_biased! 
{ + _ = keepalive_timer => { + let _ = tx.send(FrameView::ping(Vec::new())).await; + + keepalive_timer.set(executor.timer(KEEPALIVE_INTERVAL).fuse()); + } + frame = rx.next() => { + let Some(frame) = frame else { + break; + }; + + match frame.opcode { + OpCode::Binary => { + let message_result = MessageToClient::deserialize(&frame.payload); + message_tx.unbounded_send(message_result).ok(); + } + OpCode::Close => { + break; + } + _ => {} + } + } + } + } + }; + + let task = cx.spawn(async move |cx| handle_io(cx.background_executor().clone()).await); + + (message_rx.into_stream().boxed(), task) + } +} diff --git a/crates/cloud_api_types/Cargo.toml b/crates/cloud_api_types/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..28e0a36a44f023e883bea98e4facacd9085e0efb --- /dev/null +++ b/crates/cloud_api_types/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "cloud_api_types" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "Apache-2.0" + +[lints] +workspace = true + +[lib] +path = "src/cloud_api_types.rs" + +[dependencies] +anyhow.workspace = true +chrono.workspace = true +ciborium.workspace = true +cloud_llm_client.workspace = true +serde.workspace = true +workspace-hack.workspace = true + +[dev-dependencies] +pretty_assertions.workspace = true +serde_json.workspace = true diff --git a/crates/cloud_api_types/LICENSE-APACHE b/crates/cloud_api_types/LICENSE-APACHE new file mode 120000 index 0000000000000000000000000000000000000000..1cd601d0a3affae83854be02a0afdec3b7a9ec4d --- /dev/null +++ b/crates/cloud_api_types/LICENSE-APACHE @@ -0,0 +1 @@ +../../LICENSE-APACHE \ No newline at end of file diff --git a/crates/cloud_api_types/src/cloud_api_types.rs b/crates/cloud_api_types/src/cloud_api_types.rs new file mode 100644 index 0000000000000000000000000000000000000000..fa189cd3b5ed7e87e2f3f2a6803d9095c6305105 --- /dev/null +++ b/crates/cloud_api_types/src/cloud_api_types.rs @@ -0,0 +1,56 @@ +mod timestamp; +pub mod websocket_protocol; + +use serde::{Deserialize, Serialize}; + +pub use crate::timestamp::Timestamp; + +pub const ZED_SYSTEM_ID_HEADER_NAME: &str = "x-zed-system-id"; + +#[derive(Debug, PartialEq, Serialize, Deserialize)] +pub struct GetAuthenticatedUserResponse { + pub user: AuthenticatedUser, + pub feature_flags: Vec, + pub plan: PlanInfo, +} + +#[derive(Debug, PartialEq, Serialize, Deserialize)] +pub struct AuthenticatedUser { + pub id: i32, + pub metrics_id: String, + pub avatar_url: String, + pub github_login: String, + pub name: Option, + pub is_staff: bool, + pub accepted_tos_at: Option, +} + +#[derive(Debug, PartialEq, Serialize, Deserialize)] +pub struct PlanInfo { + pub plan: cloud_llm_client::Plan, + pub subscription_period: Option, + pub usage: cloud_llm_client::CurrentUsage, + pub trial_started_at: Option, + pub is_usage_based_billing_enabled: bool, + pub is_account_too_young: bool, + pub has_overdue_invoices: bool, +} + +#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)] +pub struct SubscriptionPeriod { + pub started_at: Timestamp, + pub ended_at: Timestamp, +} + +#[derive(Debug, PartialEq, Serialize, Deserialize)] +pub struct AcceptTermsOfServiceResponse { + pub user: AuthenticatedUser, +} + +#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] +pub struct LlmToken(pub String); + +#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] +pub struct CreateLlmTokenResponse { + pub token: LlmToken, +} diff --git a/crates/cloud_api_types/src/timestamp.rs b/crates/cloud_api_types/src/timestamp.rs new 
file mode 100644
index 0000000000000000000000000000000000000000..1f055d58ef42a856ee9cbf5c2effd2edeaa7c850
--- /dev/null
+++ b/crates/cloud_api_types/src/timestamp.rs
@@ -0,0 +1,166 @@
+use chrono::{DateTime, NaiveDateTime, SecondsFormat, Utc};
+use serde::{Deserialize, Deserializer, Serialize, Serializer};
+
+/// A timestamp with a serialized representation in RFC 3339 format.
+#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
+pub struct Timestamp(pub DateTime<Utc>);
+
+impl Timestamp {
+    pub fn new(datetime: DateTime<Utc>) -> Self {
+        Self(datetime)
+    }
+}
+
+impl From<DateTime<Utc>> for Timestamp {
+    fn from(value: DateTime<Utc>) -> Self {
+        Self(value)
+    }
+}
+
+impl From<NaiveDateTime> for Timestamp {
+    fn from(value: NaiveDateTime) -> Self {
+        Self(value.and_utc())
+    }
+}
+
+impl Serialize for Timestamp {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        let rfc3339_string = self.0.to_rfc3339_opts(SecondsFormat::Millis, true);
+        serializer.serialize_str(&rfc3339_string)
+    }
+}
+
+impl<'de> Deserialize<'de> for Timestamp {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let value = String::deserialize(deserializer)?;
+        let datetime = DateTime::parse_from_rfc3339(&value)
+            .map_err(serde::de::Error::custom)?
+            .to_utc();
+        Ok(Self(datetime))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use chrono::NaiveDate;
+    use pretty_assertions::assert_eq;
+
+    use super::*;
+
+    #[test]
+    fn test_timestamp_serialization() {
+        let datetime = DateTime::parse_from_rfc3339("2023-12-25T14:30:45.123Z")
+            .unwrap()
+            .to_utc();
+        let timestamp = Timestamp::new(datetime);
+
+        let json = serde_json::to_string(&timestamp).unwrap();
+        assert_eq!(json, "\"2023-12-25T14:30:45.123Z\"");
+    }
+
+    #[test]
+    fn test_timestamp_deserialization() {
+        let json = "\"2023-12-25T14:30:45.123Z\"";
+        let timestamp: Timestamp = serde_json::from_str(json).unwrap();
+
+        let expected = DateTime::parse_from_rfc3339("2023-12-25T14:30:45.123Z")
+            .unwrap()
+            .to_utc();
+
+        assert_eq!(timestamp.0, expected);
+    }
+
+    #[test]
+    fn test_timestamp_roundtrip() {
+        let original = DateTime::parse_from_rfc3339("2023-12-25T14:30:45.123Z")
+            .unwrap()
+            .to_utc();
+
+        let timestamp = Timestamp::new(original);
+        let json = serde_json::to_string(&timestamp).unwrap();
+        let deserialized: Timestamp = serde_json::from_str(&json).unwrap();
+
+        assert_eq!(deserialized.0, original);
+    }
+
+    #[test]
+    fn test_timestamp_from_datetime_utc() {
+        let datetime = DateTime::parse_from_rfc3339("2023-12-25T14:30:45.123Z")
+            .unwrap()
+            .to_utc();
+
+        let timestamp = Timestamp::from(datetime);
+        assert_eq!(timestamp.0, datetime);
+    }
+
+    #[test]
+    fn test_timestamp_from_naive_datetime() {
+        let naive_dt = NaiveDate::from_ymd_opt(2023, 12, 25)
+            .unwrap()
+            .and_hms_milli_opt(14, 30, 45, 123)
+            .unwrap();
+
+        let timestamp = Timestamp::from(naive_dt);
+        let expected = naive_dt.and_utc();
+
+        assert_eq!(timestamp.0, expected);
+    }
+
+    #[test]
+    fn test_timestamp_serialization_with_microseconds() {
+        // Test that microseconds are truncated to milliseconds
+        let datetime = NaiveDate::from_ymd_opt(2023, 12, 25)
+            .unwrap()
+            .and_hms_micro_opt(14, 30, 45, 123456)
+            .unwrap()
+            .and_utc();
+
+        let timestamp = Timestamp::new(datetime);
+        let json = serde_json::to_string(&timestamp).unwrap();
+
+        // Should be truncated to milliseconds
+        assert_eq!(json, "\"2023-12-25T14:30:45.123Z\"");
+    }
+
+    #[test]
+    fn test_timestamp_deserialization_without_milliseconds() {
+        let json = "\"2023-12-25T14:30:45Z\"";
+        let timestamp: Timestamp = serde_json::from_str(json).unwrap();
+
+
let expected = NaiveDate::from_ymd_opt(2023, 12, 25) + .unwrap() + .and_hms_opt(14, 30, 45) + .unwrap() + .and_utc(); + + assert_eq!(timestamp.0, expected); + } + + #[test] + fn test_timestamp_deserialization_with_timezone() { + let json = "\"2023-12-25T14:30:45.123+05:30\""; + let timestamp: Timestamp = serde_json::from_str(json).unwrap(); + + // Should be converted to UTC + let expected = NaiveDate::from_ymd_opt(2023, 12, 25) + .unwrap() + .and_hms_milli_opt(9, 0, 45, 123) // 14:30:45 at a +05:30 offset is 09:00:45 in UTC + .unwrap() + .and_utc(); + + assert_eq!(timestamp.0, expected); + } + + #[test] + fn test_timestamp_deserialization_with_invalid_format() { + let json = "\"invalid-date\""; + let result: Result<Timestamp, _> = serde_json::from_str(json); + assert!(result.is_err()); + } +} diff --git a/crates/cloud_api_types/src/websocket_protocol.rs b/crates/cloud_api_types/src/websocket_protocol.rs new file mode 100644 index 0000000000000000000000000000000000000000..75f6a73b43d16226bc596ece038e721ce709227e --- /dev/null +++ b/crates/cloud_api_types/src/websocket_protocol.rs @@ -0,0 +1,28 @@ +use anyhow::{Context as _, Result}; +use serde::{Deserialize, Serialize}; + +/// The version of the Cloud WebSocket protocol. +pub const PROTOCOL_VERSION: u32 = 0; + +/// The name of the header used to indicate the protocol version in use. +pub const PROTOCOL_VERSION_HEADER_NAME: &str = "x-zed-protocol-version"; + +/// A message from Cloud to the Zed client. +#[derive(Debug, Serialize, Deserialize)] +pub enum MessageToClient { + /// The user was updated and should be refreshed. + UserUpdated, +} + +impl MessageToClient { + pub fn serialize(&self) -> Result<Vec<u8>> { + let mut buffer = Vec::new(); + ciborium::into_writer(self, &mut buffer).context("failed to serialize message")?; + + Ok(buffer) + } + + pub fn deserialize(data: &[u8]) -> Result<Self> { + ciborium::from_reader(data).context("failed to deserialize message") + } +}
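`MessageToClient` is encoded as CBOR via `ciborium`, so a consumer only needs the `serialize`/`deserialize` pair defined above. A minimal round-trip sketch (the standalone example crate and its use of `anyhow::Result` are assumptions, not part of this change):

```rust
use cloud_api_types::websocket_protocol::{
    MessageToClient, PROTOCOL_VERSION, PROTOCOL_VERSION_HEADER_NAME,
};

fn main() -> anyhow::Result<()> {
    // Encode the message as CBOR, as the server does before writing a binary
    // WebSocket frame, then decode it the way the client's read loop does.
    let bytes = MessageToClient::UserUpdated.serialize()?;
    let decoded = MessageToClient::deserialize(&bytes)?;
    assert!(matches!(decoded, MessageToClient::UserUpdated));

    // The protocol version itself is negotiated out of band via a header.
    println!("{PROTOCOL_VERSION_HEADER_NAME}: {PROTOCOL_VERSION}");
    Ok(())
}
```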
diff --git a/crates/cloud_llm_client/Cargo.toml b/crates/cloud_llm_client/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..6f090d3c6ea67d8bb189212fb9704b618554f671 --- /dev/null +++ b/crates/cloud_llm_client/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "cloud_llm_client" +version = "0.1.0" +publish.workspace = true +edition.workspace = true +license = "Apache-2.0" + +[lints] +workspace = true + +[lib] +path = "src/cloud_llm_client.rs" + +[dependencies] +anyhow.workspace = true +serde = { workspace = true, features = ["derive", "rc"] } +serde_json.workspace = true +strum = { workspace = true, features = ["derive"] } +uuid = { workspace = true, features = ["serde"] } +workspace-hack.workspace = true + +[dev-dependencies] +pretty_assertions.workspace = true diff --git a/crates/cloud_llm_client/LICENSE-APACHE b/crates/cloud_llm_client/LICENSE-APACHE new file mode 120000 index 0000000000000000000000000000000000000000..1cd601d0a3affae83854be02a0afdec3b7a9ec4d --- /dev/null +++ b/crates/cloud_llm_client/LICENSE-APACHE @@ -0,0 +1 @@ +../../LICENSE-APACHE \ No newline at end of file diff --git a/crates/cloud_llm_client/src/cloud_llm_client.rs b/crates/cloud_llm_client/src/cloud_llm_client.rs new file mode 100644 index 0000000000000000000000000000000000000000..e78957ec4905b05bae4752613707f84316edcfba --- /dev/null +++ b/crates/cloud_llm_client/src/cloud_llm_client.rs @@ -0,0 +1,386 @@ +use std::str::FromStr; +use std::sync::Arc; + +use anyhow::Context as _; +use serde::{Deserialize, Serialize}; +use strum::{Display, EnumIter, EnumString}; +use uuid::Uuid; + +/// The name of the header used to indicate which version of Zed the client is running. +pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version"; + +/// The name of the header used to indicate when a request failed due to an +/// expired LLM token. +/// +/// The client may use this as a signal to refresh the token. +pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token"; + +/// The name of the header used to indicate what plan the user is currently on. +pub const CURRENT_PLAN_HEADER_NAME: &str = "x-zed-plan"; + +/// The name of the header used to indicate the usage limit for model requests. +pub const MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-model-requests-usage-limit"; + +/// The name of the header used to indicate the usage amount for model requests. +pub const MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-model-requests-usage-amount"; + +/// The name of the header used to indicate the usage limit for edit predictions. +pub const EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-limit"; + +/// The name of the header used to indicate the usage amount for edit predictions. +pub const EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-amount"; + +/// The name of the header used to indicate the resource for which the subscription limit has been reached. +pub const SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME: &str = "x-zed-subscription-limit-resource"; + +pub const MODEL_REQUESTS_RESOURCE_HEADER_VALUE: &str = "model_requests"; +pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions"; + +/// The name of the header used to indicate that the maximum number of consecutive tool uses has been reached. +pub const TOOL_USE_LIMIT_REACHED_HEADER_NAME: &str = "x-zed-tool-use-limit-reached"; + +/// The name of the header used to indicate the minimum required Zed version. +/// +/// This can be used to force a Zed upgrade in order to continue communicating +/// with the LLM service. +pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-version"; + +/// The name of the header used by the client to indicate to the server that it supports receiving status messages. +pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str = + "x-zed-client-supports-status-messages"; + +/// The name of the header used by the server to indicate to the client that it supports sending status messages. 
+pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str = + "x-zed-server-supports-status-messages"; + +#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum UsageLimit { + Limited(i32), + Unlimited, +} + +impl FromStr for UsageLimit { + type Err = anyhow::Error; + + fn from_str(value: &str) -> Result<Self, Self::Err> { + match value { + "unlimited" => Ok(Self::Unlimited), + limit => limit + .parse::<i32>() + .map(Self::Limited) + .context("failed to parse limit"), + } + } +} + +#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum Plan { + #[default] + #[serde(alias = "Free")] + ZedFree, + #[serde(alias = "ZedPro")] + ZedPro, + #[serde(alias = "ZedProTrial")] + ZedProTrial, +} + +impl Plan { + pub fn as_str(&self) -> &'static str { + match self { + Plan::ZedFree => "zed_free", + Plan::ZedPro => "zed_pro", + Plan::ZedProTrial => "zed_pro_trial", + } + } + + pub fn model_requests_limit(&self) -> UsageLimit { + match self { + Plan::ZedPro => UsageLimit::Limited(500), + Plan::ZedProTrial => UsageLimit::Limited(150), + Plan::ZedFree => UsageLimit::Limited(50), + } + } + + pub fn edit_predictions_limit(&self) -> UsageLimit { + match self { + Plan::ZedPro => UsageLimit::Unlimited, + Plan::ZedProTrial => UsageLimit::Unlimited, + Plan::ZedFree => UsageLimit::Limited(2_000), + } + } +} + +impl FromStr for Plan { + type Err = anyhow::Error; + + fn from_str(value: &str) -> Result<Self, Self::Err> { + match value { + "zed_free" => Ok(Plan::ZedFree), + "zed_pro" => Ok(Plan::ZedPro), + "zed_pro_trial" => Ok(Plan::ZedProTrial), + plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")), + } + } +}
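A client that knows its `Plan` and a used-request count can derive the remaining allowance from `model_requests_limit()`. A small illustrative helper (hypothetical, not part of the crate), mirroring the `UsageData`/`UsageLimit` pairing defined later in this file:

```rust
use cloud_llm_client::{Plan, UsageLimit};

/// Remaining model requests in the current period, or `None` when the plan
/// is unlimited. Purely a sketch for illustration.
fn remaining_model_requests(plan: Plan, used: u32) -> Option<u32> {
    match plan.model_requests_limit() {
        UsageLimit::Limited(limit) => Some((limit.max(0) as u32).saturating_sub(used)),
        UsageLimit::Unlimited => None,
    }
}

fn main() {
    assert_eq!(remaining_model_requests(Plan::ZedFree, 10), Some(40));
    assert_eq!(remaining_model_requests(Plan::ZedPro, 600), Some(0));
    assert_eq!(remaining_model_requests(Plan::ZedProTrial, 0), Some(150));
}
```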
+#[derive( + Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display, +)] +#[serde(rename_all = "snake_case")] +#[strum(serialize_all = "snake_case")] +pub enum LanguageModelProvider { + Anthropic, + OpenAi, + Google, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PredictEditsBody { + #[serde(skip_serializing_if = "Option::is_none", default)] + pub outline: Option<String>, + pub input_events: String, + pub input_excerpt: String, + #[serde(skip_serializing_if = "Option::is_none", default)] + pub speculated_output: Option<String>, + /// Whether the user provided consent for sampling this interaction. + #[serde(default, alias = "data_collection_permission")] + pub can_collect_data: bool, + #[serde(skip_serializing_if = "Option::is_none", default)] + pub diagnostic_groups: Option>, + /// Info about the git repository state, only present when can_collect_data is true. + #[serde(skip_serializing_if = "Option::is_none", default)] + pub git_info: Option<PredictEditsGitInfo>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PredictEditsGitInfo { + /// SHA of git HEAD commit at time of prediction. + #[serde(skip_serializing_if = "Option::is_none", default)] + pub head_sha: Option<String>, + /// URL of the remote called `origin`. + #[serde(skip_serializing_if = "Option::is_none", default)] + pub remote_origin_url: Option<String>, + /// URL of the remote called `upstream`. + #[serde(skip_serializing_if = "Option::is_none", default)] + pub remote_upstream_url: Option<String>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PredictEditsResponse { + pub request_id: Uuid, + pub output_excerpt: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AcceptEditPredictionBody { + pub request_id: Uuid, +} + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum CompletionMode { + Normal, + Max, +} + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum CompletionIntent { + UserPrompt, + ToolResults, + ThreadSummarization, + ThreadContextSummarization, + CreateFile, + EditFile, + InlineAssist, + TerminalInlineAssist, + GenerateGitCommitMessage, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct CompletionBody { + #[serde(skip_serializing_if = "Option::is_none", default)] + pub thread_id: Option<String>, + #[serde(skip_serializing_if = "Option::is_none", default)] + pub prompt_id: Option<String>, + #[serde(skip_serializing_if = "Option::is_none", default)] + pub intent: Option<CompletionIntent>, + #[serde(skip_serializing_if = "Option::is_none", default)] + pub mode: Option<CompletionMode>, + pub provider: LanguageModelProvider, + pub model: String, + pub provider_request: serde_json::Value, +} + +#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum CompletionRequestStatus { + Queued { + position: usize, + }, + Started, + Failed { + code: String, + message: String, + request_id: Uuid, + /// Retry duration in seconds. + retry_after: Option<f64>, + }, + UsageUpdated { + amount: usize, + limit: UsageLimit, + }, + ToolUseLimitReached, +} + +#[derive(Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum CompletionEvent<T> { + Status(CompletionRequestStatus), + Event(T), +} + +impl<T> CompletionEvent<T> { + pub fn into_status(self) -> Option<CompletionRequestStatus> { + match self { + Self::Status(status) => Some(status), + Self::Event(_) => None, + } + } + + pub fn into_event(self) -> Option<T> { + match self { + Self::Event(event) => Some(event), + Self::Status(_) => None, + } + } +} + +#[derive(Serialize, Deserialize)] +pub struct WebSearchBody { + pub query: String, +} + +#[derive(Serialize, Deserialize, Clone)] +pub struct WebSearchResponse { + pub results: Vec<WebSearchResult>, +} + +#[derive(Serialize, Deserialize, Clone)] +pub struct WebSearchResult { + pub title: String, + pub url: String, + pub text: String, +} + +#[derive(Serialize, Deserialize)] +pub struct CountTokensBody { + pub provider: LanguageModelProvider, + pub model: String, + pub provider_request: serde_json::Value, +} + +#[derive(Serialize, Deserialize)] +pub struct CountTokensResponse { + pub tokens: usize, +} + +#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)] +pub struct LanguageModelId(pub Arc<str>); + +impl std::fmt::Display for LanguageModelId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct LanguageModel { + pub provider: LanguageModelProvider, + pub id: LanguageModelId, + pub display_name: String, + pub max_token_count: usize, + pub max_token_count_in_max_mode: Option<usize>, + pub max_output_tokens: usize, + pub supports_tools: bool, + pub supports_images: bool, + pub supports_thinking: bool, + pub supports_max_mode: bool, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct 
ListModelsResponse { + pub models: Vec, + pub default_model: LanguageModelId, + pub default_fast_model: LanguageModelId, + pub recommended_models: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct GetSubscriptionResponse { + pub plan: Plan, + pub usage: Option, +} + +#[derive(Debug, PartialEq, Serialize, Deserialize)] +pub struct CurrentUsage { + pub model_requests: UsageData, + pub edit_predictions: UsageData, +} + +#[derive(Debug, PartialEq, Serialize, Deserialize)] +pub struct UsageData { + pub used: u32, + pub limit: UsageLimit, +} + +#[cfg(test)] +mod tests { + use pretty_assertions::assert_eq; + use serde_json::json; + + use super::*; + + #[test] + fn test_plan_deserialize_snake_case() { + let plan = serde_json::from_value::(json!("zed_free")).unwrap(); + assert_eq!(plan, Plan::ZedFree); + + let plan = serde_json::from_value::(json!("zed_pro")).unwrap(); + assert_eq!(plan, Plan::ZedPro); + + let plan = serde_json::from_value::(json!("zed_pro_trial")).unwrap(); + assert_eq!(plan, Plan::ZedProTrial); + } + + #[test] + fn test_plan_deserialize_aliases() { + let plan = serde_json::from_value::(json!("Free")).unwrap(); + assert_eq!(plan, Plan::ZedFree); + + let plan = serde_json::from_value::(json!("ZedPro")).unwrap(); + assert_eq!(plan, Plan::ZedPro); + + let plan = serde_json::from_value::(json!("ZedProTrial")).unwrap(); + assert_eq!(plan, Plan::ZedProTrial); + } + + #[test] + fn test_usage_limit_from_str() { + let limit = UsageLimit::from_str("unlimited").unwrap(); + assert!(matches!(limit, UsageLimit::Unlimited)); + + let limit = UsageLimit::from_str(&0.to_string()).unwrap(); + assert!(matches!(limit, UsageLimit::Limited(0))); + + let limit = UsageLimit::from_str(&50.to_string()).unwrap(); + assert!(matches!(limit, UsageLimit::Limited(50))); + + for value in ["not_a_number", "50xyz"] { + let limit = UsageLimit::from_str(value); + assert!(limit.is_err()); + } + } +} diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index 55c15cac5ac84c9d166c54a127dd18b2237b9bd9..9af95317e60db78fc93b9a1fa01eaee687fac4fc 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -23,19 +23,20 @@ async-stripe.workspace = true async-trait.workspace = true async-tungstenite.workspace = true aws-config = { version = "1.1.5" } -aws-sdk-s3 = { version = "1.15.0" } aws-sdk-kinesis = "1.51.0" +aws-sdk-s3 = { version = "1.15.0" } axum = { version = "0.6", features = ["json", "headers", "ws"] } axum-extra = { version = "0.4", features = ["erased-json"] } base64.workspace = true chrono.workspace = true clock.workspace = true +cloud_llm_client.workspace = true collections.workspace = true dashmap.workspace = true derive_more.workspace = true envy = "0.4.2" futures.workspace = true -gpui = { workspace = true, features = ["screen-capture"] } +gpui.workspace = true hex.workspace = true http_client.workspace = true jsonwebtoken.workspace = true @@ -75,7 +76,6 @@ tracing-subscriber = { version = "0.3.18", features = ["env-filter", "json", "re util.workspace = true uuid.workspace = true workspace-hack.workspace = true -zed_llm_client.workspace = true [dev-dependencies] agent_settings.workspace = true @@ -94,6 +94,7 @@ context_server.workspace = true ctor.workspace = true dap = { workspace = true, features = ["test-support"] } dap_adapters = { workspace = true, features = ["test-support"] } +dap-types.workspace = true debugger_ui = { workspace = true, features = ["test-support"] } editor = { workspace = true, features = ["test-support"] } extension.workspace = true @@ -126,6 +127,7 
@@ sea-orm = { version = "1.1.0-rc.1", features = ["sqlx-sqlite"] } serde_json.workspace = true session = { workspace = true, features = ["test-support"] } settings = { workspace = true, features = ["test-support"] } +smol.workspace = true sqlx = { version = "0.8", features = ["sqlite"] } task.workspace = true theme.workspace = true diff --git a/crates/collab/k8s/environments/production.sh b/crates/collab/k8s/environments/production.sh index e9e68849b88a5cb7afe4301dbf2805b22ea8a14d..2861f378962fb7c4a5f6fd22a7b9f6ef906af301 100644 --- a/crates/collab/k8s/environments/production.sh +++ b/crates/collab/k8s/environments/production.sh @@ -2,5 +2,6 @@ ZED_ENVIRONMENT=production RUST_LOG=info INVITE_LINK_PREFIX=https://zed.dev/invites/ AUTO_JOIN_CHANNEL_ID=283 -DATABASE_MAX_CONNECTIONS=250 +# Set DATABASE_MAX_CONNECTIONS max connections in the `deploy_collab.yml`: +# https://github.com/zed-industries/zed/blob/main/.github/workflows/deploy_collab.yml LLM_DATABASE_MAX_CONNECTIONS=25 diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index ca840493ad5772b2eaa054f86c8c927fea5d13b9..73d473ab767e633ae2cefc309d87074523811851 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -173,6 +173,7 @@ CREATE TABLE "language_servers" ( "id" INTEGER NOT NULL, "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, "name" VARCHAR NOT NULL, + "capabilities" TEXT NOT NULL, PRIMARY KEY (project_id, id) ); diff --git a/crates/collab/migrations/20250804080620_language_server_capabilities.sql b/crates/collab/migrations/20250804080620_language_server_capabilities.sql new file mode 100644 index 0000000000000000000000000000000000000000..f74f094ed25d488720f2f85f30b6762f83647b02 --- /dev/null +++ b/crates/collab/migrations/20250804080620_language_server_capabilities.sql @@ -0,0 +1,5 @@ +ALTER TABLE language_servers + ADD COLUMN capabilities TEXT NOT NULL DEFAULT '{}'; + +ALTER TABLE language_servers + ALTER COLUMN capabilities DROP DEFAULT; diff --git a/crates/collab/src/api.rs b/crates/collab/src/api.rs index 7fca27c5c2b9580b7ef6546e4188c0aac7f73e3c..6cf3f68f54eda75ac19950c53cf535ff30a107a9 100644 --- a/crates/collab/src/api.rs +++ b/crates/collab/src/api.rs @@ -11,7 +11,9 @@ use crate::{ db::{User, UserId}, rpc, }; +use ::rpc::proto; use anyhow::Context as _; +use axum::extract; use axum::{ Extension, Json, Router, body::Body, @@ -23,6 +25,7 @@ use axum::{ routing::{get, post}, }; use axum_extra::response::ErasedJson; +use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use std::sync::{Arc, OnceLock}; use tower::ServiceBuilder; @@ -97,11 +100,11 @@ impl std::fmt::Display for SystemIdHeader { pub fn routes(rpc_server: Arc) -> Router<(), Body> { Router::new() - .route("/user", get(update_or_create_authenticated_user)) .route("/users/look_up", get(look_up_user)) .route("/users/:id/access_tokens", post(create_access_token)) + .route("/users/:id/refresh_llm_tokens", post(refresh_llm_tokens)) + .route("/users/:id/update_plan", post(update_plan)) .route("/rpc_server_snapshot", get(get_rpc_server_snapshot)) - .merge(billing::router()) .merge(contributors::router()) .layer( ServiceBuilder::new() @@ -141,48 +144,6 @@ pub async fn validate_api_token(req: Request, next: Next) -> impl IntoR Ok::<_, Error>(next.run(req).await) } -#[derive(Debug, Deserialize)] -struct AuthenticatedUserParams { - github_user_id: i32, - 
github_login: String, - github_email: Option, - github_name: Option, - github_user_created_at: chrono::DateTime, -} - -#[derive(Debug, Serialize)] -struct AuthenticatedUserResponse { - user: User, - metrics_id: String, - feature_flags: Vec, -} - -async fn update_or_create_authenticated_user( - Query(params): Query, - Extension(app): Extension>, -) -> Result> { - let initial_channel_id = app.config.auto_join_channel_id; - - let user = app - .db - .update_or_create_user_by_github_account( - ¶ms.github_login, - params.github_user_id, - params.github_email.as_deref(), - params.github_name.as_deref(), - params.github_user_created_at, - initial_channel_id, - ) - .await?; - let metrics_id = app.db.get_user_metrics_id(user.id).await?; - let feature_flags = app.db.get_user_flags(user.id).await?; - Ok(Json(AuthenticatedUserResponse { - user, - metrics_id, - feature_flags, - })) -} - #[derive(Debug, Deserialize)] struct LookUpUserParams { identifier: String, @@ -334,3 +295,90 @@ async fn create_access_token( encrypted_access_token, })) } + +#[derive(Serialize)] +struct RefreshLlmTokensResponse {} + +async fn refresh_llm_tokens( + Path(user_id): Path, + Extension(rpc_server): Extension>, +) -> Result> { + rpc_server.refresh_llm_tokens_for_user(user_id).await; + + Ok(Json(RefreshLlmTokensResponse {})) +} + +#[derive(Debug, Serialize, Deserialize)] +struct UpdatePlanBody { + pub plan: cloud_llm_client::Plan, + pub subscription_period: SubscriptionPeriod, + pub usage: cloud_llm_client::CurrentUsage, + pub trial_started_at: Option>, + pub is_usage_based_billing_enabled: bool, + pub is_account_too_young: bool, + pub has_overdue_invoices: bool, +} + +#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)] +struct SubscriptionPeriod { + pub started_at: DateTime, + pub ended_at: DateTime, +} + +#[derive(Serialize)] +struct UpdatePlanResponse {} + +async fn update_plan( + Path(user_id): Path, + Extension(rpc_server): Extension>, + extract::Json(body): extract::Json, +) -> Result> { + let plan = match body.plan { + cloud_llm_client::Plan::ZedFree => proto::Plan::Free, + cloud_llm_client::Plan::ZedPro => proto::Plan::ZedPro, + cloud_llm_client::Plan::ZedProTrial => proto::Plan::ZedProTrial, + }; + + let update_user_plan = proto::UpdateUserPlan { + plan: plan.into(), + trial_started_at: body + .trial_started_at + .map(|trial_started_at| trial_started_at.timestamp() as u64), + is_usage_based_billing_enabled: Some(body.is_usage_based_billing_enabled), + usage: Some(proto::SubscriptionUsage { + model_requests_usage_amount: body.usage.model_requests.used, + model_requests_usage_limit: Some(usage_limit_to_proto(body.usage.model_requests.limit)), + edit_predictions_usage_amount: body.usage.edit_predictions.used, + edit_predictions_usage_limit: Some(usage_limit_to_proto( + body.usage.edit_predictions.limit, + )), + }), + subscription_period: Some(proto::SubscriptionPeriod { + started_at: body.subscription_period.started_at.timestamp() as u64, + ended_at: body.subscription_period.ended_at.timestamp() as u64, + }), + account_too_young: Some(body.is_account_too_young), + has_overdue_invoices: Some(body.has_overdue_invoices), + }; + + rpc_server + .update_plan_for_user(user_id, update_user_plan) + .await?; + + Ok(Json(UpdatePlanResponse {})) +} + +fn usage_limit_to_proto(limit: cloud_llm_client::UsageLimit) -> proto::UsageLimit { + proto::UsageLimit { + variant: Some(match limit { + cloud_llm_client::UsageLimit::Limited(limit) => { + proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { + limit: limit 
as u32, + }) + } + cloud_llm_client::UsageLimit::Unlimited => { + proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) + } + }), + } +} diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index c8df066cbf1bbefd0515000a093d34371842c387..a0325d14c4a1b9f4221b17b446983b17f767fcbe 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -1,1312 +1,10 @@ -use anyhow::{Context as _, bail}; -use axum::{ - Extension, Json, Router, - extract::{self, Query}, - routing::{get, post}, -}; -use chrono::{DateTime, SecondsFormat, Utc}; -use collections::HashSet; -use reqwest::StatusCode; -use sea_orm::ActiveValue; -use serde::{Deserialize, Serialize}; -use serde_json::json; -use std::{str::FromStr, sync::Arc, time::Duration}; -use stripe::{ - BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession, - CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion, - CreateBillingPortalSessionFlowDataAfterCompletionRedirect, - CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm, - CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems, - CreateBillingPortalSessionFlowDataType, CustomerId, EventObject, EventType, ListEvents, - PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus, -}; -use util::{ResultExt, maybe}; +use std::sync::Arc; +use stripe::SubscriptionStatus; -use crate::api::events::SnowflakeRow; -use crate::db::billing_subscription::{ - StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind, -}; -use crate::llm::db::subscription_usage_meter::CompletionMode; -use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND}; -use crate::rpc::{ResultExt as _, Server}; -use crate::stripe_client::{ - StripeCancellationDetailsReason, StripeClient, StripeCustomerId, StripeSubscription, - StripeSubscriptionId, UpdateCustomerParams, -}; -use crate::{AppState, Error, Result}; -use crate::{db::UserId, llm::db::LlmDatabase}; -use crate::{ - db::{ - BillingSubscriptionId, CreateBillingCustomerParams, CreateBillingSubscriptionParams, - CreateProcessedStripeEventParams, UpdateBillingCustomerParams, - UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams, billing_customer, - }, - stripe_billing::StripeBilling, -}; - -pub fn router() -> Router { - Router::new() - .route( - "/billing/preferences", - get(get_billing_preferences).put(update_billing_preferences), - ) - .route( - "/billing/subscriptions", - get(list_billing_subscriptions).post(create_billing_subscription), - ) - .route( - "/billing/subscriptions/manage", - post(manage_billing_subscription), - ) - .route( - "/billing/subscriptions/sync", - post(sync_billing_subscription), - ) - .route("/billing/usage", get(get_current_usage)) -} - -#[derive(Debug, Deserialize)] -struct GetBillingPreferencesParams { - github_user_id: i32, -} - -#[derive(Debug, Serialize)] -struct BillingPreferencesResponse { - trial_started_at: Option, - max_monthly_llm_usage_spending_in_cents: i32, - model_request_overages_enabled: bool, - model_request_overages_spend_limit_in_cents: i32, -} - -async fn get_billing_preferences( - Extension(app): Extension>, - Query(params): Query, -) -> Result> { - let user = app - .db - .get_user_by_github_user_id(params.github_user_id) - .await? 
- .context("user not found")?; - - let billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?; - let preferences = app.db.get_billing_preferences(user.id).await?; - - Ok(Json(BillingPreferencesResponse { - trial_started_at: billing_customer - .and_then(|billing_customer| billing_customer.trial_started_at) - .map(|trial_started_at| { - trial_started_at - .and_utc() - .to_rfc3339_opts(SecondsFormat::Millis, true) - }), - max_monthly_llm_usage_spending_in_cents: preferences - .as_ref() - .map_or(DEFAULT_MAX_MONTHLY_SPEND.0 as i32, |preferences| { - preferences.max_monthly_llm_usage_spending_in_cents - }), - model_request_overages_enabled: preferences.as_ref().map_or(false, |preferences| { - preferences.model_request_overages_enabled - }), - model_request_overages_spend_limit_in_cents: preferences - .as_ref() - .map_or(0, |preferences| { - preferences.model_request_overages_spend_limit_in_cents - }), - })) -} - -#[derive(Debug, Deserialize)] -struct UpdateBillingPreferencesBody { - github_user_id: i32, - #[serde(default)] - max_monthly_llm_usage_spending_in_cents: i32, - #[serde(default)] - model_request_overages_enabled: bool, - #[serde(default)] - model_request_overages_spend_limit_in_cents: i32, -} - -async fn update_billing_preferences( - Extension(app): Extension>, - Extension(rpc_server): Extension>, - extract::Json(body): extract::Json, -) -> Result> { - let user = app - .db - .get_user_by_github_user_id(body.github_user_id) - .await? - .context("user not found")?; - - let billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?; - - let max_monthly_llm_usage_spending_in_cents = - body.max_monthly_llm_usage_spending_in_cents.max(0); - let model_request_overages_spend_limit_in_cents = - body.model_request_overages_spend_limit_in_cents.max(0); - - let billing_preferences = - if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? { - app.db - .update_billing_preferences( - user.id, - &UpdateBillingPreferencesParams { - max_monthly_llm_usage_spending_in_cents: ActiveValue::set( - max_monthly_llm_usage_spending_in_cents, - ), - model_request_overages_enabled: ActiveValue::set( - body.model_request_overages_enabled, - ), - model_request_overages_spend_limit_in_cents: ActiveValue::set( - model_request_overages_spend_limit_in_cents, - ), - }, - ) - .await? - } else { - app.db - .create_billing_preferences( - user.id, - &crate::db::CreateBillingPreferencesParams { - max_monthly_llm_usage_spending_in_cents, - model_request_overages_enabled: body.model_request_overages_enabled, - model_request_overages_spend_limit_in_cents, - }, - ) - .await? 
- }; - - SnowflakeRow::new( - "Billing Preferences Updated", - Some(user.metrics_id), - user.admin, - None, - json!({ - "user_id": user.id, - "model_request_overages_enabled": billing_preferences.model_request_overages_enabled, - "model_request_overages_spend_limit_in_cents": billing_preferences.model_request_overages_spend_limit_in_cents, - "max_monthly_llm_usage_spending_in_cents": billing_preferences.max_monthly_llm_usage_spending_in_cents, - }), - ) - .write(&app.kinesis_client, &app.config.kinesis_stream) - .await - .log_err(); - - rpc_server.refresh_llm_tokens_for_user(user.id).await; - - Ok(Json(BillingPreferencesResponse { - trial_started_at: billing_customer - .and_then(|billing_customer| billing_customer.trial_started_at) - .map(|trial_started_at| { - trial_started_at - .and_utc() - .to_rfc3339_opts(SecondsFormat::Millis, true) - }), - max_monthly_llm_usage_spending_in_cents: billing_preferences - .max_monthly_llm_usage_spending_in_cents, - model_request_overages_enabled: billing_preferences.model_request_overages_enabled, - model_request_overages_spend_limit_in_cents: billing_preferences - .model_request_overages_spend_limit_in_cents, - })) -} - -#[derive(Debug, Deserialize)] -struct ListBillingSubscriptionsParams { - github_user_id: i32, -} - -#[derive(Debug, Serialize)] -struct BillingSubscriptionJson { - id: BillingSubscriptionId, - name: String, - status: StripeSubscriptionStatus, - period: Option, - trial_end_at: Option, - cancel_at: Option, - /// Whether this subscription can be canceled. - is_cancelable: bool, -} - -#[derive(Debug, Serialize)] -struct BillingSubscriptionPeriodJson { - start_at: String, - end_at: String, -} - -#[derive(Debug, Serialize)] -struct ListBillingSubscriptionsResponse { - subscriptions: Vec, -} - -async fn list_billing_subscriptions( - Extension(app): Extension>, - Query(params): Query, -) -> Result> { - let user = app - .db - .get_user_by_github_user_id(params.github_user_id) - .await? 
- .context("user not found")?; - - let subscriptions = app.db.get_billing_subscriptions(user.id).await?; - - Ok(Json(ListBillingSubscriptionsResponse { - subscriptions: subscriptions - .into_iter() - .map(|subscription| BillingSubscriptionJson { - id: subscription.id, - name: match subscription.kind { - Some(SubscriptionKind::ZedPro) => "Zed Pro".to_string(), - Some(SubscriptionKind::ZedProTrial) => "Zed Pro (Trial)".to_string(), - Some(SubscriptionKind::ZedFree) => "Zed Free".to_string(), - None => "Zed LLM Usage".to_string(), - }, - status: subscription.stripe_subscription_status, - period: maybe!({ - let start_at = subscription.current_period_start_at()?; - let end_at = subscription.current_period_end_at()?; - - Some(BillingSubscriptionPeriodJson { - start_at: start_at.to_rfc3339_opts(SecondsFormat::Millis, true), - end_at: end_at.to_rfc3339_opts(SecondsFormat::Millis, true), - }) - }), - trial_end_at: if subscription.kind == Some(SubscriptionKind::ZedProTrial) { - maybe!({ - let end_at = subscription.stripe_current_period_end?; - let end_at = DateTime::from_timestamp(end_at, 0)?; - - Some(end_at.to_rfc3339_opts(SecondsFormat::Millis, true)) - }) - } else { - None - }, - cancel_at: subscription.stripe_cancel_at.map(|cancel_at| { - cancel_at - .and_utc() - .to_rfc3339_opts(SecondsFormat::Millis, true) - }), - is_cancelable: subscription.kind != Some(SubscriptionKind::ZedFree) - && subscription.stripe_subscription_status.is_cancelable() - && subscription.stripe_cancel_at.is_none(), - }) - .collect(), - })) -} - -#[derive(Debug, PartialEq, Clone, Copy, Deserialize)] -#[serde(rename_all = "snake_case")] -enum ProductCode { - ZedPro, - ZedProTrial, -} - -#[derive(Debug, Deserialize)] -struct CreateBillingSubscriptionBody { - github_user_id: i32, - product: ProductCode, -} - -#[derive(Debug, Serialize)] -struct CreateBillingSubscriptionResponse { - checkout_session_url: String, -} - -/// Initiates a Stripe Checkout session for creating a billing subscription. -async fn create_billing_subscription( - Extension(app): Extension>, - extract::Json(body): extract::Json, -) -> Result> { - let user = app - .db - .get_user_by_github_user_id(body.github_user_id) - .await? - .context("user not found")?; - - let Some(stripe_billing) = app.stripe_billing.clone() else { - log::error!("failed to retrieve Stripe billing object"); - Err(Error::http( - StatusCode::NOT_IMPLEMENTED, - "not supported".into(), - ))? - }; - - if let Some(existing_subscription) = app.db.get_active_billing_subscription(user.id).await? 
{ - let is_checkout_allowed = body.product == ProductCode::ZedProTrial - && existing_subscription.kind == Some(SubscriptionKind::ZedFree); - - if !is_checkout_allowed { - return Err(Error::http( - StatusCode::CONFLICT, - "user already has an active subscription".into(), - )); - } - } - - let existing_billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?; - if let Some(existing_billing_customer) = &existing_billing_customer { - if existing_billing_customer.has_overdue_invoices { - return Err(Error::http( - StatusCode::PAYMENT_REQUIRED, - "user has overdue invoices".into(), - )); - } - } - - let customer_id = if let Some(existing_customer) = &existing_billing_customer { - let customer_id = StripeCustomerId(existing_customer.stripe_customer_id.clone().into()); - if let Some(email) = user.email_address.as_deref() { - stripe_billing - .client() - .update_customer(&customer_id, UpdateCustomerParams { email: Some(email) }) - .await - // Update of email address is best-effort - continue checkout even if it fails - .context("error updating stripe customer email address") - .log_err(); - } - customer_id - } else { - stripe_billing - .find_or_create_customer_by_email(user.email_address.as_deref()) - .await? - }; - - let success_url = format!( - "{}/account?checkout_complete=1", - app.config.zed_dot_dev_url() - ); - - let checkout_session_url = match body.product { - ProductCode::ZedPro => { - stripe_billing - .checkout_with_zed_pro(&customer_id, &user.github_login, &success_url) - .await? - } - ProductCode::ZedProTrial => { - if let Some(existing_billing_customer) = &existing_billing_customer { - if existing_billing_customer.trial_started_at.is_some() { - return Err(Error::http( - StatusCode::FORBIDDEN, - "user already used free trial".into(), - )); - } - } - - let feature_flags = app.db.get_user_flags(user.id).await?; - - stripe_billing - .checkout_with_zed_pro_trial( - &customer_id, - &user.github_login, - feature_flags, - &success_url, - ) - .await? - } - }; - - Ok(Json(CreateBillingSubscriptionResponse { - checkout_session_url, - })) -} - -#[derive(Debug, PartialEq, Deserialize)] -#[serde(rename_all = "snake_case")] -enum ManageSubscriptionIntent { - /// The user intends to manage their subscription. - /// - /// This will open the Stripe billing portal without putting the user in a specific flow. - ManageSubscription, - /// The user intends to update their payment method. - UpdatePaymentMethod, - /// The user intends to upgrade to Zed Pro. - UpgradeToPro, - /// The user intends to cancel their subscription. - Cancel, - /// The user intends to stop the cancellation of their subscription. - StopCancellation, -} - -#[derive(Debug, Deserialize)] -struct ManageBillingSubscriptionBody { - github_user_id: i32, - intent: ManageSubscriptionIntent, - /// The ID of the subscription to manage. - subscription_id: BillingSubscriptionId, - redirect_to: Option, -} - -#[derive(Debug, Serialize)] -struct ManageBillingSubscriptionResponse { - billing_portal_session_url: Option, -} - -/// Initiates a Stripe customer portal session for managing a billing subscription. -async fn manage_billing_subscription( - Extension(app): Extension>, - extract::Json(body): extract::Json, -) -> Result> { - let user = app - .db - .get_user_by_github_user_id(body.github_user_id) - .await? 
- .context("user not found")?; - - let Some(stripe_client) = app.real_stripe_client.clone() else { - log::error!("failed to retrieve Stripe client"); - Err(Error::http( - StatusCode::NOT_IMPLEMENTED, - "not supported".into(), - ))? - }; - - let Some(stripe_billing) = app.stripe_billing.clone() else { - log::error!("failed to retrieve Stripe billing object"); - Err(Error::http( - StatusCode::NOT_IMPLEMENTED, - "not supported".into(), - ))? - }; - - let customer = app - .db - .get_billing_customer_by_user_id(user.id) - .await? - .context("billing customer not found")?; - let customer_id = CustomerId::from_str(&customer.stripe_customer_id) - .context("failed to parse customer ID")?; - - let subscription = app - .db - .get_billing_subscription_by_id(body.subscription_id) - .await? - .context("subscription not found")?; - let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id) - .context("failed to parse subscription ID")?; - - if body.intent == ManageSubscriptionIntent::StopCancellation { - let updated_stripe_subscription = Subscription::update( - &stripe_client, - &subscription_id, - stripe::UpdateSubscription { - cancel_at_period_end: Some(false), - ..Default::default() - }, - ) - .await?; - - app.db - .update_billing_subscription( - subscription.id, - &UpdateBillingSubscriptionParams { - stripe_cancel_at: ActiveValue::set( - updated_stripe_subscription - .cancel_at - .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0)) - .map(|time| time.naive_utc()), - ), - ..Default::default() - }, - ) - .await?; - - return Ok(Json(ManageBillingSubscriptionResponse { - billing_portal_session_url: None, - })); - } - - let flow = match body.intent { - ManageSubscriptionIntent::ManageSubscription => None, - ManageSubscriptionIntent::UpgradeToPro => { - let zed_pro_price_id: stripe::PriceId = - stripe_billing.zed_pro_price_id().await?.try_into()?; - let zed_free_price_id: stripe::PriceId = - stripe_billing.zed_free_price_id().await?.try_into()?; - - let stripe_subscription = - Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?; - - let is_on_zed_pro_trial = stripe_subscription.status == SubscriptionStatus::Trialing - && stripe_subscription.items.data.iter().any(|item| { - item.price - .as_ref() - .map_or(false, |price| price.id == zed_pro_price_id) - }); - if is_on_zed_pro_trial { - let payment_methods = PaymentMethod::list( - &stripe_client, - &stripe::ListPaymentMethods { - customer: Some(stripe_subscription.customer.id()), - ..Default::default() - }, - ) - .await?; - - let has_payment_method = !payment_methods.data.is_empty(); - if !has_payment_method { - return Err(Error::http( - StatusCode::BAD_REQUEST, - "missing payment method".into(), - )); - } - - // If the user is already on a Zed Pro trial and wants to upgrade to Pro, we just need to end their trial early. 
- Subscription::update( - &stripe_client, - &stripe_subscription.id, - stripe::UpdateSubscription { - trial_end: Some(stripe::Scheduled::now()), - ..Default::default() - }, - ) - .await?; - - return Ok(Json(ManageBillingSubscriptionResponse { - billing_portal_session_url: None, - })); - } - - let subscription_item_to_update = stripe_subscription - .items - .data - .iter() - .find_map(|item| { - let price = item.price.as_ref()?; - - if price.id == zed_free_price_id { - Some(item.id.clone()) - } else { - None - } - }) - .context("No subscription item to update")?; - - Some(CreateBillingPortalSessionFlowData { - type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm, - subscription_update_confirm: Some( - CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm { - subscription: subscription.stripe_subscription_id, - items: vec![ - CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems { - id: subscription_item_to_update.to_string(), - price: Some(zed_pro_price_id.to_string()), - quantity: Some(1), - }, - ], - discounts: None, - }, - ), - ..Default::default() - }) - } - ManageSubscriptionIntent::UpdatePaymentMethod => Some(CreateBillingPortalSessionFlowData { - type_: CreateBillingPortalSessionFlowDataType::PaymentMethodUpdate, - after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion { - type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect, - redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect { - return_url: format!( - "{}{path}", - app.config.zed_dot_dev_url(), - path = body.redirect_to.unwrap_or_else(|| "/account".to_string()) - ), - }), - ..Default::default() - }), - ..Default::default() - }), - ManageSubscriptionIntent::Cancel => { - if subscription.kind == Some(SubscriptionKind::ZedFree) { - return Err(Error::http( - StatusCode::BAD_REQUEST, - "free subscription cannot be canceled".into(), - )); - } - - Some(CreateBillingPortalSessionFlowData { - type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel, - after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion { - type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect, - redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect { - return_url: format!("{}/account", app.config.zed_dot_dev_url()), - }), - ..Default::default() - }), - subscription_cancel: Some( - stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel { - subscription: subscription.stripe_subscription_id, - retention: None, - }, - ), - ..Default::default() - }) - } - ManageSubscriptionIntent::StopCancellation => unreachable!(), - }; - - let mut params = CreateBillingPortalSession::new(customer_id); - params.flow_data = flow; - let return_url = format!("{}/account", app.config.zed_dot_dev_url()); - params.return_url = Some(&return_url); - - let session = BillingPortalSession::create(&stripe_client, params).await?; - - Ok(Json(ManageBillingSubscriptionResponse { - billing_portal_session_url: Some(session.url), - })) -} - -#[derive(Debug, Deserialize)] -struct SyncBillingSubscriptionBody { - github_user_id: i32, -} - -#[derive(Debug, Serialize)] -struct SyncBillingSubscriptionResponse { - stripe_customer_id: String, -} - -async fn sync_billing_subscription( - Extension(app): Extension>, - extract::Json(body): extract::Json, -) -> Result> { - let Some(stripe_client) = app.stripe_client.clone() else { - log::error!("failed to retrieve Stripe client"); - Err(Error::http( - StatusCode::NOT_IMPLEMENTED, - "not 
supported".into(), - ))? - }; - - let user = app - .db - .get_user_by_github_user_id(body.github_user_id) - .await? - .context("user not found")?; - - let billing_customer = app - .db - .get_billing_customer_by_user_id(user.id) - .await? - .context("billing customer not found")?; - let stripe_customer_id = StripeCustomerId(billing_customer.stripe_customer_id.clone().into()); - - let subscriptions = stripe_client - .list_subscriptions_for_customer(&stripe_customer_id) - .await?; - - for subscription in subscriptions { - let subscription_id = subscription.id.clone(); - - sync_subscription(&app, &stripe_client, subscription) - .await - .with_context(|| { - format!( - "failed to sync subscription {subscription_id} for user {}", - user.id, - ) - })?; - } - - Ok(Json(SyncBillingSubscriptionResponse { - stripe_customer_id: billing_customer.stripe_customer_id.clone(), - })) -} - -/// The amount of time we wait in between each poll of Stripe events. -/// -/// This value should strike a balance between: -/// 1. Being short enough that we update quickly when something in Stripe changes -/// 2. Being long enough that we don't eat into our rate limits. -/// -/// As a point of reference, the Sequin folks say they have this at **500ms**: -/// -/// > We poll the Stripe /events endpoint every 500ms per account -/// > -/// > — https://blog.sequinstream.com/events-not-webhooks/ -const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5); - -/// The maximum number of events to return per page. -/// -/// We set this to 100 (the max) so we have to make fewer requests to Stripe. -/// -/// > Limit can range between 1 and 100, and the default is 10. -const EVENTS_LIMIT_PER_PAGE: u64 = 100; - -/// The number of pages consisting entirely of already-processed events that we -/// will see before we stop retrieving events. -/// -/// This is used to prevent over-fetching the Stripe events API for events we've -/// already seen and processed. -const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4; - -/// Polls the Stripe events API periodically to reconcile the records in our -/// database with the data in Stripe. -pub fn poll_stripe_events_periodically(app: Arc, rpc_server: Arc) { - let Some(real_stripe_client) = app.real_stripe_client.clone() else { - log::warn!("failed to retrieve Stripe client"); - return; - }; - let Some(stripe_client) = app.stripe_client.clone() else { - log::warn!("failed to retrieve Stripe client"); - return; - }; - - let executor = app.executor.clone(); - executor.spawn_detached({ - let executor = executor.clone(); - async move { - loop { - poll_stripe_events(&app, &rpc_server, &stripe_client, &real_stripe_client) - .await - .log_err(); - - executor.sleep(POLL_EVENTS_INTERVAL).await; - } - } - }); -} - -async fn poll_stripe_events( - app: &Arc, - rpc_server: &Arc, - stripe_client: &Arc, - real_stripe_client: &stripe::Client, -) -> anyhow::Result<()> { - fn event_type_to_string(event_type: EventType) -> String { - // Calling `to_string` on `stripe::EventType` members gives us a quoted string, - // so we need to unquote it. 
- event_type.to_string().trim_matches('"').to_string() - } - - let event_types = [ - EventType::CustomerCreated, - EventType::CustomerUpdated, - EventType::CustomerSubscriptionCreated, - EventType::CustomerSubscriptionUpdated, - EventType::CustomerSubscriptionPaused, - EventType::CustomerSubscriptionResumed, - EventType::CustomerSubscriptionDeleted, - ] - .into_iter() - .map(event_type_to_string) - .collect::>(); - - let mut pages_of_already_processed_events = 0; - let mut unprocessed_events = Vec::new(); - - log::info!( - "Stripe events: starting retrieval for {}", - event_types.join(", ") - ); - let mut params = ListEvents::new(); - params.types = Some(event_types.clone()); - params.limit = Some(EVENTS_LIMIT_PER_PAGE); - - let mut event_pages = stripe::Event::list(&real_stripe_client, ¶ms) - .await? - .paginate(params); - - loop { - let processed_event_ids = { - let event_ids = event_pages - .page - .data - .iter() - .map(|event| event.id.as_str()) - .collect::>(); - app.db - .get_processed_stripe_events_by_event_ids(&event_ids) - .await? - .into_iter() - .map(|event| event.stripe_event_id) - .collect::>() - }; - - let mut processed_events_in_page = 0; - let events_in_page = event_pages.page.data.len(); - for event in &event_pages.page.data { - if processed_event_ids.contains(&event.id.to_string()) { - processed_events_in_page += 1; - log::debug!("Stripe events: already processed '{}', skipping", event.id); - } else { - unprocessed_events.push(event.clone()); - } - } - - if processed_events_in_page == events_in_page { - pages_of_already_processed_events += 1; - } - - if event_pages.page.has_more { - if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP - { - log::info!( - "Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events" - ); - break; - } else { - log::info!("Stripe events: retrieving next page"); - event_pages = event_pages.next(&real_stripe_client).await?; - } - } else { - break; - } - } - - log::info!("Stripe events: unprocessed {}", unprocessed_events.len()); - - // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred. - unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id))); - - for event in unprocessed_events { - let event_id = event.id.clone(); - let processed_event_params = CreateProcessedStripeEventParams { - stripe_event_id: event.id.to_string(), - stripe_event_type: event_type_to_string(event.type_), - stripe_event_created_timestamp: event.created, - }; - - // If the event has happened too far in the past, we don't want to - // process it and risk overwriting other more-recent updates. - // - // 1 day was chosen arbitrarily. This could be made longer or shorter. 
- let one_day = Duration::from_secs(24 * 60 * 60); - let a_day_ago = Utc::now() - one_day; - if a_day_ago.timestamp() > event.created { - log::info!( - "Stripe events: event '{}' is more than {one_day:?} old, marking as processed", - event_id - ); - app.db - .create_processed_stripe_event(&processed_event_params) - .await?; - - continue; - } - - let process_result = match event.type_ { - EventType::CustomerCreated | EventType::CustomerUpdated => { - handle_customer_event(app, real_stripe_client, event).await - } - EventType::CustomerSubscriptionCreated - | EventType::CustomerSubscriptionUpdated - | EventType::CustomerSubscriptionPaused - | EventType::CustomerSubscriptionResumed - | EventType::CustomerSubscriptionDeleted => { - handle_customer_subscription_event(app, rpc_server, stripe_client, event).await - } - _ => Ok(()), - }; - - if let Some(()) = process_result - .with_context(|| format!("failed to process event {event_id} successfully")) - .log_err() - { - app.db - .create_processed_stripe_event(&processed_event_params) - .await?; - } - } - - Ok(()) -} - -async fn handle_customer_event( - app: &Arc, - _stripe_client: &stripe::Client, - event: stripe::Event, -) -> anyhow::Result<()> { - let EventObject::Customer(customer) = event.data.object else { - bail!("unexpected event payload for {}", event.id); - }; - - log::info!("handling Stripe {} event: {}", event.type_, event.id); - - let Some(email) = customer.email else { - log::info!("Stripe customer has no email: skipping"); - return Ok(()); - }; - - let Some(user) = app.db.get_user_by_email(&email).await? else { - log::info!("no user found for email: skipping"); - return Ok(()); - }; - - if let Some(existing_customer) = app - .db - .get_billing_customer_by_stripe_customer_id(&customer.id) - .await? - { - app.db - .update_billing_customer( - existing_customer.id, - &UpdateBillingCustomerParams { - // For now we just leave the information as-is, as it is not - // likely to change. - ..Default::default() - }, - ) - .await?; - } else { - app.db - .create_billing_customer(&CreateBillingCustomerParams { - user_id: user.id, - stripe_customer_id: customer.id.to_string(), - }) - .await?; - } - - Ok(()) -} - -async fn sync_subscription( - app: &Arc, - stripe_client: &Arc, - subscription: StripeSubscription, -) -> anyhow::Result { - let subscription_kind = if let Some(stripe_billing) = &app.stripe_billing { - stripe_billing - .determine_subscription_kind(&subscription) - .await - } else { - None - }; - - let billing_customer = - find_or_create_billing_customer(app, stripe_client.as_ref(), &subscription.customer) - .await? 
- .context("billing customer not found")?; - - if let Some(SubscriptionKind::ZedProTrial) = subscription_kind { - if subscription.status == SubscriptionStatus::Trialing { - let current_period_start = - DateTime::from_timestamp(subscription.current_period_start, 0) - .context("No trial subscription period start")?; - - app.db - .update_billing_customer( - billing_customer.id, - &UpdateBillingCustomerParams { - trial_started_at: ActiveValue::set(Some(current_period_start.naive_utc())), - ..Default::default() - }, - ) - .await?; - } - } - - let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled - && subscription - .cancellation_details - .as_ref() - .and_then(|details| details.reason) - .map_or(false, |reason| { - reason == StripeCancellationDetailsReason::PaymentFailed - }); - - if was_canceled_due_to_payment_failure { - app.db - .update_billing_customer( - billing_customer.id, - &UpdateBillingCustomerParams { - has_overdue_invoices: ActiveValue::set(true), - ..Default::default() - }, - ) - .await?; - } - - if let Some(existing_subscription) = app - .db - .get_billing_subscription_by_stripe_subscription_id(subscription.id.0.as_ref()) - .await? - { - app.db - .update_billing_subscription( - existing_subscription.id, - &UpdateBillingSubscriptionParams { - billing_customer_id: ActiveValue::set(billing_customer.id), - kind: ActiveValue::set(subscription_kind), - stripe_subscription_id: ActiveValue::set(subscription.id.to_string()), - stripe_subscription_status: ActiveValue::set(subscription.status.into()), - stripe_cancel_at: ActiveValue::set( - subscription - .cancel_at - .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0)) - .map(|time| time.naive_utc()), - ), - stripe_cancellation_reason: ActiveValue::set( - subscription - .cancellation_details - .and_then(|details| details.reason) - .map(|reason| reason.into()), - ), - stripe_current_period_start: ActiveValue::set(Some( - subscription.current_period_start, - )), - stripe_current_period_end: ActiveValue::set(Some( - subscription.current_period_end, - )), - }, - ) - .await?; - } else { - if let Some(existing_subscription) = app - .db - .get_active_billing_subscription(billing_customer.user_id) - .await? - { - if existing_subscription.kind == Some(SubscriptionKind::ZedFree) - && subscription_kind == Some(SubscriptionKind::ZedProTrial) - { - let stripe_subscription_id = StripeSubscriptionId( - existing_subscription.stripe_subscription_id.clone().into(), - ); - - stripe_client - .cancel_subscription(&stripe_subscription_id) - .await?; - } else { - // If the user already has an active billing subscription, ignore the - // event and return an `Ok` to signal that it was processed - // successfully. - // - // There is the possibility that this could cause us to not create a - // subscription in the following scenario: - // - // 1. User has an active subscription A - // 2. User cancels subscription A - // 3. User creates a new subscription B - // 4. We process the new subscription B before the cancellation of subscription A - // 5. User ends up with no subscriptions - // - // In theory this situation shouldn't arise as we try to process the events in the order they occur. 
- - log::info!( - "user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}", - user_id = billing_customer.user_id, - subscription_id = subscription.id - ); - return Ok(billing_customer); - } - } - - app.db - .create_billing_subscription(&CreateBillingSubscriptionParams { - billing_customer_id: billing_customer.id, - kind: subscription_kind, - stripe_subscription_id: subscription.id.to_string(), - stripe_subscription_status: subscription.status.into(), - stripe_cancellation_reason: subscription - .cancellation_details - .and_then(|details| details.reason) - .map(|reason| reason.into()), - stripe_current_period_start: Some(subscription.current_period_start), - stripe_current_period_end: Some(subscription.current_period_end), - }) - .await?; - } - - if let Some(stripe_billing) = app.stripe_billing.as_ref() { - if subscription.status == SubscriptionStatus::Canceled - || subscription.status == SubscriptionStatus::Paused - { - let already_has_active_billing_subscription = app - .db - .has_active_billing_subscription(billing_customer.user_id) - .await?; - if !already_has_active_billing_subscription { - let stripe_customer_id = - StripeCustomerId(billing_customer.stripe_customer_id.clone().into()); - - stripe_billing - .subscribe_to_zed_free(stripe_customer_id) - .await?; - } - } - } - - Ok(billing_customer) -} - -async fn handle_customer_subscription_event( - app: &Arc, - rpc_server: &Arc, - stripe_client: &Arc, - event: stripe::Event, -) -> anyhow::Result<()> { - let EventObject::Subscription(subscription) = event.data.object else { - bail!("unexpected event payload for {}", event.id); - }; - - log::info!("handling Stripe {} event: {}", event.type_, event.id); - - let billing_customer = sync_subscription(app, stripe_client, subscription.into()).await?; - - // When the user's subscription changes, push down any changes to their plan. - rpc_server - .update_plan_for_user(billing_customer.user_id) - .await - .trace_err(); - - // When the user's subscription changes, we want to refresh their LLM tokens - // to either grant/revoke access. - rpc_server - .refresh_llm_tokens_for_user(billing_customer.user_id) - .await; - - Ok(()) -} - -#[derive(Debug, Deserialize)] -struct GetCurrentUsageParams { - github_user_id: i32, -} - -#[derive(Debug, Serialize)] -struct UsageCounts { - pub used: i32, - pub limit: Option, - pub remaining: Option, -} - -#[derive(Debug, Serialize)] -struct ModelRequestUsage { - pub model: String, - pub mode: CompletionMode, - pub requests: i32, -} - -#[derive(Debug, Serialize)] -struct CurrentUsage { - pub model_requests: UsageCounts, - pub model_request_usage: Vec, - pub edit_predictions: UsageCounts, -} - -#[derive(Debug, Default, Serialize)] -struct GetCurrentUsageResponse { - pub plan: String, - pub current_usage: Option, -} - -async fn get_current_usage( - Extension(app): Extension>, - Query(params): Query, -) -> Result> { - let user = app - .db - .get_user_by_github_user_id(params.github_user_id) - .await? - .context("user not found")?; - - let feature_flags = app.db.get_user_flags(user.id).await?; - let has_extended_trial = feature_flags - .iter() - .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG); - - let Some(llm_db) = app.llm_db.clone() else { - return Err(Error::http( - StatusCode::NOT_IMPLEMENTED, - "LLM database not available".into(), - )); - }; - - let Some(subscription) = app.db.get_active_billing_subscription(user.id).await? 
else { - return Ok(Json(GetCurrentUsageResponse::default())); - }; - - let subscription_period = maybe!({ - let period_start_at = subscription.current_period_start_at()?; - let period_end_at = subscription.current_period_end_at()?; - - Some((period_start_at, period_end_at)) - }); - - let Some((period_start_at, period_end_at)) = subscription_period else { - return Ok(Json(GetCurrentUsageResponse::default())); - }; - - let usage = llm_db - .get_subscription_usage_for_period(user.id, period_start_at, period_end_at) - .await?; - - let plan = subscription - .kind - .map(Into::into) - .unwrap_or(zed_llm_client::Plan::ZedFree); - - let model_requests_limit = match plan.model_requests_limit() { - zed_llm_client::UsageLimit::Limited(limit) => { - let limit = if plan == zed_llm_client::Plan::ZedProTrial && has_extended_trial { - 1_000 - } else { - limit - }; - - Some(limit) - } - zed_llm_client::UsageLimit::Unlimited => None, - }; - - let edit_predictions_limit = match plan.edit_predictions_limit() { - zed_llm_client::UsageLimit::Limited(limit) => Some(limit), - zed_llm_client::UsageLimit::Unlimited => None, - }; - - let Some(usage) = usage else { - return Ok(Json(GetCurrentUsageResponse { - plan: plan.as_str().to_string(), - current_usage: Some(CurrentUsage { - model_requests: UsageCounts { - used: 0, - limit: model_requests_limit, - remaining: model_requests_limit, - }, - model_request_usage: Vec::new(), - edit_predictions: UsageCounts { - used: 0, - limit: edit_predictions_limit, - remaining: edit_predictions_limit, - }, - }), - })); - }; - - let subscription_usage_meters = llm_db - .get_current_subscription_usage_meters_for_user(user.id, Utc::now()) - .await?; - - let model_request_usage = subscription_usage_meters - .into_iter() - .filter_map(|(usage_meter, _usage)| { - let model = llm_db.model_by_id(usage_meter.model_id).ok()?; - - Some(ModelRequestUsage { - model: model.name.clone(), - mode: usage_meter.mode, - requests: usage_meter.requests, - }) - }) - .collect::>(); - - Ok(Json(GetCurrentUsageResponse { - plan: plan.as_str().to_string(), - current_usage: Some(CurrentUsage { - model_requests: UsageCounts { - used: usage.model_requests, - limit: model_requests_limit, - remaining: model_requests_limit.map(|limit| (limit - usage.model_requests).max(0)), - }, - model_request_usage, - edit_predictions: UsageCounts { - used: usage.edit_predictions, - limit: edit_predictions_limit, - remaining: edit_predictions_limit - .map(|limit| (limit - usage.edit_predictions).max(0)), - }, - }), - })) -} +use crate::AppState; +use crate::db::billing_subscription::StripeSubscriptionStatus; +use crate::db::{CreateBillingCustomerParams, billing_customer}; +use crate::stripe_client::{StripeClient, StripeCustomerId}; impl From for StripeSubscriptionStatus { fn from(value: SubscriptionStatus) -> Self { @@ -1323,16 +21,6 @@ impl From for StripeSubscriptionStatus { } } -impl From for StripeCancellationReason { - fn from(value: CancellationDetailsReason) -> Self { - match value { - CancellationDetailsReason::CancellationRequested => Self::CancellationRequested, - CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed, - CancellationDetailsReason::PaymentFailed => Self::PaymentFailed, - } - } -} - /// Finds or creates a billing customer using the provided customer. 
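The `get_current_usage` handler above reports each limit as an `Option`, where `None` means unlimited, and clamps `remaining` at zero once usage exceeds the limit. A small sketch of that arithmetic under those assumptions; the struct mirrors the handler's `UsageCounts`, but the helper function name is illustrative:

    // Sketch of the usage-count arithmetic used by the endpoint above:
    // a `None` limit means "unlimited", and `remaining` never goes negative.
    #[allow(dead_code)]
    #[derive(Debug)]
    struct UsageCounts {
        used: i32,
        limit: Option<i32>,
        remaining: Option<i32>,
    }

    fn usage_counts(used: i32, limit: Option<i32>) -> UsageCounts {
        UsageCounts {
            used,
            limit,
            remaining: limit.map(|limit| (limit - used).max(0)),
        }
    }

    fn main() {
        // Within the limit: 500 - 120 requests left.
        assert_eq!(usage_counts(120, Some(500)).remaining, Some(380));
        // Over the limit: clamp to zero rather than reporting a negative count.
        assert_eq!(usage_counts(600, Some(500)).remaining, Some(0));
        // Unlimited plans have neither a limit nor a remaining count.
        assert!(usage_counts(600, None).remaining.is_none());
    }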
pub async fn find_or_create_billing_customer( app: &Arc, @@ -1369,152 +57,3 @@ pub async fn find_or_create_billing_customer( Ok(Some(billing_customer)) } - -const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60); - -pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc) { - let Some(stripe_billing) = app.stripe_billing.clone() else { - log::warn!("failed to retrieve Stripe billing object"); - return; - }; - let Some(llm_db) = app.llm_db.clone() else { - log::warn!("failed to retrieve LLM database"); - return; - }; - - let executor = app.executor.clone(); - executor.spawn_detached({ - let executor = executor.clone(); - async move { - loop { - sync_model_request_usage_with_stripe(&app, &llm_db, &stripe_billing) - .await - .context("failed to sync LLM request usage to Stripe") - .trace_err(); - executor - .sleep(SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL) - .await; - } - } - }); -} - -async fn sync_model_request_usage_with_stripe( - app: &Arc, - llm_db: &Arc, - stripe_billing: &Arc, -) -> anyhow::Result<()> { - log::info!("Stripe usage sync: Starting"); - let started_at = Utc::now(); - - let staff_users = app.db.get_staff_users().await?; - let staff_user_ids = staff_users - .iter() - .map(|user| user.id) - .collect::>(); - - let usage_meters = llm_db - .get_current_subscription_usage_meters(Utc::now()) - .await?; - let usage_meters = usage_meters - .into_iter() - .filter(|(_, usage)| !staff_user_ids.contains(&usage.user_id)) - .collect::>(); - let user_ids = usage_meters - .iter() - .map(|(_, usage)| usage.user_id) - .collect::>(); - let billing_subscriptions = app - .db - .get_active_zed_pro_billing_subscriptions(user_ids) - .await?; - - let claude_sonnet_4 = stripe_billing - .find_price_by_lookup_key("claude-sonnet-4-requests") - .await?; - let claude_sonnet_4_max = stripe_billing - .find_price_by_lookup_key("claude-sonnet-4-requests-max") - .await?; - let claude_opus_4 = stripe_billing - .find_price_by_lookup_key("claude-opus-4-requests") - .await?; - let claude_opus_4_max = stripe_billing - .find_price_by_lookup_key("claude-opus-4-requests-max") - .await?; - let claude_3_5_sonnet = stripe_billing - .find_price_by_lookup_key("claude-3-5-sonnet-requests") - .await?; - let claude_3_7_sonnet = stripe_billing - .find_price_by_lookup_key("claude-3-7-sonnet-requests") - .await?; - let claude_3_7_sonnet_max = stripe_billing - .find_price_by_lookup_key("claude-3-7-sonnet-requests-max") - .await?; - - let usage_meter_count = usage_meters.len(); - - log::info!("Stripe usage sync: Syncing {usage_meter_count} usage meters"); - - for (usage_meter, usage) in usage_meters { - maybe!(async { - let Some((billing_customer, billing_subscription)) = - billing_subscriptions.get(&usage.user_id) - else { - bail!( - "Attempted to sync usage meter for user who is not a Stripe customer: {}", - usage.user_id - ); - }; - - let stripe_customer_id = - StripeCustomerId(billing_customer.stripe_customer_id.clone().into()); - let stripe_subscription_id = - StripeSubscriptionId(billing_subscription.stripe_subscription_id.clone().into()); - - let model = llm_db.model_by_id(usage_meter.model_id)?; - - let (price, meter_event_name) = match model.name.as_str() { - "claude-opus-4" => match usage_meter.mode { - CompletionMode::Normal => (&claude_opus_4, "claude_opus_4/requests"), - CompletionMode::Max => (&claude_opus_4_max, "claude_opus_4/requests/max"), - }, - "claude-sonnet-4" => match usage_meter.mode { - CompletionMode::Normal => (&claude_sonnet_4, "claude_sonnet_4/requests"), - 
CompletionMode::Max => (&claude_sonnet_4_max, "claude_sonnet_4/requests/max"), - }, - "claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"), - "claude-3-7-sonnet" => match usage_meter.mode { - CompletionMode::Normal => (&claude_3_7_sonnet, "claude_3_7_sonnet/requests"), - CompletionMode::Max => { - (&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max") - } - }, - model_name => { - bail!("Attempted to sync usage meter for unsupported model: {model_name:?}") - } - }; - - stripe_billing - .subscribe_to_price(&stripe_subscription_id, price) - .await?; - stripe_billing - .bill_model_request_usage( - &stripe_customer_id, - meter_event_name, - usage_meter.requests, - ) - .await?; - - Ok(()) - }) - .await - .log_err(); - } - - log::info!( - "Stripe usage sync: Synced {usage_meter_count} usage meters in {:?}", - Utc::now() - started_at - ); - - Ok(()) -} diff --git a/crates/collab/src/api/contributors.rs b/crates/collab/src/api/contributors.rs index 9296c1d4282078d73dccfe40536fc59102ec248d..8cfef0ad7e717614e23c3cf9d04852c976f1f55f 100644 --- a/crates/collab/src/api/contributors.rs +++ b/crates/collab/src/api/contributors.rs @@ -8,7 +8,6 @@ use axum::{ use chrono::{NaiveDateTime, SecondsFormat}; use serde::{Deserialize, Serialize}; -use crate::api::AuthenticatedUserParams; use crate::db::ContributorSelector; use crate::{AppState, Result}; @@ -104,9 +103,18 @@ impl RenovateBot { } } +#[derive(Debug, Deserialize)] +struct AddContributorBody { + github_user_id: i32, + github_login: String, + github_email: Option, + github_name: Option, + github_user_created_at: chrono::DateTime, +} + async fn add_contributor( Extension(app): Extension>, - extract::Json(params): extract::Json, + extract::Json(params): extract::Json, ) -> Result<()> { let initial_channel_id = app.config.auto_join_channel_id; app.db diff --git a/crates/collab/src/api/events.rs b/crates/collab/src/api/events.rs index 6ccc86c520082998c10e37a5cc4bea339a5d3a8d..2f34a843a860d9d2933a4819788d0f9285473edf 100644 --- a/crates/collab/src/api/events.rs +++ b/crates/collab/src/api/events.rs @@ -389,53 +389,58 @@ pub async fn post_panic( } } - let backtrace = if panic.backtrace.len() > 25 { - let total = panic.backtrace.len(); - format!( - "{}\n and {} more", - panic - .backtrace - .iter() - .take(20) - .cloned() - .collect::>() - .join("\n"), - total - 20 - ) - } else { - panic.backtrace.join("\n") - }; - if !report_to_slack(&panic) { return Ok(()); } - let backtrace_with_summary = panic.payload + "\n" + &backtrace; - if let Some(slack_panics_webhook) = app.config.slack_panics_webhook.clone() { + let backtrace = if panic.backtrace.len() > 25 { + let total = panic.backtrace.len(); + format!( + "{}\n and {} more", + panic + .backtrace + .iter() + .take(20) + .cloned() + .collect::>() + .join("\n"), + total - 20 + ) + } else { + panic.backtrace.join("\n") + }; + let backtrace_with_summary = panic.payload + "\n" + &backtrace; + + let version = if panic.release_channel == "nightly" + && !panic.app_version.contains("remote-server") + && let Some(sha) = panic.app_commit_sha + { + format!("Zed Nightly {}", sha.chars().take(7).collect::()) + } else { + panic.app_version + }; + let payload = slack::WebhookBody::new(|w| { w.add_section(|s| s.text(slack::Text::markdown("Panic request".to_string()))) .add_section(|s| { - s.add_field(slack::Text::markdown(format!( - "*Version:*\n {} ", - panic.app_version - ))) - .add_field({ - let hostname = app.config.blob_store_url.clone().unwrap_or_default(); - let hostname = 
hostname.strip_prefix("https://").unwrap_or_else(|| { - hostname.strip_prefix("http://").unwrap_or_default() - }); - - slack::Text::markdown(format!( - "*{} {}:*\n", - panic.os_name, - panic.os_version.unwrap_or_default(), - CRASH_REPORTS_BUCKET, - hostname, - incident_id, - incident_id.chars().take(8).collect::(), - )) - }) + s.add_field(slack::Text::markdown(format!("*Version:*\n {version} ",))) + .add_field({ + let hostname = app.config.blob_store_url.clone().unwrap_or_default(); + let hostname = hostname.strip_prefix("https://").unwrap_or_else(|| { + hostname.strip_prefix("http://").unwrap_or_default() + }); + + slack::Text::markdown(format!( + "*{} {}:*\n", + panic.os_name, + panic.os_version.unwrap_or_default(), + CRASH_REPORTS_BUCKET, + hostname, + incident_id, + incident_id.chars().take(8).collect::(), + )) + }) }) .add_rich_text(|r| r.add_preformatted(|p| p.add_text(backtrace_with_summary))) }); @@ -575,7 +580,7 @@ fn for_snowflake( }, serde_json::to_value(e).unwrap(), ), - Event::InlineCompletion(e) => ( + Event::EditPrediction(e) => ( format!( "Edit Prediction {}", if e.suggestion_accepted { @@ -586,7 +591,7 @@ fn for_snowflake( ), serde_json::to_value(e).unwrap(), ), - Event::InlineCompletionRating(e) => ( + Event::EditPredictionRating(e) => ( "Edit Prediction Rated".to_string(), serde_json::to_value(e).unwrap(), ), diff --git a/crates/collab/src/cents.rs b/crates/collab/src/cents.rs deleted file mode 100644 index a05971f1417339664d667665ddff63a13237f4dc..0000000000000000000000000000000000000000 --- a/crates/collab/src/cents.rs +++ /dev/null @@ -1,83 +0,0 @@ -use serde::Serialize; - -/// A number of cents. -#[derive( - Debug, - PartialEq, - Eq, - PartialOrd, - Ord, - Hash, - Clone, - Copy, - derive_more::Add, - derive_more::AddAssign, - derive_more::Sub, - derive_more::SubAssign, - Serialize, -)] -pub struct Cents(pub u32); - -impl Cents { - pub const ZERO: Self = Self(0); - - pub const fn new(cents: u32) -> Self { - Self(cents) - } - - pub const fn from_dollars(dollars: u32) -> Self { - Self(dollars * 100) - } - - pub fn saturating_sub(self, other: Cents) -> Self { - Self(self.0.saturating_sub(other.0)) - } -} - -#[cfg(test)] -mod tests { - use pretty_assertions::assert_eq; - - use super::*; - - #[test] - fn test_cents_new() { - assert_eq!(Cents::new(50), Cents(50)); - } - - #[test] - fn test_cents_from_dollars() { - assert_eq!(Cents::from_dollars(1), Cents(100)); - assert_eq!(Cents::from_dollars(5), Cents(500)); - } - - #[test] - fn test_cents_zero() { - assert_eq!(Cents::ZERO, Cents(0)); - } - - #[test] - fn test_cents_add() { - assert_eq!(Cents(50) + Cents(30), Cents(80)); - } - - #[test] - fn test_cents_add_assign() { - let mut cents = Cents(50); - cents += Cents(30); - assert_eq!(cents, Cents(80)); - } - - #[test] - fn test_cents_saturating_sub() { - assert_eq!(Cents(50).saturating_sub(Cents(30)), Cents(20)); - assert_eq!(Cents(30).saturating_sub(Cents(50)), Cents(0)); - } - - #[test] - fn test_cents_ordering() { - assert!(Cents(50) > Cents(30)); - assert!(Cents(30) < Cents(50)); - assert_eq!(Cents(50), Cents(50)); - } -} diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index cc2924569776f7be5bb2be546fa67413bbf75d4c..2c22ca206945eb02752680b6149d7796643ee938 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -42,9 +42,6 @@ pub use tests::TestDb; pub use ids::*; pub use queries::billing_customers::{CreateBillingCustomerParams, UpdateBillingCustomerParams}; -pub use queries::billing_preferences::{ - CreateBillingPreferencesParams, 
UpdateBillingPreferencesParams, -}; pub use queries::billing_subscriptions::{ CreateBillingSubscriptionParams, UpdateBillingSubscriptionParams, }; @@ -532,11 +529,17 @@ pub struct RejoinedProject { pub worktrees: Vec, pub updated_repositories: Vec, pub removed_repositories: Vec, - pub language_servers: Vec, + pub language_servers: Vec, } impl RejoinedProject { pub fn to_proto(&self) -> proto::RejoinedProject { + let (language_servers, language_server_capabilities) = self + .language_servers + .clone() + .into_iter() + .map(|server| (server.server, server.capabilities)) + .unzip(); proto::RejoinedProject { id: self.id.to_proto(), worktrees: self @@ -554,7 +557,8 @@ impl RejoinedProject { .iter() .map(|collaborator| collaborator.to_proto()) .collect(), - language_servers: self.language_servers.clone(), + language_servers, + language_server_capabilities, } } } @@ -601,7 +605,7 @@ pub struct Project { pub collaborators: Vec, pub worktrees: BTreeMap, pub repositories: Vec, - pub language_servers: Vec, + pub language_servers: Vec, } pub struct ProjectCollaborator { @@ -626,6 +630,12 @@ impl ProjectCollaborator { } } +#[derive(Debug, Clone)] +pub struct LanguageServer { + pub server: proto::LanguageServer, + pub capabilities: String, +} + #[derive(Debug)] pub struct LeftProject { pub id: ProjectId, diff --git a/crates/collab/src/db/queries/billing_preferences.rs b/crates/collab/src/db/queries/billing_preferences.rs index 1a6fbe946a47e5c47e5ad5c4c41db32ab25e4e7c..f370964ecd7d5c762c88e5fb572fde84ce81935d 100644 --- a/crates/collab/src/db/queries/billing_preferences.rs +++ b/crates/collab/src/db/queries/billing_preferences.rs @@ -1,21 +1,5 @@ -use anyhow::Context as _; - use super::*; -#[derive(Debug)] -pub struct CreateBillingPreferencesParams { - pub max_monthly_llm_usage_spending_in_cents: i32, - pub model_request_overages_enabled: bool, - pub model_request_overages_spend_limit_in_cents: i32, -} - -#[derive(Debug, Default)] -pub struct UpdateBillingPreferencesParams { - pub max_monthly_llm_usage_spending_in_cents: ActiveValue, - pub model_request_overages_enabled: ActiveValue, - pub model_request_overages_spend_limit_in_cents: ActiveValue, -} - impl Database { /// Returns the billing preferences for the given user, if they exist. pub async fn get_billing_preferences( @@ -30,62 +14,4 @@ impl Database { }) .await } - - /// Creates new billing preferences for the given user. - pub async fn create_billing_preferences( - &self, - user_id: UserId, - params: &CreateBillingPreferencesParams, - ) -> Result { - self.transaction(|tx| async move { - let preferences = billing_preference::Entity::insert(billing_preference::ActiveModel { - user_id: ActiveValue::set(user_id), - max_monthly_llm_usage_spending_in_cents: ActiveValue::set( - params.max_monthly_llm_usage_spending_in_cents, - ), - model_request_overages_enabled: ActiveValue::set( - params.model_request_overages_enabled, - ), - model_request_overages_spend_limit_in_cents: ActiveValue::set( - params.model_request_overages_spend_limit_in_cents, - ), - ..Default::default() - }) - .exec_with_returning(&*tx) - .await?; - - Ok(preferences) - }) - .await - } - - /// Updates the billing preferences for the given user. 
- pub async fn update_billing_preferences( - &self, - user_id: UserId, - params: &UpdateBillingPreferencesParams, - ) -> Result { - self.transaction(|tx| async move { - let preferences = billing_preference::Entity::update_many() - .set(billing_preference::ActiveModel { - max_monthly_llm_usage_spending_in_cents: params - .max_monthly_llm_usage_spending_in_cents - .clone(), - model_request_overages_enabled: params.model_request_overages_enabled.clone(), - model_request_overages_spend_limit_in_cents: params - .model_request_overages_spend_limit_in_cents - .clone(), - ..Default::default() - }) - .filter(billing_preference::Column::UserId.eq(user_id)) - .exec_with_returning(&*tx) - .await?; - - Ok(preferences - .into_iter() - .next() - .context("billing preferences not found")?) - }) - .await - } } diff --git a/crates/collab/src/db/queries/billing_subscriptions.rs b/crates/collab/src/db/queries/billing_subscriptions.rs index f25d0abeaaba9b303d915350d138557e268824f9..8361d6b4d07f8e6b59f9c7b39b18057e6f62b3c0 100644 --- a/crates/collab/src/db/queries/billing_subscriptions.rs +++ b/crates/collab/src/db/queries/billing_subscriptions.rs @@ -85,19 +85,6 @@ impl Database { .await } - /// Returns the billing subscription with the specified ID. - pub async fn get_billing_subscription_by_id( - &self, - id: BillingSubscriptionId, - ) -> Result> { - self.transaction(|tx| async move { - Ok(billing_subscription::Entity::find_by_id(id) - .one(&*tx) - .await?) - }) - .await - } - /// Returns the billing subscription with the specified Stripe subscription ID. pub async fn get_billing_subscription_by_stripe_subscription_id( &self, @@ -143,92 +130,6 @@ impl Database { .await } - /// Returns all of the billing subscriptions for the user with the specified ID. - /// - /// Note that this returns the subscriptions regardless of their status. - /// If you're wanting to check if a use has an active billing subscription, - /// use `get_active_billing_subscriptions` instead. - pub async fn get_billing_subscriptions( - &self, - user_id: UserId, - ) -> Result> { - self.transaction(|tx| async move { - let subscriptions = billing_subscription::Entity::find() - .inner_join(billing_customer::Entity) - .filter(billing_customer::Column::UserId.eq(user_id)) - .order_by_asc(billing_subscription::Column::Id) - .all(&*tx) - .await?; - - Ok(subscriptions) - }) - .await - } - - pub async fn get_active_billing_subscriptions( - &self, - user_ids: HashSet, - ) -> Result> { - self.transaction(|tx| { - let user_ids = user_ids.clone(); - async move { - let mut rows = billing_subscription::Entity::find() - .inner_join(billing_customer::Entity) - .select_also(billing_customer::Entity) - .filter(billing_customer::Column::UserId.is_in(user_ids)) - .filter( - billing_subscription::Column::StripeSubscriptionStatus - .eq(StripeSubscriptionStatus::Active), - ) - .filter(billing_subscription::Column::Kind.is_null()) - .order_by_asc(billing_subscription::Column::Id) - .stream(&*tx) - .await?; - - let mut subscriptions = HashMap::default(); - while let Some(row) = rows.next().await { - if let (subscription, Some(customer)) = row? 
{ - subscriptions.insert(customer.user_id, (customer, subscription)); - } - } - Ok(subscriptions) - } - }) - .await - } - - pub async fn get_active_zed_pro_billing_subscriptions( - &self, - user_ids: HashSet, - ) -> Result> { - self.transaction(|tx| { - let user_ids = user_ids.clone(); - async move { - let mut rows = billing_subscription::Entity::find() - .inner_join(billing_customer::Entity) - .select_also(billing_customer::Entity) - .filter(billing_customer::Column::UserId.is_in(user_ids)) - .filter( - billing_subscription::Column::StripeSubscriptionStatus - .eq(StripeSubscriptionStatus::Active), - ) - .filter(billing_subscription::Column::Kind.eq(SubscriptionKind::ZedPro)) - .order_by_asc(billing_subscription::Column::Id) - .stream(&*tx) - .await?; - - let mut subscriptions = HashMap::default(); - while let Some(row) = rows.next().await { - if let (subscription, Some(customer)) = row? { - subscriptions.insert(customer.user_id, (customer, subscription)); - } - } - Ok(subscriptions) - } - }) - .await - } - /// Returns whether the user has an active billing subscription. pub async fn has_active_billing_subscription(&self, user_id: UserId) -> Result { Ok(self.count_active_billing_subscriptions(user_id).await? > 0) diff --git a/crates/collab/src/db/queries/buffers.rs b/crates/collab/src/db/queries/buffers.rs index a288a4e7ebc3a957493221255a552303c2a09fa2..2e6b4719d1c126230849ac81bc1f215092bc0b5e 100644 --- a/crates/collab/src/db/queries/buffers.rs +++ b/crates/collab/src/db/queries/buffers.rs @@ -786,6 +786,32 @@ impl Database { }) .collect()) } + + /// Update language server capabilities for a given id. + pub async fn update_server_capabilities( + &self, + project_id: ProjectId, + server_id: u64, + new_capabilities: String, + ) -> Result<()> { + self.transaction(|tx| { + let new_capabilities = new_capabilities.clone(); + async move { + Ok( + language_server::Entity::update(language_server::ActiveModel { + project_id: ActiveValue::unchanged(project_id), + id: ActiveValue::unchanged(server_id as i64), + capabilities: ActiveValue::set(new_capabilities), + ..Default::default() + }) + .exec(&*tx) + .await?, + ) + } + }) + .await?; + Ok(()) + } } fn operation_to_storage( diff --git a/crates/collab/src/db/queries/projects.rs b/crates/collab/src/db/queries/projects.rs index ba22a7b4e38fcef13b419474d0a6b97465e9ad3e..82f74d910ba0d12c1473719189e066eb9d0307eb 100644 --- a/crates/collab/src/db/queries/projects.rs +++ b/crates/collab/src/db/queries/projects.rs @@ -692,13 +692,17 @@ impl Database { project_id: ActiveValue::set(project_id), id: ActiveValue::set(server.id as i64), name: ActiveValue::set(server.name.clone()), + capabilities: ActiveValue::set(update.capabilities.clone()), }) .on_conflict( OnConflict::columns([ language_server::Column::ProjectId, language_server::Column::Id, ]) - .update_column(language_server::Column::Name) + .update_columns([ + language_server::Column::Name, + language_server::Column::Capabilities, + ]) .to_owned(), ) .exec(&*tx) @@ -1054,10 +1058,13 @@ impl Database { repositories, language_servers: language_servers .into_iter() - .map(|language_server| proto::LanguageServer { - id: language_server.id as u64, - name: language_server.name, - worktree_id: None, + .map(|language_server| LanguageServer { + server: proto::LanguageServer { + id: language_server.id as u64, + name: language_server.name, + worktree_id: None, + }, + capabilities: language_server.capabilities, }) .collect(), }; diff --git a/crates/collab/src/db/queries/rooms.rs b/crates/collab/src/db/queries/rooms.rs 
index cb805786dd224aa4c9edde0fff945b4b4268313c..c63d7133be2ec616a95fa73359a5050c289501bf 100644 --- a/crates/collab/src/db/queries/rooms.rs +++ b/crates/collab/src/db/queries/rooms.rs @@ -804,10 +804,13 @@ impl Database { .all(tx) .await? .into_iter() - .map(|language_server| proto::LanguageServer { - id: language_server.id as u64, - name: language_server.name, - worktree_id: None, + .map(|language_server| LanguageServer { + server: proto::LanguageServer { + id: language_server.id as u64, + name: language_server.name, + worktree_id: None, + }, + capabilities: language_server.capabilities, }) .collect::>(); diff --git a/crates/collab/src/db/tables/billing_subscription.rs b/crates/collab/src/db/tables/billing_subscription.rs index 43198f9859f004f18e944b1ccb591bbbaa6ca69b..522973dbc970b69947b8e790e370bfc9fa93aa99 100644 --- a/crates/collab/src/db/tables/billing_subscription.rs +++ b/crates/collab/src/db/tables/billing_subscription.rs @@ -95,7 +95,7 @@ pub enum SubscriptionKind { ZedFree, } -impl From for zed_llm_client::Plan { +impl From for cloud_llm_client::Plan { fn from(value: SubscriptionKind) -> Self { match value { SubscriptionKind::ZedPro => Self::ZedPro, diff --git a/crates/collab/src/db/tables/language_server.rs b/crates/collab/src/db/tables/language_server.rs index 9ff8c75fc686442ad1d0cb65af66ad881c5fb6b7..34c7514d917b313990521acf8542c31394d009fc 100644 --- a/crates/collab/src/db/tables/language_server.rs +++ b/crates/collab/src/db/tables/language_server.rs @@ -9,6 +9,7 @@ pub struct Model { #[sea_orm(primary_key)] pub id: i64, pub name: String, + pub capabilities: String, } #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] diff --git a/crates/collab/src/db/tests.rs b/crates/collab/src/db/tests.rs index 9404e2670c87744f210df9b16b35fe93da16466a..6c2f9dc82a88c159df1111d01a213259ab3a6c76 100644 --- a/crates/collab/src/db/tests.rs +++ b/crates/collab/src/db/tests.rs @@ -1,4 +1,3 @@ -mod billing_subscription_tests; mod buffer_tests; mod channel_tests; mod contributor_tests; diff --git a/crates/collab/src/db/tests/billing_subscription_tests.rs b/crates/collab/src/db/tests/billing_subscription_tests.rs deleted file mode 100644 index fb5f8552a366d8fe3663aa3da84e964a4f3b23d7..0000000000000000000000000000000000000000 --- a/crates/collab/src/db/tests/billing_subscription_tests.rs +++ /dev/null @@ -1,96 +0,0 @@ -use std::sync::Arc; - -use crate::db::billing_subscription::StripeSubscriptionStatus; -use crate::db::tests::new_test_user; -use crate::db::{CreateBillingCustomerParams, CreateBillingSubscriptionParams}; -use crate::test_both_dbs; - -use super::Database; - -test_both_dbs!( - test_get_active_billing_subscriptions, - test_get_active_billing_subscriptions_postgres, - test_get_active_billing_subscriptions_sqlite -); - -async fn test_get_active_billing_subscriptions(db: &Arc) { - // A user with no subscription has no active billing subscriptions. - { - let user_id = new_test_user(db, "no-subscription-user@example.com").await; - let subscription_count = db - .count_active_billing_subscriptions(user_id) - .await - .unwrap(); - - assert_eq!(subscription_count, 0); - } - - // A user with an active subscription has one active billing subscription. 
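Circling back to the language server changes earlier in this diff: the `language_server` table gains a `capabilities` column, and the upsert in projects.rs now updates both `name` and `capabilities` on conflict. A hedged SeaORM sketch of that widened upsert, with a simplified entity standing in for the real table (column set assumed from the hunks above):

    // Hedged sketch of the widened upsert: when a (project_id, id) row already
    // exists, update both `name` and the new `capabilities` column instead of
    // `name` alone. The entity is a simplified stand-in for the real table.
    use sea_orm::entity::prelude::*;
    use sea_orm::sea_query::OnConflict;
    use sea_orm::{ActiveValue, DatabaseConnection};

    #[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
    #[sea_orm(table_name = "language_servers")]
    pub struct Model {
        #[sea_orm(primary_key, auto_increment = false)]
        pub project_id: i64,
        #[sea_orm(primary_key, auto_increment = false)]
        pub id: i64,
        pub name: String,
        pub capabilities: String,
    }

    #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
    pub enum Relation {}

    impl ActiveModelBehavior for ActiveModel {}

    pub async fn upsert_language_server(
        db: &DatabaseConnection,
        project_id: i64,
        server_id: i64,
        name: String,
        capabilities: String,
    ) -> Result<(), DbErr> {
        Entity::insert(ActiveModel {
            project_id: ActiveValue::set(project_id),
            id: ActiveValue::set(server_id),
            name: ActiveValue::set(name),
            capabilities: ActiveValue::set(capabilities),
        })
        .on_conflict(
            OnConflict::columns([Column::ProjectId, Column::Id])
                .update_columns([Column::Name, Column::Capabilities])
                .to_owned(),
        )
        .exec(db)
        .await?;
        Ok(())
    }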
- { - let user_id = new_test_user(db, "active-user@example.com").await; - let customer = db - .create_billing_customer(&CreateBillingCustomerParams { - user_id, - stripe_customer_id: "cus_active_user".into(), - }) - .await - .unwrap(); - assert_eq!(customer.stripe_customer_id, "cus_active_user".to_string()); - - db.create_billing_subscription(&CreateBillingSubscriptionParams { - billing_customer_id: customer.id, - kind: None, - stripe_subscription_id: "sub_active_user".into(), - stripe_subscription_status: StripeSubscriptionStatus::Active, - stripe_cancellation_reason: None, - stripe_current_period_start: None, - stripe_current_period_end: None, - }) - .await - .unwrap(); - - let subscriptions = db.get_billing_subscriptions(user_id).await.unwrap(); - assert_eq!(subscriptions.len(), 1); - - let subscription = &subscriptions[0]; - assert_eq!( - subscription.stripe_subscription_id, - "sub_active_user".to_string() - ); - assert_eq!( - subscription.stripe_subscription_status, - StripeSubscriptionStatus::Active - ); - } - - // A user with a past-due subscription has no active billing subscriptions. - { - let user_id = new_test_user(db, "past-due-user@example.com").await; - let customer = db - .create_billing_customer(&CreateBillingCustomerParams { - user_id, - stripe_customer_id: "cus_past_due_user".into(), - }) - .await - .unwrap(); - assert_eq!(customer.stripe_customer_id, "cus_past_due_user".to_string()); - - db.create_billing_subscription(&CreateBillingSubscriptionParams { - billing_customer_id: customer.id, - kind: None, - stripe_subscription_id: "sub_past_due_user".into(), - stripe_subscription_status: StripeSubscriptionStatus::PastDue, - stripe_cancellation_reason: None, - stripe_current_period_start: None, - stripe_current_period_end: None, - }) - .await - .unwrap(); - - let subscription_count = db - .count_active_billing_subscriptions(user_id) - .await - .unwrap(); - assert_eq!(subscription_count, 0); - } -} diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index 2b20c8f080e5d55eab3e81d946d7a0aaf06cffd8..905859ca6996c3593e1f13fbcb0e723531595ff6 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -1,6 +1,5 @@ pub mod api; pub mod auth; -mod cents; pub mod db; pub mod env; pub mod executor; @@ -21,7 +20,6 @@ use axum::{ http::{HeaderMap, StatusCode}, response::IntoResponse, }; -pub use cents::*; use db::{ChannelId, Database}; use executor::Executor; use llm::db::LlmDatabase; diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index cf5dec6e282662a6766edd96f3669aa096206afc..de74858168fd94ab677cee03f721a1e3fbbdfd46 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -1,8 +1,6 @@ pub mod db; mod token; -use crate::Cents; - pub use token::*; pub const AGENT_EXTENDED_TRIAL_FEATURE_FLAG: &str = "agent-extended-trial"; @@ -12,9 +10,3 @@ pub const BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG: &str = "bypass-account-age-chec /// The minimum account age an account must have in order to use the LLM service. pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30); - -/// The default value to use for maximum spend per month if the user did not -/// explicitly set a maximum spend. -/// -/// Used to prevent surprise bills. 
-pub const DEFAULT_MAX_MONTHLY_SPEND: Cents = Cents::from_dollars(10); diff --git a/crates/collab/src/llm/db.rs b/crates/collab/src/llm/db.rs index 6a6efca0de1aea2fb979572400b19ce4094fcfcd..18ad624dab840c47df766a55c2f59cf9a17c55e6 100644 --- a/crates/collab/src/llm/db.rs +++ b/crates/collab/src/llm/db.rs @@ -6,11 +6,11 @@ mod tables; #[cfg(test)] mod tests; +use cloud_llm_client::LanguageModelProvider; use collections::HashMap; pub use ids::*; pub use seed::*; pub use tables::*; -use zed_llm_client::LanguageModelProvider; #[cfg(test)] pub use tests::TestLlmDb; diff --git a/crates/collab/src/llm/db/queries.rs b/crates/collab/src/llm/db/queries.rs index 3565366fdd788cbe282a2d631c572efead7fd1bb..0087218b3ff9fe81850870bc8022bd81fe0ee48d 100644 --- a/crates/collab/src/llm/db/queries.rs +++ b/crates/collab/src/llm/db/queries.rs @@ -1,6 +1,5 @@ use super::*; pub mod providers; -pub mod subscription_usage_meters; pub mod subscription_usages; pub mod usages; diff --git a/crates/collab/src/llm/db/queries/subscription_usage_meters.rs b/crates/collab/src/llm/db/queries/subscription_usage_meters.rs deleted file mode 100644 index c0ce5d679bf83ca61f254d34c0090d0885a2c029..0000000000000000000000000000000000000000 --- a/crates/collab/src/llm/db/queries/subscription_usage_meters.rs +++ /dev/null @@ -1,72 +0,0 @@ -use crate::db::UserId; -use crate::llm::db::queries::subscription_usages::convert_chrono_to_time; - -use super::*; - -impl LlmDatabase { - /// Returns all current subscription usage meters as of the given timestamp. - pub async fn get_current_subscription_usage_meters( - &self, - now: DateTimeUtc, - ) -> Result> { - let now = convert_chrono_to_time(now)?; - - self.transaction(|tx| async move { - let result = subscription_usage_meter::Entity::find() - .inner_join(subscription_usage::Entity) - .filter( - subscription_usage::Column::PeriodStartAt - .lte(now) - .and(subscription_usage::Column::PeriodEndAt.gte(now)), - ) - .select_also(subscription_usage::Entity) - .all(&*tx) - .await?; - - let result = result - .into_iter() - .filter_map(|(meter, usage)| { - let usage = usage?; - Some((meter, usage)) - }) - .collect(); - - Ok(result) - }) - .await - } - - /// Returns all current subscription usage meters for the given user as of the given timestamp. 
- pub async fn get_current_subscription_usage_meters_for_user( - &self, - user_id: UserId, - now: DateTimeUtc, - ) -> Result> { - let now = convert_chrono_to_time(now)?; - - self.transaction(|tx| async move { - let result = subscription_usage_meter::Entity::find() - .inner_join(subscription_usage::Entity) - .filter(subscription_usage::Column::UserId.eq(user_id)) - .filter( - subscription_usage::Column::PeriodStartAt - .lte(now) - .and(subscription_usage::Column::PeriodEndAt.gte(now)), - ) - .select_also(subscription_usage::Entity) - .all(&*tx) - .await?; - - let result = result - .into_iter() - .filter_map(|(meter, usage)| { - let usage = usage?; - Some((meter, usage)) - }) - .collect(); - - Ok(result) - }) - .await - } -} diff --git a/crates/collab/src/llm/db/queries/subscription_usages.rs b/crates/collab/src/llm/db/queries/subscription_usages.rs index ee1ebf59b8cfc6bf3f4b74a81061389a160ccb4e..8a519790753099be62868e94e8b068958095d320 100644 --- a/crates/collab/src/llm/db/queries/subscription_usages.rs +++ b/crates/collab/src/llm/db/queries/subscription_usages.rs @@ -1,28 +1,7 @@ -use time::PrimitiveDateTime; - use crate::db::UserId; use super::*; -pub fn convert_chrono_to_time(datetime: DateTimeUtc) -> anyhow::Result { - use chrono::{Datelike as _, Timelike as _}; - - let date = time::Date::from_calendar_date( - datetime.year(), - time::Month::try_from(datetime.month() as u8).unwrap(), - datetime.day() as u8, - )?; - - let time = time::Time::from_hms_nano( - datetime.hour() as u8, - datetime.minute() as u8, - datetime.second() as u8, - datetime.nanosecond(), - )?; - - Ok(PrimitiveDateTime::new(date, time)) -} - impl LlmDatabase { pub async fn get_subscription_usage_for_period( &self, diff --git a/crates/collab/src/llm/db/tests/provider_tests.rs b/crates/collab/src/llm/db/tests/provider_tests.rs index 7d52964b939e7b17ca8ec9f986756c00bd0dad55..f4e1de40ec10705ed9b740619754fcf9ec5f3e1e 100644 --- a/crates/collab/src/llm/db/tests/provider_tests.rs +++ b/crates/collab/src/llm/db/tests/provider_tests.rs @@ -1,5 +1,5 @@ +use cloud_llm_client::LanguageModelProvider; use pretty_assertions::assert_eq; -use zed_llm_client::LanguageModelProvider; use crate::llm::db::LlmDatabase; use crate::test_llm_db; diff --git a/crates/collab/src/llm/token.rs b/crates/collab/src/llm/token.rs index d4566ffcb40715c11a6a714b105238f47f333969..da01c7f3bed5cab1e7dbd6cfdef8cd4d7643044c 100644 --- a/crates/collab/src/llm/token.rs +++ b/crates/collab/src/llm/token.rs @@ -4,12 +4,12 @@ use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEA use crate::{Config, db::billing_preference}; use anyhow::{Context as _, Result}; use chrono::{NaiveDateTime, Utc}; +use cloud_llm_client::Plan; use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation}; use serde::{Deserialize, Serialize}; use std::time::Duration; use thiserror::Error; use uuid::Uuid; -use zed_llm_client::Plan; #[derive(Clone, Debug, Default, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index 6a78049b3f17af97f1671137076a47b5d6ec5a84..20641cb2322a6aa10372064ca208eef091b2ae5a 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -7,8 +7,8 @@ use axum::{ routing::get, }; +use collab::ServiceMode; use collab::api::CloudflareIpCountryHeader; -use collab::api::billing::sync_llm_request_usage_with_stripe_periodically; use collab::llm::db::LlmDatabase; use collab::migrations::run_database_migrations; use 
collab::user_backfiller::spawn_user_backfiller; @@ -16,7 +16,6 @@ use collab::{ AppState, Config, Result, api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor, rpc::ResultExt, }; -use collab::{ServiceMode, api::billing::poll_stripe_events_periodically}; use db::Database; use std::{ env::args, @@ -31,7 +30,7 @@ use tower_http::trace::TraceLayer; use tracing_subscriber::{ Layer, filter::EnvFilter, fmt::format::JsonFields, util::SubscriberInitExt, }; -use util::{ResultExt as _, maybe}; +use util::ResultExt as _; const VERSION: &str = env!("CARGO_PKG_VERSION"); const REVISION: Option<&'static str> = option_env!("GITHUB_SHA"); @@ -120,8 +119,6 @@ async fn main() -> Result<()> { let rpc_server = collab::rpc::Server::new(epoch, state.clone()); rpc_server.start().await?; - poll_stripe_events_periodically(state.clone(), rpc_server.clone()); - app = app .merge(collab::api::routes(rpc_server.clone())) .merge(collab::rpc::routes(rpc_server.clone())); @@ -133,29 +130,6 @@ async fn main() -> Result<()> { fetch_extensions_from_blob_store_periodically(state.clone()); spawn_user_backfiller(state.clone()); - let llm_db = maybe!(async { - let database_url = state - .config - .llm_database_url - .as_ref() - .context("missing LLM_DATABASE_URL")?; - let max_connections = state - .config - .llm_database_max_connections - .context("missing LLM_DATABASE_MAX_CONNECTIONS")?; - - let mut db_options = db::ConnectOptions::new(database_url); - db_options.max_connections(max_connections); - LlmDatabase::new(db_options, state.executor.clone()).await - }) - .await - .trace_err(); - - if let Some(mut llm_db) = llm_db { - llm_db.initialize().await?; - sync_llm_request_usage_with_stripe_periodically(state.clone()); - } - app = app .merge(collab::api::events::router()) .merge(collab::api::extensions::router()) diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 7a454e11cfced2fa7f9f1dc8c0263934830c7cad..ec1105b138f728de36d0798ddb6ddf22aa4bc8c8 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -23,6 +23,7 @@ use anyhow::{Context as _, anyhow, bail}; use async_tungstenite::tungstenite::{ Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame, }; +use axum::headers::UserAgent; use axum::{ Extension, Router, TypedHeader, body::Body, @@ -40,9 +41,11 @@ use chrono::Utc; use collections::{HashMap, HashSet}; pub use connection_pool::{ConnectionPool, ZedVersion}; use core::fmt::{self, Debug, Formatter}; +use futures::TryFutureExt as _; use reqwest_client::ReqwestClient; -use rpc::proto::split_repository_update; +use rpc::proto::{MultiLspQuery, split_repository_update}; use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi}; +use tracing::Span; use futures::{ FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture, @@ -93,8 +96,13 @@ const MAX_CONCURRENT_CONNECTIONS: usize = 512; static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0); +const TOTAL_DURATION_MS: &str = "total_duration_ms"; +const PROCESSING_DURATION_MS: &str = "processing_duration_ms"; +const QUEUE_DURATION_MS: &str = "queue_duration_ms"; +const HOST_WAITING_MS: &str = "host_waiting_ms"; + type MessageHandler = - Box, Session) -> BoxFuture<'static, ()>>; + Box, Session, Span) -> BoxFuture<'static, ()>>; pub struct ConnectionGuard; @@ -162,6 +170,42 @@ impl Principal { } } +#[derive(Clone)] +struct MessageContext { + session: Session, + span: tracing::Span, +} + +impl Deref for MessageContext { + type Target = Session; + + fn deref(&self) 
-> &Self::Target { + &self.session + } +} + +impl MessageContext { + pub fn forward_request( + &self, + receiver_id: ConnectionId, + request: T, + ) -> impl Future> { + let request_start_time = Instant::now(); + let span = self.span.clone(); + tracing::info!("start forwarding request"); + self.peer + .forward_request(self.connection_id, receiver_id, request) + .inspect(move |_| { + span.record( + HOST_WAITING_MS, + request_start_time.elapsed().as_micros() as f64 / 1000.0, + ); + }) + .inspect_err(|_| tracing::error!("error forwarding request")) + .inspect_ok(|_| tracing::info!("finished forwarding request")) + } +} + #[derive(Clone)] struct Session { principal: Principal, @@ -314,7 +358,7 @@ impl Server { .add_request_handler(forward_read_only_project_request::) .add_request_handler(forward_read_only_project_request::) .add_request_handler(forward_read_only_project_request::) - .add_request_handler(forward_find_search_candidates_request) + .add_request_handler(forward_read_only_project_request::) .add_request_handler(forward_read_only_project_request::) .add_request_handler(forward_read_only_project_request::) .add_request_handler(forward_read_only_project_request::) @@ -339,9 +383,6 @@ impl Server { .add_request_handler(forward_read_only_project_request::) .add_request_handler(forward_read_only_project_request::) .add_request_handler(forward_read_only_project_request::) - .add_request_handler( - forward_read_only_project_request::, - ) .add_request_handler(forward_read_only_project_request::) .add_request_handler( forward_mutating_project_request::, @@ -373,7 +414,7 @@ impl Server { .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) - .add_request_handler(forward_mutating_project_request::) + .add_request_handler(multi_lsp_query) .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) @@ -433,6 +474,8 @@ impl Server { .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) + .add_request_handler(forward_mutating_project_request::) + .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_read_only_project_request::) @@ -646,42 +689,37 @@ impl Server { fn add_handler(&mut self, handler: F) -> &mut Self where - F: 'static + Send + Sync + Fn(TypedEnvelope, Session) -> Fut, + F: 'static + Send + Sync + Fn(TypedEnvelope, MessageContext) -> Fut, Fut: 'static + Send + Future>, M: EnvelopedMessage, { let prev_handler = self.handlers.insert( TypeId::of::(), - Box::new(move |envelope, session| { + Box::new(move |envelope, session, span| { let envelope = envelope.into_any().downcast::>().unwrap(); let received_at = envelope.received_at; tracing::info!("message received"); let start_time = Instant::now(); - let future = (handler)(*envelope, session); + let future = (handler)( + *envelope, + MessageContext { + session, + span: span.clone(), + }, + ); async move { let result = future.await; let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0; let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0; let queue_duration_ms = total_duration_ms - processing_duration_ms; - let 
payload_type = M::NAME; - + span.record(TOTAL_DURATION_MS, total_duration_ms); + span.record(PROCESSING_DURATION_MS, processing_duration_ms); + span.record(QUEUE_DURATION_MS, queue_duration_ms); match result { Err(error) => { - tracing::error!( - ?error, - total_duration_ms, - processing_duration_ms, - queue_duration_ms, - payload_type, - "error handling message" - ) + tracing::error!(?error, "error handling message") } - Ok(()) => tracing::info!( - total_duration_ms, - processing_duration_ms, - queue_duration_ms, - "finished handling message" - ), + Ok(()) => tracing::info!("finished handling message"), } } .boxed() @@ -695,7 +733,7 @@ impl Server { fn add_message_handler(&mut self, handler: F) -> &mut Self where - F: 'static + Send + Sync + Fn(M, Session) -> Fut, + F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut, Fut: 'static + Send + Future>, M: EnvelopedMessage, { @@ -705,7 +743,7 @@ impl Server { fn add_request_handler(&mut self, handler: F) -> &mut Self where - F: 'static + Send + Sync + Fn(M, Response, Session) -> Fut, + F: 'static + Send + Sync + Fn(M, Response, MessageContext) -> Fut, Fut: Send + Future>, M: RequestMessage, { @@ -748,6 +786,8 @@ impl Server { address: String, principal: Principal, zed_version: ZedVersion, + release_channel: Option, + user_agent: Option, geoip_country_code: Option, system_id: Option, send_connection_id: Option>, @@ -760,9 +800,18 @@ impl Server { user_id=field::Empty, login=field::Empty, impersonator=field::Empty, - geoip_country_code=field::Empty + user_agent=field::Empty, + geoip_country_code=field::Empty, + release_channel=field::Empty, ); principal.update_span(&span); + if let Some(user_agent) = user_agent { + span.record("user_agent", user_agent); + } + if let Some(release_channel) = release_channel { + span.record("release_channel", release_channel); + } + if let Some(country_code) = geoip_country_code.as_ref() { span.record("geoip_country_code", country_code); } @@ -771,12 +820,11 @@ impl Server { async move { if *teardown.borrow() { tracing::error!("server is tearing down"); - return + return; } - let (connection_id, handle_io, mut incoming_rx) = this - .peer - .add_connection(connection, { + let (connection_id, handle_io, mut incoming_rx) = + this.peer.add_connection(connection, { let executor = executor.clone(); move |duration| executor.sleep(duration) }); @@ -793,10 +841,14 @@ impl Server { } }; - let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new( - supermaven_admin_api_key.to_string(), - http_client.clone(), - ))); + let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map( + |supermaven_admin_api_key| { + Arc::new(SupermavenAdminApi::new( + supermaven_admin_api_key.to_string(), + http_client.clone(), + )) + }, + ); let session = Session { principal: principal.clone(), @@ -811,7 +863,15 @@ impl Server { supermaven_client, }; - if let Err(error) = this.send_initial_client_update(connection_id, zed_version, send_connection_id, &session).await { + if let Err(error) = this + .send_initial_client_update( + connection_id, + zed_version, + send_connection_id, + &session, + ) + .await + { tracing::error!(?error, "failed to send initial client update"); return; } @@ -828,14 +888,22 @@ impl Server { // // This arrangement ensures we will attempt to process earlier messages first, but fall // back to processing messages arrived later in the spirit of making progress. 
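The comment block above describes the connection loop's concurrency arrangement: a semaphore caps how many handlers run at once, and the next message is only pulled off the wire once a permit is free. A minimal, self-contained sketch of that back-pressure pattern using tokio primitives; the real server keeps foreground handlers on the connection task via FuturesUnordered and its own executor, so the channel, handler, and task spawning here are stand-ins:

    // Minimal sketch of the bounded-handler pattern described above: don't pull
    // the next message until a semaphore permit is free, so slow handlers apply
    // back-pressure to the connection. Message source and handler are stand-ins.
    use std::sync::Arc;

    use tokio::sync::{Semaphore, mpsc};
    use tokio::task::JoinSet;

    const MAX_CONCURRENT_HANDLERS: usize = 4;

    #[tokio::main]
    async fn main() {
        let (tx, mut rx) = mpsc::channel::<u32>(64);
        let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
        let mut handlers = JoinSet::new();

        // Producer: a stand-in for the websocket connection feeding messages.
        tokio::spawn(async move {
            for message in 0..16 {
                tx.send(message).await.unwrap();
            }
        });

        loop {
            // Wait for a free slot before reading the next message, so at most
            // MAX_CONCURRENT_HANDLERS handlers are ever in flight.
            let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
            // Queue depth as this handler starts, mirroring the
            // `concurrent_handlers` count the diff records on the receive-message span.
            let in_flight = MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits();
            let Some(message) = rx.recv().await else { break };
            handlers.spawn(async move {
                println!("handling message {message} ({in_flight} in flight)");
                // Dropping the permit frees the slot for the next message.
                drop(permit);
            });
        }

        // Drain any handlers still running once the connection closes.
        while handlers.join_next().await.is_some() {}
    }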
+ const MAX_CONCURRENT_HANDLERS: usize = 256; let mut foreground_message_handlers = FuturesUnordered::new(); - let concurrent_handlers = Arc::new(Semaphore::new(256)); + let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS)); + let get_concurrent_handlers = { + let concurrent_handlers = concurrent_handlers.clone(); + move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits() + }; loop { let next_message = async { let permit = concurrent_handlers.clone().acquire_owned().await.unwrap(); let message = incoming_rx.next().await; - (permit, message) - }.fuse(); + // Cache the concurrent_handlers here, so that we know what the + // queue looks like as each handler starts + (permit, message, get_concurrent_handlers()) + } + .fuse(); futures::pin_mut!(next_message); futures::select_biased! { _ = teardown.changed().fuse() => return, @@ -847,21 +915,30 @@ impl Server { } _ = foreground_message_handlers.next() => {} next_message = next_message => { - let (permit, message) = next_message; + let (permit, message, concurrent_handlers) = next_message; if let Some(message) = message { let type_name = message.payload_type_name(); // note: we copy all the fields from the parent span so we can query them in the logs. // (https://github.com/tokio-rs/tracing/issues/2670). - let span = tracing::info_span!("receive message", %connection_id, %address, type_name, + let span = tracing::info_span!("receive message", + %connection_id, + %address, + type_name, + concurrent_handlers, user_id=field::Empty, login=field::Empty, impersonator=field::Empty, + multi_lsp_query_request=field::Empty, + { TOTAL_DURATION_MS }=field::Empty, + { PROCESSING_DURATION_MS }=field::Empty, + { QUEUE_DURATION_MS }=field::Empty, + { HOST_WAITING_MS }=field::Empty ); principal.update_span(&span); let span_enter = span.enter(); if let Some(handler) = this.handlers.get(&message.payload_type_id()) { let is_background = message.is_background(); - let handle_message = (handler)(message, session.clone()); + let handle_message = (handler)(message, session.clone(), span.clone()); drop(span_enter); let handle_message = async move { @@ -885,12 +962,13 @@ impl Server { } drop(foreground_message_handlers); - tracing::info!("signing out"); + let concurrent_handlers = get_concurrent_handlers(); + tracing::info!(concurrent_handlers, "signing out"); if let Err(error) = connection_lost(session, teardown, executor).await { tracing::error!(?error, "error signing out"); } - - }.instrument(span) + } + .instrument(span) } async fn send_initial_client_update( @@ -1002,7 +1080,26 @@ impl Server { Ok(()) } - pub async fn update_plan_for_user(self: &Arc, user_id: UserId) -> Result<()> { + pub async fn update_plan_for_user( + self: &Arc, + user_id: UserId, + update_user_plan: proto::UpdateUserPlan, + ) -> Result<()> { + let pool = self.connection_pool.lock(); + for connection_id in pool.user_connection_ids(user_id) { + self.peer + .send(connection_id, update_user_plan.clone()) + .trace_err(); + } + + Ok(()) + } + + /// This is the legacy way of updating the user's plan, where we fetch the data to construct the `UpdateUserPlan` + /// message on the Collab server. + /// + /// The new way is to receive the data from Cloud via the `POST /users/:id/update_plan` endpoint. 
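The handler plumbing above pre-declares the timing fields on the receive-message span as `field::Empty`, using the constant field names defined earlier, and fills them in with `span.record` once the handler finishes. A small standalone sketch of that pattern; the subscriber setup exists only so the example prints something, and queue time is derived as total minus processing time, as in the hunk:

    // Sketch of the span-timing pattern above: declare the duration fields as
    // `Empty` up front, then record them when the handler completes. Queue time
    // is derived as total minus processing time.
    use std::time::Instant;

    use tracing::field;

    const TOTAL_DURATION_MS: &str = "total_duration_ms";
    const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
    const QUEUE_DURATION_MS: &str = "queue_duration_ms";

    fn main() {
        // Demo-only subscriber so the recorded fields show up on stdout.
        tracing_subscriber::fmt().init();

        let received_at = Instant::now();
        // ... message sits in the queue until a handler slot frees up ...
        let span = tracing::info_span!(
            "receive message",
            { TOTAL_DURATION_MS } = field::Empty,
            { PROCESSING_DURATION_MS } = field::Empty,
            { QUEUE_DURATION_MS } = field::Empty,
        );
        let _enter = span.enter();

        let start_time = Instant::now();
        // ... handle the message ...
        let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
        let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
        span.record(TOTAL_DURATION_MS, total_duration_ms);
        span.record(PROCESSING_DURATION_MS, processing_duration_ms);
        span.record(QUEUE_DURATION_MS, total_duration_ms - processing_duration_ms);

        tracing::info!("finished handling message");
    }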
+ pub async fn update_plan_for_user_legacy(self: &Arc, user_id: UserId) -> Result<()> { let user = self .app_state .db @@ -1018,14 +1115,7 @@ impl Server { ) .await?; - let pool = self.connection_pool.lock(); - for connection_id in pool.user_connection_ids(user_id) { - self.peer - .send(connection_id, update_user_plan.clone()) - .trace_err(); - } - - Ok(()) + self.update_plan_for_user(user_id, update_user_plan).await } pub async fn refresh_llm_tokens_for_user(self: &Arc, user_id: UserId) { @@ -1140,6 +1230,35 @@ impl Header for AppVersionHeader { } } +#[derive(Debug)] +pub struct ReleaseChannelHeader(String); + +impl Header for ReleaseChannelHeader { + fn name() -> &'static HeaderName { + static ZED_RELEASE_CHANNEL: OnceLock = OnceLock::new(); + ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel")) + } + + fn decode<'i, I>(values: &mut I) -> Result + where + Self: Sized, + I: Iterator, + { + Ok(Self( + values + .next() + .ok_or_else(axum::headers::Error::invalid)? + .to_str() + .map_err(|_| axum::headers::Error::invalid())? + .to_owned(), + )) + } + + fn encode>(&self, values: &mut E) { + values.extend([self.0.parse().unwrap()]); + } +} + pub fn routes(server: Arc) -> Router<(), Body> { Router::new() .route("/rpc", get(handle_websocket_request)) @@ -1155,9 +1274,11 @@ pub fn routes(server: Arc) -> Router<(), Body> { pub async fn handle_websocket_request( TypedHeader(ProtocolVersion(protocol_version)): TypedHeader, app_version_header: Option>, + release_channel_header: Option>, ConnectInfo(socket_address): ConnectInfo, Extension(server): Extension>, Extension(principal): Extension, + user_agent: Option>, country_code_header: Option>, system_id_header: Option>, ws: WebSocketUpgrade, @@ -1178,6 +1299,8 @@ pub async fn handle_websocket_request( .into_response(); }; + let release_channel = release_channel_header.map(|header| header.0.0); + if !version.can_collaborate() { return ( StatusCode::UPGRADE_REQUIRED, @@ -1213,6 +1336,8 @@ pub async fn handle_websocket_request( socket_address, principal, version, + release_channel, + user_agent.map(|header| header.to_string()), country_code_header.map(|header| header.to_string()), system_id_header.map(|header| header.to_string()), None, @@ -1305,7 +1430,11 @@ async fn connection_lost( } /// Acknowledges a ping from a client, used to keep the connection alive. -async fn ping(_: proto::Ping, response: Response, _session: Session) -> Result<()> { +async fn ping( + _: proto::Ping, + response: Response, + _session: MessageContext, +) -> Result<()> { response.send(proto::Ack {})?; Ok(()) } @@ -1314,7 +1443,7 @@ async fn ping(_: proto::Ping, response: Response, _session: Session async fn create_room( _request: proto::CreateRoom, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let livekit_room = nanoid::nanoid!(30); @@ -1354,7 +1483,7 @@ async fn create_room( async fn join_room( request: proto::JoinRoom, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let room_id = RoomId::from_proto(request.id); @@ -1421,7 +1550,7 @@ async fn join_room( async fn rejoin_room( request: proto::RejoinRoom, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let room; let channel; @@ -1549,15 +1678,15 @@ fn notify_rejoined_projects( } // Stream this worktree's diagnostics. 
- for summary in worktree.diagnostic_summaries { - session.peer.send( - session.connection_id, - proto::UpdateDiagnosticSummary { - project_id: project.id.to_proto(), - worktree_id: worktree.id, - summary: Some(summary), - }, - )?; + let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter(); + if let Some(summary) = worktree_diagnostics.next() { + let message = proto::UpdateDiagnosticSummary { + project_id: project.id.to_proto(), + worktree_id: worktree.id, + summary: Some(summary), + more_summaries: worktree_diagnostics.collect(), + }; + session.peer.send(session.connection_id, message)?; } for settings_file in worktree.settings_files { @@ -1598,7 +1727,7 @@ fn notify_rejoined_projects( async fn leave_room( _: proto::LeaveRoom, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { leave_room_for_session(&session, session.connection_id).await?; response.send(proto::Ack {})?; @@ -1609,7 +1738,7 @@ async fn leave_room( async fn set_room_participant_role( request: proto::SetRoomParticipantRole, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let user_id = UserId::from_proto(request.user_id); let role = ChannelRole::from(request.role()); @@ -1657,7 +1786,7 @@ async fn set_room_participant_role( async fn call( request: proto::Call, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let room_id = RoomId::from_proto(request.room_id); let calling_user_id = session.user_id(); @@ -1726,7 +1855,7 @@ async fn call( async fn cancel_call( request: proto::CancelCall, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let called_user_id = UserId::from_proto(request.called_user_id); let room_id = RoomId::from_proto(request.room_id); @@ -1761,7 +1890,7 @@ async fn cancel_call( } /// Decline an incoming call. -async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> { +async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> { let room_id = RoomId::from_proto(message.room_id); { let room = session @@ -1796,7 +1925,7 @@ async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<( async fn update_participant_location( request: proto::UpdateParticipantLocation, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let room_id = RoomId::from_proto(request.room_id); let location = request.location.context("invalid location")?; @@ -1815,7 +1944,7 @@ async fn update_participant_location( async fn share_project( request: proto::ShareProject, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let (project_id, room) = &*session .db() @@ -1836,7 +1965,7 @@ async fn share_project( } /// Unshare a project from the room. -async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> { +async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> { let project_id = ProjectId::from_proto(message.project_id); unshare_project_internal(project_id, session.connection_id, &session).await } @@ -1883,7 +2012,7 @@ async fn unshare_project_internal( async fn join_project( request: proto::JoinProject, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let project_id = ProjectId::from_proto(request.project_id); @@ -1944,12 +2073,19 @@ async fn join_project( } // First, we send the metadata associated with each worktree. 
+ let (language_servers, language_server_capabilities) = project + .language_servers + .clone() + .into_iter() + .map(|server| (server.server, server.capabilities)) + .unzip(); response.send(proto::JoinProjectResponse { project_id: project.id.0 as u64, worktrees: worktrees.clone(), replica_id: replica_id.0 as u32, collaborators: collaborators.clone(), - language_servers: project.language_servers.clone(), + language_servers, + language_server_capabilities, role: project.role.into(), })?; @@ -1972,15 +2108,15 @@ async fn join_project( } // Stream this worktree's diagnostics. - for summary in worktree.diagnostic_summaries { - session.peer.send( - session.connection_id, - proto::UpdateDiagnosticSummary { - project_id: project_id.to_proto(), - worktree_id: worktree.id, - summary: Some(summary), - }, - )?; + let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter(); + if let Some(summary) = worktree_diagnostics.next() { + let message = proto::UpdateDiagnosticSummary { + project_id: project.id.to_proto(), + worktree_id: worktree.id, + summary: Some(summary), + more_summaries: worktree_diagnostics.collect(), + }; + session.peer.send(session.connection_id, message)?; } for settings_file in worktree.settings_files { @@ -2008,8 +2144,8 @@ async fn join_project( session.connection_id, proto::UpdateLanguageServer { project_id: project_id.to_proto(), - server_name: Some(language_server.name.clone()), - language_server_id: language_server.id, + server_name: Some(language_server.server.name.clone()), + language_server_id: language_server.server.id, variant: Some( proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated( proto::LspDiskBasedDiagnosticsUpdated {}, @@ -2023,7 +2159,7 @@ async fn join_project( } /// Leave someone elses shared project. 
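Both `UpdateDiagnosticSummary` hunks above replace the per-summary send loop with a single message carrying the first summary plus the rest in the new `more_summaries` field, so a worktree's diagnostics go out in one message rather than one per summary. A sketch of that split, with toy stand-ins for the proto types:

    // Sketch of the batching change above: put the first summary in `summary`
    // and the remainder in `more_summaries`, sending nothing for an empty
    // worktree. `Summary` and `UpdateDiagnosticSummary` are simplified
    // stand-ins for the proto types.
    #[derive(Clone, Debug, PartialEq)]
    struct Summary {
        path: String,
        error_count: u32,
    }

    #[derive(Debug, PartialEq)]
    struct UpdateDiagnosticSummary {
        worktree_id: u64,
        summary: Option<Summary>,
        more_summaries: Vec<Summary>,
    }

    /// Returns `None` when there are no summaries, so empty worktrees do not
    /// produce an empty message.
    fn batch_summaries(
        worktree_id: u64,
        summaries: Vec<Summary>,
    ) -> Option<UpdateDiagnosticSummary> {
        let mut summaries = summaries.into_iter();
        let first = summaries.next()?;
        Some(UpdateDiagnosticSummary {
            worktree_id,
            summary: Some(first),
            more_summaries: summaries.collect(),
        })
    }

    fn main() {
        let summaries = vec![
            Summary { path: "a.rs".into(), error_count: 2 },
            Summary { path: "b.rs".into(), error_count: 0 },
        ];
        let message = batch_summaries(7, summaries).unwrap();
        assert_eq!(message.summary.as_ref().map(|s| s.path.as_str()), Some("a.rs"));
        assert_eq!(message.more_summaries.len(), 1);
        assert!(batch_summaries(7, Vec::new()).is_none());
    }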
-async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> { +async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> { let sender_id = session.connection_id; let project_id = ProjectId::from_proto(request.project_id); let db = session.db().await; @@ -2046,7 +2182,7 @@ async fn leave_project(request: proto::LeaveProject, session: Session) -> Result async fn update_project( request: proto::UpdateProject, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let project_id = ProjectId::from_proto(request.project_id); let (room, guest_connection_ids) = &*session @@ -2075,7 +2211,7 @@ async fn update_project( async fn update_worktree( request: proto::UpdateWorktree, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let guest_connection_ids = session .db() @@ -2099,7 +2235,7 @@ async fn update_worktree( async fn update_repository( request: proto::UpdateRepository, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let guest_connection_ids = session .db() @@ -2123,7 +2259,7 @@ async fn update_repository( async fn remove_repository( request: proto::RemoveRepository, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let guest_connection_ids = session .db() @@ -2147,7 +2283,7 @@ async fn remove_repository( /// Updates other participants with changes to the diagnostics async fn update_diagnostic_summary( message: proto::UpdateDiagnosticSummary, - session: Session, + session: MessageContext, ) -> Result<()> { let guest_connection_ids = session .db() @@ -2171,7 +2307,7 @@ async fn update_diagnostic_summary( /// Updates other participants with changes to the worktree settings async fn update_worktree_settings( message: proto::UpdateWorktreeSettings, - session: Session, + session: MessageContext, ) -> Result<()> { let guest_connection_ids = session .db() @@ -2195,7 +2331,7 @@ async fn update_worktree_settings( /// Notify other participants that a language server has started. async fn start_language_server( request: proto::StartLanguageServer, - session: Session, + session: MessageContext, ) -> Result<()> { let guest_connection_ids = session .db() @@ -2218,12 +2354,20 @@ async fn start_language_server( /// Notify other participants that a language server has changed. 
async fn update_language_server( request: proto::UpdateLanguageServer, - session: Session, + session: MessageContext, ) -> Result<()> { let project_id = ProjectId::from_proto(request.project_id); - let project_connection_ids = session - .db() - .await + let db = session.db().await; + + if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant + { + if let Some(capabilities) = update.capabilities.clone() { + db.update_server_capabilities(project_id, request.language_server_id, capabilities) + .await?; + } + } + + let project_connection_ids = db .project_connection_ids(project_id, session.connection_id, true) .await?; broadcast( @@ -2243,7 +2387,7 @@ async fn update_language_server( async fn forward_read_only_project_request( request: T, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> where T: EntityMessage + RequestMessage, @@ -2254,29 +2398,7 @@ where .await .host_for_read_only_project_request(project_id, session.connection_id) .await?; - let payload = session - .peer - .forward_request(session.connection_id, host_connection_id, request) - .await?; - response.send(payload)?; - Ok(()) -} - -async fn forward_find_search_candidates_request( - request: proto::FindSearchCandidates, - response: Response, - session: Session, -) -> Result<()> { - let project_id = ProjectId::from_proto(request.remote_entity_id()); - let host_connection_id = session - .db() - .await - .host_for_read_only_project_request(project_id, session.connection_id) - .await?; - let payload = session - .peer - .forward_request(session.connection_id, host_connection_id, request) - .await?; + let payload = session.forward_request(host_connection_id, request).await?; response.send(payload)?; Ok(()) } @@ -2286,7 +2408,7 @@ async fn forward_find_search_candidates_request( async fn forward_mutating_project_request( request: T, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> where T: EntityMessage + RequestMessage, @@ -2298,18 +2420,25 @@ where .await .host_for_mutating_project_request(project_id, session.connection_id) .await?; - let payload = session - .peer - .forward_request(session.connection_id, host_connection_id, request) - .await?; + let payload = session.forward_request(host_connection_id, request).await?; response.send(payload)?; Ok(()) } +async fn multi_lsp_query( + request: MultiLspQuery, + response: Response, + session: MessageContext, +) -> Result<()> { + tracing::Span::current().record("multi_lsp_query_request", request.request_str()); + tracing::info!("multi_lsp_query message received"); + forward_mutating_project_request(request, response, session).await +} + /// Notify other participants that a new buffer has been created async fn create_buffer_for_peer( request: proto::CreateBufferForPeer, - session: Session, + session: MessageContext, ) -> Result<()> { session .db() @@ -2331,7 +2460,7 @@ async fn create_buffer_for_peer( async fn update_buffer( request: proto::UpdateBuffer, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let project_id = ProjectId::from_proto(request.project_id); let mut capability = Capability::ReadOnly; @@ -2366,17 +2495,14 @@ async fn update_buffer( }; if host != session.connection_id { - session - .peer - .forward_request(session.connection_id, host, request.clone()) - .await?; + session.forward_request(host, request.clone()).await?; } response.send(proto::Ack {})?; Ok(()) } -async fn update_context(message: proto::UpdateContext, session: Session) -> 
Result<()> { +async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> { let project_id = ProjectId::from_proto(message.project_id); let operation = message.operation.as_ref().context("invalid operation")?; @@ -2421,7 +2547,7 @@ async fn update_context(message: proto::UpdateContext, session: Session) -> Resu /// Notify other participants that a project has been updated. async fn broadcast_project_message_from_host>( request: T, - session: Session, + session: MessageContext, ) -> Result<()> { let project_id = ProjectId::from_proto(request.remote_entity_id()); let project_connection_ids = session @@ -2446,7 +2572,7 @@ async fn broadcast_project_message_from_host, - session: Session, + session: MessageContext, ) -> Result<()> { let room_id = RoomId::from_proto(request.room_id); let project_id = request.project_id.map(ProjectId::from_proto); @@ -2459,10 +2585,7 @@ async fn follow( .check_room_participants(room_id, leader_id, session.connection_id) .await?; - let response_payload = session - .peer - .forward_request(session.connection_id, leader_id, request) - .await?; + let response_payload = session.forward_request(leader_id, request).await?; response.send(response_payload)?; if let Some(project_id) = project_id { @@ -2478,7 +2601,7 @@ async fn follow( } /// Stop following another user in a call. -async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> { +async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> { let room_id = RoomId::from_proto(request.room_id); let project_id = request.project_id.map(ProjectId::from_proto); let leader_id = request.leader_id.context("invalid leader id")?.into(); @@ -2507,7 +2630,7 @@ async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> { } /// Notify everyone following you of your current location. 
-async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> { +async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> { let room_id = RoomId::from_proto(request.room_id); let database = session.db.lock().await; @@ -2542,7 +2665,7 @@ async fn update_followers(request: proto::UpdateFollowers, session: Session) -> async fn get_users( request: proto::GetUsers, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let user_ids = request .user_ids @@ -2570,7 +2693,7 @@ async fn get_users( async fn fuzzy_search_users( request: proto::FuzzySearchUsers, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let query = request.query; let users = match query.len() { @@ -2602,7 +2725,7 @@ async fn fuzzy_search_users( async fn request_contact( request: proto::RequestContact, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let requester_id = session.user_id(); let responder_id = UserId::from_proto(request.responder_id); @@ -2649,7 +2772,7 @@ async fn request_contact( async fn respond_to_contact_request( request: proto::RespondToContactRequest, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let responder_id = session.user_id(); let requester_id = UserId::from_proto(request.requester_id); @@ -2707,7 +2830,7 @@ async fn respond_to_contact_request( async fn remove_contact( request: proto::RemoveContact, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let requester_id = session.user_id(); let responder_id = UserId::from_proto(request.user_id); @@ -2845,12 +2968,12 @@ async fn make_update_user_plan_message( } fn model_requests_limit( - plan: zed_llm_client::Plan, + plan: cloud_llm_client::Plan, feature_flags: &Vec, -) -> zed_llm_client::UsageLimit { +) -> cloud_llm_client::UsageLimit { match plan.model_requests_limit() { - zed_llm_client::UsageLimit::Limited(limit) => { - let limit = if plan == zed_llm_client::Plan::ZedProTrial + cloud_llm_client::UsageLimit::Limited(limit) => { + let limit = if plan == cloud_llm_client::Plan::ZedProTrial && feature_flags .iter() .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG) @@ -2860,9 +2983,9 @@ fn model_requests_limit( limit }; - zed_llm_client::UsageLimit::Limited(limit) + cloud_llm_client::UsageLimit::Limited(limit) } - zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited, + cloud_llm_client::UsageLimit::Unlimited => cloud_llm_client::UsageLimit::Unlimited, } } @@ -2872,21 +2995,21 @@ fn subscription_usage_to_proto( feature_flags: &Vec, ) -> proto::SubscriptionUsage { let plan = match plan { - proto::Plan::Free => zed_llm_client::Plan::ZedFree, - proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro, - proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial, + proto::Plan::Free => cloud_llm_client::Plan::ZedFree, + proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro, + proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial, }; proto::SubscriptionUsage { model_requests_usage_amount: usage.model_requests as u32, model_requests_usage_limit: Some(proto::UsageLimit { variant: Some(match model_requests_limit(plan, feature_flags) { - zed_llm_client::UsageLimit::Limited(limit) => { + cloud_llm_client::UsageLimit::Limited(limit) => { proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { limit: limit as u32, }) } - zed_llm_client::UsageLimit::Unlimited => { + 
cloud_llm_client::UsageLimit::Unlimited => { proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) } }), @@ -2894,12 +3017,12 @@ fn subscription_usage_to_proto( edit_predictions_usage_amount: usage.edit_predictions as u32, edit_predictions_usage_limit: Some(proto::UsageLimit { variant: Some(match plan.edit_predictions_limit() { - zed_llm_client::UsageLimit::Limited(limit) => { + cloud_llm_client::UsageLimit::Limited(limit) => { proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { limit: limit as u32, }) } - zed_llm_client::UsageLimit::Unlimited => { + cloud_llm_client::UsageLimit::Unlimited => { proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) } }), @@ -2912,21 +3035,21 @@ fn make_default_subscription_usage( feature_flags: &Vec, ) -> proto::SubscriptionUsage { let plan = match plan { - proto::Plan::Free => zed_llm_client::Plan::ZedFree, - proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro, - proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial, + proto::Plan::Free => cloud_llm_client::Plan::ZedFree, + proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro, + proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial, }; proto::SubscriptionUsage { model_requests_usage_amount: 0, model_requests_usage_limit: Some(proto::UsageLimit { variant: Some(match model_requests_limit(plan, feature_flags) { - zed_llm_client::UsageLimit::Limited(limit) => { + cloud_llm_client::UsageLimit::Limited(limit) => { proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { limit: limit as u32, }) } - zed_llm_client::UsageLimit::Unlimited => { + cloud_llm_client::UsageLimit::Unlimited => { proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) } }), @@ -2934,12 +3057,12 @@ fn make_default_subscription_usage( edit_predictions_usage_amount: 0, edit_predictions_usage_limit: Some(proto::UsageLimit { variant: Some(match plan.edit_predictions_limit() { - zed_llm_client::UsageLimit::Limited(limit) => { + cloud_llm_client::UsageLimit::Limited(limit) => { proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { limit: limit as u32, }) } - zed_llm_client::UsageLimit::Unlimited => { + cloud_llm_client::UsageLimit::Unlimited => { proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) } }), @@ -2966,7 +3089,10 @@ async fn update_user_plan(session: &Session) -> Result<()> { Ok(()) } -async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> { +async fn subscribe_to_channels( + _: proto::SubscribeToChannels, + session: MessageContext, +) -> Result<()> { subscribe_user_to_channels(session.user_id(), &session).await?; Ok(()) } @@ -2992,7 +3118,7 @@ async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Resul async fn create_channel( request: proto::CreateChannel, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; @@ -3047,7 +3173,7 @@ async fn create_channel( async fn delete_channel( request: proto::DeleteChannel, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; @@ -3075,7 +3201,7 @@ async fn delete_channel( async fn invite_channel_member( request: proto::InviteChannelMember, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -3112,7 +3238,7 @@ async fn invite_channel_member( async fn 
remove_channel_member( request: proto::RemoveChannelMember, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -3156,7 +3282,7 @@ async fn remove_channel_member( async fn set_channel_visibility( request: proto::SetChannelVisibility, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -3201,7 +3327,7 @@ async fn set_channel_visibility( async fn set_channel_member_role( request: proto::SetChannelMemberRole, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -3249,7 +3375,7 @@ async fn set_channel_member_role( async fn rename_channel( request: proto::RenameChannel, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -3281,7 +3407,7 @@ async fn rename_channel( async fn move_channel( request: proto::MoveChannel, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); let to = ChannelId::from_proto(request.to); @@ -3323,7 +3449,7 @@ async fn move_channel( async fn reorder_channel( request: proto::ReorderChannel, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); let direction = request.direction(); @@ -3369,7 +3495,7 @@ async fn reorder_channel( async fn get_channel_members( request: proto::GetChannelMembers, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -3389,7 +3515,7 @@ async fn get_channel_members( async fn respond_to_channel_invite( request: proto::RespondToChannelInvite, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -3430,7 +3556,7 @@ async fn respond_to_channel_invite( async fn join_channel( request: proto::JoinChannel, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); join_channel_internal(channel_id, Box::new(response), session).await @@ -3453,7 +3579,7 @@ impl JoinChannelInternalResponse for Response { async fn join_channel_internal( channel_id: ChannelId, response: Box, - session: Session, + session: MessageContext, ) -> Result<()> { let joined_room = { let mut db = session.db().await; @@ -3548,7 +3674,7 @@ async fn join_channel_internal( async fn join_channel_buffer( request: proto::JoinChannelBuffer, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -3579,7 +3705,7 @@ async fn join_channel_buffer( /// Edit the channel notes async fn update_channel_buffer( request: proto::UpdateChannelBuffer, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -3631,7 +3757,7 @@ async fn update_channel_buffer( async fn rejoin_channel_buffers( request: 
proto::RejoinChannelBuffers, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; let buffers = db @@ -3666,7 +3792,7 @@ async fn rejoin_channel_buffers( async fn leave_channel_buffer( request: proto::LeaveChannelBuffer, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -3728,7 +3854,7 @@ fn send_notifications( async fn send_channel_message( request: proto::SendChannelMessage, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { // Validate the message body. let body = request.body.trim().to_string(); @@ -3821,7 +3947,7 @@ async fn send_channel_message( async fn remove_channel_message( request: proto::RemoveChannelMessage, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); let message_id = MessageId::from_proto(request.message_id); @@ -3856,7 +3982,7 @@ async fn remove_channel_message( async fn update_channel_message( request: proto::UpdateChannelMessage, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); let message_id = MessageId::from_proto(request.message_id); @@ -3940,7 +4066,7 @@ async fn update_channel_message( /// Mark a channel message as read async fn acknowledge_channel_message( request: proto::AckChannelMessage, - session: Session, + session: MessageContext, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); let message_id = MessageId::from_proto(request.message_id); @@ -3960,7 +4086,7 @@ async fn acknowledge_channel_message( /// Mark a buffer version as synced async fn acknowledge_buffer_version( request: proto::AckBufferOperation, - session: Session, + session: MessageContext, ) -> Result<()> { let buffer_id = BufferId::from_proto(request.buffer_id); session @@ -3980,7 +4106,7 @@ async fn acknowledge_buffer_version( async fn get_supermaven_api_key( _request: proto::GetSupermavenApiKey, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let user_id: String = session.user_id().to_string(); if !session.is_staff() { @@ -4009,7 +4135,7 @@ async fn get_supermaven_api_key( async fn join_channel_chat( request: proto::JoinChannelChat, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); @@ -4027,7 +4153,10 @@ async fn join_channel_chat( } /// Stop receiving chat updates for a channel -async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> { +async fn leave_channel_chat( + request: proto::LeaveChannelChat, + session: MessageContext, +) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); session .db() @@ -4041,7 +4170,7 @@ async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) async fn get_channel_messages( request: proto::GetChannelMessages, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); let messages = session @@ -4065,7 +4194,7 @@ async fn get_channel_messages( async fn get_channel_messages_by_id( request: proto::GetChannelMessagesById, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let message_ids = request 
.message_ids @@ -4088,7 +4217,7 @@ async fn get_channel_messages_by_id( async fn get_notifications( request: proto::GetNotifications, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let notifications = session .db() @@ -4110,7 +4239,7 @@ async fn get_notifications( async fn mark_notification_as_read( request: proto::MarkNotificationRead, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let database = &session.db().await; let notifications = database @@ -4132,7 +4261,7 @@ async fn mark_notification_as_read( async fn get_private_user_info( _request: proto::GetPrivateUserInfo, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; @@ -4156,7 +4285,7 @@ async fn get_private_user_info( async fn accept_terms_of_service( _request: proto::AcceptTermsOfService, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; @@ -4167,13 +4296,20 @@ async fn accept_terms_of_service( response.send(proto::AcceptTermsOfServiceResponse { accepted_tos_at: accepted_tos_at.timestamp() as u64, })?; + + // When the user accepts the terms of service, we want to refresh their LLM + // token to grant access. + session + .peer + .send(session.connection_id, proto::RefreshLlmToken {})?; + Ok(()) } async fn get_llm_api_token( _request: proto::GetLlmToken, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; diff --git a/crates/collab/src/stripe_billing.rs b/crates/collab/src/stripe_billing.rs index 8bf6c08158b9fa742f0f9e59711c7df80013614d..ef5bef3e7e5d6c687e4b963f820d5d484e6c4537 100644 --- a/crates/collab/src/stripe_billing.rs +++ b/crates/collab/src/stripe_billing.rs @@ -1,26 +1,15 @@ use std::sync::Arc; -use anyhow::{Context as _, anyhow}; -use chrono::Utc; +use anyhow::anyhow; use collections::HashMap; use stripe::SubscriptionStatus; use tokio::sync::RwLock; -use uuid::Uuid; use crate::Result; -use crate::db::billing_subscription::SubscriptionKind; -use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG; use crate::stripe_client::{ - RealStripeClient, StripeBillingAddressCollection, StripeCheckoutSessionMode, - StripeCheckoutSessionPaymentMethodCollection, StripeClient, - StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams, - StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, - StripeCreateMeterEventPayload, StripeCreateSubscriptionItems, StripeCreateSubscriptionParams, - StripeCustomerId, StripeCustomerUpdate, StripeCustomerUpdateAddress, StripeCustomerUpdateName, - StripeMeter, StripePrice, StripePriceId, StripeSubscription, StripeSubscriptionId, - StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior, - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems, - UpdateSubscriptionParams, + RealStripeClient, StripeAutomaticTax, StripeClient, StripeCreateSubscriptionItems, + StripeCreateSubscriptionParams, StripeCustomerId, StripePrice, StripePriceId, + StripeSubscription, }; pub struct StripeBilling { @@ -30,8 +19,6 @@ pub struct StripeBilling { #[derive(Default)] struct StripeBillingState { - meters_by_event_name: HashMap, - price_ids_by_meter_id: HashMap, prices_by_lookup_key: HashMap, } @@ -60,24 +47,11 @@ impl StripeBilling { let mut state = self.state.write().await; - let (meters, prices) = - futures::try_join!(self.client.list_meters(), self.client.list_prices())?; - - for 
meter in meters { - state - .meters_by_event_name - .insert(meter.event_name.clone(), meter); - } + let prices = self.client.list_prices().await?; for price in prices { if let Some(lookup_key) = price.lookup_key.clone() { - state.prices_by_lookup_key.insert(lookup_key, price.clone()); - } - - if let Some(recurring) = price.recurring { - if let Some(meter) = recurring.meter { - state.price_ids_by_meter_id.insert(meter, price.id); - } + state.prices_by_lookup_key.insert(lookup_key, price); } } @@ -114,30 +88,6 @@ impl StripeBilling { .ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}"))) } - pub async fn determine_subscription_kind( - &self, - subscription: &StripeSubscription, - ) -> Option { - let zed_pro_price_id = self.zed_pro_price_id().await.ok()?; - let zed_free_price_id = self.zed_free_price_id().await.ok()?; - - subscription.items.iter().find_map(|item| { - let price = item.price.as_ref()?; - - if price.id == zed_pro_price_id { - Some(if subscription.status == SubscriptionStatus::Trialing { - SubscriptionKind::ZedProTrial - } else { - SubscriptionKind::ZedPro - }) - } else if price.id == zed_free_price_id { - Some(SubscriptionKind::ZedFree) - } else { - None - } - }) - } - /// Returns the Stripe customer associated with the provided email address, or creates a new customer, if one does /// not already exist. /// @@ -170,152 +120,6 @@ impl StripeBilling { Ok(customer_id) } - pub async fn subscribe_to_price( - &self, - subscription_id: &StripeSubscriptionId, - price: &StripePrice, - ) -> Result<()> { - let subscription = self.client.get_subscription(subscription_id).await?; - - if subscription_contains_price(&subscription, &price.id) { - return Ok(()); - } - - const BILLING_THRESHOLD_IN_CENTS: i64 = 20 * 100; - - let price_per_unit = price.unit_amount.unwrap_or_default(); - let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit; - - self.client - .update_subscription( - subscription_id, - UpdateSubscriptionParams { - items: Some(vec![UpdateSubscriptionItems { - price: Some(price.id.clone()), - }]), - trial_settings: Some(StripeSubscriptionTrialSettings { - end_behavior: StripeSubscriptionTrialSettingsEndBehavior { - missing_payment_method: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel - }, - }), - }, - ) - .await?; - - Ok(()) - } - - pub async fn bill_model_request_usage( - &self, - customer_id: &StripeCustomerId, - event_name: &str, - requests: i32, - ) -> Result<()> { - let timestamp = Utc::now().timestamp(); - let idempotency_key = Uuid::new_v4(); - - self.client - .create_meter_event(StripeCreateMeterEventParams { - identifier: &format!("model_requests/{}", idempotency_key), - event_name, - payload: StripeCreateMeterEventPayload { - value: requests as u64, - stripe_customer_id: customer_id, - }, - timestamp: Some(timestamp), - }) - .await?; - - Ok(()) - } - - pub async fn checkout_with_zed_pro( - &self, - customer_id: &StripeCustomerId, - github_login: &str, - success_url: &str, - ) -> Result { - let zed_pro_price_id = self.zed_pro_price_id().await?; - - let mut params = StripeCreateCheckoutSessionParams::default(); - params.mode = Some(StripeCheckoutSessionMode::Subscription); - params.customer = Some(customer_id); - params.client_reference_id = Some(github_login); - params.line_items = Some(vec![StripeCreateCheckoutSessionLineItems { - price: Some(zed_pro_price_id.to_string()), - quantity: Some(1), - }]); - params.success_url = Some(success_url); - params.billing_address_collection = 
Some(StripeBillingAddressCollection::Required); - params.customer_update = Some(StripeCustomerUpdate { - address: Some(StripeCustomerUpdateAddress::Auto), - name: Some(StripeCustomerUpdateName::Auto), - shipping: None, - }); - - let session = self.client.create_checkout_session(params).await?; - Ok(session.url.context("no checkout session URL")?) - } - - pub async fn checkout_with_zed_pro_trial( - &self, - customer_id: &StripeCustomerId, - github_login: &str, - feature_flags: Vec, - success_url: &str, - ) -> Result { - let zed_pro_price_id = self.zed_pro_price_id().await?; - - let eligible_for_extended_trial = feature_flags - .iter() - .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG); - - let trial_period_days = if eligible_for_extended_trial { 60 } else { 14 }; - - let mut subscription_metadata = std::collections::HashMap::new(); - if eligible_for_extended_trial { - subscription_metadata.insert( - "promo_feature_flag".to_string(), - AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string(), - ); - } - - let mut params = StripeCreateCheckoutSessionParams::default(); - params.subscription_data = Some(StripeCreateCheckoutSessionSubscriptionData { - trial_period_days: Some(trial_period_days), - trial_settings: Some(StripeSubscriptionTrialSettings { - end_behavior: StripeSubscriptionTrialSettingsEndBehavior { - missing_payment_method: - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel, - }, - }), - metadata: if !subscription_metadata.is_empty() { - Some(subscription_metadata) - } else { - None - }, - }); - params.mode = Some(StripeCheckoutSessionMode::Subscription); - params.payment_method_collection = - Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired); - params.customer = Some(customer_id); - params.client_reference_id = Some(github_login); - params.line_items = Some(vec![StripeCreateCheckoutSessionLineItems { - price: Some(zed_pro_price_id.to_string()), - quantity: Some(1), - }]); - params.success_url = Some(success_url); - params.billing_address_collection = Some(StripeBillingAddressCollection::Required); - params.customer_update = Some(StripeCustomerUpdate { - address: Some(StripeCustomerUpdateAddress::Auto), - name: Some(StripeCustomerUpdateName::Auto), - shipping: None, - }); - - let session = self.client.create_checkout_session(params).await?; - Ok(session.url.context("no checkout session URL")?) 
- } - pub async fn subscribe_to_zed_free( &self, customer_id: StripeCustomerId, @@ -342,6 +146,7 @@ impl StripeBilling { price: Some(zed_free_price_id), quantity: Some(1), }], + automatic_tax: Some(StripeAutomaticTax { enabled: true }), }; let subscription = self.client.create_subscription(params).await?; @@ -349,14 +154,3 @@ impl StripeBilling { Ok(subscription) } } - -fn subscription_contains_price( - subscription: &StripeSubscription, - price_id: &StripePriceId, -) -> bool { - subscription.items.iter().any(|item| { - item.price - .as_ref() - .map_or(false, |price| price.id == *price_id) - }) -} diff --git a/crates/collab/src/stripe_client.rs b/crates/collab/src/stripe_client.rs index 9ffcb2ba6c9fde13ebc84b9e7c509851158e0a1e..6e75a4d874bf41e7cb4418d4b56cfeb6040e5ff8 100644 --- a/crates/collab/src/stripe_client.rs +++ b/crates/collab/src/stripe_client.rs @@ -73,6 +73,7 @@ pub enum StripeCancellationDetailsReason { pub struct StripeCreateSubscriptionParams { pub customer: StripeCustomerId, pub items: Vec, + pub automatic_tax: Option, } #[derive(Debug)] @@ -190,6 +191,7 @@ pub struct StripeCreateCheckoutSessionParams<'a> { pub success_url: Option<&'a str>, pub billing_address_collection: Option, pub customer_update: Option, + pub tax_id_collection: Option, } #[derive(Debug, PartialEq, Eq, Clone, Copy)] @@ -218,6 +220,16 @@ pub struct StripeCreateCheckoutSessionSubscriptionData { pub trial_settings: Option, } +#[derive(Debug, PartialEq, Clone)] +pub struct StripeTaxIdCollection { + pub enabled: bool, +} + +#[derive(Debug, Clone)] +pub struct StripeAutomaticTax { + pub enabled: bool, +} + #[derive(Debug)] pub struct StripeCheckoutSession { pub url: Option, diff --git a/crates/collab/src/stripe_client/fake_stripe_client.rs b/crates/collab/src/stripe_client/fake_stripe_client.rs index 11b210dd0e7aba54148d26de0670f23415ae7cea..9bb08443ec6a5fd04ad11a8e24b1a71b03e4867b 100644 --- a/crates/collab/src/stripe_client/fake_stripe_client.rs +++ b/crates/collab/src/stripe_client/fake_stripe_client.rs @@ -14,8 +14,8 @@ use crate::stripe_client::{ StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeCustomerUpdate, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripeSubscription, - StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId, UpdateCustomerParams, - UpdateSubscriptionParams, + StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId, StripeTaxIdCollection, + UpdateCustomerParams, UpdateSubscriptionParams, }; #[derive(Debug, Clone)] @@ -38,6 +38,7 @@ pub struct StripeCreateCheckoutSessionCall { pub success_url: Option, pub billing_address_collection: Option, pub customer_update: Option, + pub tax_id_collection: Option, } pub struct FakeStripeClient { @@ -236,6 +237,7 @@ impl StripeClient for FakeStripeClient { success_url: params.success_url.map(|url| url.to_string()), billing_address_collection: params.billing_address_collection, customer_update: params.customer_update, + tax_id_collection: params.tax_id_collection, }); Ok(StripeCheckoutSession { diff --git a/crates/collab/src/stripe_client/real_stripe_client.rs b/crates/collab/src/stripe_client/real_stripe_client.rs index 7108e8d7597a3afd235c2ae48a4b05c5fc5de014..07c191ff30400ccbf4b73c4c84f09aa47e0fd9aa 100644 --- a/crates/collab/src/stripe_client/real_stripe_client.rs +++ b/crates/collab/src/stripe_client/real_stripe_client.rs @@ -10,16 +10,17 @@ use stripe::{ CreateCheckoutSessionSubscriptionData, 
CreateCheckoutSessionSubscriptionDataTrialSettings, CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior, CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod, - CreateCustomer, Customer, CustomerId, ListCustomers, Price, PriceId, Recurring, Subscription, - SubscriptionId, SubscriptionItem, SubscriptionItemId, UpdateCustomer, UpdateSubscriptionItems, - UpdateSubscriptionTrialSettings, UpdateSubscriptionTrialSettingsEndBehavior, + CreateCustomer, CreateSubscriptionAutomaticTax, Customer, CustomerId, ListCustomers, Price, + PriceId, Recurring, Subscription, SubscriptionId, SubscriptionItem, SubscriptionItemId, + UpdateCustomer, UpdateSubscriptionItems, UpdateSubscriptionTrialSettings, + UpdateSubscriptionTrialSettingsEndBehavior, UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, }; use crate::stripe_client::{ - CreateCustomerParams, StripeBillingAddressCollection, StripeCancellationDetails, - StripeCancellationDetailsReason, StripeCheckoutSession, StripeCheckoutSessionMode, - StripeCheckoutSessionPaymentMethodCollection, StripeClient, + CreateCustomerParams, StripeAutomaticTax, StripeBillingAddressCollection, + StripeCancellationDetails, StripeCancellationDetailsReason, StripeCheckoutSession, + StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection, StripeClient, StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams, StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeCustomerUpdate, @@ -27,8 +28,8 @@ use crate::stripe_client::{ StripeMeter, StripePrice, StripePriceId, StripePriceRecurring, StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId, StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior, - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateCustomerParams, - UpdateSubscriptionParams, + StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, StripeTaxIdCollection, + UpdateCustomerParams, UpdateSubscriptionParams, }; pub struct RealStripeClient { @@ -151,6 +152,7 @@ impl StripeClient for RealStripeClient { }) .collect(), ); + create_subscription.automatic_tax = params.automatic_tax.map(Into::into); let subscription = Subscription::create(&self.client, create_subscription).await?; @@ -366,6 +368,15 @@ impl From for StripeSubscriptionItem { } } +impl From for CreateSubscriptionAutomaticTax { + fn from(value: StripeAutomaticTax) -> Self { + Self { + enabled: value.enabled, + liability: None, + } + } +} + impl From for UpdateSubscriptionTrialSettings { fn from(value: StripeSubscriptionTrialSettings) -> Self { Self { @@ -448,6 +459,7 @@ impl<'a> TryFrom> for CreateCheckoutSessio success_url: value.success_url, billing_address_collection: value.billing_address_collection.map(Into::into), customer_update: value.customer_update.map(Into::into), + tax_id_collection: value.tax_id_collection.map(Into::into), ..Default::default() }) } @@ -590,3 +602,11 @@ impl From for stripe::CreateCheckoutSessionCustomerUpdate } } } + +impl From for stripe::CreateCheckoutSessionTaxIdCollection { + fn from(value: StripeTaxIdCollection) -> Self { + stripe::CreateCheckoutSessionTaxIdCollection { + enabled: value.enabled, + } + } +} diff --git a/crates/collab/src/tests.rs b/crates/collab/src/tests.rs index 19e410de5bd34a6c1984dba15367889f8c1689eb..8d5d076780733406904cd1c0431d56d6ebbc776f 100644 --- a/crates/collab/src/tests.rs +++ 
b/crates/collab/src/tests.rs @@ -38,12 +38,12 @@ fn room_participants(room: &Entity, cx: &mut TestAppContext) -> RoomPartic let mut remote = room .remote_participants() .values() - .map(|participant| participant.user.github_login.clone()) + .map(|participant| participant.user.github_login.clone().to_string()) .collect::>(); let mut pending = room .pending_participants() .iter() - .map(|user| user.github_login.clone()) + .map(|user| user.github_login.clone().to_string()) .collect::>(); remote.sort(); pending.sort(); diff --git a/crates/collab/src/tests/editor_tests.rs b/crates/collab/src/tests/editor_tests.rs index 2cc3ca76d1b639cc479cb44cde93a73570d5eb7f..8754b53f6eac550b465e3bc8fcc38c5363335af8 100644 --- a/crates/collab/src/tests/editor_tests.rs +++ b/crates/collab/src/tests/editor_tests.rs @@ -24,10 +24,7 @@ use language::{ }; use project::{ ProjectPath, SERVER_PROGRESS_THROTTLE_TIMEOUT, - lsp_store::{ - lsp_ext_command::{ExpandedMacro, LspExtExpandMacro}, - rust_analyzer_ext::RUST_ANALYZER_NAME, - }, + lsp_store::lsp_ext_command::{ExpandedMacro, LspExtExpandMacro}, project_settings::{InlineBlameSettings, ProjectSettings}, }; use recent_projects::disconnected_overlay::DisconnectedOverlay; @@ -296,19 +293,28 @@ async fn test_collaborating_with_completion(cx_a: &mut TestAppContext, cx_b: &mu .await; let active_call_a = cx_a.read(ActiveCall::global); + let capabilities = lsp::ServerCapabilities { + completion_provider: Some(lsp::CompletionOptions { + trigger_characters: Some(vec![".".to_string()]), + resolve_provider: Some(true), + ..lsp::CompletionOptions::default() + }), + ..lsp::ServerCapabilities::default() + }; client_a.language_registry().add(rust_lang()); let mut fake_language_servers = client_a.language_registry().register_fake_lsp( "Rust", FakeLspAdapter { - capabilities: lsp::ServerCapabilities { - completion_provider: Some(lsp::CompletionOptions { - trigger_characters: Some(vec![".".to_string()]), - resolve_provider: Some(true), - ..Default::default() - }), - ..Default::default() - }, - ..Default::default() + capabilities: capabilities.clone(), + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + capabilities, + ..FakeLspAdapter::default() }, ); @@ -566,11 +572,14 @@ async fn test_collaborating_with_code_actions( cx_b.update(editor::init); - // Set up a fake language server. client_a.language_registry().add(rust_lang()); let mut fake_language_servers = client_a .language_registry() .register_fake_lsp("Rust", FakeLspAdapter::default()); + client_b.language_registry().add(rust_lang()); + client_b + .language_registry() + .register_fake_lsp("Rust", FakeLspAdapter::default()); client_a .fs() @@ -775,19 +784,27 @@ async fn test_collaborating_with_renames(cx_a: &mut TestAppContext, cx_b: &mut T cx_b.update(editor::init); - // Set up a fake language server. 
+ let capabilities = lsp::ServerCapabilities { + rename_provider: Some(lsp::OneOf::Right(lsp::RenameOptions { + prepare_provider: Some(true), + work_done_progress_options: Default::default(), + })), + ..lsp::ServerCapabilities::default() + }; client_a.language_registry().add(rust_lang()); let mut fake_language_servers = client_a.language_registry().register_fake_lsp( "Rust", FakeLspAdapter { - capabilities: lsp::ServerCapabilities { - rename_provider: Some(lsp::OneOf::Right(lsp::RenameOptions { - prepare_provider: Some(true), - work_done_progress_options: Default::default(), - })), - ..Default::default() - }, - ..Default::default() + capabilities: capabilities.clone(), + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + capabilities, + ..FakeLspAdapter::default() }, ); @@ -818,6 +835,8 @@ async fn test_collaborating_with_renames(cx_a: &mut TestAppContext, cx_b: &mut T .downcast::() .unwrap(); let fake_language_server = fake_language_servers.next().await.unwrap(); + cx_a.run_until_parked(); + cx_b.run_until_parked(); // Move cursor to a location that can be renamed. let prepare_rename = editor_b.update_in(cx_b, |editor, window, cx| { @@ -1055,7 +1074,7 @@ async fn test_language_server_statuses(cx_a: &mut TestAppContext, cx_b: &mut Tes project_a.read_with(cx_a, |project, cx| { let status = project.language_server_statuses(cx).next().unwrap().1; - assert_eq!(status.name, "the-language-server"); + assert_eq!(status.name.0, "the-language-server"); assert_eq!(status.pending_work.len(), 1); assert_eq!( status.pending_work["the-token"].message.as_ref().unwrap(), @@ -1072,7 +1091,7 @@ async fn test_language_server_statuses(cx_a: &mut TestAppContext, cx_b: &mut Tes project_b.read_with(cx_b, |project, cx| { let status = project.language_server_statuses(cx).next().unwrap().1; - assert_eq!(status.name, "the-language-server"); + assert_eq!(status.name.0, "the-language-server"); }); executor.advance_clock(SERVER_PROGRESS_THROTTLE_TIMEOUT); @@ -1089,7 +1108,7 @@ async fn test_language_server_statuses(cx_a: &mut TestAppContext, cx_b: &mut Tes project_a.read_with(cx_a, |project, cx| { let status = project.language_server_statuses(cx).next().unwrap().1; - assert_eq!(status.name, "the-language-server"); + assert_eq!(status.name.0, "the-language-server"); assert_eq!(status.pending_work.len(), 1); assert_eq!( status.pending_work["the-token"].message.as_ref().unwrap(), @@ -1099,7 +1118,7 @@ async fn test_language_server_statuses(cx_a: &mut TestAppContext, cx_b: &mut Tes project_b.read_with(cx_b, |project, cx| { let status = project.language_server_statuses(cx).next().unwrap().1; - assert_eq!(status.name, "the-language-server"); + assert_eq!(status.name.0, "the-language-server"); assert_eq!(status.pending_work.len(), 1); assert_eq!( status.pending_work["the-token"].message.as_ref().unwrap(), @@ -1422,18 +1441,27 @@ async fn test_on_input_format_from_guest_to_host( .await; let active_call_a = cx_a.read(ActiveCall::global); + let capabilities = lsp::ServerCapabilities { + document_on_type_formatting_provider: Some(lsp::DocumentOnTypeFormattingOptions { + first_trigger_character: ":".to_string(), + more_trigger_character: Some(vec![">".to_string()]), + }), + ..lsp::ServerCapabilities::default() + }; client_a.language_registry().add(rust_lang()); let mut fake_language_servers = client_a.language_registry().register_fake_lsp( "Rust", FakeLspAdapter { - capabilities: lsp::ServerCapabilities { - 
document_on_type_formatting_provider: Some(lsp::DocumentOnTypeFormattingOptions { - first_trigger_character: ":".to_string(), - more_trigger_character: Some(vec![">".to_string()]), - }), - ..Default::default() - }, - ..Default::default() + capabilities: capabilities.clone(), + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + capabilities, + ..FakeLspAdapter::default() }, ); @@ -1588,16 +1616,24 @@ async fn test_mutual_editor_inlay_hint_cache_update( }); }); + let capabilities = lsp::ServerCapabilities { + inlay_hint_provider: Some(lsp::OneOf::Left(true)), + ..lsp::ServerCapabilities::default() + }; client_a.language_registry().add(rust_lang()); - client_b.language_registry().add(rust_lang()); let mut fake_language_servers = client_a.language_registry().register_fake_lsp( "Rust", FakeLspAdapter { - capabilities: lsp::ServerCapabilities { - inlay_hint_provider: Some(lsp::OneOf::Left(true)), - ..Default::default() - }, - ..Default::default() + capabilities: capabilities.clone(), + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + capabilities, + ..FakeLspAdapter::default() }, ); @@ -1830,16 +1866,24 @@ async fn test_inlay_hint_refresh_is_forwarded( }); }); + let capabilities = lsp::ServerCapabilities { + inlay_hint_provider: Some(lsp::OneOf::Left(true)), + ..lsp::ServerCapabilities::default() + }; client_a.language_registry().add(rust_lang()); - client_b.language_registry().add(rust_lang()); let mut fake_language_servers = client_a.language_registry().register_fake_lsp( "Rust", FakeLspAdapter { - capabilities: lsp::ServerCapabilities { - inlay_hint_provider: Some(lsp::OneOf::Left(true)), - ..Default::default() - }, - ..Default::default() + capabilities: capabilities.clone(), + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + capabilities, + ..FakeLspAdapter::default() }, ); @@ -2004,15 +2048,23 @@ async fn test_lsp_document_color(cx_a: &mut TestAppContext, cx_b: &mut TestAppCo }); }); + let capabilities = lsp::ServerCapabilities { + color_provider: Some(lsp::ColorProviderCapability::Simple(true)), + ..lsp::ServerCapabilities::default() + }; client_a.language_registry().add(rust_lang()); - client_b.language_registry().add(rust_lang()); let mut fake_language_servers = client_a.language_registry().register_fake_lsp( "Rust", FakeLspAdapter { - capabilities: lsp::ServerCapabilities { - color_provider: Some(lsp::ColorProviderCapability::Simple(true)), - ..lsp::ServerCapabilities::default() - }, + capabilities: capabilities.clone(), + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + capabilities, ..FakeLspAdapter::default() }, ); @@ -2063,6 +2115,8 @@ async fn test_lsp_document_color(cx_a: &mut TestAppContext, cx_b: &mut TestAppCo .unwrap(); let fake_language_server = fake_language_servers.next().await.unwrap(); + cx_a.run_until_parked(); + cx_b.run_until_parked(); let requests_made = Arc::new(AtomicUsize::new(0)); let closure_requests_made = Arc::clone(&requests_made); @@ -2246,8 +2300,11 @@ async fn test_lsp_document_color(cx_a: &mut TestAppContext, cx_b: &mut TestAppCo }); } -#[gpui::test(iterations 
= 10)] -async fn test_lsp_pull_diagnostics(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { +async fn test_lsp_pull_diagnostics( + should_stream_workspace_diagnostic: bool, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { let mut server = TestServer::start(cx_a.executor()).await; let executor = cx_a.executor(); let client_a = server.create_client(cx_a, "user_a").await; @@ -2261,24 +2318,32 @@ async fn test_lsp_pull_diagnostics(cx_a: &mut TestAppContext, cx_b: &mut TestApp cx_a.update(editor::init); cx_b.update(editor::init); + let capabilities = lsp::ServerCapabilities { + diagnostic_provider: Some(lsp::DiagnosticServerCapabilities::Options( + lsp::DiagnosticOptions { + identifier: Some("test-pulls".to_string()), + inter_file_dependencies: true, + workspace_diagnostics: true, + work_done_progress_options: lsp::WorkDoneProgressOptions { + work_done_progress: None, + }, + }, + )), + ..lsp::ServerCapabilities::default() + }; client_a.language_registry().add(rust_lang()); - client_b.language_registry().add(rust_lang()); let mut fake_language_servers = client_a.language_registry().register_fake_lsp( "Rust", FakeLspAdapter { - capabilities: lsp::ServerCapabilities { - diagnostic_provider: Some(lsp::DiagnosticServerCapabilities::Options( - lsp::DiagnosticOptions { - identifier: Some("test-pulls".to_string()), - inter_file_dependencies: true, - workspace_diagnostics: true, - work_done_progress_options: lsp::WorkDoneProgressOptions { - work_done_progress: None, - }, - }, - )), - ..lsp::ServerCapabilities::default() - }, + capabilities: capabilities.clone(), + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + capabilities, ..FakeLspAdapter::default() }, ); @@ -2331,6 +2396,8 @@ async fn test_lsp_pull_diagnostics(cx_a: &mut TestAppContext, cx_b: &mut TestApp .unwrap(); let fake_language_server = fake_language_servers.next().await.unwrap(); + cx_a.run_until_parked(); + cx_b.run_until_parked(); let expected_push_diagnostic_main_message = "pushed main diagnostic"; let expected_push_diagnostic_lib_message = "pushed lib diagnostic"; let expected_pull_diagnostic_main_message = "pulled main diagnostic"; @@ -2396,12 +2463,25 @@ async fn test_lsp_pull_diagnostics(cx_a: &mut TestAppContext, cx_b: &mut TestApp let closure_workspace_diagnostics_pulls_made = workspace_diagnostics_pulls_made.clone(); let closure_workspace_diagnostics_pulls_result_ids = workspace_diagnostics_pulls_result_ids.clone(); + let (workspace_diagnostic_cancel_tx, closure_workspace_diagnostic_cancel_rx) = + smol::channel::bounded::<()>(1); + let (closure_workspace_diagnostic_received_tx, workspace_diagnostic_received_rx) = + smol::channel::bounded::<()>(1); + let expected_workspace_diagnostic_token = lsp::ProgressToken::String(format!( + "workspace/diagnostic-{}-1", + fake_language_server.server.server_id() + )); + let closure_expected_workspace_diagnostic_token = expected_workspace_diagnostic_token.clone(); let mut workspace_diagnostics_pulls_handle = fake_language_server .set_request_handler::( move |params, _| { let workspace_requests_made = closure_workspace_diagnostics_pulls_made.clone(); let workspace_diagnostics_pulls_result_ids = closure_workspace_diagnostics_pulls_result_ids.clone(); + let workspace_diagnostic_cancel_rx = closure_workspace_diagnostic_cancel_rx.clone(); + let workspace_diagnostic_received_tx = closure_workspace_diagnostic_received_tx.clone(); + let 
expected_workspace_diagnostic_token = + closure_expected_workspace_diagnostic_token.clone(); async move { let workspace_request_count = workspace_requests_made.fetch_add(1, atomic::Ordering::Release) + 1; @@ -2411,6 +2491,21 @@ async fn test_lsp_pull_diagnostics(cx_a: &mut TestAppContext, cx_b: &mut TestApp .await .extend(params.previous_result_ids.into_iter().map(|id| id.value)); } + if should_stream_workspace_diagnostic && !workspace_diagnostic_cancel_rx.is_closed() + { + assert_eq!( + params.partial_result_params.partial_result_token, + Some(expected_workspace_diagnostic_token) + ); + workspace_diagnostic_received_tx.send(()).await.unwrap(); + workspace_diagnostic_cancel_rx.recv().await.unwrap(); + workspace_diagnostic_cancel_rx.close(); + // https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#partialResults + // > The final response has to be empty in terms of result values. + return Ok(lsp::WorkspaceDiagnosticReportResult::Report( + lsp::WorkspaceDiagnosticReport { items: Vec::new() }, + )); + } Ok(lsp::WorkspaceDiagnosticReportResult::Report( lsp::WorkspaceDiagnosticReport { items: vec![ @@ -2479,7 +2574,11 @@ async fn test_lsp_pull_diagnostics(cx_a: &mut TestAppContext, cx_b: &mut TestApp }, ); - workspace_diagnostics_pulls_handle.next().await.unwrap(); + if should_stream_workspace_diagnostic { + workspace_diagnostic_received_rx.recv().await.unwrap(); + } else { + workspace_diagnostics_pulls_handle.next().await.unwrap(); + } assert_eq!( 1, workspace_diagnostics_pulls_made.load(atomic::Ordering::Acquire), @@ -2503,10 +2602,10 @@ async fn test_lsp_pull_diagnostics(cx_a: &mut TestAppContext, cx_b: &mut TestApp "Expected single diagnostic, but got: {all_diagnostics:?}" ); let diagnostic = &all_diagnostics[0]; - let expected_messages = [ - expected_workspace_pull_diagnostics_main_message, - expected_pull_diagnostic_main_message, - ]; + let mut expected_messages = vec![expected_pull_diagnostic_main_message]; + if !should_stream_workspace_diagnostic { + expected_messages.push(expected_workspace_pull_diagnostics_main_message); + } assert!( expected_messages.contains(&diagnostic.diagnostic.message.as_str()), "Expected {expected_messages:?} on the host, but got: {}", @@ -2556,6 +2655,70 @@ async fn test_lsp_pull_diagnostics(cx_a: &mut TestAppContext, cx_b: &mut TestApp version: None, }, ); + + if should_stream_workspace_diagnostic { + fake_language_server.notify::(&lsp::ProgressParams { + token: expected_workspace_diagnostic_token.clone(), + value: lsp::ProgressParamsValue::WorkspaceDiagnostic( + lsp::WorkspaceDiagnosticReportResult::Report(lsp::WorkspaceDiagnosticReport { + items: vec![ + lsp::WorkspaceDocumentDiagnosticReport::Full( + lsp::WorkspaceFullDocumentDiagnosticReport { + uri: lsp::Url::from_file_path(path!("/a/main.rs")).unwrap(), + version: None, + full_document_diagnostic_report: + lsp::FullDocumentDiagnosticReport { + result_id: Some(format!( + "workspace_{}", + workspace_diagnostics_pulls_made + .fetch_add(1, atomic::Ordering::Release) + + 1 + )), + items: vec![lsp::Diagnostic { + range: lsp::Range { + start: lsp::Position { + line: 0, + character: 1, + }, + end: lsp::Position { + line: 0, + character: 2, + }, + }, + severity: Some(lsp::DiagnosticSeverity::ERROR), + message: + expected_workspace_pull_diagnostics_main_message + .to_string(), + ..lsp::Diagnostic::default() + }], + }, + }, + ), + lsp::WorkspaceDocumentDiagnosticReport::Full( + lsp::WorkspaceFullDocumentDiagnosticReport { + uri: 
lsp::Url::from_file_path(path!("/a/lib.rs")).unwrap(), + version: None, + full_document_diagnostic_report: + lsp::FullDocumentDiagnosticReport { + result_id: Some(format!( + "workspace_{}", + workspace_diagnostics_pulls_made + .fetch_add(1, atomic::Ordering::Release) + + 1 + )), + items: Vec::new(), + }, + }, + ), + ], + }), + ), + }); + }; + + let mut workspace_diagnostic_start_count = + workspace_diagnostics_pulls_made.load(atomic::Ordering::Acquire); + executor.run_until_parked(); editor_a_main.update(cx_a, |editor, cx| { let snapshot = editor.buffer().read(cx).snapshot(cx); @@ -2590,6 +2753,7 @@ async fn test_lsp_pull_diagnostics(cx_a: &mut TestAppContext, cx_b: &mut TestApp .unwrap() .downcast::() .unwrap(); + cx_b.run_until_parked(); pull_diagnostics_handle.next().await.unwrap(); assert_eq!( @@ -2599,7 +2763,7 @@ async fn test_lsp_pull_diagnostics(cx_a: &mut TestAppContext, cx_b: &mut TestApp ); executor.run_until_parked(); assert_eq!( - 1, + workspace_diagnostic_start_count, workspace_diagnostics_pulls_made.load(atomic::Ordering::Acquire), "Workspace diagnostics should not be changed as the remote client does not initialize the workspace diagnostics pull" ); @@ -2646,7 +2810,7 @@ async fn test_lsp_pull_diagnostics(cx_a: &mut TestAppContext, cx_b: &mut TestApp ); executor.run_until_parked(); assert_eq!( - 1, + workspace_diagnostic_start_count, workspace_diagnostics_pulls_made.load(atomic::Ordering::Acquire), "The remote client still did not anything to trigger the workspace diagnostics pull" ); @@ -2673,6 +2837,75 @@ async fn test_lsp_pull_diagnostics(cx_a: &mut TestAppContext, cx_b: &mut TestApp ); } }); + + if should_stream_workspace_diagnostic { + fake_language_server.notify::(&lsp::ProgressParams { + token: expected_workspace_diagnostic_token.clone(), + value: lsp::ProgressParamsValue::WorkspaceDiagnostic( + lsp::WorkspaceDiagnosticReportResult::Report(lsp::WorkspaceDiagnosticReport { + items: vec![lsp::WorkspaceDocumentDiagnosticReport::Full( + lsp::WorkspaceFullDocumentDiagnosticReport { + uri: lsp::Url::from_file_path(path!("/a/lib.rs")).unwrap(), + version: None, + full_document_diagnostic_report: lsp::FullDocumentDiagnosticReport { + result_id: Some(format!( + "workspace_{}", + workspace_diagnostics_pulls_made + .fetch_add(1, atomic::Ordering::Release) + + 1 + )), + items: vec![lsp::Diagnostic { + range: lsp::Range { + start: lsp::Position { + line: 0, + character: 1, + }, + end: lsp::Position { + line: 0, + character: 2, + }, + }, + severity: Some(lsp::DiagnosticSeverity::ERROR), + message: expected_workspace_pull_diagnostics_lib_message + .to_string(), + ..lsp::Diagnostic::default() + }], + }, + }, + )], + }), + ), + }); + workspace_diagnostic_start_count = + workspace_diagnostics_pulls_made.load(atomic::Ordering::Acquire); + workspace_diagnostic_cancel_tx.send(()).await.unwrap(); + workspace_diagnostics_pulls_handle.next().await.unwrap(); + executor.run_until_parked(); + editor_b_lib.update(cx_b, |editor, cx| { + let snapshot = editor.buffer().read(cx).snapshot(cx); + let all_diagnostics = snapshot + .diagnostics_in_range(0..snapshot.len()) + .collect::>(); + let expected_messages = [ + expected_workspace_pull_diagnostics_lib_message, + // TODO bug: the pushed diagnostics are not being sent to the client when they open the corresponding buffer. 
+ // expected_push_diagnostic_lib_message, + ]; + assert_eq!( + all_diagnostics.len(), + 1, + "Expected pull diagnostics, but got: {all_diagnostics:?}" + ); + for diagnostic in all_diagnostics { + assert!( + expected_messages.contains(&diagnostic.diagnostic.message.as_str()), + "The client should get both push and pull messages: {expected_messages:?}, but got: {}", + diagnostic.diagnostic.message + ); + } + }); + }; + { assert!( diagnostics_pulls_result_ids.lock().await.len() > 0, @@ -2701,7 +2934,7 @@ async fn test_lsp_pull_diagnostics(cx_a: &mut TestAppContext, cx_b: &mut TestApp ); workspace_diagnostics_pulls_handle.next().await.unwrap(); assert_eq!( - 2, + workspace_diagnostic_start_count + 1, workspace_diagnostics_pulls_made.load(atomic::Ordering::Acquire), "After client lib.rs edits, the workspace diagnostics request should follow" ); @@ -2720,7 +2953,7 @@ async fn test_lsp_pull_diagnostics(cx_a: &mut TestAppContext, cx_b: &mut TestApp ); workspace_diagnostics_pulls_handle.next().await.unwrap(); assert_eq!( - 3, + workspace_diagnostic_start_count + 2, workspace_diagnostics_pulls_made.load(atomic::Ordering::Acquire), "After client main.rs edits, the workspace diagnostics pull should follow" ); @@ -2739,7 +2972,7 @@ async fn test_lsp_pull_diagnostics(cx_a: &mut TestAppContext, cx_b: &mut TestApp ); workspace_diagnostics_pulls_handle.next().await.unwrap(); assert_eq!( - 4, + workspace_diagnostic_start_count + 3, workspace_diagnostics_pulls_made.load(atomic::Ordering::Acquire), "After host main.rs edits, the workspace diagnostics pull should follow" ); @@ -2769,7 +3002,7 @@ async fn test_lsp_pull_diagnostics(cx_a: &mut TestAppContext, cx_b: &mut TestApp ); workspace_diagnostics_pulls_handle.next().await.unwrap(); assert_eq!( - 5, + workspace_diagnostic_start_count + 4, workspace_diagnostics_pulls_made.load(atomic::Ordering::Acquire), "Another workspace diagnostics pull should happen after the diagnostics refresh server request" ); @@ -2840,6 +3073,19 @@ async fn test_lsp_pull_diagnostics(cx_a: &mut TestAppContext, cx_b: &mut TestApp }); } +#[gpui::test(iterations = 10)] +async fn test_non_streamed_lsp_pull_diagnostics( + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + test_lsp_pull_diagnostics(false, cx_a, cx_b).await; +} + +#[gpui::test(iterations = 10)] +async fn test_streamed_lsp_pull_diagnostics(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { + test_lsp_pull_diagnostics(true, cx_a, cx_b).await; +} + #[gpui::test(iterations = 10)] async fn test_git_blame_is_forwarded(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { let mut server = TestServer::start(cx_a.executor()).await; @@ -3537,11 +3783,18 @@ async fn test_client_can_query_lsp_ext(cx_a: &mut TestAppContext, cx_b: &mut Tes cx_b.update(editor::init); client_a.language_registry().add(rust_lang()); - client_b.language_registry().add(rust_lang()); let mut fake_language_servers = client_a.language_registry().register_fake_lsp( "Rust", FakeLspAdapter { - name: RUST_ANALYZER_NAME, + name: "rust-analyzer", + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + name: "rust-analyzer", ..FakeLspAdapter::default() }, ); diff --git a/crates/collab/src/tests/following_tests.rs b/crates/collab/src/tests/following_tests.rs index a77112213f195190e613c2382300bfbbeca70066..d9fd8ffeb2a6c693c3570409070f7a0fbfe33ea2 100644 --- a/crates/collab/src/tests/following_tests.rs +++ 
b/crates/collab/src/tests/following_tests.rs @@ -439,7 +439,7 @@ async fn test_basic_following( editor_a1.item_id() ); - #[cfg(all(not(target_os = "macos"), not(target_os = "windows")))] + // #[cfg(all(not(target_os = "macos"), not(target_os = "windows")))] { use crate::rpc::RECONNECT_TIMEOUT; use gpui::TestScreenCaptureSource; @@ -456,11 +456,19 @@ async fn test_basic_following( .await .unwrap(); cx_b.set_screen_capture_sources(vec![display]); + let source = cx_b + .read(|cx| cx.screen_capture_sources()) + .await + .unwrap() + .unwrap() + .into_iter() + .next() + .unwrap(); active_call_b .update(cx_b, |call, cx| { call.room() .unwrap() - .update(cx, |room, cx| room.share_screen(cx)) + .update(cx, |room, cx| room.share_screen(source, cx)) }) .await .unwrap(); @@ -1013,7 +1021,7 @@ async fn test_peers_following_each_other(cx_a: &mut TestAppContext, cx_b: &mut T // and some of which were originally opened by client B. workspace_b.update_in(cx_b, |workspace, window, cx| { workspace.active_pane().update(cx, |pane, cx| { - pane.close_inactive_items(&Default::default(), window, cx) + pane.close_other_items(&Default::default(), None, window, cx) .detach(); }); }); diff --git a/crates/collab/src/tests/integration_tests.rs b/crates/collab/src/tests/integration_tests.rs index d1099a327a4d090dcd26fff8d5308e36922a49b6..5a2c40b890cfe32510347c33a1257af4cbea0768 100644 --- a/crates/collab/src/tests/integration_tests.rs +++ b/crates/collab/src/tests/integration_tests.rs @@ -277,11 +277,19 @@ async fn test_basic_calls( let events_b = active_call_events(cx_b); let events_c = active_call_events(cx_c); cx_a.set_screen_capture_sources(vec![display]); + let screen_a = cx_a + .update(|cx| cx.screen_capture_sources()) + .await + .unwrap() + .unwrap() + .into_iter() + .next() + .unwrap(); active_call_a .update(cx_a, |call, cx| { call.room() .unwrap() - .update(cx, |room, cx| room.share_screen(cx)) + .update(cx, |room, cx| room.share_screen(screen_a, cx)) }) .await .unwrap(); @@ -834,7 +842,7 @@ async fn test_client_disconnecting_from_room( // Allow user A to reconnect to the server. server.allow_connections(); - executor.advance_clock(RECEIVE_TIMEOUT); + executor.advance_clock(RECONNECT_TIMEOUT); // Call user B again from client A. active_call_a @@ -1278,7 +1286,7 @@ async fn test_calls_on_multiple_connections( client_b1.disconnect(&cx_b1.to_async()); executor.advance_clock(RECEIVE_TIMEOUT); client_b1 - .authenticate_and_connect(false, &cx_b1.to_async()) + .connect(false, &cx_b1.to_async()) .await .into_response() .unwrap(); @@ -1350,7 +1358,7 @@ async fn test_calls_on_multiple_connections( // User A reconnects automatically, then calls user B again. server.allow_connections(); - executor.advance_clock(RECEIVE_TIMEOUT); + executor.advance_clock(RECONNECT_TIMEOUT); active_call_a .update(cx_a, |call, cx| { call.invite(client_b1.user_id().unwrap(), None, cx) @@ -1659,7 +1667,7 @@ async fn test_project_reconnect( // Client A reconnects. Their project is re-shared, and client B re-joins it. server.allow_connections(); client_a - .authenticate_and_connect(false, &cx_a.to_async()) + .connect(false, &cx_a.to_async()) .await .into_response() .unwrap(); @@ -1788,7 +1796,7 @@ async fn test_project_reconnect( // Client B reconnects. They re-join the room and the remaining shared project. 
server.allow_connections(); client_b - .authenticate_and_connect(false, &cx_b.to_async()) + .connect(false, &cx_b.to_async()) .await .into_response() .unwrap(); @@ -1873,7 +1881,7 @@ async fn test_active_call_events( vec![room::Event::RemoteProjectShared { owner: Arc::new(User { id: client_a.user_id().unwrap(), - github_login: "user_a".to_string(), + github_login: "user_a".into(), avatar_uri: "avatar_a".into(), name: None, }), @@ -1892,7 +1900,7 @@ async fn test_active_call_events( vec![room::Event::RemoteProjectShared { owner: Arc::new(User { id: client_b.user_id().unwrap(), - github_login: "user_b".to_string(), + github_login: "user_b".into(), avatar_uri: "avatar_b".into(), name: None, }), @@ -4770,10 +4778,27 @@ async fn test_definition( .await; let active_call_a = cx_a.read(ActiveCall::global); - let mut fake_language_servers = client_a - .language_registry() - .register_fake_lsp("Rust", Default::default()); + let capabilities = lsp::ServerCapabilities { + definition_provider: Some(OneOf::Left(true)), + type_definition_provider: Some(lsp::TypeDefinitionProviderCapability::Simple(true)), + ..lsp::ServerCapabilities::default() + }; client_a.language_registry().add(rust_lang()); + let mut fake_language_servers = client_a.language_registry().register_fake_lsp( + "Rust", + FakeLspAdapter { + capabilities: capabilities.clone(), + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + capabilities, + ..FakeLspAdapter::default() + }, + ); client_a .fs() @@ -4819,13 +4844,19 @@ async fn test_definition( ))) }, ); + cx_a.run_until_parked(); + cx_b.run_until_parked(); let definitions_1 = project_b .update(cx_b, |p, cx| p.definitions(&buffer_b, 23, cx)) .await .unwrap(); cx_b.read(|cx| { - assert_eq!(definitions_1.len(), 1); + assert_eq!( + definitions_1.len(), + 1, + "Unexpected definitions: {definitions_1:?}" + ); assert_eq!(project_b.read(cx).worktrees(cx).count(), 2); let target_buffer = definitions_1[0].target.buffer.read(cx); assert_eq!( @@ -4893,7 +4924,11 @@ async fn test_definition( .await .unwrap(); cx_b.read(|cx| { - assert_eq!(type_definitions.len(), 1); + assert_eq!( + type_definitions.len(), + 1, + "Unexpected type definitions: {type_definitions:?}" + ); let target_buffer = type_definitions[0].target.buffer.read(cx); assert_eq!(target_buffer.text(), "type T2 = usize;"); assert_eq!( @@ -4917,16 +4952,26 @@ async fn test_references( .await; let active_call_a = cx_a.read(ActiveCall::global); + let capabilities = lsp::ServerCapabilities { + references_provider: Some(lsp::OneOf::Left(true)), + ..lsp::ServerCapabilities::default() + }; client_a.language_registry().add(rust_lang()); let mut fake_language_servers = client_a.language_registry().register_fake_lsp( "Rust", FakeLspAdapter { name: "my-fake-lsp-adapter", - capabilities: lsp::ServerCapabilities { - references_provider: Some(lsp::OneOf::Left(true)), - ..Default::default() - }, - ..Default::default() + capabilities: capabilities.clone(), + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + name: "my-fake-lsp-adapter", + capabilities: capabilities, + ..FakeLspAdapter::default() }, ); @@ -4981,6 +5026,8 @@ async fn test_references( } } }); + cx_a.run_until_parked(); + cx_b.run_until_parked(); let references = project_b.update(cx_b, |p, cx| p.references(&buffer_b, 7, cx)); @@ -4988,7 +5035,7 @@ 
async fn test_references( executor.run_until_parked(); project_b.read_with(cx_b, |project, cx| { let status = project.language_server_statuses(cx).next().unwrap().1; - assert_eq!(status.name, "my-fake-lsp-adapter"); + assert_eq!(status.name.0, "my-fake-lsp-adapter"); assert_eq!( status.pending_work.values().next().unwrap().message, Some("Finding references...".into()) @@ -5046,7 +5093,7 @@ async fn test_references( executor.run_until_parked(); project_b.read_with(cx_b, |project, cx| { let status = project.language_server_statuses(cx).next().unwrap().1; - assert_eq!(status.name, "my-fake-lsp-adapter"); + assert_eq!(status.name.0, "my-fake-lsp-adapter"); assert_eq!( status.pending_work.values().next().unwrap().message, Some("Finding references...".into()) @@ -5196,10 +5243,26 @@ async fn test_document_highlights( ) .await; - let mut fake_language_servers = client_a - .language_registry() - .register_fake_lsp("Rust", Default::default()); client_a.language_registry().add(rust_lang()); + let capabilities = lsp::ServerCapabilities { + document_highlight_provider: Some(lsp::OneOf::Left(true)), + ..lsp::ServerCapabilities::default() + }; + let mut fake_language_servers = client_a.language_registry().register_fake_lsp( + "Rust", + FakeLspAdapter { + capabilities: capabilities.clone(), + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + capabilities, + ..FakeLspAdapter::default() + }, + ); let (project_a, worktree_id) = client_a.build_local_project(path!("/root-1"), cx_a).await; let project_id = active_call_a @@ -5248,6 +5311,8 @@ async fn test_document_highlights( ])) }, ); + cx_a.run_until_parked(); + cx_b.run_until_parked(); let highlights = project_b .update(cx_b, |p, cx| p.document_highlights(&buffer_b, 34, cx)) @@ -5298,30 +5363,49 @@ async fn test_lsp_hover( client_a.language_registry().add(rust_lang()); let language_server_names = ["rust-analyzer", "CrabLang-ls"]; + let capabilities_1 = lsp::ServerCapabilities { + hover_provider: Some(lsp::HoverProviderCapability::Simple(true)), + ..lsp::ServerCapabilities::default() + }; + let capabilities_2 = lsp::ServerCapabilities { + hover_provider: Some(lsp::HoverProviderCapability::Simple(true)), + ..lsp::ServerCapabilities::default() + }; let mut language_servers = [ client_a.language_registry().register_fake_lsp( "Rust", FakeLspAdapter { - name: "rust-analyzer", - capabilities: lsp::ServerCapabilities { - hover_provider: Some(lsp::HoverProviderCapability::Simple(true)), - ..lsp::ServerCapabilities::default() - }, + name: language_server_names[0], + capabilities: capabilities_1.clone(), ..FakeLspAdapter::default() }, ), client_a.language_registry().register_fake_lsp( "Rust", FakeLspAdapter { - name: "CrabLang-ls", - capabilities: lsp::ServerCapabilities { - hover_provider: Some(lsp::HoverProviderCapability::Simple(true)), - ..lsp::ServerCapabilities::default() - }, + name: language_server_names[1], + capabilities: capabilities_2.clone(), ..FakeLspAdapter::default() }, ), ]; + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + name: language_server_names[0], + capabilities: capabilities_1, + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + name: language_server_names[1], + capabilities: capabilities_2, + ..FakeLspAdapter::default() + }, + ); let (project_a, 
worktree_id) = client_a.build_local_project(path!("/root-1"), cx_a).await; let project_id = active_call_a @@ -5415,6 +5499,8 @@ async fn test_lsp_hover( unexpected => panic!("Unexpected server name: {unexpected}"), } } + cx_a.run_until_parked(); + cx_b.run_until_parked(); // Request hover information as the guest. let mut hovers = project_b @@ -5597,10 +5683,26 @@ async fn test_open_buffer_while_getting_definition_pointing_to_it( .await; let active_call_a = cx_a.read(ActiveCall::global); + let capabilities = lsp::ServerCapabilities { + definition_provider: Some(OneOf::Left(true)), + ..lsp::ServerCapabilities::default() + }; client_a.language_registry().add(rust_lang()); - let mut fake_language_servers = client_a - .language_registry() - .register_fake_lsp("Rust", Default::default()); + let mut fake_language_servers = client_a.language_registry().register_fake_lsp( + "Rust", + FakeLspAdapter { + capabilities: capabilities.clone(), + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + capabilities, + ..FakeLspAdapter::default() + }, + ); client_a .fs() @@ -5641,6 +5743,8 @@ async fn test_open_buffer_while_getting_definition_pointing_to_it( let definitions; let buffer_b2; if rng.r#gen() { + cx_a.run_until_parked(); + cx_b.run_until_parked(); definitions = project_b.update(cx_b, |p, cx| p.definitions(&buffer_b1, 23, cx)); (buffer_b2, _) = project_b .update(cx_b, |p, cx| { @@ -5655,11 +5759,17 @@ async fn test_open_buffer_while_getting_definition_pointing_to_it( }) .await .unwrap(); + cx_a.run_until_parked(); + cx_b.run_until_parked(); definitions = project_b.update(cx_b, |p, cx| p.definitions(&buffer_b1, 23, cx)); } let definitions = definitions.await.unwrap(); - assert_eq!(definitions.len(), 1); + assert_eq!( + definitions.len(), + 1, + "Unexpected definitions: {definitions:?}" + ); assert_eq!(definitions[0].target.buffer, buffer_b2); } @@ -5730,7 +5840,7 @@ async fn test_contacts( server.allow_connections(); client_c - .authenticate_and_connect(false, &cx_c.to_async()) + .connect(false, &cx_c.to_async()) .await .into_response() .unwrap(); @@ -6071,7 +6181,7 @@ async fn test_contacts( .iter() .map(|contact| { ( - contact.user.github_login.clone(), + contact.user.github_login.clone().to_string(), if contact.online { "online" } else { "offline" }, if contact.busy { "busy" } else { "free" }, ) @@ -6261,7 +6371,7 @@ async fn test_contact_requests( client.disconnect(&cx.to_async()); client.clear_contacts(cx).await; client - .authenticate_and_connect(false, &cx.to_async()) + .connect(false, &cx.to_async()) .await .into_response() .unwrap(); @@ -6312,11 +6422,20 @@ async fn test_join_call_after_screen_was_shared( // User A shares their screen let display = gpui::TestScreenCaptureSource::new(); cx_a.set_screen_capture_sources(vec![display]); + let screen_a = cx_a + .update(|cx| cx.screen_capture_sources()) + .await + .unwrap() + .unwrap() + .into_iter() + .next() + .unwrap(); + active_call_a .update(cx_a, |call, cx| { call.room() .unwrap() - .update(cx, |room, cx| room.share_screen(cx)) + .update(cx, |room, cx| room.share_screen(screen_a, cx)) }) .await .unwrap(); diff --git a/crates/collab/src/tests/notification_tests.rs b/crates/collab/src/tests/notification_tests.rs index 4e64b5526bc3554d2ee15a3acf0fb79d9859166f..9bf906694ef84e66f529269fdadfd1fe6ad1fd08 100644 --- a/crates/collab/src/tests/notification_tests.rs +++ b/crates/collab/src/tests/notification_tests.rs @@ -3,6 +3,7 
@@ use std::sync::Arc; use gpui::{BackgroundExecutor, TestAppContext}; use notifications::NotificationEvent; use parking_lot::Mutex; +use pretty_assertions::assert_eq; use rpc::{Notification, proto}; use crate::tests::TestServer; @@ -17,6 +18,9 @@ async fn test_notifications( let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; + // Wait for authentication/connection to Collab to be established. + executor.run_until_parked(); + let notification_events_a = Arc::new(Mutex::new(Vec::new())); let notification_events_b = Arc::new(Mutex::new(Vec::new())); client_a.notification_store().update(cx_a, |_, cx| { diff --git a/crates/collab/src/tests/remote_editing_collaboration_tests.rs b/crates/collab/src/tests/remote_editing_collaboration_tests.rs index 7aeb381c02beeb6165e44ccd5bbd72f5744cc964..8ab6e6910c88880bc8b6451d972e39b5c2315812 100644 --- a/crates/collab/src/tests/remote_editing_collaboration_tests.rs +++ b/crates/collab/src/tests/remote_editing_collaboration_tests.rs @@ -2,6 +2,7 @@ use crate::tests::TestServer; use call::ActiveCall; use collections::{HashMap, HashSet}; +use dap::{Capabilities, adapters::DebugTaskDefinition, transport::RequestHandling}; use debugger_ui::debugger_panel::DebugPanel; use extension::ExtensionHostProxy; use fs::{FakeFs, Fs as _, RemoveOptions}; @@ -22,6 +23,7 @@ use language::{ use node_runtime::NodeRuntime; use project::{ ProjectPath, + debugger::session::ThreadId, lsp_store::{FormatTrigger, LspFormatTarget}, }; use remote::SshRemoteClient; @@ -29,7 +31,11 @@ use remote_server::{HeadlessAppState, HeadlessProject}; use rpc::proto; use serde_json::json; use settings::SettingsStore; -use std::{path::Path, sync::Arc}; +use std::{ + path::Path, + sync::{Arc, atomic::AtomicUsize}, +}; +use task::TcpArgumentsTemplate; use util::path; #[gpui::test(iterations = 10)] @@ -688,3 +694,162 @@ async fn test_remote_server_debugger( shutdown_session.await.unwrap(); } + +#[gpui::test] +async fn test_slow_adapter_startup_retries( + cx_a: &mut TestAppContext, + server_cx: &mut TestAppContext, + executor: BackgroundExecutor, +) { + cx_a.update(|cx| { + release_channel::init(SemanticVersion::default(), cx); + command_palette_hooks::init(cx); + zlog::init_test(); + dap_adapters::init(cx); + }); + server_cx.update(|cx| { + release_channel::init(SemanticVersion::default(), cx); + dap_adapters::init(cx); + }); + let (opts, server_ssh) = SshRemoteClient::fake_server(cx_a, server_cx); + let remote_fs = FakeFs::new(server_cx.executor()); + remote_fs + .insert_tree( + path!("/code"), + json!({ + "lib.rs": "fn one() -> usize { 1 }" + }), + ) + .await; + + // User A connects to the remote project via SSH. 
+ server_cx.update(HeadlessProject::init); + let remote_http_client = Arc::new(BlockedHttpClient); + let node = NodeRuntime::unavailable(); + let languages = Arc::new(LanguageRegistry::new(server_cx.executor())); + let _headless_project = server_cx.new(|cx| { + client::init_settings(cx); + HeadlessProject::new( + HeadlessAppState { + session: server_ssh, + fs: remote_fs.clone(), + http_client: remote_http_client, + node_runtime: node, + languages, + extension_host_proxy: Arc::new(ExtensionHostProxy::new()), + }, + cx, + ) + }); + + let client_ssh = SshRemoteClient::fake_client(opts, cx_a).await; + let mut server = TestServer::start(server_cx.executor()).await; + let client_a = server.create_client(cx_a, "user_a").await; + cx_a.update(|cx| { + debugger_ui::init(cx); + command_palette_hooks::init(cx); + }); + let (project_a, _) = client_a + .build_ssh_project(path!("/code"), client_ssh.clone(), cx_a) + .await; + + let (workspace, cx_a) = client_a.build_workspace(&project_a, cx_a); + + let debugger_panel = workspace + .update_in(cx_a, |_workspace, window, cx| { + cx.spawn_in(window, DebugPanel::load) + }) + .await + .unwrap(); + + workspace.update_in(cx_a, |workspace, window, cx| { + workspace.add_panel(debugger_panel, window, cx); + }); + + cx_a.run_until_parked(); + let debug_panel = workspace + .update(cx_a, |workspace, cx| workspace.panel::(cx)) + .unwrap(); + + let workspace_window = cx_a + .window_handle() + .downcast::() + .unwrap(); + + let count = Arc::new(AtomicUsize::new(0)); + let session = debugger_ui::tests::start_debug_session_with( + &workspace_window, + cx_a, + DebugTaskDefinition { + adapter: "fake-adapter".into(), + label: "test".into(), + config: json!({ + "request": "launch" + }), + tcp_connection: Some(TcpArgumentsTemplate { + port: None, + host: None, + timeout: None, + }), + }, + move |client| { + let count = count.clone(); + client.on_request_ext::(move |_seq, _request| { + if count.fetch_add(1, std::sync::atomic::Ordering::SeqCst) < 5 { + return RequestHandling::Exit; + } + RequestHandling::Respond(Ok(Capabilities::default())) + }); + }, + ) + .unwrap(); + cx_a.run_until_parked(); + + let client = session.update(cx_a, |session, _| session.adapter_client().unwrap()); + client + .fake_event(dap::messages::Events::Stopped(dap::StoppedEvent { + reason: dap::StoppedEventReason::Pause, + description: None, + thread_id: Some(1), + preserve_focus_hint: None, + text: None, + all_threads_stopped: None, + hit_breakpoint_ids: None, + })) + .await; + + cx_a.run_until_parked(); + + let active_session = debug_panel + .update(cx_a, |this, _| this.active_session()) + .unwrap(); + + let running_state = active_session.update(cx_a, |active_session, _| { + active_session.running_state().clone() + }); + + assert_eq!( + client.id(), + running_state.read_with(cx_a, |running_state, _| running_state.session_id()) + ); + assert_eq!( + ThreadId(1), + running_state.read_with(cx_a, |running_state, _| running_state + .selected_thread_id() + .unwrap()) + ); + + let shutdown_session = workspace.update(cx_a, |workspace, cx| { + workspace.project().update(cx, |project, cx| { + project.dap_store().update(cx, |dap_store, cx| { + dap_store.shutdown_session(session.read(cx).session_id(), cx) + }) + }) + }); + + client_ssh.update(cx_a, |a, _| { + a.shutdown_processes(Some(proto::ShutdownRemoteServer {}), executor) + }); + + shutdown_session.await.unwrap(); +} diff --git a/crates/collab/src/tests/stripe_billing_tests.rs b/crates/collab/src/tests/stripe_billing_tests.rs index 
c19eb0a23432fb835b99007b0ebca2e4a5a8f2e6..bb84bedfcfc1fb4f95724f60bbd80707b12c215a 100644 --- a/crates/collab/src/tests/stripe_billing_tests.rs +++ b/crates/collab/src/tests/stripe_billing_tests.rs @@ -1,20 +1,9 @@ use std::sync::Arc; -use chrono::{Duration, Utc}; use pretty_assertions::assert_eq; -use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG; use crate::stripe_billing::StripeBilling; -use crate::stripe_client::{ - FakeStripeClient, StripeBillingAddressCollection, StripeCheckoutSessionMode, - StripeCheckoutSessionPaymentMethodCollection, StripeCreateCheckoutSessionLineItems, - StripeCreateCheckoutSessionSubscriptionData, StripeCustomerId, StripeCustomerUpdate, - StripeCustomerUpdateAddress, StripeCustomerUpdateName, StripeMeter, StripeMeterId, StripePrice, - StripePriceId, StripePriceRecurring, StripeSubscription, StripeSubscriptionId, - StripeSubscriptionItem, StripeSubscriptionItemId, StripeSubscriptionTrialSettings, - StripeSubscriptionTrialSettingsEndBehavior, - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems, -}; +use crate::stripe_client::{FakeStripeClient, StripePrice, StripePriceId, StripePriceRecurring}; fn make_stripe_billing() -> (StripeBilling, Arc) { let stripe_client = Arc::new(FakeStripeClient::new()); @@ -27,24 +16,6 @@ fn make_stripe_billing() -> (StripeBilling, Arc) { async fn test_initialize() { let (stripe_billing, stripe_client) = make_stripe_billing(); - // Add test meters - let meter1 = StripeMeter { - id: StripeMeterId("meter_1".into()), - event_name: "event_1".to_string(), - }; - let meter2 = StripeMeter { - id: StripeMeterId("meter_2".into()), - event_name: "event_2".to_string(), - }; - stripe_client - .meters - .lock() - .insert(meter1.id.clone(), meter1); - stripe_client - .meters - .lock() - .insert(meter2.id.clone(), meter2); - // Add test prices let price1 = StripePrice { id: StripePriceId("price_1".into()), @@ -150,454 +121,3 @@ async fn test_find_or_create_customer_by_email() { assert_eq!(customer.email.as_deref(), Some(email)); } } - -#[gpui::test] -async fn test_subscribe_to_price() { - let (stripe_billing, stripe_client) = make_stripe_billing(); - - let price = StripePrice { - id: StripePriceId("price_test".into()), - unit_amount: Some(2000), - lookup_key: Some("test-price".to_string()), - recurring: None, - }; - stripe_client - .prices - .lock() - .insert(price.id.clone(), price.clone()); - - let now = Utc::now(); - let subscription = StripeSubscription { - id: StripeSubscriptionId("sub_test".into()), - customer: StripeCustomerId("cus_test".into()), - status: stripe::SubscriptionStatus::Active, - current_period_start: now.timestamp(), - current_period_end: (now + Duration::days(30)).timestamp(), - items: vec![], - cancel_at: None, - cancellation_details: None, - }; - stripe_client - .subscriptions - .lock() - .insert(subscription.id.clone(), subscription.clone()); - - stripe_billing - .subscribe_to_price(&subscription.id, &price) - .await - .unwrap(); - - let update_subscription_calls = stripe_client - .update_subscription_calls - .lock() - .iter() - .map(|(id, params)| (id.clone(), params.clone())) - .collect::>(); - assert_eq!(update_subscription_calls.len(), 1); - assert_eq!(update_subscription_calls[0].0, subscription.id); - assert_eq!( - update_subscription_calls[0].1.items, - Some(vec![UpdateSubscriptionItems { - price: Some(price.id.clone()) - }]) - ); - - // Subscribing to a price that is already on the subscription is a no-op. 
- { - let now = Utc::now(); - let subscription = StripeSubscription { - id: StripeSubscriptionId("sub_test".into()), - customer: StripeCustomerId("cus_test".into()), - status: stripe::SubscriptionStatus::Active, - current_period_start: now.timestamp(), - current_period_end: (now + Duration::days(30)).timestamp(), - items: vec![StripeSubscriptionItem { - id: StripeSubscriptionItemId("si_test".into()), - price: Some(price.clone()), - }], - cancel_at: None, - cancellation_details: None, - }; - stripe_client - .subscriptions - .lock() - .insert(subscription.id.clone(), subscription.clone()); - - stripe_billing - .subscribe_to_price(&subscription.id, &price) - .await - .unwrap(); - - assert_eq!(stripe_client.update_subscription_calls.lock().len(), 1); - } -} - -#[gpui::test] -async fn test_subscribe_to_zed_free() { - let (stripe_billing, stripe_client) = make_stripe_billing(); - - let zed_pro_price = StripePrice { - id: StripePriceId("price_1".into()), - unit_amount: Some(0), - lookup_key: Some("zed-pro".to_string()), - recurring: None, - }; - stripe_client - .prices - .lock() - .insert(zed_pro_price.id.clone(), zed_pro_price.clone()); - let zed_free_price = StripePrice { - id: StripePriceId("price_2".into()), - unit_amount: Some(0), - lookup_key: Some("zed-free".to_string()), - recurring: None, - }; - stripe_client - .prices - .lock() - .insert(zed_free_price.id.clone(), zed_free_price.clone()); - - stripe_billing.initialize().await.unwrap(); - - // Customer is subscribed to Zed Free when not already subscribed to a plan. - { - let customer_id = StripeCustomerId("cus_no_plan".into()); - - let subscription = stripe_billing - .subscribe_to_zed_free(customer_id) - .await - .unwrap(); - - assert_eq!(subscription.items[0].price.as_ref(), Some(&zed_free_price)); - } - - // Customer is not subscribed to Zed Free when they already have an active subscription. - { - let customer_id = StripeCustomerId("cus_active_subscription".into()); - - let now = Utc::now(); - let existing_subscription = StripeSubscription { - id: StripeSubscriptionId("sub_existing_active".into()), - customer: customer_id.clone(), - status: stripe::SubscriptionStatus::Active, - current_period_start: now.timestamp(), - current_period_end: (now + Duration::days(30)).timestamp(), - items: vec![StripeSubscriptionItem { - id: StripeSubscriptionItemId("si_test".into()), - price: Some(zed_pro_price.clone()), - }], - cancel_at: None, - cancellation_details: None, - }; - stripe_client.subscriptions.lock().insert( - existing_subscription.id.clone(), - existing_subscription.clone(), - ); - - let subscription = stripe_billing - .subscribe_to_zed_free(customer_id) - .await - .unwrap(); - - assert_eq!(subscription, existing_subscription); - } - - // Customer is not subscribed to Zed Free when they already have a trial subscription. 
- { - let customer_id = StripeCustomerId("cus_trial_subscription".into()); - - let now = Utc::now(); - let existing_subscription = StripeSubscription { - id: StripeSubscriptionId("sub_existing_trial".into()), - customer: customer_id.clone(), - status: stripe::SubscriptionStatus::Trialing, - current_period_start: now.timestamp(), - current_period_end: (now + Duration::days(14)).timestamp(), - items: vec![StripeSubscriptionItem { - id: StripeSubscriptionItemId("si_test".into()), - price: Some(zed_pro_price.clone()), - }], - cancel_at: None, - cancellation_details: None, - }; - stripe_client.subscriptions.lock().insert( - existing_subscription.id.clone(), - existing_subscription.clone(), - ); - - let subscription = stripe_billing - .subscribe_to_zed_free(customer_id) - .await - .unwrap(); - - assert_eq!(subscription, existing_subscription); - } -} - -#[gpui::test] -async fn test_bill_model_request_usage() { - let (stripe_billing, stripe_client) = make_stripe_billing(); - - let customer_id = StripeCustomerId("cus_test".into()); - - stripe_billing - .bill_model_request_usage(&customer_id, "some_model/requests", 73) - .await - .unwrap(); - - let create_meter_event_calls = stripe_client - .create_meter_event_calls - .lock() - .iter() - .cloned() - .collect::>(); - assert_eq!(create_meter_event_calls.len(), 1); - assert!( - create_meter_event_calls[0] - .identifier - .starts_with("model_requests/") - ); - assert_eq!(create_meter_event_calls[0].stripe_customer_id, customer_id); - assert_eq!( - create_meter_event_calls[0].event_name.as_ref(), - "some_model/requests" - ); - assert_eq!(create_meter_event_calls[0].value, 73); -} - -#[gpui::test] -async fn test_checkout_with_zed_pro() { - let (stripe_billing, stripe_client) = make_stripe_billing(); - - let customer_id = StripeCustomerId("cus_test".into()); - let github_login = "zeduser1"; - let success_url = "https://example.com/success"; - - // It returns an error when the Zed Pro price doesn't exist. - { - let result = stripe_billing - .checkout_with_zed_pro(&customer_id, github_login, success_url) - .await; - - assert!(result.is_err()); - assert_eq!( - result.err().unwrap().to_string(), - r#"no price ID found for "zed-pro""# - ); - } - - // Successful checkout. - { - let price = StripePrice { - id: StripePriceId("price_1".into()), - unit_amount: Some(2000), - lookup_key: Some("zed-pro".to_string()), - recurring: None, - }; - stripe_client - .prices - .lock() - .insert(price.id.clone(), price.clone()); - - stripe_billing.initialize().await.unwrap(); - - let checkout_url = stripe_billing - .checkout_with_zed_pro(&customer_id, github_login, success_url) - .await - .unwrap(); - - assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay")); - - let create_checkout_session_calls = stripe_client - .create_checkout_session_calls - .lock() - .drain(..) 
- .collect::>(); - assert_eq!(create_checkout_session_calls.len(), 1); - let call = create_checkout_session_calls.into_iter().next().unwrap(); - assert_eq!(call.customer, Some(customer_id)); - assert_eq!(call.client_reference_id.as_deref(), Some(github_login)); - assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription)); - assert_eq!( - call.line_items, - Some(vec![StripeCreateCheckoutSessionLineItems { - price: Some(price.id.to_string()), - quantity: Some(1) - }]) - ); - assert_eq!(call.payment_method_collection, None); - assert_eq!(call.subscription_data, None); - assert_eq!(call.success_url.as_deref(), Some(success_url)); - assert_eq!( - call.billing_address_collection, - Some(StripeBillingAddressCollection::Required) - ); - assert_eq!( - call.customer_update, - Some(StripeCustomerUpdate { - address: Some(StripeCustomerUpdateAddress::Auto), - name: Some(StripeCustomerUpdateName::Auto), - shipping: None, - }) - ); - } -} - -#[gpui::test] -async fn test_checkout_with_zed_pro_trial() { - let (stripe_billing, stripe_client) = make_stripe_billing(); - - let customer_id = StripeCustomerId("cus_test".into()); - let github_login = "zeduser1"; - let success_url = "https://example.com/success"; - - // It returns an error when the Zed Pro price doesn't exist. - { - let result = stripe_billing - .checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url) - .await; - - assert!(result.is_err()); - assert_eq!( - result.err().unwrap().to_string(), - r#"no price ID found for "zed-pro""# - ); - } - - let price = StripePrice { - id: StripePriceId("price_1".into()), - unit_amount: Some(2000), - lookup_key: Some("zed-pro".to_string()), - recurring: None, - }; - stripe_client - .prices - .lock() - .insert(price.id.clone(), price.clone()); - - stripe_billing.initialize().await.unwrap(); - - // Successful checkout. - { - let checkout_url = stripe_billing - .checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url) - .await - .unwrap(); - - assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay")); - - let create_checkout_session_calls = stripe_client - .create_checkout_session_calls - .lock() - .drain(..) 
- .collect::>(); - assert_eq!(create_checkout_session_calls.len(), 1); - let call = create_checkout_session_calls.into_iter().next().unwrap(); - assert_eq!(call.customer.as_ref(), Some(&customer_id)); - assert_eq!(call.client_reference_id.as_deref(), Some(github_login)); - assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription)); - assert_eq!( - call.line_items, - Some(vec![StripeCreateCheckoutSessionLineItems { - price: Some(price.id.to_string()), - quantity: Some(1) - }]) - ); - assert_eq!( - call.payment_method_collection, - Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired) - ); - assert_eq!( - call.subscription_data, - Some(StripeCreateCheckoutSessionSubscriptionData { - trial_period_days: Some(14), - trial_settings: Some(StripeSubscriptionTrialSettings { - end_behavior: StripeSubscriptionTrialSettingsEndBehavior { - missing_payment_method: - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel, - }, - }), - metadata: None, - }) - ); - assert_eq!(call.success_url.as_deref(), Some(success_url)); - assert_eq!( - call.billing_address_collection, - Some(StripeBillingAddressCollection::Required) - ); - assert_eq!( - call.customer_update, - Some(StripeCustomerUpdate { - address: Some(StripeCustomerUpdateAddress::Auto), - name: Some(StripeCustomerUpdateName::Auto), - shipping: None, - }) - ); - } - - // Successful checkout with extended trial. - { - let checkout_url = stripe_billing - .checkout_with_zed_pro_trial( - &customer_id, - github_login, - vec![AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string()], - success_url, - ) - .await - .unwrap(); - - assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay")); - - let create_checkout_session_calls = stripe_client - .create_checkout_session_calls - .lock() - .drain(..) 
- .collect::>(); - assert_eq!(create_checkout_session_calls.len(), 1); - let call = create_checkout_session_calls.into_iter().next().unwrap(); - assert_eq!(call.customer, Some(customer_id)); - assert_eq!(call.client_reference_id.as_deref(), Some(github_login)); - assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription)); - assert_eq!( - call.line_items, - Some(vec![StripeCreateCheckoutSessionLineItems { - price: Some(price.id.to_string()), - quantity: Some(1) - }]) - ); - assert_eq!( - call.payment_method_collection, - Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired) - ); - assert_eq!( - call.subscription_data, - Some(StripeCreateCheckoutSessionSubscriptionData { - trial_period_days: Some(60), - trial_settings: Some(StripeSubscriptionTrialSettings { - end_behavior: StripeSubscriptionTrialSettingsEndBehavior { - missing_payment_method: - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel, - }, - }), - metadata: Some(std::collections::HashMap::from_iter([( - "promo_feature_flag".into(), - AGENT_EXTENDED_TRIAL_FEATURE_FLAG.into() - )])), - }) - ); - assert_eq!(call.success_url.as_deref(), Some(success_url)); - assert_eq!( - call.billing_address_collection, - Some(StripeBillingAddressCollection::Required) - ); - assert_eq!( - call.customer_update, - Some(StripeCustomerUpdate { - address: Some(StripeCustomerUpdateAddress::Auto), - name: Some(StripeCustomerUpdateName::Auto), - shipping: None, - }) - ); - } -} diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index ab84e02b190443787aa0165ada558382a5d08da9..f5a0e8ea81f0befbb3bae44ab516a7b8f4b04b52 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -8,6 +8,7 @@ use crate::{ use anyhow::anyhow; use call::ActiveCall; use channel::{ChannelBuffer, ChannelStore}; +use client::test::{make_get_authenticated_user_response, parse_authorization_header}; use client::{ self, ChannelId, Client, Connection, Credentials, EstablishConnectionError, UserStore, proto::PeerId, @@ -20,7 +21,7 @@ use fs::FakeFs; use futures::{StreamExt as _, channel::oneshot}; use git::GitHostingProviderRegistry; use gpui::{AppContext as _, BackgroundExecutor, Entity, Task, TestAppContext, VisualTestContext}; -use http_client::FakeHttpClient; +use http_client::{FakeHttpClient, Method}; use language::LanguageRegistry; use node_runtime::NodeRuntime; use notifications::NotificationStore; @@ -161,6 +162,8 @@ impl TestServer { } pub async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> TestClient { + const ACCESS_TOKEN: &str = "the-token"; + let fs = FakeFs::new(cx.executor()); cx.update(|cx| { @@ -175,7 +178,7 @@ impl TestServer { }); let clock = Arc::new(FakeSystemClock::new()); - let http = FakeHttpClient::with_404_response(); + let user_id = if let Ok(Some(user)) = self.app_state.db.get_user_by_github_login(name).await { user.id @@ -197,6 +200,47 @@ impl TestServer { .expect("creating user failed") .user_id }; + + let http = FakeHttpClient::create({ + let name = name.to_string(); + move |req| { + let name = name.clone(); + async move { + match (req.method(), req.uri().path()) { + (&Method::GET, "/client/users/me") => { + let credentials = parse_authorization_header(&req); + if credentials + != Some(Credentials { + user_id: user_id.to_proto(), + access_token: ACCESS_TOKEN.into(), + }) + { + return Ok(http_client::Response::builder() + .status(401) + .body("Unauthorized".into()) + .unwrap()); + } + + Ok(http_client::Response::builder() 
+ .status(200) + .body( + serde_json::to_string(&make_get_authenticated_user_response( + user_id.0, name, + )) + .unwrap() + .into(), + ) + .unwrap()) + } + _ => Ok(http_client::Response::builder() + .status(404) + .body("Not Found".into()) + .unwrap()), + } + } + } + }); + let client_name = name.to_string(); let mut client = cx.update(|cx| Client::new(clock, http.clone(), cx)); let server = self.server.clone(); @@ -208,11 +252,10 @@ impl TestServer { .unwrap() .set_id(user_id.to_proto()) .override_authenticate(move |cx| { - let access_token = "the-token".to_string(); cx.spawn(async move |_| { Ok(Credentials { user_id: user_id.to_proto(), - access_token, + access_token: ACCESS_TOKEN.into(), }) }) }) @@ -221,7 +264,7 @@ impl TestServer { credentials, &Credentials { user_id: user_id.0 as u64, - access_token: "the-token".into() + access_token: ACCESS_TOKEN.into(), } ); @@ -254,6 +297,8 @@ impl TestServer { client_name, Principal::User(user), ZedVersion(SemanticVersion::new(1, 0, 0)), + Some("test".to_string()), + None, None, None, Some(connection_id_tx), @@ -318,7 +363,7 @@ impl TestServer { }); client - .authenticate_and_connect(false, &cx.to_async()) + .connect(false, &cx.to_async()) .await .into_response() .unwrap(); @@ -691,17 +736,17 @@ impl TestClient { current: store .contacts() .iter() - .map(|contact| contact.user.github_login.clone()) + .map(|contact| contact.user.github_login.clone().to_string()) .collect(), outgoing_requests: store .outgoing_contact_requests() .iter() - .map(|user| user.github_login.clone()) + .map(|user| user.github_login.clone().to_string()) .collect(), incoming_requests: store .incoming_contact_requests() .iter() - .map(|user| user.github_login.clone()) + .map(|user| user.github_login.clone().to_string()) .collect(), }) } diff --git a/crates/collab_ui/src/chat_panel.rs b/crates/collab_ui/src/chat_panel.rs index 3e2d813f1ba6474dc9e089d1fde2d48b26c7a31a..51d9f003f813212d40ff8e0716c86b1439fd4de6 100644 --- a/crates/collab_ui/src/chat_panel.rs +++ b/crates/collab_ui/src/chat_panel.rs @@ -103,28 +103,16 @@ impl ChatPanel { }); cx.new(|cx| { - let entity = cx.entity().downgrade(); - let message_list = ListState::new( - 0, - gpui::ListAlignment::Bottom, - px(1000.), - move |ix, window, cx| { - if let Some(entity) = entity.upgrade() { - entity.update(cx, |this: &mut Self, cx| { - this.render_message(ix, window, cx).into_any_element() - }) - } else { - div().into_any() + let message_list = ListState::new(0, gpui::ListAlignment::Bottom, px(1000.)); + + message_list.set_scroll_handler(cx.listener( + |this: &mut Self, event: &ListScrollEvent, _, cx| { + if event.visible_range.start < MESSAGE_LOADING_THRESHOLD { + this.load_more_messages(cx); } + this.is_scrolled_to_bottom = !event.is_scrolled; }, - ); - - message_list.set_scroll_handler(cx.listener(|this, event: &ListScrollEvent, _, cx| { - if event.visible_range.start < MESSAGE_LOADING_THRESHOLD { - this.load_more_messages(cx); - } - this.is_scrolled_to_bottom = !event.is_scrolled; - })); + )); let local_offset = chrono::Local::now().offset().local_minus_utc(); let mut this = Self { @@ -399,7 +387,7 @@ impl ChatPanel { ix: usize, window: &mut Window, cx: &mut Context, - ) -> impl IntoElement { + ) -> AnyElement { let active_chat = &self.active_chat.as_ref().unwrap().0; let (message, is_continuation_from_previous, is_admin) = active_chat.update(cx, |active_chat, cx| { @@ -582,6 +570,7 @@ impl ChatPanel { self.render_popover_buttons(message_id, can_delete_message, can_edit_message, cx) .mt_neg_2p5(), ) + .into_any_element() 
} fn has_open_menu(&self, message_id: Option) -> bool { @@ -979,7 +968,13 @@ impl Render for ChatPanel { ) .child(div().flex_grow().px_2().map(|this| { if self.active_chat.is_some() { - this.child(list(self.message_list.clone()).size_full()) + this.child( + list( + self.message_list.clone(), + cx.processor(Self::render_message), + ) + .size_full(), + ) } else { this.child( div() @@ -1162,7 +1157,7 @@ impl Panel for ChatPanel { } fn icon(&self, _window: &Window, cx: &App) -> Option { - self.enabled(cx).then(|| ui::IconName::MessageBubbles) + self.enabled(cx).then(|| ui::IconName::Chat) } fn icon_tooltip(&self, _: &Window, _: &App) -> Option<&'static str> { diff --git a/crates/collab_ui/src/collab_panel.rs b/crates/collab_ui/src/collab_panel.rs index ec23e2c3f536dc38db05f448f0d239d243a15756..51e4ff8965ee253f2d6bcb858ec6881b8a1cadc4 100644 --- a/crates/collab_ui/src/collab_panel.rs +++ b/crates/collab_ui/src/collab_panel.rs @@ -144,10 +144,22 @@ pub fn init(cx: &mut App) { if let Some(room) = room { window.defer(cx, move |_window, cx| { room.update(cx, |room, cx| { - if room.is_screen_sharing() { - room.unshare_screen(cx).ok(); + if room.is_sharing_screen() { + room.unshare_screen(true, cx).ok(); } else { - room.share_screen(cx).detach_and_log_err(cx); + let sources = cx.screen_capture_sources(); + + cx.spawn(async move |room, cx| { + let sources = sources.await??; + let first = sources.into_iter().next(); + if let Some(first) = first { + room.update(cx, |room, cx| room.share_screen(first, cx))? + .await + } else { + Ok(()) + } + }) + .detach_and_log_err(cx); }; }); }); @@ -312,20 +324,6 @@ impl CollabPanel { ) .detach(); - let entity = cx.entity().downgrade(); - let list_state = ListState::new( - 0, - gpui::ListAlignment::Top, - px(1000.), - move |ix, window, cx| { - if let Some(entity) = entity.upgrade() { - entity.update(cx, |this, cx| this.render_list_entry(ix, window, cx)) - } else { - div().into_any() - } - }, - ); - let mut this = Self { width: None, focus_handle: cx.focus_handle(), @@ -333,7 +331,7 @@ impl CollabPanel { fs: workspace.app_state().fs.clone(), pending_serialization: Task::ready(None), context_menu: None, - list_state, + list_state: ListState::new(0, gpui::ListAlignment::Top, px(1000.)), channel_name_editor, filter_editor, entries: Vec::default(), @@ -528,10 +526,10 @@ impl CollabPanel { project_id: project.id, worktree_root_names: project.worktree_root_names.clone(), host_user_id: user_id, - is_last: projects.peek().is_none() && !room.is_screen_sharing(), + is_last: projects.peek().is_none() && !room.is_sharing_screen(), }); } - if room.is_screen_sharing() { + if room.is_sharing_screen() { self.entries.push(ListEntry::ParticipantScreen { peer_id: None, is_last: true, @@ -928,7 +926,7 @@ impl CollabPanel { room.read(cx).local_participant().role == proto::ChannelRole::Admin }); - ListItem::new(SharedString::from(user.github_login.clone())) + ListItem::new(user.github_login.clone()) .start_slot(Avatar::new(user.avatar_uri.clone())) .child(Label::new(user.github_login.clone())) .toggle_state(is_selected) @@ -1112,7 +1110,7 @@ impl CollabPanel { .relative() .gap_1() .child(render_tree_branch(false, false, window, cx)) - .child(IconButton::new(0, IconName::MessageBubbles)) + .child(IconButton::new(0, IconName::Chat)) .children(has_messages_notification.then(|| { div() .w_1p5() @@ -2319,7 +2317,7 @@ impl CollabPanel { let client = this.client.clone(); cx.spawn_in(window, async move |_, cx| { client - .authenticate_and_connect(true, &cx) + .connect(true, &cx) .await 
.into_response() .notify_async_err(cx); @@ -2419,7 +2417,13 @@ impl CollabPanel { }); v_flex() .size_full() - .child(list(self.list_state.clone()).size_full()) + .child( + list( + self.list_state.clone(), + cx.processor(Self::render_list_entry), + ) + .size_full(), + ) .child( v_flex() .child(div().mx_2().border_primary(cx).border_t_1()) @@ -2571,7 +2575,7 @@ impl CollabPanel { ) -> impl IntoElement { let online = contact.online; let busy = contact.busy || calling; - let github_login = SharedString::from(contact.user.github_login.clone()); + let github_login = contact.user.github_login.clone(); let item = ListItem::new(github_login.clone()) .indent_level(1) .indent_step_size(px(20.)) @@ -2593,7 +2597,7 @@ impl CollabPanel { let contact = contact.clone(); move |this, event: &ClickEvent, window, cx| { this.deploy_contact_context_menu( - event.down.position, + event.position(), contact.clone(), window, cx, @@ -2650,7 +2654,7 @@ impl CollabPanel { is_selected: bool, cx: &mut Context, ) -> impl IntoElement { - let github_login = SharedString::from(user.github_login.clone()); + let github_login = user.github_login.clone(); let user_id = user.id; let is_response_pending = self.user_store.read(cx).is_contact_request_pending(user); let color = if is_response_pending { @@ -2911,7 +2915,7 @@ impl CollabPanel { .gap_1() .px_1() .child( - IconButton::new("channel_chat", IconName::MessageBubbles) + IconButton::new("channel_chat", IconName::Chat) .style(ButtonStyle::Filled) .shape(ui::IconButtonShape::Square) .icon_size(IconSize::Small) @@ -2927,7 +2931,7 @@ impl CollabPanel { .visible_on_hover(""), ) .child( - IconButton::new("channel_notes", IconName::File) + IconButton::new("channel_notes", IconName::FileText) .style(ButtonStyle::Filled) .shape(ui::IconButtonShape::Square) .icon_size(IconSize::Small) @@ -3049,7 +3053,7 @@ impl Render for CollabPanel { .on_action(cx.listener(CollabPanel::move_channel_down)) .track_focus(&self.focus_handle) .size_full() - .child(if self.user_store.read(cx).current_user().is_none() { + .child(if !self.client.status().borrow().is_connected() { self.render_signed_out(cx) } else { self.render_signed_in(window, cx) diff --git a/crates/collab_ui/src/notification_panel.rs b/crates/collab_ui/src/notification_panel.rs index fba8f66c2d19153a0288148b02e593ee37078fb0..3a280ff6677c9a5f9598d5ecaf473af232a8fed1 100644 --- a/crates/collab_ui/src/notification_panel.rs +++ b/crates/collab_ui/src/notification_panel.rs @@ -118,16 +118,7 @@ impl NotificationPanel { }) .detach(); - let entity = cx.entity().downgrade(); - let notification_list = - ListState::new(0, ListAlignment::Top, px(1000.), move |ix, window, cx| { - entity - .upgrade() - .and_then(|entity| { - entity.update(cx, |this, cx| this.render_notification(ix, window, cx)) - }) - .unwrap_or_else(|| div().into_any()) - }); + let notification_list = ListState::new(0, ListAlignment::Top, px(1000.)); notification_list.set_scroll_handler(cx.listener( |this, event: &ListScrollEvent, _, cx| { if event.count.saturating_sub(event.visible_range.end) < LOADING_THRESHOLD { @@ -634,13 +625,13 @@ impl Render for NotificationPanel { .child(Icon::new(IconName::Envelope)), ) .map(|this| { - if self.client.user_id().is_none() { + if !self.client.status().borrow().is_connected() { this.child( v_flex() .gap_2() .p_4() .child( - Button::new("sign_in_prompt_button", "Sign in") + Button::new("connect_prompt_button", "Connect") .icon_color(Color::Muted) .icon(IconName::Github) .icon_position(IconPosition::Start) @@ -652,10 +643,7 @@ impl Render for 
NotificationPanel { let client = client.clone(); window .spawn(cx, async move |cx| { - match client - .authenticate_and_connect(true, &cx) - .await - { + match client.connect(true, &cx).await { util::ConnectionResult::Timeout => { log::error!("Connection timeout"); } @@ -673,7 +661,7 @@ impl Render for NotificationPanel { ) .child( div().flex().w_full().items_center().child( - Label::new("Sign in to view notifications.") + Label::new("Connect to view notifications.") .color(Color::Muted) .size(LabelSize::Small), ), @@ -690,7 +678,16 @@ impl Render for NotificationPanel { ), ) } else { - this.child(list(self.notification_list.clone()).size_full()) + this.child( + list( + self.notification_list.clone(), + cx.processor(|this, ix, window, cx| { + this.render_notification(ix, window, cx) + .unwrap_or_else(|| div().into_any()) + }), + ) + .size_full(), + ) } }) } diff --git a/crates/command_palette/src/command_palette.rs b/crates/command_palette/src/command_palette.rs index abb8978d5a103fb66f862af6c5ee69beee0f6251..b8800ff91284e6f105c029f7fffe9b4b83b6bcd1 100644 --- a/crates/command_palette/src/command_palette.rs +++ b/crates/command_palette/src/command_palette.rs @@ -136,7 +136,10 @@ impl Focusable for CommandPalette { impl Render for CommandPalette { fn render(&mut self, _window: &mut Window, _cx: &mut Context) -> impl IntoElement { - v_flex().w(rems(34.)).child(self.picker.clone()) + v_flex() + .key_context("CommandPalette") + .w(rems(34.)) + .child(self.picker.clone()) } } @@ -242,7 +245,7 @@ impl CommandPaletteDelegate { self.selected_ix = cmp::min(self.selected_ix, self.matches.len() - 1); } } - /// + /// Hit count for each command in the palette. /// We only account for commands triggered directly via command palette and not by e.g. keystrokes because /// if a user already knows a keystroke for a command, they are unlikely to use a command palette to look for it. 
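
Aside (editorial, not part of the patch): the chat panel, collab panel, and notification panel hunks above all migrate to the same gpui list API, where `ListState::new` no longer captures a render closure and the row renderer is instead supplied at `list(...)` call time via `cx.processor(...)`. Below is a minimal sketch of that pattern, assuming the usual gpui re-exports; `RowPanel`, `render_row`, and `load_more_rows` are hypothetical names used only for illustration and do not appear in the diff.

use gpui::{
    AnyElement, Context, IntoElement, ListAlignment, ListScrollEvent, ListState,
    ParentElement as _, Render, Styled as _, Window, div, list, px,
};

struct RowPanel {
    row_list: ListState,
}

impl RowPanel {
    fn new(cx: &mut Context<Self>) -> Self {
        // The state no longer takes a renderer: just item count, alignment, and overdraw height.
        let row_list = ListState::new(0, ListAlignment::Top, px(1000.));
        // Scroll handling still lives on the state, wired through cx.listener as in the hunks above.
        row_list.set_scroll_handler(cx.listener(
            |this: &mut Self, event: &ListScrollEvent, _window, cx| {
                if event.visible_range.start < 5 {
                    this.load_more_rows(cx); // hypothetical lazy loader
                }
            },
        ));
        Self { row_list }
    }

    // One row; the processor calls this per visible index.
    fn render_row(&mut self, ix: usize, _: &mut Window, _: &mut Context<Self>) -> AnyElement {
        div().child(format!("row {ix}")).into_any_element()
    }

    fn load_more_rows(&mut self, _: &mut Context<Self>) {}
}

impl Render for RowPanel {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        // The renderer is bound at call time, replacing the old entity.upgrade() closure dance.
        list(self.row_list.clone(), cx.processor(Self::render_row)).size_full()
    }
}
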
diff --git a/crates/component/src/component.rs b/crates/component/src/component.rs index 02840cc3cb922f2e8a37c5985db529f66d7791b0..0c05ba4a97f4598e9f7982cbc294831a955f1fc6 100644 --- a/crates/component/src/component.rs +++ b/crates/component/src/component.rs @@ -318,8 +318,10 @@ pub enum ComponentScope { Notification, #[strum(serialize = "Overlays & Layering")] Overlays, + Onboarding, Status, Typography, + Utilities, #[strum(serialize = "Version Control")] VersionControl, } diff --git a/crates/component/src/component_layout.rs b/crates/component/src/component_layout.rs index b749ea20eab8b347b83bf34e35c33ec4ef5c614f..58bf1d8f0c85533a4a06bd38c07f840c08cc6de3 100644 --- a/crates/component/src/component_layout.rs +++ b/crates/component/src/component_layout.rs @@ -48,20 +48,20 @@ impl RenderOnce for ComponentExample { ) .child( div() - .flex() - .w_full() - .rounded_xl() .min_h(px(100.)) - .justify_center() + .w_full() .p_8() + .flex() + .items_center() + .justify_center() + .rounded_xl() .border_1() .border_color(cx.theme().colors().border.opacity(0.5)) .bg(pattern_slash( - cx.theme().colors().surface_background.opacity(0.5), + cx.theme().colors().surface_background.opacity(0.25), 12.0, 12.0, )) - .shadow_xs() .child(self.element), ) .into_any_element() @@ -118,8 +118,8 @@ impl RenderOnce for ComponentExampleGroup { .flex() .items_center() .gap_3() - .pb_1() - .child(div().h_px().w_4().bg(cx.theme().colors().border)) + .mt_4() + .mb_1() .child( div() .flex_none() diff --git a/crates/context_server/Cargo.toml b/crates/context_server/Cargo.toml index 96bb9e071f42dd1f6f7fa0782ed8ca425e1cd379..5e4f8369c45f0edb58efda1618bf8fe0aad55749 100644 --- a/crates/context_server/Cargo.toml +++ b/crates/context_server/Cargo.toml @@ -21,12 +21,14 @@ collections.workspace = true futures.workspace = true gpui.workspace = true log.workspace = true +net.workspace = true parking_lot.workspace = true postage.workspace = true schemars.workspace = true serde.workspace = true serde_json.workspace = true smol.workspace = true +tempfile.workspace = true url = { workspace = true, features = ["serde"] } util.workspace = true workspace-hack.workspace = true diff --git a/crates/context_server/src/client.rs b/crates/context_server/src/client.rs index 83d815432da6bba4ac6e077e0378a31739655548..65283afa87d94fae3ec51f8a89574713080bded2 100644 --- a/crates/context_server/src/client.rs +++ b/crates/context_server/src/client.rs @@ -1,6 +1,6 @@ use anyhow::{Context as _, Result, anyhow}; use collections::HashMap; -use futures::{FutureExt, StreamExt, channel::oneshot, select}; +use futures::{FutureExt, StreamExt, channel::oneshot, future, select}; use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, Task}; use parking_lot::Mutex; use postage::barrier; @@ -10,15 +10,19 @@ use smol::channel; use std::{ fmt, path::PathBuf, + pin::pin, sync::{ Arc, atomic::{AtomicI32, Ordering::SeqCst}, }, time::{Duration, Instant}, }; -use util::TryFutureExt; +use util::{ResultExt, TryFutureExt}; -use crate::transport::{StdioTransport, Transport}; +use crate::{ + transport::{StdioTransport, Transport}, + types::{CancelledParams, ClientNotification, Notification as _, notifications::Cancelled}, +}; const JSON_RPC_VERSION: &str = "2.0"; const REQUEST_TIMEOUT: Duration = Duration::from_secs(60); @@ -32,6 +36,7 @@ pub const INTERNAL_ERROR: i32 = -32603; type ResponseHandler = Box)>; type NotificationHandler = Box; +type RequestHandler = Box; #[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)] #[serde(untagged)] @@ -70,12 +75,21 @@ fn 
is_null_value(value: &T) -> bool { } #[derive(Serialize, Deserialize)] -struct Request<'a, T> { - jsonrpc: &'static str, - id: RequestId, - method: &'a str, +pub struct Request<'a, T> { + pub jsonrpc: &'static str, + pub id: RequestId, + pub method: &'a str, #[serde(skip_serializing_if = "is_null_value")] - params: T, + pub params: T, +} + +#[derive(Serialize, Deserialize)] +pub struct AnyRequest<'a> { + pub jsonrpc: &'a str, + pub id: RequestId, + pub method: &'a str, + #[serde(skip_serializing_if = "is_null_value")] + pub params: Option<&'a RawValue>, } #[derive(Serialize, Deserialize)] @@ -88,18 +102,18 @@ struct AnyResponse<'a> { result: Option<&'a RawValue>, } -#[derive(Deserialize)] +#[derive(Serialize, Deserialize)] #[allow(dead_code)] -struct Response { - jsonrpc: &'static str, - id: RequestId, +pub(crate) struct Response { + pub jsonrpc: &'static str, + pub id: RequestId, #[serde(flatten)] - value: CspResult, + pub value: CspResult, } -#[derive(Deserialize)] +#[derive(Serialize, Deserialize)] #[serde(rename_all = "snake_case")] -enum CspResult { +pub(crate) enum CspResult { #[serde(rename = "result")] Ok(Option), #[allow(dead_code)] @@ -123,8 +137,9 @@ struct AnyNotification<'a> { } #[derive(Debug, Serialize, Deserialize)] -struct Error { - message: String, +pub(crate) struct Error { + pub message: String, + pub code: i32, } #[derive(Debug, Clone, Deserialize)] @@ -143,6 +158,7 @@ impl Client { pub fn stdio( server_id: ContextServerId, binary: ModelContextServerBinary, + working_directory: &Option, cx: AsyncApp, ) -> Result { log::info!( @@ -157,7 +173,7 @@ impl Client { .map(|name| name.to_string_lossy().to_string()) .unwrap_or_else(String::new); - let transport = Arc::new(StdioTransport::new(binary, &cx)?); + let transport = Arc::new(StdioTransport::new(binary, working_directory, &cx)?); Self::new(server_id, server_name.into(), transport, cx) } @@ -175,15 +191,23 @@ impl Client { Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default())); let response_handlers = Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default()))); + let request_handlers = Arc::new(Mutex::new(HashMap::<_, RequestHandler>::default())); let receive_input_task = cx.spawn({ let notification_handlers = notification_handlers.clone(); let response_handlers = response_handlers.clone(); + let request_handlers = request_handlers.clone(); let transport = transport.clone(); async move |cx| { - Self::handle_input(transport, notification_handlers, response_handlers, cx) - .log_err() - .await + Self::handle_input( + transport, + notification_handlers, + request_handlers, + response_handlers, + cx, + ) + .log_err() + .await } }); let receive_err_task = cx.spawn({ @@ -229,13 +253,24 @@ impl Client { async fn handle_input( transport: Arc, notification_handlers: Arc>>, + request_handlers: Arc>>, response_handlers: Arc>>>, cx: &mut AsyncApp, ) -> anyhow::Result<()> { let mut receiver = transport.receive(); while let Some(message) = receiver.next().await { - if let Ok(response) = serde_json::from_str::(&message) { + log::trace!("recv: {}", &message); + if let Ok(request) = serde_json::from_str::(&message) { + let mut request_handlers = request_handlers.lock(); + if let Some(handler) = request_handlers.get_mut(request.method) { + handler( + request.id, + request.params.unwrap_or(RawValue::NULL), + cx.clone(), + ); + } + } else if let Ok(response) = serde_json::from_str::(&message) { if let Some(handlers) = response_handlers.lock().as_mut() { if let Some(handler) = handlers.remove(&response.id) { 
handler(Ok(message.to_string())); @@ -246,6 +281,8 @@ impl Client { if let Some(handler) = notification_handlers.get_mut(notification.method.as_str()) { handler(notification.params.unwrap_or(Value::Null), cx.clone()); } + } else { + log::error!("Unhandled JSON from context_server: {}", message); } } @@ -293,6 +330,17 @@ impl Client { &self, method: &str, params: impl Serialize, + ) -> Result { + self.request_with(method, params, None, Some(REQUEST_TIMEOUT)) + .await + } + + pub async fn request_with( + &self, + method: &str, + params: impl Serialize, + cancel_rx: Option>, + timeout: Option, ) -> Result { let id = self.next_id.fetch_add(1, SeqCst); let request = serde_json::to_string(&Request { @@ -328,7 +376,23 @@ impl Client { handle_response?; send?; - let mut timeout = executor.timer(REQUEST_TIMEOUT).fuse(); + let mut timeout_fut = pin!( + match timeout { + Some(timeout) => future::Either::Left(executor.timer(timeout)), + None => future::Either::Right(future::pending()), + } + .fuse() + ); + let mut cancel_fut = pin!( + match cancel_rx { + Some(rx) => future::Either::Left(async { + rx.await.log_err(); + }), + None => future::Either::Right(future::pending()), + } + .fuse() + ); + select! { response = rx.fuse() => { let elapsed = started.elapsed(); @@ -347,8 +411,18 @@ impl Client { Err(_) => anyhow::bail!("cancelled") } } - _ = timeout => { - log::error!("cancelled csp request task for {method:?} id {id} which took over {:?}", REQUEST_TIMEOUT); + _ = cancel_fut => { + self.notify( + Cancelled::METHOD, + ClientNotification::Cancelled(CancelledParams { + request_id: RequestId::Int(id), + reason: None + }) + ).log_err(); + anyhow::bail!(RequestCanceled) + } + _ = timeout_fut => { + log::error!("cancelled csp request task for {method:?} id {id} which took over {:?}", timeout.unwrap()); anyhow::bail!("Context server request timeout"); } } @@ -367,14 +441,23 @@ impl Client { Ok(()) } - #[allow(unused)] - pub fn on_notification(&self, method: &'static str, f: F) - where - F: 'static + Send + FnMut(Value, AsyncApp), - { - self.notification_handlers - .lock() - .insert(method, Box::new(f)); + pub fn on_notification( + &self, + method: &'static str, + f: Box, + ) { + self.notification_handlers.lock().insert(method, f); + } +} + +#[derive(Debug)] +pub struct RequestCanceled; + +impl std::error::Error for RequestCanceled {} + +impl std::fmt::Display for RequestCanceled { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("Context server request was canceled") } } diff --git a/crates/context_server/src/context_server.rs b/crates/context_server/src/context_server.rs index 905435fcce57dc8ce8719e5056b28118168e9a04..34fa29678d5d68f864de7d9df3bef82d4c667f05 100644 --- a/crates/context_server/src/context_server.rs +++ b/crates/context_server/src/context_server.rs @@ -1,13 +1,14 @@ pub mod client; +pub mod listener; pub mod protocol; #[cfg(any(test, feature = "test-support"))] pub mod test; pub mod transport; pub mod types; -use std::fmt::Display; use std::path::Path; use std::sync::Arc; +use std::{fmt::Display, path::PathBuf}; use anyhow::Result; use client::Client; @@ -30,7 +31,7 @@ impl Display for ContextServerId { #[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema)] pub struct ContextServerCommand { #[serde(rename = "command")] - pub path: String, + pub path: PathBuf, pub args: Vec, pub env: Option>, } @@ -52,7 +53,7 @@ impl std::fmt::Debug for ContextServerCommand { } enum ContextServerTransport { - Stdio(ContextServerCommand), + Stdio(ContextServerCommand, 
Option), Custom(Arc), } @@ -63,11 +64,18 @@ pub struct ContextServer { } impl ContextServer { - pub fn stdio(id: ContextServerId, command: ContextServerCommand) -> Self { + pub fn stdio( + id: ContextServerId, + command: ContextServerCommand, + working_directory: Option>, + ) -> Self { Self { id, client: RwLock::new(None), - configuration: ContextServerTransport::Stdio(command), + configuration: ContextServerTransport::Stdio( + command, + working_directory.map(|directory| directory.to_path_buf()), + ), } } @@ -87,15 +95,36 @@ impl ContextServer { self.client.read().clone() } - pub async fn start(self: Arc, cx: &AsyncApp) -> Result<()> { - let client = match &self.configuration { - ContextServerTransport::Stdio(command) => Client::stdio( + pub async fn start(&self, cx: &AsyncApp) -> Result<()> { + self.initialize(self.new_client(cx)?).await + } + + /// Starts the context server, making sure handlers are registered before initialization happens + pub async fn start_with_handlers( + &self, + notification_handlers: Vec<( + &'static str, + Box, + )>, + cx: &AsyncApp, + ) -> Result<()> { + let client = self.new_client(cx)?; + for (method, handler) in notification_handlers { + client.on_notification(method, handler); + } + self.initialize(client).await + } + + fn new_client(&self, cx: &AsyncApp) -> Result { + Ok(match &self.configuration { + ContextServerTransport::Stdio(command, working_directory) => Client::stdio( client::ContextServerId(self.id.0.clone()), client::ModelContextServerBinary { executable: Path::new(&command.path).to_path_buf(), args: command.args.clone(), env: command.env.clone(), }, + working_directory, cx.clone(), )?, ContextServerTransport::Custom(transport) => Client::new( @@ -104,8 +133,7 @@ impl ContextServer { transport.clone(), cx.clone(), )?, - }; - self.initialize(client).await + }) } async fn initialize(&self, client: Client) -> Result<()> { diff --git a/crates/context_server/src/listener.rs b/crates/context_server/src/listener.rs new file mode 100644 index 0000000000000000000000000000000000000000..0e85fb21292739ab0a92d0898fc449a31efe6f29 --- /dev/null +++ b/crates/context_server/src/listener.rs @@ -0,0 +1,443 @@ +use ::serde::{Deserialize, Serialize}; +use anyhow::{Context as _, Result}; +use collections::HashMap; +use futures::{ + AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, FutureExt, + channel::mpsc::{UnboundedReceiver, UnboundedSender, unbounded}, + io::BufReader, + select_biased, +}; +use gpui::{App, AppContext, AsyncApp, Task}; +use net::async_net::{UnixListener, UnixStream}; +use schemars::JsonSchema; +use serde::de::DeserializeOwned; +use serde_json::{json, value::RawValue}; +use smol::stream::StreamExt; +use std::{ + cell::RefCell, + path::{Path, PathBuf}, + rc::Rc, +}; +use util::ResultExt; + +use crate::{ + client::{CspResult, RequestId, Response}, + types::{ + CallToolParams, CallToolResponse, ListToolsResponse, Request, Tool, ToolAnnotations, + ToolResponseContent, + requests::{CallTool, ListTools}, + }, +}; + +pub struct McpServer { + socket_path: PathBuf, + tools: Rc>>, + handlers: Rc>>, + _server_task: Task<()>, +} + +struct RegisteredTool { + tool: Tool, + handler: ToolHandler, +} + +type ToolHandler = Box< + dyn Fn( + Option, + &mut AsyncApp, + ) -> Task>>, +>; +type RequestHandler = Box>, &App) -> Task>; + +impl McpServer { + pub fn new(cx: &AsyncApp) -> Task> { + let task = cx.background_spawn(async move { + let temp_dir = tempfile::Builder::new().prefix("zed-mcp").tempdir()?; + let socket_path = temp_dir.path().join("mcp.sock"); + let 
listener = UnixListener::bind(&socket_path).context("creating mcp socket")?; + + anyhow::Ok((temp_dir, socket_path, listener)) + }); + + cx.spawn(async move |cx| { + let (temp_dir, socket_path, listener) = task.await?; + let tools = Rc::new(RefCell::new(HashMap::default())); + let handlers = Rc::new(RefCell::new(HashMap::default())); + let server_task = cx.spawn({ + let tools = tools.clone(); + let handlers = handlers.clone(); + async move |cx| { + while let Ok((stream, _)) = listener.accept().await { + Self::serve_connection(stream, tools.clone(), handlers.clone(), cx); + } + drop(temp_dir) + } + }); + Ok(Self { + socket_path, + _server_task: server_task, + tools, + handlers: handlers, + }) + }) + } + + pub fn add_tool(&mut self, tool: T) { + let mut settings = schemars::generate::SchemaSettings::draft07(); + settings.inline_subschemas = true; + let mut generator = settings.into_generator(); + + let output_schema = generator.root_schema_for::(); + let unit_schema = generator.root_schema_for::(); + + let registered_tool = RegisteredTool { + tool: Tool { + name: T::NAME.into(), + description: Some(tool.description().into()), + input_schema: generator.root_schema_for::().into(), + output_schema: if output_schema == unit_schema { + None + } else { + Some(output_schema.into()) + }, + annotations: Some(tool.annotations()), + }, + handler: Box::new({ + let tool = tool.clone(); + move |input_value, cx| { + let input = match input_value { + Some(input) => serde_json::from_value(input), + None => serde_json::from_value(serde_json::Value::Null), + }; + + let tool = tool.clone(); + match input { + Ok(input) => cx.spawn(async move |cx| { + let output = tool.run(input, cx).await?; + + Ok(ToolResponse { + content: output.content, + structured_content: serde_json::to_value(output.structured_content) + .unwrap_or_default(), + }) + }), + Err(err) => Task::ready(Err(err.into())), + } + } + }), + }; + + self.tools.borrow_mut().insert(T::NAME, registered_tool); + } + + pub fn handle_request( + &mut self, + f: impl Fn(R::Params, &App) -> Task> + 'static, + ) { + let f = Box::new(f); + self.handlers.borrow_mut().insert( + R::METHOD, + Box::new(move |req_id, opt_params, cx| { + let result = match opt_params { + Some(params) => serde_json::from_str(params.get()), + None => serde_json::from_value(serde_json::Value::Null), + }; + + let params: R::Params = match result { + Ok(params) => params, + Err(e) => { + return Task::ready( + serde_json::to_string(&Response:: { + jsonrpc: "2.0", + id: req_id, + value: CspResult::Error(Some(crate::client::Error { + message: format!("{e}"), + code: -32700, + })), + }) + .unwrap(), + ); + } + }; + let task = f(params, cx); + cx.background_spawn(async move { + match task.await { + Ok(result) => serde_json::to_string(&Response { + jsonrpc: "2.0", + id: req_id, + value: CspResult::Ok(Some(result)), + }) + .unwrap(), + Err(e) => serde_json::to_string(&Response { + jsonrpc: "2.0", + id: req_id, + value: CspResult::Error::(Some(crate::client::Error { + message: format!("{e}"), + code: -32603, + })), + }) + .unwrap(), + } + }) + }), + ); + } + + pub fn socket_path(&self) -> &Path { + &self.socket_path + } + + fn serve_connection( + stream: UnixStream, + tools: Rc>>, + handlers: Rc>>, + cx: &mut AsyncApp, + ) { + let (read, write) = smol::io::split(stream); + let (incoming_tx, mut incoming_rx) = unbounded(); + let (outgoing_tx, outgoing_rx) = unbounded(); + + cx.background_spawn(Self::handle_io(outgoing_rx, incoming_tx, write, read)) + .detach(); + + cx.spawn(async move |cx| { + while 
let Some(request) = incoming_rx.next().await { + let Some(request_id) = request.id.clone() else { + continue; + }; + + if request.method == CallTool::METHOD { + Self::handle_call_tool(request_id, request.params, &tools, &outgoing_tx, cx) + .await; + } else if request.method == ListTools::METHOD { + Self::handle_list_tools(request.id.unwrap(), &tools, &outgoing_tx); + } else if let Some(handler) = handlers.borrow().get(&request.method.as_ref()) { + let outgoing_tx = outgoing_tx.clone(); + + if let Some(task) = cx + .update(|cx| handler(request_id, request.params, cx)) + .log_err() + { + cx.spawn(async move |_| { + let response = task.await; + outgoing_tx.unbounded_send(response).ok(); + }) + .detach(); + } + } else { + Self::send_err( + request_id, + format!("unhandled method {}", request.method), + &outgoing_tx, + ); + } + } + }) + .detach(); + } + + fn handle_list_tools( + request_id: RequestId, + tools: &Rc>>, + outgoing_tx: &UnboundedSender, + ) { + let response = ListToolsResponse { + tools: tools.borrow().values().map(|t| t.tool.clone()).collect(), + next_cursor: None, + meta: None, + }; + + outgoing_tx + .unbounded_send( + serde_json::to_string(&Response { + jsonrpc: "2.0", + id: request_id, + value: CspResult::Ok(Some(response)), + }) + .unwrap_or_default(), + ) + .ok(); + } + + async fn handle_call_tool( + request_id: RequestId, + params: Option>, + tools: &Rc>>, + outgoing_tx: &UnboundedSender, + cx: &mut AsyncApp, + ) { + let result: Result = match params.as_ref() { + Some(params) => serde_json::from_str(params.get()), + None => serde_json::from_value(serde_json::Value::Null), + }; + + match result { + Ok(params) => { + if let Some(tool) = tools.borrow().get(¶ms.name.as_ref()) { + let outgoing_tx = outgoing_tx.clone(); + + let task = (tool.handler)(params.arguments, cx); + cx.spawn(async move |_| { + let response = match task.await { + Ok(result) => CallToolResponse { + content: result.content, + is_error: Some(false), + meta: None, + structured_content: if result.structured_content.is_null() { + None + } else { + Some(result.structured_content) + }, + }, + Err(err) => CallToolResponse { + content: vec![ToolResponseContent::Text { + text: err.to_string(), + }], + is_error: Some(true), + meta: None, + structured_content: None, + }, + }; + + outgoing_tx + .unbounded_send( + serde_json::to_string(&Response { + jsonrpc: "2.0", + id: request_id, + value: CspResult::Ok(Some(response)), + }) + .unwrap_or_default(), + ) + .ok(); + }) + .detach(); + } else { + Self::send_err( + request_id, + format!("Tool not found: {}", params.name), + &outgoing_tx, + ); + } + } + Err(err) => { + Self::send_err(request_id, err.to_string(), &outgoing_tx); + } + } + } + + fn send_err( + request_id: RequestId, + message: impl Into, + outgoing_tx: &UnboundedSender, + ) { + outgoing_tx + .unbounded_send( + serde_json::to_string(&Response::<()> { + jsonrpc: "2.0", + id: request_id, + value: CspResult::Error(Some(crate::client::Error { + message: message.into(), + code: -32601, + })), + }) + .unwrap(), + ) + .ok(); + } + + async fn handle_io( + mut outgoing_rx: UnboundedReceiver, + incoming_tx: UnboundedSender, + mut outgoing_bytes: impl Unpin + AsyncWrite, + incoming_bytes: impl Unpin + AsyncRead, + ) -> Result<()> { + let mut output_reader = BufReader::new(incoming_bytes); + let mut incoming_line = String::new(); + loop { + select_biased! 
{ + message = outgoing_rx.next().fuse() => { + if let Some(message) = message { + log::trace!("send: {}", &message); + outgoing_bytes.write_all(message.as_bytes()).await?; + outgoing_bytes.write_all(&[b'\n']).await?; + } else { + break; + } + } + bytes_read = output_reader.read_line(&mut incoming_line).fuse() => { + if bytes_read? == 0 { + break + } + log::trace!("recv: {}", &incoming_line); + match serde_json::from_str(&incoming_line) { + Ok(message) => { + incoming_tx.unbounded_send(message).log_err(); + } + Err(error) => { + outgoing_bytes.write_all(serde_json::to_string(&json!({ + "jsonrpc": "2.0", + "error": json!({ + "code": -32603, + "message": format!("Failed to parse: {error}"), + }), + }))?.as_bytes()).await?; + outgoing_bytes.write_all(&[b'\n']).await?; + log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}"); + } + } + incoming_line.clear(); + } + } + } + Ok(()) + } +} + +pub trait McpServerTool { + type Input: DeserializeOwned + JsonSchema; + type Output: Serialize + JsonSchema; + + const NAME: &'static str; + + fn description(&self) -> &'static str; + + fn annotations(&self) -> ToolAnnotations { + ToolAnnotations { + title: None, + read_only_hint: None, + destructive_hint: None, + idempotent_hint: None, + open_world_hint: None, + } + } + + fn run( + &self, + input: Self::Input, + cx: &mut AsyncApp, + ) -> impl Future>>; +} + +pub struct ToolResponse { + pub content: Vec, + pub structured_content: T, +} + +#[derive(Debug, Serialize, Deserialize)] +struct RawRequest { + #[serde(skip_serializing_if = "Option::is_none")] + id: Option, + method: String, + #[serde(skip_serializing_if = "Option::is_none")] + params: Option>, +} + +#[derive(Serialize, Deserialize)] +struct RawResponse { + jsonrpc: &'static str, + id: RequestId, + #[serde(skip_serializing_if = "Option::is_none")] + error: Option, + #[serde(skip_serializing_if = "Option::is_none")] + result: Option>, +} diff --git a/crates/context_server/src/protocol.rs b/crates/context_server/src/protocol.rs index d8bbac60d616268dcb771d653cf02ee3adc59122..5355f20f620b5bed76bf945e863fdb5cbcc2ff43 100644 --- a/crates/context_server/src/protocol.rs +++ b/crates/context_server/src/protocol.rs @@ -5,7 +5,12 @@ //! read/write messages and the types from types.rs for serialization/deserialization //! of messages. 
+use std::time::Duration; + use anyhow::Result; +use futures::channel::oneshot; +use gpui::AsyncApp; +use serde_json::Value; use crate::client::Client; use crate::types::{self, Notification, Request}; @@ -95,7 +100,26 @@ impl InitializedContextServerProtocol { self.inner.request(T::METHOD, params).await } + pub async fn request_with( + &self, + params: T::Params, + cancel_rx: Option>, + timeout: Option, + ) -> Result { + self.inner + .request_with(T::METHOD, params, cancel_rx, timeout) + .await + } + pub fn notify(&self, params: T::Params) -> Result<()> { self.inner.notify(T::METHOD, params) } + + pub fn on_notification( + &self, + method: &'static str, + f: Box, + ) { + self.inner.on_notification(method, f); + } } diff --git a/crates/context_server/src/transport/stdio_transport.rs b/crates/context_server/src/transport/stdio_transport.rs index 56d0240fa5e86149091c59102d277fca3580a970..443b8c16f160394f4bede9a72315b4e80c652726 100644 --- a/crates/context_server/src/transport/stdio_transport.rs +++ b/crates/context_server/src/transport/stdio_transport.rs @@ -1,3 +1,4 @@ +use std::path::PathBuf; use std::pin::Pin; use anyhow::{Context as _, Result}; @@ -22,7 +23,11 @@ pub struct StdioTransport { } impl StdioTransport { - pub fn new(binary: ModelContextServerBinary, cx: &AsyncApp) -> Result { + pub fn new( + binary: ModelContextServerBinary, + working_directory: &Option, + cx: &AsyncApp, + ) -> Result { let mut command = util::command::new_smol_command(&binary.executable); command .args(&binary.args) @@ -32,6 +37,10 @@ impl StdioTransport { .stderr(std::process::Stdio::piped()) .kill_on_drop(true); + if let Some(working_directory) = working_directory { + command.current_dir(working_directory); + } + let mut server = command.spawn().with_context(|| { format!( "failed to spawn command. 
(path={:?}, args={:?})", diff --git a/crates/context_server/src/types.rs b/crates/context_server/src/types.rs index 8e3daf9e222a29cf373ba7a3bb37d83c2950acf7..5fa2420a3d40ce04ee97b4f88c1105711dea8793 100644 --- a/crates/context_server/src/types.rs +++ b/crates/context_server/src/types.rs @@ -3,6 +3,8 @@ use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use url::Url; +use crate::client::RequestId; + pub const LATEST_PROTOCOL_VERSION: &str = "2025-03-26"; pub const VERSION_2024_11_05: &str = "2024-11-05"; @@ -100,6 +102,7 @@ pub mod notifications { notification!("notifications/initialized", Initialized, ()); notification!("notifications/progress", Progress, ProgressParams); notification!("notifications/message", Message, MessageParams); + notification!("notifications/cancelled", Cancelled, CancelledParams); notification!( "notifications/resources/updated", ResourcesUpdated, @@ -153,7 +156,7 @@ pub struct InitializeParams { pub struct CallToolParams { pub name: String, #[serde(skip_serializing_if = "Option::is_none")] - pub arguments: Option>, + pub arguments: Option, #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] pub meta: Option>, } @@ -492,18 +495,20 @@ pub struct RootsCapabilities { pub list_changed: Option, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct Tool { pub name: String, #[serde(skip_serializing_if = "Option::is_none")] pub description: Option, pub input_schema: serde_json::Value, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub output_schema: Option, #[serde(skip_serializing_if = "Option::is_none")] pub annotations: Option, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ToolAnnotations { /// A human-readable title for the tool. 
@@ -617,11 +622,15 @@ pub enum ClientNotification { Initialized, Progress(ProgressParams), RootsListChanged, - Cancelled { - request_id: String, - #[serde(skip_serializing_if = "Option::is_none")] - reason: Option, - }, + Cancelled(CancelledParams), +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CancelledParams { + pub request_id: RequestId, + #[serde(skip_serializing_if = "Option::is_none")] + pub reason: Option, } #[derive(Debug, Serialize, Deserialize)] @@ -673,6 +682,20 @@ pub struct CallToolResponse { pub is_error: Option, #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] pub meta: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub structured_content: Option, +} + +impl CallToolResponse { + pub fn text_contents(&self) -> String { + let mut text = String::new(); + for chunk in &self.content { + if let ToolResponseContent::Text { text: chunk } = chunk { + text.push_str(&chunk) + }; + } + text + } } #[derive(Debug, Serialize, Deserialize)] diff --git a/crates/copilot/Cargo.toml b/crates/copilot/Cargo.toml index 234875d420bd035a1c2d0955a4a68d165ef702d1..0fc119f31125f4ef3925799fd98fd47cac7ca9da 100644 --- a/crates/copilot/Cargo.toml +++ b/crates/copilot/Cargo.toml @@ -34,7 +34,7 @@ fs.workspace = true futures.workspace = true gpui.workspace = true http_client.workspace = true -inline_completion.workspace = true +edit_prediction.workspace = true language.workspace = true log.workspace = true lsp.workspace = true @@ -46,6 +46,7 @@ project.workspace = true serde.workspace = true serde_json.workspace = true settings.workspace = true +sum_tree.workspace = true task.workspace = true ui.workspace = true util.workspace = true diff --git a/crates/copilot/src/copilot.rs b/crates/copilot/src/copilot.rs index e4370d2e67cef9c5c4db68123edfb7dca5d7fa00..49ae2b9d9c92c5deba00c54c51d48deb82d03dcc 100644 --- a/crates/copilot/src/copilot.rs +++ b/crates/copilot/src/copilot.rs @@ -23,8 +23,10 @@ use language::{ use lsp::{LanguageServer, LanguageServerBinary, LanguageServerId, LanguageServerName}; use node_runtime::NodeRuntime; use parking_lot::Mutex; +use project::DisableAiSettings; use request::StatusNotification; use serde_json::json; +use settings::Settings; use settings::SettingsStore; use sign_in::{reinstall_and_sign_in_within_workspace, sign_out_within_workspace}; use std::collections::hash_map::Entry; @@ -37,6 +39,7 @@ use std::{ path::{Path, PathBuf}, sync::Arc, }; +use sum_tree::Dimensions; use util::{ResultExt, fs::remove_matching}; use workspace::Workspace; @@ -83,37 +86,13 @@ pub fn init( move |cx| Copilot::start(new_server_id, fs, node_runtime, cx) }); Copilot::set_global(copilot.clone(), cx); - cx.observe(&copilot, |handle, cx| { - let copilot_action_types = [ - TypeId::of::(), - TypeId::of::(), - TypeId::of::(), - TypeId::of::(), - ]; - let copilot_auth_action_types = [TypeId::of::()]; - let copilot_no_auth_action_types = [TypeId::of::()]; - let status = handle.read(cx).status(); - let filter = CommandPaletteFilter::global_mut(cx); - - match status { - Status::Disabled => { - filter.hide_action_types(&copilot_action_types); - filter.hide_action_types(&copilot_auth_action_types); - filter.hide_action_types(&copilot_no_auth_action_types); - } - Status::Authorized => { - filter.hide_action_types(&copilot_no_auth_action_types); - filter.show_action_types( - copilot_action_types - .iter() - .chain(&copilot_auth_action_types), - ); - } - _ => { - filter.hide_action_types(&copilot_action_types); - 
filter.hide_action_types(&copilot_auth_action_types); - filter.show_action_types(copilot_no_auth_action_types.iter()); - } + cx.observe(&copilot, |copilot, cx| { + copilot.update(cx, |copilot, cx| copilot.update_action_visibilities(cx)); + }) + .detach(); + cx.observe_global::(|cx| { + if let Some(copilot) = Copilot::global(cx) { + copilot.update(cx, |copilot, cx| copilot.update_action_visibilities(cx)); } }) .detach(); @@ -209,8 +188,14 @@ impl Status { matches!(self, Status::Authorized) } - pub fn is_disabled(&self) -> bool { - matches!(self, Status::Disabled) + pub fn is_configured(&self) -> bool { + matches!( + self, + Status::Starting { .. } + | Status::Error(_) + | Status::SigningIn { .. } + | Status::Authorized + ) } } @@ -255,7 +240,7 @@ impl RegisteredBuffer { let new_snapshot = new_snapshot.clone(); async move { new_snapshot - .edits_since::<(PointUtf16, usize)>(&old_version) + .edits_since::>(&old_version) .map(|edit| { let edit_start = edit.new.start.0; let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0); @@ -1115,6 +1100,44 @@ impl Copilot { cx.notify(); } } + + fn update_action_visibilities(&self, cx: &mut App) { + let signed_in_actions = [ + TypeId::of::(), + TypeId::of::(), + TypeId::of::(), + TypeId::of::(), + ]; + let auth_actions = [TypeId::of::()]; + let no_auth_actions = [TypeId::of::()]; + let status = self.status(); + + let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai; + let filter = CommandPaletteFilter::global_mut(cx); + + if is_ai_disabled { + filter.hide_action_types(&signed_in_actions); + filter.hide_action_types(&auth_actions); + filter.hide_action_types(&no_auth_actions); + } else { + match status { + Status::Disabled => { + filter.hide_action_types(&signed_in_actions); + filter.hide_action_types(&auth_actions); + filter.hide_action_types(&no_auth_actions); + } + Status::Authorized => { + filter.hide_action_types(&no_auth_actions); + filter.show_action_types(signed_in_actions.iter().chain(&auth_actions)); + } + _ => { + filter.hide_action_types(&signed_in_actions); + filter.hide_action_types(&auth_actions); + filter.show_action_types(no_auth_actions.iter()); + } + } + } + } } fn id_for_language(language: Option<&Arc>) -> String { diff --git a/crates/copilot/src/copilot_completion_provider.rs b/crates/copilot/src/copilot_completion_provider.rs index 8dc04622f9020c2fe175304764157b409c7936c1..2fd6df27b9e15d4247d85edca4d8836c35b23df1 100644 --- a/crates/copilot/src/copilot_completion_provider.rs +++ b/crates/copilot/src/copilot_completion_provider.rs @@ -1,7 +1,7 @@ use crate::{Completion, Copilot}; use anyhow::Result; +use edit_prediction::{Direction, EditPrediction, EditPredictionProvider}; use gpui::{App, Context, Entity, EntityId, Task}; -use inline_completion::{Direction, EditPredictionProvider, InlineCompletion}; use language::{Buffer, OffsetRangeExt, ToOffset, language_settings::AllLanguageSettings}; use project::Project; use settings::Settings; @@ -58,11 +58,19 @@ impl EditPredictionProvider for CopilotCompletionProvider { } fn show_completions_in_menu() -> bool { + true + } + + fn show_tab_accept_marker() -> bool { + true + } + + fn supports_jump_to_edit() -> bool { false } fn is_refreshing(&self) -> bool { - self.pending_refresh.is_some() + self.pending_refresh.is_some() && self.completions.is_empty() } fn is_enabled( @@ -210,7 +218,7 @@ impl EditPredictionProvider for CopilotCompletionProvider { buffer: &Entity, cursor_position: language::Anchor, cx: &mut Context, - ) -> Option { + ) -> Option { let buffer_id = 
buffer.entity_id(); let buffer = buffer.read(cx); let completion = self.active_completion()?; @@ -241,7 +249,7 @@ impl EditPredictionProvider for CopilotCompletionProvider { None } else { let position = cursor_position.bias_right(buffer); - Some(InlineCompletion { + Some(EditPrediction { id: None, edits: vec![(position..position, completion_text.into())], edit_preview: None, @@ -343,8 +351,8 @@ mod tests { executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); cx.update_editor(|editor, window, cx| { assert!(editor.context_menu_visible()); - assert!(!editor.has_active_inline_completion()); - // Since we have both, the copilot suggestion is not shown inline + assert!(editor.has_active_edit_prediction()); + // Since we have both, the copilot suggestion exists but does not show up as ghost text assert_eq!(editor.text(cx), "one.\ntwo\nthree\n"); assert_eq!(editor.display_text(cx), "one.\ntwo\nthree\n"); @@ -355,7 +363,7 @@ mod tests { .unwrap() .detach(); assert!(!editor.context_menu_visible()); - assert!(!editor.has_active_inline_completion()); + assert!(!editor.has_active_edit_prediction()); assert_eq!(editor.text(cx), "one.completion_a\ntwo\nthree\n"); assert_eq!(editor.display_text(cx), "one.completion_a\ntwo\nthree\n"); }); @@ -389,7 +397,7 @@ mod tests { executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); cx.update_editor(|editor, _, cx| { assert!(!editor.context_menu_visible()); - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); // Since only the copilot is available, it's shown inline assert_eq!(editor.display_text(cx), "one.copilot1\ntwo\nthree\n"); assert_eq!(editor.text(cx), "one.\ntwo\nthree\n"); @@ -400,7 +408,7 @@ mod tests { executor.run_until_parked(); cx.update_editor(|editor, _, cx| { assert!(!editor.context_menu_visible()); - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one.copilot1\ntwo\nthree\n"); assert_eq!(editor.text(cx), "one.c\ntwo\nthree\n"); }); @@ -418,25 +426,25 @@ mod tests { executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); cx.update_editor(|editor, window, cx| { assert!(!editor.context_menu_visible()); - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one.copilot2\ntwo\nthree\n"); assert_eq!(editor.text(cx), "one.c\ntwo\nthree\n"); // Canceling should remove the active Copilot suggestion. editor.cancel(&Default::default(), window, cx); - assert!(!editor.has_active_inline_completion()); + assert!(!editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one.c\ntwo\nthree\n"); assert_eq!(editor.text(cx), "one.c\ntwo\nthree\n"); // After canceling, tabbing shouldn't insert the previously shown suggestion. editor.tab(&Default::default(), window, cx); - assert!(!editor.has_active_inline_completion()); + assert!(!editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one.c \ntwo\nthree\n"); assert_eq!(editor.text(cx), "one.c \ntwo\nthree\n"); // When undoing the previously active suggestion is shown again. editor.undo(&Default::default(), window, cx); - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one.copilot2\ntwo\nthree\n"); assert_eq!(editor.text(cx), "one.c\ntwo\nthree\n"); }); @@ -444,25 +452,25 @@ mod tests { // If an edit occurs outside of this editor, the suggestion is still correctly interpolated. 
cx.update_buffer(|buffer, cx| buffer.edit([(5..5, "o")], None, cx)); cx.update_editor(|editor, window, cx| { - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one.copilot2\ntwo\nthree\n"); assert_eq!(editor.text(cx), "one.co\ntwo\nthree\n"); // AcceptEditPrediction when there is an active suggestion inserts it. editor.accept_edit_prediction(&Default::default(), window, cx); - assert!(!editor.has_active_inline_completion()); + assert!(!editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one.copilot2\ntwo\nthree\n"); assert_eq!(editor.text(cx), "one.copilot2\ntwo\nthree\n"); // When undoing the previously active suggestion is shown again. editor.undo(&Default::default(), window, cx); - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one.copilot2\ntwo\nthree\n"); assert_eq!(editor.text(cx), "one.co\ntwo\nthree\n"); // Hide suggestion. editor.cancel(&Default::default(), window, cx); - assert!(!editor.has_active_inline_completion()); + assert!(!editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one.co\ntwo\nthree\n"); assert_eq!(editor.text(cx), "one.co\ntwo\nthree\n"); }); @@ -471,7 +479,7 @@ mod tests { // we won't make it visible. cx.update_buffer(|buffer, cx| buffer.edit([(6..6, "p")], None, cx)); cx.update_editor(|editor, _, cx| { - assert!(!editor.has_active_inline_completion()); + assert!(!editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one.cop\ntwo\nthree\n"); assert_eq!(editor.text(cx), "one.cop\ntwo\nthree\n"); }); @@ -498,19 +506,19 @@ mod tests { }); executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); cx.update_editor(|editor, window, cx| { - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "fn foo() {\n let x = 4;\n}"); assert_eq!(editor.text(cx), "fn foo() {\n \n}"); // Tabbing inside of leading whitespace inserts indentation without accepting the suggestion. editor.tab(&Default::default(), window, cx); - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.text(cx), "fn foo() {\n \n}"); assert_eq!(editor.display_text(cx), "fn foo() {\n let x = 4;\n}"); // Using AcceptEditPrediction again accepts the suggestion. editor.accept_edit_prediction(&Default::default(), window, cx); - assert!(!editor.has_active_inline_completion()); + assert!(!editor.has_active_edit_prediction()); assert_eq!(editor.text(cx), "fn foo() {\n let x = 4;\n}"); assert_eq!(editor.display_text(cx), "fn foo() {\n let x = 4;\n}"); }); @@ -575,17 +583,17 @@ mod tests { ); executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); cx.update_editor(|editor, window, cx| { - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); // Accepting the first word of the suggestion should only accept the first word and still show the rest. 
- editor.accept_partial_inline_completion(&Default::default(), window, cx); - assert!(editor.has_active_inline_completion()); + editor.accept_partial_edit_prediction(&Default::default(), window, cx); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.text(cx), "one.copilot\ntwo\nthree\n"); assert_eq!(editor.display_text(cx), "one.copilot1\ntwo\nthree\n"); // Accepting next word should accept the non-word and copilot suggestion should be gone - editor.accept_partial_inline_completion(&Default::default(), window, cx); - assert!(!editor.has_active_inline_completion()); + editor.accept_partial_edit_prediction(&Default::default(), window, cx); + assert!(!editor.has_active_edit_prediction()); assert_eq!(editor.text(cx), "one.copilot1\ntwo\nthree\n"); assert_eq!(editor.display_text(cx), "one.copilot1\ntwo\nthree\n"); }); @@ -617,11 +625,11 @@ mod tests { ); executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); cx.update_editor(|editor, window, cx| { - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); // Accepting the first word (non-word) of the suggestion should only accept the first word and still show the rest. - editor.accept_partial_inline_completion(&Default::default(), window, cx); - assert!(editor.has_active_inline_completion()); + editor.accept_partial_edit_prediction(&Default::default(), window, cx); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.text(cx), "one.123. \ntwo\nthree\n"); assert_eq!( editor.display_text(cx), @@ -629,8 +637,8 @@ mod tests { ); // Accepting next word should accept the next word and copilot suggestion should still exist - editor.accept_partial_inline_completion(&Default::default(), window, cx); - assert!(editor.has_active_inline_completion()); + editor.accept_partial_edit_prediction(&Default::default(), window, cx); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.text(cx), "one.123. copilot\ntwo\nthree\n"); assert_eq!( editor.display_text(cx), @@ -638,8 +646,8 @@ mod tests { ); // Accepting the whitespace should accept the non-word/whitespaces with newline and copilot suggestion should be gone - editor.accept_partial_inline_completion(&Default::default(), window, cx); - assert!(!editor.has_active_inline_completion()); + editor.accept_partial_edit_prediction(&Default::default(), window, cx); + assert!(!editor.has_active_edit_prediction()); assert_eq!(editor.text(cx), "one.123. copilot\n 456\ntwo\nthree\n"); assert_eq!( editor.display_text(cx), @@ -692,29 +700,29 @@ mod tests { }); executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); cx.update_editor(|editor, window, cx| { - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one\ntwo.foo()\nthree\n"); assert_eq!(editor.text(cx), "one\ntw\nthree\n"); editor.backspace(&Default::default(), window, cx); - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one\ntwo.foo()\nthree\n"); assert_eq!(editor.text(cx), "one\nt\nthree\n"); editor.backspace(&Default::default(), window, cx); - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one\ntwo.foo()\nthree\n"); assert_eq!(editor.text(cx), "one\n\nthree\n"); // Deleting across the original suggestion range invalidates it. 
editor.backspace(&Default::default(), window, cx); - assert!(!editor.has_active_inline_completion()); + assert!(!editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one\nthree\n"); assert_eq!(editor.text(cx), "one\nthree\n"); // Undoing the deletion restores the suggestion. editor.undo(&Default::default(), window, cx); - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one\ntwo.foo()\nthree\n"); assert_eq!(editor.text(cx), "one\n\nthree\n"); }); @@ -775,7 +783,7 @@ mod tests { }); executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); _ = editor.update(cx, |editor, _, cx| { - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!( editor.display_text(cx), "\n\na = 1\nb = 2 + a\n\n\n\nc = 3\nd = 4\n" @@ -797,7 +805,7 @@ mod tests { editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { s.select_ranges([Point::new(4, 5)..Point::new(4, 5)]) }); - assert!(!editor.has_active_inline_completion()); + assert!(!editor.has_active_edit_prediction()); assert_eq!( editor.display_text(cx), "\n\na = 1\nb = 2\n\n\n\nc = 3\nd = 4\n" @@ -806,7 +814,7 @@ mod tests { // Type a character, ensuring we don't even try to interpolate the previous suggestion. editor.handle_input(" ", window, cx); - assert!(!editor.has_active_inline_completion()); + assert!(!editor.has_active_edit_prediction()); assert_eq!( editor.display_text(cx), "\n\na = 1\nb = 2\n\n\n\nc = 3\nd = 4 \n" @@ -817,7 +825,7 @@ mod tests { // Ensure the new suggestion is displayed when the debounce timeout expires. executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); _ = editor.update(cx, |editor, _, cx| { - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!( editor.display_text(cx), "\n\na = 1\nb = 2\n\n\n\nc = 3\nd = 4 + c\n" @@ -880,7 +888,7 @@ mod tests { executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); cx.update_editor(|editor, _, cx| { assert!(!editor.context_menu_visible()); - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one\ntwo.foo()\nthree\n"); assert_eq!(editor.text(cx), "one\ntw\nthree\n"); }); @@ -907,7 +915,7 @@ mod tests { executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); cx.update_editor(|editor, _, cx| { assert!(!editor.context_menu_visible()); - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one\ntwo.foo()\nthree\n"); assert_eq!(editor.text(cx), "one\ntwo\nthree\n"); }); @@ -934,8 +942,9 @@ mod tests { executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); cx.update_editor(|editor, _, cx| { assert!(editor.context_menu_visible()); - assert!(!editor.has_active_inline_completion(),); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.text(cx), "one\ntwo.\nthree\n"); + assert_eq!(editor.display_text(cx), "one\ntwo.\nthree\n"); }); } @@ -1023,7 +1032,7 @@ mod tests { editor.change_selections(SelectionEffects::no_scroll(), window, cx, |selections| { selections.select_ranges([Point::new(0, 0)..Point::new(0, 0)]) }); - editor.refresh_inline_completion(true, false, window, cx); + editor.refresh_edit_prediction(true, false, window, cx); }); executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); @@ -1033,7 +1042,7 @@ mod tests { editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { s.select_ranges([Point::new(5, 
0)..Point::new(5, 0)]) }); - editor.refresh_inline_completion(true, false, window, cx); + editor.refresh_edit_prediction(true, false, window, cx); }); executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); @@ -1077,8 +1086,6 @@ mod tests { vec![complete_from_marker.clone(), replace_range_marker.clone()], ); - let complete_from_position = - cx.to_lsp(marked_ranges.remove(&complete_from_marker).unwrap()[0].start); let replace_range = cx.to_lsp_range(marked_ranges.remove(&replace_range_marker).unwrap()[0].clone()); @@ -1087,10 +1094,6 @@ mod tests { let completions = completions.clone(); async move { assert_eq!(params.text_document_position.text_document.uri, url.clone()); - assert_eq!( - params.text_document_position.position, - complete_from_position - ); Ok(Some(lsp::CompletionResponse::Array( completions .iter() diff --git a/crates/crashes/Cargo.toml b/crates/crashes/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..641a97765a70f75edbfe478d3a493322b9c443eb --- /dev/null +++ b/crates/crashes/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "crashes" +version = "0.1.0" +publish.workspace = true +edition.workspace = true +license = "GPL-3.0-or-later" + +[dependencies] +crash-handler.workspace = true +log.workspace = true +minidumper.workspace = true +paths.workspace = true +smol.workspace = true +workspace-hack.workspace = true + +[lints] +workspace = true + +[lib] +path = "src/crashes.rs" diff --git a/crates/crashes/LICENSE-GPL b/crates/crashes/LICENSE-GPL new file mode 120000 index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4 --- /dev/null +++ b/crates/crashes/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/crashes/src/crashes.rs b/crates/crashes/src/crashes.rs new file mode 100644 index 0000000000000000000000000000000000000000..cfb4b57d5dfdf9303fa799fbe0a4d200657612c0 --- /dev/null +++ b/crates/crashes/src/crashes.rs @@ -0,0 +1,172 @@ +use crash_handler::CrashHandler; +use log::info; +use minidumper::{Client, LoopAction, MinidumpBinary}; + +use std::{ + env, + fs::File, + io, + path::{Path, PathBuf}, + process::{self, Command}, + sync::{ + OnceLock, + atomic::{AtomicBool, Ordering}, + }, + thread, + time::Duration, +}; + +// set once the crash handler has initialized and the client has connected to it +pub static CRASH_HANDLER: AtomicBool = AtomicBool::new(false); +// set when the first minidump request is made to avoid generating duplicate crash reports +pub static REQUESTED_MINIDUMP: AtomicBool = AtomicBool::new(false); +const CRASH_HANDLER_TIMEOUT: Duration = Duration::from_secs(60); + +pub async fn init(id: String) { + let exe = env::current_exe().expect("unable to find ourselves"); + let zed_pid = process::id(); + // TODO: we should be able to get away with using 1 crash-handler process per machine, + // but for now we append the PID of the current process which makes it unique per remote + // server or interactive zed instance. This solves an issue where occasionally the socket + // used by the crash handler isn't destroyed correctly which causes it to stay on the file + // system and block further attempts to initialize crash handlers with that socket path. 
+ let socket_name = paths::temp_dir().join(format!("zed-crash-handler-{zed_pid}")); + #[allow(unused)] + let server_pid = Command::new(exe) + .arg("--crash-handler") + .arg(&socket_name) + .spawn() + .expect("unable to spawn server process") + .id(); + info!("spawning crash handler process"); + + let mut elapsed = Duration::ZERO; + let retry_frequency = Duration::from_millis(100); + let mut maybe_client = None; + while maybe_client.is_none() { + if let Ok(client) = Client::with_name(socket_name.as_path()) { + maybe_client = Some(client); + info!("connected to crash handler process after {elapsed:?}"); + break; + } + elapsed += retry_frequency; + smol::Timer::after(retry_frequency).await; + } + let client = maybe_client.unwrap(); + client.send_message(1, id).unwrap(); // set session id on the server + + let client = std::sync::Arc::new(client); + let handler = crash_handler::CrashHandler::attach(unsafe { + let client = client.clone(); + crash_handler::make_crash_event(move |crash_context: &crash_handler::CrashContext| { + // only request a minidump once + let res = if REQUESTED_MINIDUMP + .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed) + .is_ok() + { + client.send_message(2, "mistakes were made").unwrap(); + client.ping().unwrap(); + client.request_dump(crash_context).is_ok() + } else { + true + }; + crash_handler::CrashEventResult::Handled(res) + }) + }) + .expect("failed to attach signal handler"); + + #[cfg(target_os = "linux")] + { + handler.set_ptracer(Some(server_pid)); + } + CRASH_HANDLER.store(true, Ordering::Release); + std::mem::forget(handler); + info!("crash handler registered"); + + loop { + client.ping().ok(); + smol::Timer::after(Duration::from_secs(10)).await; + } +} + +pub struct CrashServer { + session_id: OnceLock, +} + +impl minidumper::ServerHandler for CrashServer { + fn create_minidump_file(&self) -> Result<(File, PathBuf), io::Error> { + let err_message = "Need to send a message with the ID upon starting the crash handler"; + let dump_path = paths::logs_dir() + .join(self.session_id.get().expect(err_message)) + .with_extension("dmp"); + let file = File::create(&dump_path)?; + Ok((file, dump_path)) + } + + fn on_minidump_created(&self, result: Result) -> LoopAction { + match result { + Ok(mut md_bin) => { + use io::Write; + let _ = md_bin.file.flush(); + info!("wrote minidump to disk {:?}", md_bin.path); + } + Err(e) => { + info!("failed to write minidump: {:#}", e); + } + } + LoopAction::Exit + } + + fn on_message(&self, kind: u32, buffer: Vec) { + let message = String::from_utf8(buffer).expect("invalid utf-8"); + info!("kind: {kind}, message: {message}",); + if kind == 1 { + self.session_id + .set(message) + .expect("session id already initialized"); + } + } + + fn on_client_disconnected(&self, clients: usize) -> LoopAction { + info!("client disconnected, {clients} remaining"); + if clients == 0 { + LoopAction::Exit + } else { + LoopAction::Continue + } + } +} + +pub fn handle_panic() { + // wait 500ms for the crash handler process to start up + // if it's still not there just write panic info and no minidump + let retry_frequency = Duration::from_millis(100); + for _ in 0..5 { + if CRASH_HANDLER.load(Ordering::Acquire) { + log::error!("triggering a crash to generate a minidump..."); + #[cfg(target_os = "linux")] + CrashHandler.simulate_signal(crash_handler::Signal::Trap as u32); + #[cfg(not(target_os = "linux"))] + CrashHandler.simulate_exception(None); + break; + } + thread::sleep(retry_frequency); + } +} + +pub fn crash_server(socket: &Path) 
{ + let Ok(mut server) = minidumper::Server::with_name(socket) else { + log::info!("Couldn't create socket, there may already be a running crash server"); + return; + }; + let ab = AtomicBool::new(false); + server + .run( + Box::new(CrashServer { + session_id: OnceLock::new(), + }), + &ab, + Some(CRASH_HANDLER_TIMEOUT), + ) + .expect("failed to run server"); +} diff --git a/crates/dap/src/adapters.rs b/crates/dap/src/adapters.rs index d9f26b3b348985f2e52423cb217b1c1446960bbf..687305ae94da3bc1ddd72e9e9f4594f4f4a19ee4 100644 --- a/crates/dap/src/adapters.rs +++ b/crates/dap/src/adapters.rs @@ -74,6 +74,12 @@ impl Borrow for DebugAdapterName { } } +impl Borrow for DebugAdapterName { + fn borrow(&self) -> &SharedString { + &self.0 + } +} + impl std::fmt::Display for DebugAdapterName { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { std::fmt::Display::fmt(&self.0, f) @@ -378,6 +384,14 @@ pub trait DebugAdapter: 'static + Send + Sync { fn label_for_child_session(&self, _args: &StartDebuggingRequestArguments) -> Option { None } + + fn compact_child_session(&self) -> bool { + false + } + + fn prefer_thread_name(&self) -> bool { + false + } } #[cfg(any(test, feature = "test-support"))] @@ -442,10 +456,18 @@ impl DebugAdapter for FakeAdapter { _: Option>, _: &mut AsyncApp, ) -> Result { + let connection = task_definition + .tcp_connection + .as_ref() + .map(|connection| TcpArguments { + host: connection.host(), + port: connection.port.unwrap_or(17), + timeout: connection.timeout, + }); Ok(DebugAdapterBinary { command: Some("command".into()), arguments: vec![], - connection: None, + connection, envs: HashMap::default(), cwd: None, request_args: StartDebuggingRequestArguments { diff --git a/crates/dap/src/client.rs b/crates/dap/src/client.rs index ff082e3b765b0baac294cf310a50b54534ae9bd1..7b791450ecba3b09b6571ac84fbebdf92fff57b8 100644 --- a/crates/dap/src/client.rs +++ b/crates/dap/src/client.rs @@ -2,7 +2,7 @@ use crate::{ adapters::DebugAdapterBinary, transport::{IoKind, LogKind, TransportDelegate}, }; -use anyhow::{Context as _, Result}; +use anyhow::Result; use dap_types::{ messages::{Message, Response}, requests::Request, @@ -110,9 +110,7 @@ impl DebugAdapterClient { self.transport_delegate .pending_requests .lock() - .as_mut() - .context("client is closed")? 
- .insert(sequence_id, callback_tx); + .insert(sequence_id, callback_tx)?; log::debug!( "Client {} send `{}` request with sequence_id: {}", @@ -170,6 +168,7 @@ impl DebugAdapterClient { pub fn kill(&self) { log::debug!("Killing DAP process"); self.transport_delegate.transport.lock().kill(); + self.transport_delegate.pending_requests.lock().shutdown(); } pub fn has_adapter_logs(&self) -> bool { @@ -184,11 +183,34 @@ impl DebugAdapterClient { } #[cfg(any(test, feature = "test-support"))] - pub fn on_request(&self, handler: F) + pub fn on_request(&self, mut handler: F) where F: 'static + Send + FnMut(u64, R::Arguments) -> Result, + { + use crate::transport::RequestHandling; + + self.transport_delegate + .transport + .lock() + .as_fake() + .on_request::(move |seq, request| { + RequestHandling::Respond(handler(seq, request)) + }); + } + + #[cfg(any(test, feature = "test-support"))] + pub fn on_request_ext(&self, handler: F) + where + F: 'static + + Send + + FnMut( + u64, + R::Arguments, + ) -> crate::transport::RequestHandling< + Result, + >, { self.transport_delegate .transport @@ -273,7 +295,7 @@ mod tests { request: dap_types::StartDebuggingRequestArgumentsRequest::Launch, }, }, - Box::new(|_| panic!("Did not expect to hit this code path")), + Box::new(|_| {}), &mut cx.to_async(), ) .await diff --git a/crates/dap/src/registry.rs b/crates/dap/src/registry.rs index 9435b16b924e43406d5ed99c864df78c179f27b1..212fa2bc239bb5180274ae482f3d39082a16dd3f 100644 --- a/crates/dap/src/registry.rs +++ b/crates/dap/src/registry.rs @@ -46,6 +46,7 @@ impl DapRegistry { let name = adapter.name(); let _previous_value = self.0.write().adapters.insert(name, adapter); } + pub fn add_locator(&self, locator: Arc) { self.0.write().locators.insert(locator.name(), locator); } @@ -86,7 +87,7 @@ impl DapRegistry { self.0.read().adapters.get(name).cloned() } - pub fn enumerate_adapters(&self) -> Vec { + pub fn enumerate_adapters>(&self) -> B { self.0.read().adapters.keys().cloned().collect() } } diff --git a/crates/dap/src/transport.rs b/crates/dap/src/transport.rs index 14370f66e458309e3551a769bd79d29529a0cf3d..f9fbbfc84295bfba946ad96b5eb701d13c6aa52c 100644 --- a/crates/dap/src/transport.rs +++ b/crates/dap/src/transport.rs @@ -49,6 +49,12 @@ pub enum IoKind { StdErr, } +#[cfg(any(test, feature = "test-support"))] +pub enum RequestHandling { + Respond(T), + Exit, +} + type LogHandlers = Arc>>; pub trait Transport: Send + Sync { @@ -76,7 +82,11 @@ async fn start( ) -> Result> { #[cfg(any(test, feature = "test-support"))] if cfg!(any(test, feature = "test-support")) { - return Ok(Box::new(FakeTransport::start(cx).await?)); + if let Some(connection) = binary.connection.clone() { + return Ok(Box::new(FakeTransport::start_tcp(connection, cx).await?)); + } else { + return Ok(Box::new(FakeTransport::start_stdio(cx).await?)); + } } if binary.connection.is_some() { @@ -90,11 +100,57 @@ async fn start( } } +pub(crate) struct PendingRequests { + inner: Option>>>, +} + +impl PendingRequests { + fn new() -> Self { + Self { + inner: Some(HashMap::default()), + } + } + + fn flush(&mut self, e: anyhow::Error) { + let Some(inner) = self.inner.as_mut() else { + return; + }; + for (_, sender) in inner.drain() { + sender.send(Err(e.cloned())).ok(); + } + } + + pub(crate) fn insert( + &mut self, + sequence_id: u64, + callback_tx: oneshot::Sender>, + ) -> anyhow::Result<()> { + let Some(inner) = self.inner.as_mut() else { + bail!("client is closed") + }; + inner.insert(sequence_id, callback_tx); + Ok(()) + } + + pub(crate) fn remove( + &mut 
self, + sequence_id: u64, + ) -> anyhow::Result>>> { + let Some(inner) = self.inner.as_mut() else { + bail!("client is closed"); + }; + Ok(inner.remove(&sequence_id)) + } + + pub(crate) fn shutdown(&mut self) { + self.flush(anyhow!("transport shutdown")); + self.inner = None; + } +} + pub(crate) struct TransportDelegate { log_handlers: LogHandlers, - // TODO this should really be some kind of associative channel - pub(crate) pending_requests: - Arc>>>>>, + pub(crate) pending_requests: Arc>, pub(crate) transport: Mutex>, pub(crate) server_tx: smol::lock::Mutex>>, tasks: Mutex>>, @@ -108,7 +164,7 @@ impl TransportDelegate { transport: Mutex::new(transport), log_handlers, server_tx: Default::default(), - pending_requests: Arc::new(Mutex::new(Some(HashMap::default()))), + pending_requests: Arc::new(Mutex::new(PendingRequests::new())), tasks: Default::default(), }) } @@ -151,24 +207,10 @@ impl TransportDelegate { Ok(()) => { pending_requests .lock() - .take() - .into_iter() - .flatten() - .for_each(|(_, request)| { - request - .send(Err(anyhow!("debugger shutdown unexpectedly"))) - .ok(); - }); + .flush(anyhow!("debugger shutdown unexpectedly")); } Err(e) => { - pending_requests - .lock() - .take() - .into_iter() - .flatten() - .for_each(|(_, request)| { - request.send(Err(e.cloned())).ok(); - }); + pending_requests.lock().flush(e); } } })); @@ -286,7 +328,7 @@ impl TransportDelegate { async fn recv_from_server( server_stdout: Stdout, mut message_handler: DapMessageHandler, - pending_requests: Arc>>>>>, + pending_requests: Arc>, log_handlers: Option, ) -> Result<()> where @@ -303,14 +345,10 @@ impl TransportDelegate { ConnectionResult::Timeout => anyhow::bail!("Timed out when connecting to debugger"), ConnectionResult::ConnectionReset => { log::info!("Debugger closed the connection"); - break Ok(()); + return Ok(()); } ConnectionResult::Result(Ok(Message::Response(res))) => { - let tx = pending_requests - .lock() - .as_mut() - .context("client is closed")? 
- .remove(&res.request_seq); + let tx = pending_requests.lock().remove(res.request_seq)?; if let Some(tx) = tx { if let Err(e) = tx.send(Self::process_response(res)) { log::trace!("Did not send response `{:?}` for a cancelled", e); @@ -704,8 +742,7 @@ impl Drop for StdioTransport { } #[cfg(any(test, feature = "test-support"))] -type RequestHandler = - Box dap_types::messages::Response>; +type RequestHandler = Box RequestHandling>; #[cfg(any(test, feature = "test-support"))] type ResponseHandler = Box; @@ -716,23 +753,38 @@ pub struct FakeTransport { request_handlers: Arc>>, // for reverse request responses response_handlers: Arc>>, - - stdin_writer: Option, - stdout_reader: Option, message_handler: Option>>, + kind: FakeTransportKind, +} + +#[cfg(any(test, feature = "test-support"))] +pub enum FakeTransportKind { + Stdio { + stdin_writer: Option, + stdout_reader: Option, + }, + Tcp { + connection: TcpArguments, + executor: BackgroundExecutor, + }, } #[cfg(any(test, feature = "test-support"))] impl FakeTransport { pub fn on_request(&self, mut handler: F) where - F: 'static + Send + FnMut(u64, R::Arguments) -> Result, + F: 'static + + Send + + FnMut(u64, R::Arguments) -> RequestHandling>, { self.request_handlers.lock().insert( R::COMMAND, Box::new(move |seq, args| { let result = handler(seq, serde_json::from_value(args).unwrap()); - let response = match result { + let RequestHandling::Respond(response) = result else { + return RequestHandling::Exit; + }; + let response = match response { Ok(response) => Response { seq: seq + 1, request_seq: seq, @@ -750,7 +802,7 @@ impl FakeTransport { message: None, }, }; - response + RequestHandling::Respond(response) }), ); } @@ -764,86 +816,76 @@ impl FakeTransport { .insert(R::COMMAND, Box::new(handler)); } - async fn start(cx: &mut AsyncApp) -> Result { - use dap_types::requests::{Request, RunInTerminal, StartDebugging}; - use serde_json::json; - - let (stdin_writer, stdin_reader) = async_pipe::pipe(); - let (stdout_writer, stdout_reader) = async_pipe::pipe(); - - let mut this = Self { + async fn start_tcp(connection: TcpArguments, cx: &mut AsyncApp) -> Result { + Ok(Self { request_handlers: Arc::new(Mutex::new(HashMap::default())), response_handlers: Arc::new(Mutex::new(HashMap::default())), - stdin_writer: Some(stdin_writer), - stdout_reader: Some(stdout_reader), message_handler: None, - }; + kind: FakeTransportKind::Tcp { + connection, + executor: cx.background_executor().clone(), + }, + }) + } - let request_handlers = this.request_handlers.clone(); - let response_handlers = this.response_handlers.clone(); + async fn handle_messages( + request_handlers: Arc>>, + response_handlers: Arc>>, + stdin_reader: PipeReader, + stdout_writer: PipeWriter, + ) -> Result<()> { + use dap_types::requests::{Request, RunInTerminal, StartDebugging}; + use serde_json::json; + + let mut reader = BufReader::new(stdin_reader); let stdout_writer = Arc::new(smol::lock::Mutex::new(stdout_writer)); + let mut buffer = String::new(); - this.message_handler = Some(cx.background_spawn(async move { - let mut reader = BufReader::new(stdin_reader); - let mut buffer = String::new(); + loop { + match TransportDelegate::receive_server_message(&mut reader, &mut buffer, None).await { + ConnectionResult::Timeout => { + anyhow::bail!("Timed out when connecting to debugger"); + } + ConnectionResult::ConnectionReset => { + log::info!("Debugger closed the connection"); + break Ok(()); + } + ConnectionResult::Result(Err(e)) => break Err(e), + ConnectionResult::Result(Ok(message)) => { + match 
message { + Message::Request(request) => { + // redirect reverse requests to stdout writer/reader + if request.command == RunInTerminal::COMMAND + || request.command == StartDebugging::COMMAND + { + let message = + serde_json::to_string(&Message::Request(request)).unwrap(); - loop { - match TransportDelegate::receive_server_message(&mut reader, &mut buffer, None) - .await - { - ConnectionResult::Timeout => { - anyhow::bail!("Timed out when connecting to debugger"); - } - ConnectionResult::ConnectionReset => { - log::info!("Debugger closed the connection"); - break Ok(()); - } - ConnectionResult::Result(Err(e)) => break Err(e), - ConnectionResult::Result(Ok(message)) => { - match message { - Message::Request(request) => { - // redirect reverse requests to stdout writer/reader - if request.command == RunInTerminal::COMMAND - || request.command == StartDebugging::COMMAND + let mut writer = stdout_writer.lock().await; + writer + .write_all( + TransportDelegate::build_rpc_message(message).as_bytes(), + ) + .await + .unwrap(); + writer.flush().await.unwrap(); + } else { + let response = if let Some(handle) = + request_handlers.lock().get_mut(request.command.as_str()) { - let message = - serde_json::to_string(&Message::Request(request)).unwrap(); - - let mut writer = stdout_writer.lock().await; - writer - .write_all( - TransportDelegate::build_rpc_message(message) - .as_bytes(), - ) - .await - .unwrap(); - writer.flush().await.unwrap(); + handle(request.seq, request.arguments.unwrap_or(json!({}))) } else { - let response = if let Some(handle) = - request_handlers.lock().get_mut(request.command.as_str()) - { - handle(request.seq, request.arguments.unwrap_or(json!({}))) - } else { - panic!("No request handler for {}", request.command); - }; - let message = - serde_json::to_string(&Message::Response(response)) - .unwrap(); - - let mut writer = stdout_writer.lock().await; - writer - .write_all( - TransportDelegate::build_rpc_message(message) - .as_bytes(), - ) - .await - .unwrap(); - writer.flush().await.unwrap(); - } - } - Message::Event(event) => { + panic!("No request handler for {}", request.command); + }; + let response = match response { + RequestHandling::Respond(response) => response, + RequestHandling::Exit => { + break Err(anyhow!("exit in response to request")); + } + }; + let success = response.success; let message = - serde_json::to_string(&Message::Event(event)).unwrap(); + serde_json::to_string(&Message::Response(response)).unwrap(); let mut writer = stdout_writer.lock().await; writer @@ -852,22 +894,77 @@ impl FakeTransport { ) .await .unwrap(); - writer.flush().await.unwrap(); - } - Message::Response(response) => { - if let Some(handle) = - response_handlers.lock().get(response.command.as_str()) + + if request.command == dap_types::requests::Initialize::COMMAND + && success { - handle(response); - } else { - log::error!("No response handler for {}", response.command); + let message = serde_json::to_string(&Message::Event(Box::new( + dap_types::messages::Events::Initialized(Some( + Default::default(), + )), + ))) + .unwrap(); + writer + .write_all( + TransportDelegate::build_rpc_message(message) + .as_bytes(), + ) + .await + .unwrap(); } + + writer.flush().await.unwrap(); + } + } + Message::Event(event) => { + let message = serde_json::to_string(&Message::Event(event)).unwrap(); + + let mut writer = stdout_writer.lock().await; + writer + .write_all(TransportDelegate::build_rpc_message(message).as_bytes()) + .await + .unwrap(); + writer.flush().await.unwrap(); + } + 
Message::Response(response) => { + if let Some(handle) = + response_handlers.lock().get(response.command.as_str()) + { + handle(response); + } else { + log::error!("No response handler for {}", response.command); } } } } } - })); + } + } + + async fn start_stdio(cx: &mut AsyncApp) -> Result { + let (stdin_writer, stdin_reader) = async_pipe::pipe(); + let (stdout_writer, stdout_reader) = async_pipe::pipe(); + let kind = FakeTransportKind::Stdio { + stdin_writer: Some(stdin_writer), + stdout_reader: Some(stdout_reader), + }; + + let mut this = Self { + request_handlers: Arc::new(Mutex::new(HashMap::default())), + response_handlers: Arc::new(Mutex::new(HashMap::default())), + message_handler: None, + kind, + }; + + let request_handlers = this.request_handlers.clone(); + let response_handlers = this.response_handlers.clone(); + + this.message_handler = Some(cx.background_spawn(Self::handle_messages( + request_handlers, + response_handlers, + stdin_reader, + stdout_writer, + ))); Ok(this) } @@ -876,7 +973,10 @@ impl FakeTransport { #[cfg(any(test, feature = "test-support"))] impl Transport for FakeTransport { fn tcp_arguments(&self) -> Option { - None + match &self.kind { + FakeTransportKind::Stdio { .. } => None, + FakeTransportKind::Tcp { connection, .. } => Some(connection.clone()), + } } fn connect( @@ -887,12 +987,33 @@ impl Transport for FakeTransport { Box, )>, > { - let result = util::maybe!({ - Ok(( - Box::new(self.stdin_writer.take().context("Cannot reconnect")?) as _, - Box::new(self.stdout_reader.take().context("Cannot reconnect")?) as _, - )) - }); + let result = match &mut self.kind { + FakeTransportKind::Stdio { + stdin_writer, + stdout_reader, + } => util::maybe!({ + Ok(( + Box::new(stdin_writer.take().context("Cannot reconnect")?) as _, + Box::new(stdout_reader.take().context("Cannot reconnect")?) as _, + )) + }), + FakeTransportKind::Tcp { executor, .. 
} => { + let (stdin_writer, stdin_reader) = async_pipe::pipe(); + let (stdout_writer, stdout_reader) = async_pipe::pipe(); + + let request_handlers = self.request_handlers.clone(); + let response_handlers = self.response_handlers.clone(); + + self.message_handler = Some(executor.spawn(Self::handle_messages( + request_handlers, + response_handlers, + stdin_reader, + stdout_writer, + ))); + + Ok((Box::new(stdin_writer) as _, Box::new(stdout_reader) as _)) + } + }; Task::ready(result) } diff --git a/crates/dap_adapters/Cargo.toml b/crates/dap_adapters/Cargo.toml index 65544fbb6a1b7565c4fe641058e4e6c725b21016..e7366785c810077ef2bdc3669dd5b340859c97a6 100644 --- a/crates/dap_adapters/Cargo.toml +++ b/crates/dap_adapters/Cargo.toml @@ -36,6 +36,7 @@ paths.workspace = true serde.workspace = true serde_json.workspace = true shlex.workspace = true +smol.workspace = true task.workspace = true util.workspace = true workspace-hack.workspace = true diff --git a/crates/dap_adapters/src/dap_adapters.rs b/crates/dap_adapters/src/dap_adapters.rs index a147861f8dc965c7924a70d884004d594d59a949..a4e6beb2495ebe1eec9f08ddb8394b498c0ae410 100644 --- a/crates/dap_adapters/src/dap_adapters.rs +++ b/crates/dap_adapters/src/dap_adapters.rs @@ -13,7 +13,6 @@ use dap::{ DapRegistry, adapters::{ self, AdapterVersion, DapDelegate, DebugAdapter, DebugAdapterBinary, DebugAdapterName, - GithubRepo, }, configure_tcp_connection, }; diff --git a/crates/dap_adapters/src/go.rs b/crates/dap_adapters/src/go.rs index d32f5cbf3426f1b669132e74e389862e7944267b..22d8262b93e36b17e548ae4dcc9bb725da8ca7cb 100644 --- a/crates/dap_adapters/src/go.rs +++ b/crates/dap_adapters/src/go.rs @@ -547,6 +547,7 @@ async fn handle_envs( } }; + let mut env_vars = HashMap::default(); for path in env_files { let Some(path) = path .and_then(|s| PathBuf::from_str(s).ok()) @@ -556,13 +557,33 @@ async fn handle_envs( }; if let Ok(file) = fs.open_sync(&path).await { - envs.extend(dotenvy::from_read_iter(file).filter_map(Result::ok)) + let file_envs: HashMap = dotenvy::from_read_iter(file) + .filter_map(Result::ok) + .collect(); + envs.extend(file_envs.iter().map(|(k, v)| (k.clone(), v.clone()))); + env_vars.extend(file_envs); } else { warn!("While starting Go debug session: failed to read env file {path:?}"); }; } + let mut env_obj: serde_json::Map = serde_json::Map::new(); + + for (k, v) in env_vars { + env_obj.insert(k, Value::String(v)); + } + + if let Some(existing_env) = config.get("env").and_then(|v| v.as_object()) { + for (k, v) in existing_env { + env_obj.insert(k.clone(), v.clone()); + } + } + + if !env_obj.is_empty() { + config.insert("env".to_string(), Value::Object(env_obj)); + } + // remove envFile now that it's been handled - config.remove("entry"); + config.remove("envFile"); Some(()) } diff --git a/crates/dap_adapters/src/javascript.rs b/crates/dap_adapters/src/javascript.rs index 76c1d1fb7bb3b2b3a534293957b43919a079a888..2d19921a0f0c979fe53ede5860ac0c4d26b510c3 100644 --- a/crates/dap_adapters/src/javascript.rs +++ b/crates/dap_adapters/src/javascript.rs @@ -54,20 +54,6 @@ impl JsDebugAdapter { user_args: Option>, _: &mut AsyncApp, ) -> Result { - let adapter_path = if let Some(user_installed_path) = user_installed_path { - user_installed_path - } else { - let adapter_path = paths::debug_adapters_dir().join(self.name().as_ref()); - - let file_name_prefix = format!("{}_", self.name()); - - util::fs::find_file_name_in_dir(adapter_path.as_path(), |file_name| { - file_name.starts_with(&file_name_prefix) - }) - .await - .context("Couldn't find 
JavaScript dap directory")? - }; - let tcp_connection = task_definition.tcp_connection.clone().unwrap_or_default(); let (host, port, timeout) = crate::configure_tcp_connection(tcp_connection).await?; @@ -136,21 +122,27 @@ impl JsDebugAdapter { .or_insert(true.into()); } + let adapter_path = if let Some(user_installed_path) = user_installed_path { + user_installed_path + } else { + let adapter_path = paths::debug_adapters_dir().join(self.name().as_ref()); + + let file_name_prefix = format!("{}_", self.name()); + + util::fs::find_file_name_in_dir(adapter_path.as_path(), |file_name| { + file_name.starts_with(&file_name_prefix) + }) + .await + .context("Couldn't find JavaScript dap directory")? + .join(Self::ADAPTER_PATH) + }; + let arguments = if let Some(mut args) = user_args { - args.insert( - 0, - adapter_path - .join(Self::ADAPTER_PATH) - .to_string_lossy() - .to_string(), - ); + args.insert(0, adapter_path.to_string_lossy().to_string()); args } else { vec![ - adapter_path - .join(Self::ADAPTER_PATH) - .to_string_lossy() - .to_string(), + adapter_path.to_string_lossy().to_string(), port.to_string(), host.to_string(), ] @@ -534,6 +526,14 @@ impl DebugAdapter for JsDebugAdapter { .filter(|name| !name.is_empty())?; Some(label.to_owned()) } + + fn compact_child_session(&self) -> bool { + true + } + + fn prefer_thread_name(&self) -> bool { + true + } } fn normalize_task_type(task_type: &mut Value) { diff --git a/crates/dap_adapters/src/python.rs b/crates/dap_adapters/src/python.rs index dc3d15e124578e183ba5ed09b80aee7d6dda54c8..461ce6fbb3176508d74524f8159eb6a0cc448932 100644 --- a/crates/dap_adapters/src/python.rs +++ b/crates/dap_adapters/src/python.rs @@ -1,31 +1,37 @@ use crate::*; use anyhow::Context as _; -use dap::adapters::latest_github_release; use dap::{DebugRequest, StartDebuggingRequestArguments, adapters::DebugTaskDefinition}; -use gpui::{AppContext, AsyncApp, SharedString}; +use fs::RemoveOptions; +use futures::{StreamExt, TryStreamExt}; +use gpui::http_client::AsyncBody; +use gpui::{AsyncApp, SharedString}; use json_dotpath::DotPaths; -use language::{LanguageName, Toolchain}; +use language::LanguageName; +use paths::debug_adapters_dir; use serde_json::Value; +use smol::fs::File; +use smol::io::AsyncReadExt; +use smol::lock::OnceCell; +use std::ffi::OsString; use std::net::Ipv4Addr; +use std::str::FromStr; use std::{ collections::HashMap, ffi::OsStr, path::{Path, PathBuf}, - sync::OnceLock, }; -use util::ResultExt; +use util::{ResultExt, maybe}; #[derive(Default)] pub(crate) struct PythonDebugAdapter { - checked: OnceLock<()>, + debugpy_whl_base_path: OnceCell, String>>, } impl PythonDebugAdapter { const ADAPTER_NAME: &'static str = "Debugpy"; const DEBUG_ADAPTER_NAME: DebugAdapterName = DebugAdapterName(SharedString::new_static(Self::ADAPTER_NAME)); - const ADAPTER_PACKAGE_NAME: &'static str = "debugpy"; - const ADAPTER_PATH: &'static str = "src/debugpy/adapter"; + const LANGUAGE_NAME: &'static str = "Python"; async fn generate_debugpy_arguments( @@ -33,43 +39,22 @@ impl PythonDebugAdapter { port: u16, user_installed_path: Option<&Path>, user_args: Option>, - installed_in_venv: bool, ) -> Result> { let mut args = if let Some(user_installed_path) = user_installed_path { log::debug!( "Using user-installed debugpy adapter from: {}", user_installed_path.display() ); - vec![ - user_installed_path - .join(Self::ADAPTER_PATH) - .to_string_lossy() - .to_string(), - ] - } else if installed_in_venv { - log::debug!("Using venv-installed debugpy"); - vec!["-m".to_string(), 
"debugpy.adapter".to_string()] + vec![user_installed_path.to_string_lossy().to_string()] } else { let adapter_path = paths::debug_adapters_dir().join(Self::DEBUG_ADAPTER_NAME.as_ref()); - let file_name_prefix = format!("{}_", Self::ADAPTER_NAME); - - let debugpy_dir = - util::fs::find_file_name_in_dir(adapter_path.as_path(), |file_name| { - file_name.starts_with(&file_name_prefix) - }) - .await - .context("Debugpy directory not found")?; - - log::debug!( - "Using GitHub-downloaded debugpy adapter from: {}", - debugpy_dir.display() - ); - vec![ - debugpy_dir - .join(Self::ADAPTER_PATH) - .to_string_lossy() - .to_string(), - ] + let path = adapter_path + .join("debugpy") + .join("adapter") + .to_string_lossy() + .into_owned(); + log::debug!("Using pip debugpy adapter from: {path}"); + vec![path] }; args.extend(if let Some(args) = user_args { @@ -105,44 +90,144 @@ impl PythonDebugAdapter { request, }) } - async fn fetch_latest_adapter_version( - &self, - delegate: &Arc, - ) -> Result { - let github_repo = GithubRepo { - repo_name: Self::ADAPTER_PACKAGE_NAME.into(), - repo_owner: "microsoft".into(), - }; - fetch_latest_adapter_version_from_github(github_repo, delegate.as_ref()).await - } + async fn fetch_wheel(delegate: &Arc) -> Result, String> { + let system_python = Self::system_python_name(delegate) + .await + .ok_or_else(|| String::from("Could not find a Python installation"))?; + let command: &OsStr = system_python.as_ref(); + let download_dir = debug_adapters_dir().join(Self::ADAPTER_NAME).join("wheels"); + std::fs::create_dir_all(&download_dir).map_err(|e| e.to_string())?; + let installation_succeeded = util::command::new_smol_command(command) + .args([ + "-m", + "pip", + "download", + "debugpy", + "--only-binary=:all:", + "-d", + download_dir.to_string_lossy().as_ref(), + ]) + .output() + .await + .map_err(|e| format!("{e}"))? + .status + .success(); + if !installation_succeeded { + return Err("debugpy installation failed".into()); + } + + let wheel_path = std::fs::read_dir(&download_dir) + .map_err(|e| e.to_string())? 
+ .find_map(|entry| { + entry.ok().filter(|e| { + e.file_type().is_ok_and(|typ| typ.is_file()) + && Path::new(&e.file_name()).extension() == Some("whl".as_ref()) + }) + }) + .ok_or_else(|| String::from("Did not find a .whl in {download_dir}"))?; - async fn install_binary( - adapter_name: DebugAdapterName, - version: AdapterVersion, - delegate: Arc, - ) -> Result<()> { - let version_path = adapters::download_adapter_from_github( - adapter_name, - version, - adapters::DownloadedFileType::GzipTar, - delegate.as_ref(), + util::archive::extract_zip( + &debug_adapters_dir().join(Self::ADAPTER_NAME), + File::open(&wheel_path.path()) + .await + .map_err(|e| e.to_string())?, ) - .await?; - // only needed when you install the latest version for the first time - if let Some(debugpy_dir) = - util::fs::find_file_name_in_dir(version_path.as_path(), |file_name| { - file_name.starts_with("microsoft-debugpy-") + .await + .map_err(|e| e.to_string())?; + + Ok(Arc::from(wheel_path.path())) + } + + async fn maybe_fetch_new_wheel(delegate: &Arc) { + let latest_release = delegate + .http_client() + .get( + "https://pypi.org/pypi/debugpy/json", + AsyncBody::empty(), + false, + ) + .await + .log_err(); + maybe!(async move { + let response = latest_release.filter(|response| response.status().is_success())?; + + let mut output = String::new(); + response + .into_body() + .read_to_string(&mut output) + .await + .ok()?; + let as_json = serde_json::Value::from_str(&output).ok()?; + let latest_version = as_json.get("info").and_then(|info| { + info.get("version") + .and_then(|version| version.as_str()) + .map(ToOwned::to_owned) + })?; + let dist_info_dirname: OsString = format!("debugpy-{latest_version}.dist-info").into(); + let is_up_to_date = delegate + .fs() + .read_dir(&debug_adapters_dir().join(Self::ADAPTER_NAME)) + .await + .ok()? 
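+ // The extracted wheel ships a `debugpy-<version>.dist-info` directory; if an
+ // entry matching the latest PyPI release is already present, the local copy is
+ // up to date and the wipe-and-refetch below is skipped.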
+ .into_stream() + .any(async |entry| { + entry.is_ok_and(|e| e.file_name().is_some_and(|name| name == dist_info_dirname)) + }) + .await; + + if !is_up_to_date { + delegate + .fs() + .remove_dir( + &debug_adapters_dir().join(Self::ADAPTER_NAME), + RemoveOptions { + recursive: true, + ignore_if_not_exists: true, + }, + ) + .await + .ok()?; + Self::fetch_wheel(delegate).await.ok()?; + } + Some(()) + }) + .await; + } + + async fn fetch_debugpy_whl( + &self, + delegate: &Arc, + ) -> Result, String> { + self.debugpy_whl_base_path + .get_or_init(|| async move { + Self::maybe_fetch_new_wheel(delegate).await; + Ok(Arc::from( + debug_adapters_dir() + .join(Self::ADAPTER_NAME) + .join("debugpy") + .join("adapter") + .as_ref(), + )) }) .await - { - // TODO Debugger: Rename folder instead of moving all files to another folder - // We're doing unnecessary IO work right now - util::fs::move_folder_files_to_folder(debugpy_dir.as_path(), version_path.as_path()) - .await?; - } + .clone() + } - Ok(()) + async fn system_python_name(delegate: &Arc) -> Option { + const BINARY_NAMES: [&str; 3] = ["python3", "python", "py"]; + let mut name = None; + + for cmd in BINARY_NAMES { + name = delegate + .which(OsStr::new(cmd)) + .await + .map(|path| path.to_string_lossy().to_string()); + if name.is_some() { + break; + } + } + name } async fn get_installed_binary( @@ -151,28 +236,15 @@ impl PythonDebugAdapter { config: &DebugTaskDefinition, user_installed_path: Option, user_args: Option>, - toolchain: Option, - installed_in_venv: bool, + python_from_toolchain: Option, ) -> Result { - const BINARY_NAMES: [&str; 3] = ["python3", "python", "py"]; let tcp_connection = config.tcp_connection.clone().unwrap_or_default(); let (host, port, timeout) = crate::configure_tcp_connection(tcp_connection).await?; - let python_path = if let Some(toolchain) = toolchain { - Some(toolchain.path.to_string()) + let python_path = if let Some(toolchain) = python_from_toolchain { + Some(toolchain) } else { - let mut name = None; - - for cmd in BINARY_NAMES { - name = delegate - .which(OsStr::new(cmd)) - .await - .map(|path| path.to_string_lossy().to_string()); - if name.is_some() { - break; - } - } - name + Self::system_python_name(delegate).await }; let python_command = python_path.context("failed to find binary path for Python")?; @@ -183,7 +255,6 @@ impl PythonDebugAdapter { port, user_installed_path.as_deref(), user_args, - installed_in_venv, ) .await?; @@ -605,59 +676,52 @@ impl DebugAdapter for PythonDebugAdapter { local_path.display() ); return self - .get_installed_binary( - delegate, - &config, - Some(local_path.clone()), - user_args, - None, - false, - ) + .get_installed_binary(delegate, &config, Some(local_path.clone()), user_args, None) .await; } + let base_path = config + .config + .get("cwd") + .and_then(|cwd| { + cwd.as_str() + .map(Path::new)? 
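+ // Make the scenario's `cwd` relative to the worktree root so the active Python
+ // toolchain is resolved for that subdirectory; a missing `cwd`, or one outside
+ // the worktree, falls back to the worktree root itself.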
+ .strip_prefix(delegate.worktree_root_path()) + .ok() + }) + .unwrap_or_else(|| "".as_ref()) + .into(); let toolchain = delegate .toolchain_store() .active_toolchain( delegate.worktree_id(), - Arc::from("".as_ref()), + base_path, language::LanguageName::new(Self::LANGUAGE_NAME), cx, ) .await; + let debugpy_path = self + .fetch_debugpy_whl(delegate) + .await + .map_err(|e| anyhow::anyhow!("{e}"))?; if let Some(toolchain) = &toolchain { - if let Some(path) = Path::new(&toolchain.path.to_string()).parent() { - let debugpy_path = path.join("debugpy"); - if delegate.fs().is_file(&debugpy_path).await { - log::debug!( - "Found debugpy in toolchain environment: {}", - debugpy_path.display() - ); - return self - .get_installed_binary( - delegate, - &config, - None, - user_args, - Some(toolchain.clone()), - true, - ) - .await; - } - } - } - - if self.checked.set(()).is_ok() { - delegate.output_to_console(format!("Checking latest version of {}...", self.name())); - if let Some(version) = self.fetch_latest_adapter_version(delegate).await.log_err() { - cx.background_spawn(Self::install_binary(self.name(), version, delegate.clone())) - .await - .context("Failed to install debugpy")?; - } + log::debug!( + "Found debugpy in toolchain environment: {}", + debugpy_path.display() + ); + return self + .get_installed_binary( + delegate, + &config, + None, + user_args, + Some(toolchain.path.to_string()), + ) + .await; } - self.get_installed_binary(delegate, &config, None, user_args, toolchain, false) + self.get_installed_binary(delegate, &config, None, user_args, None) .await } @@ -671,26 +735,10 @@ impl DebugAdapter for PythonDebugAdapter { } } -async fn fetch_latest_adapter_version_from_github( - github_repo: GithubRepo, - delegate: &dyn DapDelegate, -) -> Result { - let release = latest_github_release( - &format!("{}/{}", github_repo.repo_owner, github_repo.repo_name), - false, - false, - delegate.http_client(), - ) - .await?; - - Ok(AdapterVersion { - tag_name: release.tag_name, - url: release.tarball_url, - }) -} - #[cfg(test)] mod tests { + use util::path; + use super::*; use std::{net::Ipv4Addr, path::PathBuf}; @@ -700,31 +748,25 @@ mod tests { let port = 5678; // Case 1: User-defined debugpy path (highest precedence) - let user_path = PathBuf::from("/custom/path/to/debugpy"); - let user_args = PythonDebugAdapter::generate_debugpy_arguments( - &host, - port, - Some(&user_path), - None, - false, - ) - .await - .unwrap(); - - // Case 2: Venv-installed debugpy (uses -m debugpy.adapter) - let venv_args = - PythonDebugAdapter::generate_debugpy_arguments(&host, port, None, None, true) + let user_path = PathBuf::from("/custom/path/to/debugpy/src/debugpy/adapter"); + let user_args = + PythonDebugAdapter::generate_debugpy_arguments(&host, port, Some(&user_path), None) .await .unwrap(); - assert!(user_args[0].ends_with("src/debugpy/adapter")); + // Case 2: Venv-installed debugpy (uses -m debugpy.adapter) + let venv_args = PythonDebugAdapter::generate_debugpy_arguments(&host, port, None, None) + .await + .unwrap(); + + assert_eq!(user_args[0], "/custom/path/to/debugpy/src/debugpy/adapter"); assert_eq!(user_args[1], "--host=127.0.0.1"); assert_eq!(user_args[2], "--port=5678"); - assert_eq!(venv_args[0], "-m"); - assert_eq!(venv_args[1], "debugpy.adapter"); - assert_eq!(venv_args[2], "--host=127.0.0.1"); - assert_eq!(venv_args[3], "--port=5678"); + let expected_suffix = path!("debug_adapters/Debugpy/debugpy/adapter"); + assert!(venv_args[0].ends_with(expected_suffix)); + assert_eq!(venv_args[1], "--host=127.0.0.1"); 
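+ // Without user-supplied args, the default --host/--port flags follow the adapter path.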
+ assert_eq!(venv_args[2], "--port=5678"); // The same cases, with arguments overridden by the user let user_args = PythonDebugAdapter::generate_debugpy_arguments( @@ -732,7 +774,6 @@ mod tests { port, Some(&user_path), Some(vec!["foo".into()]), - false, ) .await .unwrap(); @@ -741,7 +782,6 @@ mod tests { port, None, Some(vec!["foo".into()]), - true, ) .await .unwrap(); @@ -749,9 +789,8 @@ mod tests { assert!(user_args[0].ends_with("src/debugpy/adapter")); assert_eq!(user_args[1], "foo"); - assert_eq!(venv_args[0], "-m"); - assert_eq!(venv_args[1], "debugpy.adapter"); - assert_eq!(venv_args[2], "foo"); + assert!(venv_args[0].ends_with(expected_suffix)); + assert_eq!(venv_args[1], "foo"); // Note: Case 3 (GitHub-downloaded debugpy) is not tested since this requires mocking the Github API. } diff --git a/crates/debugger_tools/src/dap_log.rs b/crates/debugger_tools/src/dap_log.rs index f2f193cad451772146f6fd39e13a75f29f13292b..b806381d251c6595a5dd12022dc3d1df8b71739f 100644 --- a/crates/debugger_tools/src/dap_log.rs +++ b/crates/debugger_tools/src/dap_log.rs @@ -32,12 +32,19 @@ use workspace::{ ui::{Button, Clickable, ContextMenu, Label, LabelCommon, PopoverMenu, h_flex}, }; +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +enum View { + AdapterLogs, + RpcMessages, + InitializationSequence, +} + struct DapLogView { editor: Entity, focus_handle: FocusHandle, log_store: Entity, editor_subscriptions: Vec, - current_view: Option<(SessionId, LogKind)>, + current_view: Option<(SessionId, View)>, project: Entity, _subscriptions: Vec, } @@ -77,6 +84,7 @@ struct DebugAdapterState { id: SessionId, log_messages: VecDeque, rpc_messages: RpcMessages, + session_label: SharedString, adapter_name: DebugAdapterName, has_adapter_logs: bool, is_terminated: bool, @@ -121,12 +129,18 @@ impl MessageKind { } impl DebugAdapterState { - fn new(id: SessionId, adapter_name: DebugAdapterName, has_adapter_logs: bool) -> Self { + fn new( + id: SessionId, + adapter_name: DebugAdapterName, + session_label: SharedString, + has_adapter_logs: bool, + ) -> Self { Self { id, log_messages: VecDeque::new(), rpc_messages: RpcMessages::new(), adapter_name, + session_label, has_adapter_logs, is_terminated: false, } @@ -371,18 +385,22 @@ impl LogStore { return None; }; - let (adapter_name, has_adapter_logs) = session.read_with(cx, |session, _| { - ( - session.adapter(), - session - .adapter_client() - .map_or(false, |client| client.has_adapter_logs()), - ) - }); + let (adapter_name, session_label, has_adapter_logs) = + session.read_with(cx, |session, _| { + ( + session.adapter(), + session.label(), + session + .adapter_client() + .map_or(false, |client| client.has_adapter_logs()), + ) + }); state.insert(DebugAdapterState::new( id.session_id, adapter_name, + session_label + .unwrap_or_else(|| format!("Session {} (child)", id.session_id.0).into()), has_adapter_logs, )); @@ -506,12 +524,13 @@ impl Render for DapLogToolbarItemView { current_client .map(|sub_item| { Cow::Owned(format!( - "{} ({}) - {}", + "{} - {} - {}", sub_item.adapter_name, - sub_item.session_id.0, + sub_item.session_label, match sub_item.selected_entry { - LogKind::Adapter => ADAPTER_LOGS, - LogKind::Rpc => RPC_MESSAGES, + View::AdapterLogs => ADAPTER_LOGS, + View::RpcMessages => RPC_MESSAGES, + View::InitializationSequence => INITIALIZATION_SEQUENCE, } )) }) @@ -529,8 +548,8 @@ impl Render for DapLogToolbarItemView { .pl_2() .child( Label::new(format!( - "{}. 
{}", - row.session_id.0, row.adapter_name, + "{} - {}", + row.adapter_name, row.session_label )) .color(workspace::ui::Color::Muted), ) @@ -669,9 +688,16 @@ impl DapLogView { let events_subscriptions = cx.subscribe(&log_store, |log_view, _, event, cx| match event { Event::NewLogEntry { id, entry, kind } => { - if log_view.current_view == Some((id.session_id, *kind)) - && log_view.project == *id.project - { + let is_current_view = match (log_view.current_view, *kind) { + (Some((i, View::AdapterLogs)), LogKind::Adapter) + | (Some((i, View::RpcMessages)), LogKind::Rpc) + if i == id.session_id => + { + log_view.project == *id.project + } + _ => false, + }; + if is_current_view { log_view.editor.update(cx, |editor, cx| { editor.set_read_only(false); let last_point = editor.buffer().read(cx).len(cx); @@ -768,10 +794,11 @@ impl DapLogView { .map(|state| DapMenuItem { session_id: state.id, adapter_name: state.adapter_name.clone(), + session_label: state.session_label.clone(), has_adapter_logs: state.has_adapter_logs, selected_entry: self .current_view - .map_or(LogKind::Adapter, |(_, kind)| kind), + .map_or(View::AdapterLogs, |(_, kind)| kind), }) .collect::>() }) @@ -789,7 +816,7 @@ impl DapLogView { .map(|state| log_contents(state.iter().cloned())) }); if let Some(rpc_log) = rpc_log { - self.current_view = Some((id.session_id, LogKind::Rpc)); + self.current_view = Some((id.session_id, View::RpcMessages)); let (editor, editor_subscriptions) = Self::editor_for_logs(rpc_log, window, cx); let language = self.project.read(cx).languages().language_for_name("JSON"); editor @@ -830,7 +857,7 @@ impl DapLogView { .map(|state| log_contents(state.iter().cloned())) }); if let Some(message_log) = message_log { - self.current_view = Some((id.session_id, LogKind::Adapter)); + self.current_view = Some((id.session_id, View::AdapterLogs)); let (editor, editor_subscriptions) = Self::editor_for_logs(message_log, window, cx); editor .read(cx) @@ -859,7 +886,7 @@ impl DapLogView { .map(|state| log_contents(state.iter().cloned())) }); if let Some(rpc_log) = rpc_log { - self.current_view = Some((id.session_id, LogKind::Rpc)); + self.current_view = Some((id.session_id, View::InitializationSequence)); let (editor, editor_subscriptions) = Self::editor_for_logs(rpc_log, window, cx); let language = self.project.read(cx).languages().language_for_name("JSON"); editor @@ -899,11 +926,12 @@ fn log_contents(lines: impl Iterator) -> String { } #[derive(Clone, PartialEq)] -pub(crate) struct DapMenuItem { - pub session_id: SessionId, - pub adapter_name: DebugAdapterName, - pub has_adapter_logs: bool, - pub selected_entry: LogKind, +struct DapMenuItem { + session_id: SessionId, + session_label: SharedString, + adapter_name: DebugAdapterName, + has_adapter_logs: bool, + selected_entry: View, } const ADAPTER_LOGS: &str = "Adapter Logs"; diff --git a/crates/debugger_ui/Cargo.toml b/crates/debugger_ui/Cargo.toml index fe9640b7b9e2276ab16066672d267cb4ce432c43..df4125860f4ab79ce3a55d6b5b4fbb8f8fc64e5e 100644 --- a/crates/debugger_ui/Cargo.toml +++ b/crates/debugger_ui/Cargo.toml @@ -35,22 +35,27 @@ command_palette_hooks.workspace = true dap.workspace = true dap_adapters = { workspace = true, optional = true } db.workspace = true +debugger_tools.workspace = true editor.workspace = true file_icons.workspace = true futures.workspace = true fuzzy.workspace = true gpui.workspace = true +hex.workspace = true indoc.workspace = true itertools.workspace = true language.workspace = true log.workspace = true menu.workspace = true 
+notifications.workspace = true parking_lot.workspace = true +parse_int.workspace = true paths.workspace = true picker.workspace = true pretty_assertions.workspace = true project.workspace = true rpc.workspace = true +schemars.workspace = true serde.workspace = true serde_json.workspace = true serde_json_lenient.workspace = true @@ -63,14 +68,13 @@ telemetry.workspace = true terminal_view.workspace = true text.workspace = true theme.workspace = true -tree-sitter.workspace = true tree-sitter-json.workspace = true +tree-sitter.workspace = true ui.workspace = true +unindent = { workspace = true, optional = true } util.workspace = true -workspace.workspace = true workspace-hack.workspace = true -debugger_tools.workspace = true -unindent = { workspace = true, optional = true } +workspace.workspace = true zed_actions.workspace = true [dev-dependencies] @@ -80,8 +84,8 @@ debugger_tools = { workspace = true, features = ["test-support"] } editor = { workspace = true, features = ["test-support"] } gpui = { workspace = true, features = ["test-support"] } project = { workspace = true, features = ["test-support"] } +tree-sitter-go.workspace = true unindent.workspace = true util = { workspace = true, features = ["test-support"] } workspace = { workspace = true, features = ["test-support"] } zlog.workspace = true -tree-sitter-go.workspace = true diff --git a/crates/debugger_ui/src/debugger_panel.rs b/crates/debugger_ui/src/debugger_panel.rs index c90a2878e925ceb7d44a80c87bc8d9a7945b4fa0..0ac419580bd7db2d17b07b002b133163497a1e9d 100644 --- a/crates/debugger_ui/src/debugger_panel.rs +++ b/crates/debugger_ui/src/debugger_panel.rs @@ -2,6 +2,7 @@ use crate::persistence::DebuggerPaneItem; use crate::session::DebugSession; use crate::session::running::RunningState; use crate::session::running::breakpoint_list::BreakpointList; + use crate::{ ClearAllBreakpoints, Continue, CopyDebugAdapterArguments, Detach, FocusBreakpointList, FocusConsole, FocusFrames, FocusLoadedSources, FocusModules, FocusTerminal, FocusVariables, @@ -9,6 +10,7 @@ use crate::{ ToggleExpandItem, ToggleSessionPicker, ToggleThreadPicker, persistence, spawn_task_or_modal, }; use anyhow::{Context as _, Result, anyhow}; +use collections::IndexMap; use dap::adapters::DebugAdapterName; use dap::debugger_settings::DebugPanelDockPosition; use dap::{ @@ -26,7 +28,7 @@ use text::ToPoint as _; use itertools::Itertools as _; use language::Buffer; -use project::debugger::session::{Session, SessionStateEvent}; +use project::debugger::session::{Session, SessionQuirks, SessionState, SessionStateEvent}; use project::{DebugScenarioContext, Fs, ProjectPath, TaskSourceKind, WorktreeId}; use project::{Project, debugger::session::ThreadStatus}; use rpc::proto::{self}; @@ -35,7 +37,7 @@ use std::sync::{Arc, LazyLock}; use task::{DebugScenario, TaskContext}; use tree_sitter::{Query, StreamingIterator as _}; use ui::{ContextMenu, Divider, PopoverMenuHandle, Tooltip, prelude::*}; -use util::{ResultExt, maybe}; +use util::{ResultExt, debug_panic, maybe}; use workspace::SplitDirection; use workspace::item::SaveOptions; use workspace::{ @@ -63,13 +65,14 @@ pub enum DebugPanelEvent { pub struct DebugPanel { size: Pixels, - sessions: Vec>, active_session: Option>, project: Entity, workspace: WeakEntity, focus_handle: FocusHandle, context_menu: Option<(Entity, Point, Subscription)>, debug_scenario_scheduled_last: bool, + pub(crate) sessions_with_children: + IndexMap, Vec>>, pub(crate) thread_picker_menu_handle: PopoverMenuHandle, pub(crate) session_picker_menu_handle: 
PopoverMenuHandle, fs: Arc, @@ -100,7 +103,7 @@ impl DebugPanel { Self { size: px(300.), - sessions: vec![], + sessions_with_children: Default::default(), active_session: None, focus_handle, breakpoint_list: BreakpointList::new( @@ -138,8 +141,9 @@ impl DebugPanel { }); } - pub(crate) fn sessions(&self) -> Vec> { - self.sessions.clone() + #[cfg(test)] + pub(crate) fn sessions(&self) -> impl Iterator> { + self.sessions_with_children.keys().cloned() } pub fn active_session(&self) -> Option> { @@ -185,12 +189,20 @@ impl DebugPanel { cx: &mut Context, ) { let dap_store = self.project.read(cx).dap_store(); + let Some(adapter) = DapRegistry::global(cx).adapter(&scenario.adapter) else { + return; + }; + let quirks = SessionQuirks { + compact: adapter.compact_child_session(), + prefer_thread_name: adapter.prefer_thread_name(), + }; let session = dap_store.update(cx, |dap_store, cx| { dap_store.new_session( - scenario.label.clone(), + Some(scenario.label.clone()), DebugAdapterName(scenario.adapter.clone()), task_context.clone(), None, + quirks, cx, ) }); @@ -267,22 +279,34 @@ impl DebugPanel { } }); - cx.spawn(async move |_, cx| { - if let Err(error) = task.await { - log::error!("{error}"); - session - .update(cx, |session, cx| { - session - .console_output(cx) - .unbounded_send(format!("error: {}", error)) - .ok(); - session.shutdown(cx) - })? - .await; + let boot_task = cx.spawn({ + let session = session.clone(); + + async move |_, cx| { + if let Err(error) = task.await { + log::error!("{error}"); + session + .update(cx, |session, cx| { + session + .console_output(cx) + .unbounded_send(format!("error: {}", error)) + .ok(); + session.shutdown(cx) + })? + .await; + } + anyhow::Ok(()) } - anyhow::Ok(()) - }) - .detach_and_log_err(cx); + }); + + session.update(cx, |session, _| match &mut session.mode { + SessionState::Booting(state_task) => { + *state_task = Some(boot_task); + } + SessionState::Running(_) => { + debug_panic!("Session state should be in building because we are just starting it"); + } + }); } pub(crate) fn rerun_last_session( @@ -363,14 +387,15 @@ impl DebugPanel { }; let dap_store_handle = self.project.read(cx).dap_store().clone(); - let label = curr_session.read(cx).label().clone(); + let label = curr_session.read(cx).label(); + let quirks = curr_session.read(cx).quirks(); let adapter = curr_session.read(cx).adapter().clone(); let binary = curr_session.read(cx).binary().cloned().unwrap(); let task_context = curr_session.read(cx).task_context().clone(); let curr_session_id = curr_session.read(cx).session_id(); - self.sessions - .retain(|session| session.read(cx).session_id(cx) != curr_session_id); + self.sessions_with_children + .retain(|session, _| session.read(cx).session_id(cx) != curr_session_id); let task = dap_store_handle.update(cx, |dap_store, cx| { dap_store.shutdown_session(curr_session_id, cx) }); @@ -379,7 +404,7 @@ impl DebugPanel { task.await.log_err(); let (session, task) = dap_store_handle.update(cx, |dap_store, cx| { - let session = dap_store.new_session(label, adapter, task_context, None, cx); + let session = dap_store.new_session(label, adapter, task_context, None, quirks, cx); let task = session.update(cx, |session, cx| { session.boot(binary, worktree, dap_store_handle.downgrade(), cx) @@ -425,6 +450,7 @@ impl DebugPanel { let dap_store_handle = self.project.read(cx).dap_store().clone(); let label = self.label_for_child_session(&parent_session, request, cx); let adapter = parent_session.read(cx).adapter().clone(); + let quirks = parent_session.read(cx).quirks(); 
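+ // Child sessions run under the same adapter as their parent, so they inherit the
+ // parent's quirks (compact display, thread-name labels) along with its binary.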
let Some(mut binary) = parent_session.read(cx).binary().cloned() else { log::error!("Attempted to start a child-session without a binary"); return; @@ -438,6 +464,7 @@ impl DebugPanel { adapter, task_context, Some(parent_session.clone()), + quirks, cx, ); @@ -463,8 +490,8 @@ impl DebugPanel { cx: &mut Context, ) { let Some(session) = self - .sessions - .iter() + .sessions_with_children + .keys() .find(|other| entity_id == other.entity_id()) .cloned() else { @@ -498,15 +525,14 @@ impl DebugPanel { } session.update(cx, |session, cx| session.shutdown(cx)).ok(); this.update(cx, |this, cx| { - this.sessions.retain(|other| entity_id != other.entity_id()); - + this.retain_sessions(|other| entity_id != other.entity_id()); if let Some(active_session_id) = this .active_session .as_ref() .map(|session| session.entity_id()) { if active_session_id == entity_id { - this.active_session = this.sessions.first().cloned(); + this.active_session = this.sessions_with_children.keys().next().cloned(); } } cx.notify() @@ -813,13 +839,24 @@ impl DebugPanel { .on_click(window.listener_for( &running_state, |this, _, _window, cx| { - this.stop_thread(cx); + if this.session().read(cx).is_building() { + this.session().update(cx, |session, cx| { + session.shutdown(cx).detach() + }); + } else { + this.stop_thread(cx); + } + }, + )) + .disabled(active_session.as_ref().is_none_or( + |session| { + session + .read(cx) + .session(cx) + .read(cx) + .is_terminated() }, )) - .disabled( - thread_status != ThreadStatus::Stopped - && thread_status != ThreadStatus::Running, - ) .tooltip({ let focus_handle = focus_handle.clone(); let label = if capabilities @@ -976,8 +1013,8 @@ impl DebugPanel { cx: &mut Context, ) { if let Some(session) = self - .sessions - .iter() + .sessions_with_children + .keys() .find(|session| session.read(cx).session_id(cx) == session_id) { self.activate_session(session.clone(), window, cx); @@ -990,7 +1027,7 @@ impl DebugPanel { window: &mut Window, cx: &mut Context, ) { - debug_assert!(self.sessions.contains(&session_item)); + debug_assert!(self.sessions_with_children.contains_key(&session_item)); session_item.focus_handle(cx).focus(window); session_item.update(cx, |this, cx| { this.running_state().update(cx, |this, cx| { @@ -1261,18 +1298,27 @@ impl DebugPanel { parent_session: &Entity, request: &StartDebuggingRequestArguments, cx: &mut Context<'_, Self>, - ) -> SharedString { + ) -> Option { let adapter = parent_session.read(cx).adapter(); if let Some(adapter) = DapRegistry::global(cx).adapter(&adapter) { if let Some(label) = adapter.label_for_child_session(request) { - return label.into(); + return Some(label.into()); } } - let mut label = parent_session.read(cx).label().clone(); - if !label.ends_with("(child)") { - label = format!("{label} (child)").into(); + None + } + + fn retain_sessions(&mut self, keep: impl Fn(&Entity) -> bool) { + self.sessions_with_children + .retain(|session, _| keep(session)); + for children in self.sessions_with_children.values_mut() { + children.retain(|child| { + let Some(child) = child.upgrade() else { + return false; + }; + keep(&child) + }); } - label } } @@ -1302,11 +1348,11 @@ async fn register_session_inner( let serialized_layout = persistence::get_serialized_layout(adapter_name).await; let debug_session = this.update_in(cx, |this, window, cx| { let parent_session = this - .sessions - .iter() + .sessions_with_children + .keys() .find(|p| Some(p.read(cx).session_id(cx)) == session.read(cx).parent_id(cx)) .cloned(); - this.sessions.retain(|session| { + 
this.retain_sessions(|session| { !session .read(cx) .running_state() @@ -1337,13 +1383,23 @@ async fn register_session_inner( ) .detach(); let insert_position = this - .sessions - .iter() + .sessions_with_children + .keys() .position(|session| Some(session) == parent_session.as_ref()) .map(|position| position + 1) - .unwrap_or(this.sessions.len()); + .unwrap_or(this.sessions_with_children.len()); // Maintain topological sort order of sessions - this.sessions.insert(insert_position, debug_session.clone()); + let (_, old) = this.sessions_with_children.insert_before( + insert_position, + debug_session.clone(), + Default::default(), + ); + debug_assert!(old.is_none()); + if let Some(parent_session) = parent_session { + this.sessions_with_children + .entry(parent_session) + .and_modify(|children| children.push(debug_session.downgrade())); + } debug_session })?; @@ -1383,7 +1439,7 @@ impl Panel for DebugPanel { cx: &mut Context, ) { if position.axis() != self.position(window, cx).axis() { - self.sessions.iter().for_each(|session_item| { + self.sessions_with_children.keys().for_each(|session_item| { session_item.update(cx, |item, cx| { item.running_state() .update(cx, |state, _| state.invert_axies()) @@ -1704,6 +1760,7 @@ impl Render for DebugPanel { category_filter: Some( zed_actions::ExtensionCategoryFilter::DebugAdapters, ), + id: None, } .boxed_clone(), cx, @@ -1749,6 +1806,7 @@ impl Render for DebugPanel { .child(breakpoint_list) .child(Divider::vertical()) .child(welcome_experience) + .child(Divider::vertical()) } else { this.items_end() .child(welcome_experience) diff --git a/crates/debugger_ui/src/debugger_ui.rs b/crates/debugger_ui/src/debugger_ui.rs index 2056232e9bd6912bbd1b4b7da7b51b769a47e63a..5f5dfd1a1e6a543cdb7a4d87e1b8e9984c4ecba9 100644 --- a/crates/debugger_ui/src/debugger_ui.rs +++ b/crates/debugger_ui/src/debugger_ui.rs @@ -3,10 +3,12 @@ use std::any::TypeId; use dap::debugger_settings::DebuggerSettings; use debugger_panel::DebugPanel; use editor::Editor; -use gpui::{App, DispatchPhase, EntityInputHandler, actions}; +use gpui::{Action, App, DispatchPhase, EntityInputHandler, actions}; use new_process_modal::{NewProcessModal, NewProcessMode}; use onboarding_modal::DebuggerOnboardingModal; use project::debugger::{self, breakpoint_store::SourceBreakpoint, session::ThreadStatus}; +use schemars::JsonSchema; +use serde::Deserialize; use session::DebugSession; use settings::Settings; use stack_trace_view::StackTraceView; @@ -86,6 +88,20 @@ actions!( ] ); +/// Extends selection down by a specified number of lines. +#[derive(PartialEq, Clone, Deserialize, Default, JsonSchema, Action)] +#[action(namespace = debugger)] +#[serde(deny_unknown_fields)] +/// Set a data breakpoint on the selected variable or memory region. 
+pub struct ToggleDataBreakpoint { + /// The type of data breakpoint + /// Read & Write + /// Read + /// Write + #[serde(default)] + pub access_type: Option, +} + actions!( dev, [ @@ -283,59 +299,76 @@ pub fn init(cx: &mut App) { else { return; }; + + let session = active_session + .read(cx) + .running_state + .read(cx) + .session() + .read(cx); + + if session.is_terminated() { + return; + } + let editor = cx.entity().downgrade(); - window.on_action(TypeId::of::(), { - let editor = editor.clone(); - let active_session = active_session.clone(); - move |_, phase, _, cx| { - if phase != DispatchPhase::Bubble { - return; - } - maybe!({ - let (buffer, position, _) = editor - .update(cx, |editor, cx| { - let cursor_point: language::Point = - editor.selections.newest(cx).head(); - editor - .buffer() - .read(cx) - .point_to_buffer_point(cursor_point, cx) - }) - .ok()??; + window.on_action_when( + session.any_stopped_thread(), + TypeId::of::(), + { + let editor = editor.clone(); + let active_session = active_session.clone(); + move |_, phase, _, cx| { + if phase != DispatchPhase::Bubble { + return; + } + maybe!({ + let (buffer, position, _) = editor + .update(cx, |editor, cx| { + let cursor_point: language::Point = + editor.selections.newest(cx).head(); - let path = + editor + .buffer() + .read(cx) + .point_to_buffer_point(cursor_point, cx) + }) + .ok()??; + + let path = debugger::breakpoint_store::BreakpointStore::abs_path_from_buffer( &buffer, cx, )?; - let source_breakpoint = SourceBreakpoint { - row: position.row, - path, - message: None, - condition: None, - hit_condition: None, - state: debugger::breakpoint_store::BreakpointState::Enabled, - }; + let source_breakpoint = SourceBreakpoint { + row: position.row, + path, + message: None, + condition: None, + hit_condition: None, + state: debugger::breakpoint_store::BreakpointState::Enabled, + }; - active_session.update(cx, |session, cx| { - session.running_state().update(cx, |state, cx| { - if let Some(thread_id) = state.selected_thread_id() { - state.session().update(cx, |session, cx| { - session.run_to_position( - source_breakpoint, - thread_id, - cx, - ); - }) - } + active_session.update(cx, |session, cx| { + session.running_state().update(cx, |state, cx| { + if let Some(thread_id) = state.selected_thread_id() { + state.session().update(cx, |session, cx| { + session.run_to_position( + source_breakpoint, + thread_id, + cx, + ); + }) + } + }); }); - }); - Some(()) - }); - } - }); + Some(()) + }); + } + }, + ); window.on_action( TypeId::of::(), diff --git a/crates/debugger_ui/src/dropdown_menus.rs b/crates/debugger_ui/src/dropdown_menus.rs index f93aceae094db9a75b9550021c97bb9723ad6811..dca15eb0527cfc78bd137889a1910e6b32abf98c 100644 --- a/crates/debugger_ui/src/dropdown_menus.rs +++ b/crates/debugger_ui/src/dropdown_menus.rs @@ -1,16 +1,82 @@ -use std::time::Duration; +use std::{rc::Rc, time::Duration}; use collections::HashMap; -use gpui::{Animation, AnimationExt as _, Entity, Transformation, percentage}; +use gpui::{Animation, AnimationExt as _, Entity, Transformation, WeakEntity, percentage}; use project::debugger::session::{ThreadId, ThreadStatus}; use ui::{ContextMenu, DropdownMenu, DropdownStyle, Indicator, prelude::*}; -use util::truncate_and_trailoff; +use util::{maybe, truncate_and_trailoff}; use crate::{ debugger_panel::DebugPanel, session::{DebugSession, running::RunningState}, }; +struct SessionListEntry { + ancestors: Vec>, + leaf: Entity, +} + +impl SessionListEntry { + pub(crate) fn label_element(&self, depth: usize, cx: &mut App) 
-> AnyElement { + const MAX_LABEL_CHARS: usize = 150; + + let mut label = String::new(); + for ancestor in &self.ancestors { + label.push_str(&ancestor.update(cx, |ancestor, cx| { + ancestor.label(cx).unwrap_or("(child)".into()) + })); + label.push_str(" » "); + } + label.push_str( + &self + .leaf + .update(cx, |leaf, cx| leaf.label(cx).unwrap_or("(child)".into())), + ); + let label = truncate_and_trailoff(&label, MAX_LABEL_CHARS); + + let is_terminated = self + .leaf + .read(cx) + .running_state + .read(cx) + .session() + .read(cx) + .is_terminated(); + let icon = { + if is_terminated { + Some(Indicator::dot().color(Color::Error)) + } else { + match self + .leaf + .read(cx) + .running_state + .read(cx) + .thread_status(cx) + .unwrap_or_default() + { + project::debugger::session::ThreadStatus::Stopped => { + Some(Indicator::dot().color(Color::Conflict)) + } + _ => Some(Indicator::dot().color(Color::Success)), + } + } + }; + + h_flex() + .id("session-label") + .ml(depth * px(16.0)) + .gap_2() + .when_some(icon, |this, indicator| this.child(indicator)) + .justify_between() + .child( + Label::new(label) + .size(LabelSize::Small) + .when(is_terminated, |this| this.strikethrough()), + ) + .into_any_element() + } +} + impl DebugPanel { fn dropdown_label(label: impl Into) -> Label { const MAX_LABEL_CHARS: usize = 50; @@ -25,145 +91,205 @@ impl DebugPanel { window: &mut Window, cx: &mut Context, ) -> Option { - if let Some(running_state) = running_state { - let sessions = self.sessions().clone(); - let weak = cx.weak_entity(); - let running_state = running_state.read(cx); - let label = if let Some(active_session) = active_session.clone() { - active_session.read(cx).session(cx).read(cx).label() - } else { - SharedString::new_static("Unknown Session") - }; + let running_state = running_state?; + + let mut session_entries = Vec::with_capacity(self.sessions_with_children.len() * 3); + let mut sessions_with_children = self.sessions_with_children.iter().peekable(); - let is_terminated = running_state.session().read(cx).is_terminated(); - let is_started = active_session - .is_some_and(|session| session.read(cx).session(cx).read(cx).is_started()); - - let session_state_indicator = if is_terminated { - Indicator::dot().color(Color::Error).into_any_element() - } else if !is_started { - Icon::new(IconName::ArrowCircle) - .size(IconSize::Small) - .color(Color::Muted) - .with_animation( - "arrow-circle", - Animation::new(Duration::from_secs(2)).repeat(), - |icon, delta| icon.transform(Transformation::rotate(percentage(delta))), - ) - .into_any_element() + while let Some((root, children)) = sessions_with_children.next() { + let root_entry = if let Ok([single_child]) = <&[_; 1]>::try_from(children.as_slice()) + && let Some(single_child) = single_child.upgrade() + && single_child.read(cx).quirks.compact + { + sessions_with_children.next(); + SessionListEntry { + leaf: single_child.clone(), + ancestors: vec![root.clone()], + } } else { - match running_state.thread_status(cx).unwrap_or_default() { - ThreadStatus::Stopped => { - Indicator::dot().color(Color::Conflict).into_any_element() - } - _ => Indicator::dot().color(Color::Success).into_any_element(), + SessionListEntry { + leaf: root.clone(), + ancestors: Vec::new(), } }; + session_entries.push(root_entry); + + session_entries.extend( + sessions_with_children + .by_ref() + .take_while(|(session, _)| { + session + .read(cx) + .session(cx) + .read(cx) + .parent_id(cx) + .is_some() + }) + .map(|(session, _)| SessionListEntry { + leaf: session.clone(), + ancestors: 
vec![], + }), + ); + } - let trigger = h_flex() - .gap_2() - .child(session_state_indicator) - .justify_between() - .child( - DebugPanel::dropdown_label(label) - .when(is_terminated, |this| this.strikethrough()), + let weak = cx.weak_entity(); + let trigger_label = if let Some(active_session) = active_session.clone() { + active_session.update(cx, |active_session, cx| { + active_session.label(cx).unwrap_or("(child)".into()) + }) + } else { + SharedString::new_static("Unknown Session") + }; + let running_state = running_state.read(cx); + + let is_terminated = running_state.session().read(cx).is_terminated(); + let is_started = active_session + .is_some_and(|session| session.read(cx).session(cx).read(cx).is_started()); + + let session_state_indicator = if is_terminated { + Indicator::dot().color(Color::Error).into_any_element() + } else if !is_started { + Icon::new(IconName::ArrowCircle) + .size(IconSize::Small) + .color(Color::Muted) + .with_animation( + "arrow-circle", + Animation::new(Duration::from_secs(2)).repeat(), + |icon, delta| icon.transform(Transformation::rotate(percentage(delta))), ) - .into_any_element(); + .into_any_element() + } else { + match running_state.thread_status(cx).unwrap_or_default() { + ThreadStatus::Stopped => Indicator::dot().color(Color::Conflict).into_any_element(), + _ => Indicator::dot().color(Color::Success).into_any_element(), + } + }; - Some( - DropdownMenu::new_with_element( - "debugger-session-list", - trigger, - ContextMenu::build(window, cx, move |mut this, _, cx| { - let context_menu = cx.weak_entity(); - let mut session_depths = HashMap::default(); - for session in sessions.into_iter() { - let weak_session = session.downgrade(); - let weak_session_id = weak_session.entity_id(); - let session_id = session.read(cx).session_id(cx); - let parent_depth = session - .read(cx) - .session(cx) - .read(cx) - .parent_id(cx) - .and_then(|parent_id| session_depths.get(&parent_id).cloned()); - let self_depth = - *session_depths.entry(session_id).or_insert_with(|| { - parent_depth.map(|depth| depth + 1).unwrap_or(0usize) - }); - this = this.custom_entry( - { - let weak = weak.clone(); - let context_menu = context_menu.clone(); - move |_, cx| { - weak_session - .read_with(cx, |session, cx| { - let context_menu = context_menu.clone(); - - let id: SharedString = - format!("debug-session-{}", session_id.0) - .into(); - - h_flex() - .w_full() - .group(id.clone()) - .justify_between() - .child(session.label_element(self_depth, cx)) - .child( - IconButton::new( - "close-debug-session", - IconName::Close, - ) - .visible_on_hover(id.clone()) - .icon_size(IconSize::Small) - .on_click({ - let weak = weak.clone(); - move |_, window, cx| { - weak.update(cx, |panel, cx| { - panel.close_session( - weak_session_id, - window, - cx, - ); - }) - .ok(); - context_menu - .update(cx, |this, cx| { - this.cancel( - &Default::default(), - window, - cx, - ); - }) - .ok(); - } - }), - ) - .into_any_element() - }) - .unwrap_or_else(|_| div().into_any_element()) - } - }, - { - let weak = weak.clone(); - move |window, cx| { - weak.update(cx, |panel, cx| { - panel.activate_session(session.clone(), window, cx); - }) - .ok(); - } - }, - ); + let trigger = h_flex() + .gap_2() + .child(session_state_indicator) + .justify_between() + .child( + DebugPanel::dropdown_label(trigger_label) + .when(is_terminated, |this| this.strikethrough()), + ) + .into_any_element(); + + let menu = DropdownMenu::new_with_element( + "debugger-session-list", + trigger, + ContextMenu::build(window, cx, move |mut this, _, cx| 
{ + let context_menu = cx.weak_entity(); + let mut session_depths = HashMap::default(); + for session_entry in session_entries { + let session_id = session_entry.leaf.read(cx).session_id(cx); + let parent_depth = session_entry + .ancestors + .first() + .unwrap_or(&session_entry.leaf) + .read(cx) + .session(cx) + .read(cx) + .parent_id(cx) + .and_then(|parent_id| session_depths.get(&parent_id).cloned()); + let self_depth = *session_depths + .entry(session_id) + .or_insert_with(|| parent_depth.map(|depth| depth + 1).unwrap_or(0usize)); + this = this.custom_entry( + { + let weak = weak.clone(); + let context_menu = context_menu.clone(); + let ancestors: Rc<[_]> = session_entry + .ancestors + .iter() + .map(|session| session.downgrade()) + .collect(); + let leaf = session_entry.leaf.downgrade(); + move |window, cx| { + Self::render_session_menu_entry( + weak.clone(), + context_menu.clone(), + ancestors.clone(), + leaf.clone(), + self_depth, + window, + cx, + ) + } + }, + { + let weak = weak.clone(); + let leaf = session_entry.leaf.clone(); + move |window, cx| { + weak.update(cx, |panel, cx| { + panel.activate_session(leaf.clone(), window, cx); + }) + .ok(); + } + }, + ); + } + this + }), + ) + .style(DropdownStyle::Ghost) + .handle(self.session_picker_menu_handle.clone()); + + Some(menu) + } + + fn render_session_menu_entry( + weak: WeakEntity, + context_menu: WeakEntity, + ancestors: Rc<[WeakEntity]>, + leaf: WeakEntity, + self_depth: usize, + _window: &mut Window, + cx: &mut App, + ) -> AnyElement { + let Some(session_entry) = maybe!({ + let ancestors = ancestors + .iter() + .map(|ancestor| ancestor.upgrade()) + .collect::>>()?; + let leaf = leaf.upgrade()?; + Some(SessionListEntry { ancestors, leaf }) + }) else { + return div().into_any_element(); + }; + + let id: SharedString = format!( + "debug-session-{}", + session_entry.leaf.read(cx).session_id(cx).0 + ) + .into(); + let session_entity_id = session_entry.leaf.entity_id(); + + h_flex() + .w_full() + .group(id.clone()) + .justify_between() + .child(session_entry.label_element(self_depth, cx)) + .child( + IconButton::new("close-debug-session", IconName::Close) + .visible_on_hover(id.clone()) + .icon_size(IconSize::Small) + .on_click({ + let weak = weak.clone(); + move |_, window, cx| { + weak.update(cx, |panel, cx| { + panel.close_session(session_entity_id, window, cx); + }) + .ok(); + context_menu + .update(cx, |this, cx| { + this.cancel(&Default::default(), window, cx); + }) + .ok(); } - this }), - ) - .style(DropdownStyle::Ghost) - .handle(self.session_picker_menu_handle.clone()), ) - } else { - None - } + .into_any_element() } pub(crate) fn render_thread_dropdown( diff --git a/crates/debugger_ui/src/new_process_modal.rs b/crates/debugger_ui/src/new_process_modal.rs index 6d7fa244a2e2bfaaaa82f1321d446627e2b0c343..4ac8e371a15052a00ed962480a9f694a8802007c 100644 --- a/crates/debugger_ui/src/new_process_modal.rs +++ b/crates/debugger_ui/src/new_process_modal.rs @@ -1,5 +1,5 @@ use anyhow::{Context as _, bail}; -use collections::{FxHashMap, HashMap}; +use collections::{FxHashMap, HashMap, HashSet}; use language::LanguageRegistry; use std::{ borrow::Cow, @@ -450,7 +450,7 @@ impl NewProcessModal { .and_then(|buffer| buffer.read(cx).language()) .cloned(); - let mut available_adapters = workspace + let mut available_adapters: Vec<_> = workspace .update(cx, |_, cx| DapRegistry::global(cx).enumerate_adapters()) .unwrap_or_default(); if let Some(language) = active_buffer_language { @@ -766,14 +766,7 @@ impl Render for NewProcessModal { )) .child( 
h_flex() - .child(div().child(self.adapter_drop_down_menu(window, cx))) - .child( - Button::new("debugger-spawn", "Start") - .on_click(cx.listener(|this, _, window, cx| { - this.start_new_session(window, cx) - })) - .disabled(disabled), - ), + .child(div().child(self.adapter_drop_down_menu(window, cx))), ) }), NewProcessMode::Debug => el, @@ -1022,15 +1015,13 @@ impl DebugDelegate { let language_names = languages.language_names(); let language = dap_registry .adapter_language(&scenario.adapter) - .map(|language| TaskSourceKind::Language { - name: language.into(), - }); + .map(|language| TaskSourceKind::Language { name: language.0 }); let language = language.or_else(|| { scenario.label.split_whitespace().find_map(|word| { language_names .iter() - .find(|name| name.eq_ignore_ascii_case(word)) + .find(|name| name.as_ref().eq_ignore_ascii_case(word)) .map(|name| TaskSourceKind::Language { name: name.to_owned().into(), }) @@ -1063,6 +1054,9 @@ impl DebugDelegate { }) }) }); + + let valid_adapters: HashSet<_> = cx.global::().enumerate_adapters(); + cx.spawn(async move |this, cx| { let (recent, scenarios) = if let Some(task) = task { task.await @@ -1103,6 +1097,7 @@ impl DebugDelegate { } => !(hide_vscode && dir.ends_with(".vscode")), _ => true, }) + .filter(|(_, scenario)| valid_adapters.contains(&scenario.adapter)) .map(|(kind, scenario)| { let (language, scenario) = Self::get_scenario_kind(&languages, &dap_registry, scenario); diff --git a/crates/debugger_ui/src/persistence.rs b/crates/debugger_ui/src/persistence.rs index d15244c3496b5cb42bf1b4151e075f624862aec0..3a0ad7a40e60d4dc28f2086b94a0a43186978542 100644 --- a/crates/debugger_ui/src/persistence.rs +++ b/crates/debugger_ui/src/persistence.rs @@ -11,7 +11,7 @@ use workspace::{Member, Pane, PaneAxis, Workspace}; use crate::session::running::{ self, DebugTerminal, RunningState, SubView, breakpoint_list::BreakpointList, console::Console, - loaded_source_list::LoadedSourceList, module_list::ModuleList, + loaded_source_list::LoadedSourceList, memory_view::MemoryView, module_list::ModuleList, stack_frame_list::StackFrameList, variable_list::VariableList, }; @@ -24,6 +24,7 @@ pub(crate) enum DebuggerPaneItem { Modules, LoadedSources, Terminal, + MemoryView, } impl DebuggerPaneItem { @@ -36,6 +37,7 @@ impl DebuggerPaneItem { DebuggerPaneItem::Modules, DebuggerPaneItem::LoadedSources, DebuggerPaneItem::Terminal, + DebuggerPaneItem::MemoryView, ]; VARIANTS } @@ -43,6 +45,9 @@ impl DebuggerPaneItem { pub(crate) fn is_supported(&self, capabilities: &Capabilities) -> bool { match self { DebuggerPaneItem::Modules => capabilities.supports_modules_request.unwrap_or_default(), + DebuggerPaneItem::MemoryView => capabilities + .supports_read_memory_request + .unwrap_or_default(), DebuggerPaneItem::LoadedSources => capabilities .supports_loaded_sources_request .unwrap_or_default(), @@ -59,6 +64,7 @@ impl DebuggerPaneItem { DebuggerPaneItem::Modules => SharedString::new_static("Modules"), DebuggerPaneItem::LoadedSources => SharedString::new_static("Sources"), DebuggerPaneItem::Terminal => SharedString::new_static("Terminal"), + DebuggerPaneItem::MemoryView => SharedString::new_static("Memory View"), } } pub(crate) fn tab_tooltip(self) -> SharedString { @@ -80,6 +86,7 @@ impl DebuggerPaneItem { DebuggerPaneItem::Terminal => { "Provides an interactive terminal session within the debugging environment." 
} + DebuggerPaneItem::MemoryView => "Allows inspection of memory contents.", }; SharedString::new_static(tooltip) } @@ -204,6 +211,7 @@ pub(crate) fn deserialize_pane_layout( breakpoint_list: &Entity, loaded_sources: &Entity, terminal: &Entity, + memory_view: &Entity, subscriptions: &mut HashMap, window: &mut Window, cx: &mut Context, @@ -228,6 +236,7 @@ pub(crate) fn deserialize_pane_layout( breakpoint_list, loaded_sources, terminal, + memory_view, subscriptions, window, cx, @@ -298,6 +307,12 @@ pub(crate) fn deserialize_pane_layout( DebuggerPaneItem::Terminal, cx, )), + DebuggerPaneItem::MemoryView => Box::new(SubView::new( + memory_view.focus_handle(cx), + memory_view.clone().into(), + DebuggerPaneItem::MemoryView, + cx, + )), }) .collect(); diff --git a/crates/debugger_ui/src/session.rs b/crates/debugger_ui/src/session.rs index 482297b13671a969c166e154b43a6c854f231e5c..73cfef78cc6410196441ff974f09b5abe3d86916 100644 --- a/crates/debugger_ui/src/session.rs +++ b/crates/debugger_ui/src/session.rs @@ -5,14 +5,13 @@ use dap::client::SessionId; use gpui::{ App, Axis, Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity, }; -use project::Project; use project::debugger::session::Session; use project::worktree_store::WorktreeStore; +use project::{Project, debugger::session::SessionQuirks}; use rpc::proto; use running::RunningState; -use std::{cell::OnceCell, sync::OnceLock}; -use ui::{Indicator, Tooltip, prelude::*}; -use util::truncate_and_trailoff; +use std::cell::OnceCell; +use ui::prelude::*; use workspace::{ CollaboratorId, FollowableItem, ViewId, Workspace, item::{self, Item}, @@ -20,8 +19,8 @@ use workspace::{ pub struct DebugSession { remote_id: Option, - running_state: Entity, - label: OnceLock, + pub(crate) running_state: Entity, + pub(crate) quirks: SessionQuirks, stack_trace_view: OnceCell>, _worktree_store: WeakEntity, workspace: WeakEntity, @@ -57,6 +56,7 @@ impl DebugSession { cx, ) }); + let quirks = session.read(cx).quirks(); cx.new(|cx| Self { _subscriptions: [cx.subscribe(&running_state, |_, _, _, cx| { @@ -64,7 +64,7 @@ impl DebugSession { })], remote_id: None, running_state, - label: OnceLock::new(), + quirks, stack_trace_view: OnceCell::new(), _worktree_store: project.read(cx).worktree_store().downgrade(), workspace, @@ -110,65 +110,28 @@ impl DebugSession { .update(cx, |state, cx| state.shutdown(cx)); } - pub(crate) fn label(&self, cx: &App) -> SharedString { - if let Some(label) = self.label.get() { - return label.clone(); - } - - let session = self.running_state.read(cx).session(); - - self.label - .get_or_init(|| session.read(cx).label()) - .to_owned() - } - - pub(crate) fn running_state(&self) -> &Entity { - &self.running_state - } - - pub(crate) fn label_element(&self, depth: usize, cx: &App) -> AnyElement { - const MAX_LABEL_CHARS: usize = 150; - - let label = self.label(cx); - let label = truncate_and_trailoff(&label, MAX_LABEL_CHARS); - - let is_terminated = self - .running_state - .read(cx) - .session() - .read(cx) - .is_terminated(); - let icon = { - if is_terminated { - Some(Indicator::dot().color(Color::Error)) - } else { - match self - .running_state - .read(cx) - .thread_status(cx) - .unwrap_or_default() - { - project::debugger::session::ThreadStatus::Stopped => { - Some(Indicator::dot().color(Color::Conflict)) - } - _ => Some(Indicator::dot().color(Color::Success)), + pub(crate) fn label(&self, cx: &mut App) -> Option { + let session = self.running_state.read(cx).session().clone(); + session.update(cx, |session, cx| { + let 
session_label = session.label(); + let quirks = session.quirks(); + let mut single_thread_name = || { + let threads = session.threads(cx); + match threads.as_slice() { + [(thread, _)] => Some(SharedString::from(&thread.name)), + _ => None, } + }; + if quirks.prefer_thread_name { + single_thread_name().or(session_label) + } else { + session_label.or_else(single_thread_name) } - }; + }) + } - h_flex() - .id("session-label") - .tooltip(Tooltip::text(format!("Session {}", self.session_id(cx).0,))) - .ml(depth * px(16.0)) - .gap_2() - .when_some(icon, |this, indicator| this.child(indicator)) - .justify_between() - .child( - Label::new(label) - .size(LabelSize::Small) - .when(is_terminated, |this| this.strikethrough()), - ) - .into_any_element() + pub fn running_state(&self) -> &Entity { + &self.running_state } } diff --git a/crates/debugger_ui/src/session/running.rs b/crates/debugger_ui/src/session/running.rs index af8c14aef77d0886071dfd899d8de5adff0d3ed6..f2f9e17d8981d8c8b73ccd4e3caad51f5278eb7d 100644 --- a/crates/debugger_ui/src/session/running.rs +++ b/crates/debugger_ui/src/session/running.rs @@ -1,16 +1,17 @@ pub(crate) mod breakpoint_list; pub(crate) mod console; pub(crate) mod loaded_source_list; +pub(crate) mod memory_view; pub(crate) mod module_list; pub mod stack_frame_list; pub mod variable_list; - use std::{any::Any, ops::ControlFlow, path::PathBuf, sync::Arc, time::Duration}; use crate::{ ToggleExpandItem, new_process_modal::resolve_path, persistence::{self, DebuggerPaneItem, SerializedLayout}, + session::running::memory_view::MemoryView, }; use super::DebugPanelItemEvent; @@ -34,7 +35,7 @@ use loaded_source_list::LoadedSourceList; use module_list::ModuleList; use project::{ DebugScenarioContext, Project, WorktreeId, - debugger::session::{Session, SessionEvent, ThreadId, ThreadStatus}, + debugger::session::{self, Session, SessionEvent, SessionStateEvent, ThreadId, ThreadStatus}, terminals::TerminalKind, }; use rpc::proto::ViewId; @@ -81,6 +82,7 @@ pub struct RunningState { _schedule_serialize: Option>, pub(crate) scenario: Option, pub(crate) scenario_context: Option, + memory_view: Entity, } impl RunningState { @@ -676,14 +678,36 @@ impl RunningState { let session_id = session.read(cx).session_id(); let weak_state = cx.weak_entity(); let stack_frame_list = cx.new(|cx| { - StackFrameList::new(workspace.clone(), session.clone(), weak_state, window, cx) + StackFrameList::new( + workspace.clone(), + session.clone(), + weak_state.clone(), + window, + cx, + ) }); let debug_terminal = parent_terminal.unwrap_or_else(|| cx.new(|cx| DebugTerminal::empty(window, cx))); - - let variable_list = - cx.new(|cx| VariableList::new(session.clone(), stack_frame_list.clone(), window, cx)); + let memory_view = cx.new(|cx| { + MemoryView::new( + session.clone(), + workspace.clone(), + stack_frame_list.downgrade(), + window, + cx, + ) + }); + let variable_list = cx.new(|cx| { + VariableList::new( + session.clone(), + stack_frame_list.clone(), + memory_view.clone(), + weak_state.clone(), + window, + cx, + ) + }); let module_list = cx.new(|cx| ModuleList::new(session.clone(), workspace.clone(), cx)); @@ -770,6 +794,15 @@ impl RunningState { cx.on_focus_out(&focus_handle, window, |this, _, window, cx| { this.serialize_layout(window, cx); }), + cx.subscribe( + &session, + |this, session, event: &SessionStateEvent, cx| match event { + SessionStateEvent::Shutdown if session.read(cx).is_building() => { + this.shutdown(cx); + } + _ => {} + }, + ), ]; let mut pane_close_subscriptions = HashMap::default(); @@ 
-786,6 +819,7 @@ impl RunningState { &breakpoint_list, &loaded_source_list, &debug_terminal, + &memory_view, &mut pane_close_subscriptions, window, cx, @@ -814,6 +848,7 @@ impl RunningState { let active_pane = panes.first_pane(); Self { + memory_view, session, workspace, focus_handle, @@ -884,6 +919,7 @@ impl RunningState { let weak_project = project.downgrade(); let weak_workspace = workspace.downgrade(); let is_local = project.read(cx).is_local(); + cx.spawn_in(window, async move |this, cx| { let DebugScenario { adapter, @@ -1224,6 +1260,12 @@ impl RunningState { item_kind, cx, )), + DebuggerPaneItem::MemoryView => Box::new(SubView::new( + self.memory_view.focus_handle(cx), + self.memory_view.clone().into(), + item_kind, + cx, + )), } } @@ -1408,7 +1450,14 @@ impl RunningState { &self.module_list } - pub(crate) fn activate_item(&self, item: DebuggerPaneItem, window: &mut Window, cx: &mut App) { + pub(crate) fn activate_item( + &mut self, + item: DebuggerPaneItem, + window: &mut Window, + cx: &mut Context, + ) { + self.ensure_pane_item(item, window, cx); + let (variable_list_position, pane) = self .panes .panes() @@ -1420,9 +1469,10 @@ impl RunningState { .map(|view| (view, pane)) }) .unwrap(); + pane.update(cx, |this, cx| { this.activate_item(variable_list_position, true, true, window, cx); - }) + }); } #[cfg(test)] @@ -1459,7 +1509,7 @@ impl RunningState { } } - pub(crate) fn selected_thread_id(&self) -> Option { + pub fn selected_thread_id(&self) -> Option { self.thread_id } @@ -1599,9 +1649,21 @@ impl RunningState { }) .log_err(); - self.session.update(cx, |session, cx| { + let is_building = self.session.update(cx, |session, cx| { session.shutdown(cx).detach(); - }) + matches!(session.mode, session::SessionState::Booting(_)) + }); + + if is_building { + self.debug_terminal.update(cx, |terminal, cx| { + if let Some(view) = terminal.terminal.as_ref() { + view.update(cx, |view, cx| { + view.terminal() + .update(cx, |terminal, _| terminal.kill_active_task()) + }) + } + }) + } } pub fn stop_thread(&self, cx: &mut Context) { diff --git a/crates/debugger_ui/src/session/running/breakpoint_list.rs b/crates/debugger_ui/src/session/running/breakpoint_list.rs index 78c87db2e6f2a1f9d54368b875d1e86b3ac5789f..a6defbbf35cc103025c286b9875285b924032ed0 100644 --- a/crates/debugger_ui/src/session/running/breakpoint_list.rs +++ b/crates/debugger_ui/src/session/running/breakpoint_list.rs @@ -24,12 +24,11 @@ use project::{ }; use ui::{ ActiveTheme, AnyElement, App, ButtonCommon, Clickable, Color, Context, Disableable, Div, - Divider, FluentBuilder as _, Icon, IconButton, IconName, IconSize, Indicator, - InteractiveElement, IntoElement, Label, LabelCommon, LabelSize, ListItem, ParentElement, - Render, RenderOnce, Scrollbar, ScrollbarState, SharedString, StatefulInteractiveElement, - Styled, Toggleable, Tooltip, Window, div, h_flex, px, v_flex, + Divider, FluentBuilder as _, Icon, IconButton, IconName, IconSize, InteractiveElement, + IntoElement, Label, LabelCommon, LabelSize, ListItem, ParentElement, Render, RenderOnce, + Scrollbar, ScrollbarState, SharedString, StatefulInteractiveElement, Styled, Toggleable, + Tooltip, Window, div, h_flex, px, v_flex, }; -use util::ResultExt; use workspace::Workspace; use zed_actions::{ToggleEnableBreakpoint, UnsetBreakpoint}; @@ -46,6 +45,7 @@ actions!( pub(crate) enum SelectedBreakpointKind { Source, Exception, + Data, } pub(crate) struct BreakpointList { workspace: WeakEntity, @@ -55,8 +55,6 @@ pub(crate) struct BreakpointList { scrollbar_state: ScrollbarState, 
breakpoints: Vec, session: Option>, - hide_scrollbar_task: Option>, - show_scrollbar: bool, focus_handle: FocusHandle, scroll_handle: UniformListScrollHandle, selected_ix: Option, @@ -102,8 +100,6 @@ impl BreakpointList { worktree_store, scrollbar_state, breakpoints: Default::default(), - hide_scrollbar_task: None, - show_scrollbar: false, workspace, session, focus_handle, @@ -188,6 +184,9 @@ impl BreakpointList { BreakpointEntryKind::ExceptionBreakpoint(bp) => { (SelectedBreakpointKind::Exception, bp.is_enabled) } + BreakpointEntryKind::DataBreakpoint(bp) => { + (SelectedBreakpointKind::Data, bp.0.is_enabled) + } }) }) } @@ -391,7 +390,8 @@ impl BreakpointList { let row = line_breakpoint.breakpoint.row; self.go_to_line_breakpoint(path, row, window, cx); } - BreakpointEntryKind::ExceptionBreakpoint(_) => {} + BreakpointEntryKind::DataBreakpoint(_) + | BreakpointEntryKind::ExceptionBreakpoint(_) => {} } } @@ -421,6 +421,10 @@ impl BreakpointList { let id = exception_breakpoint.id.clone(); self.toggle_exception_breakpoint(&id, cx); } + BreakpointEntryKind::DataBreakpoint(data_breakpoint) => { + let id = data_breakpoint.0.dap.data_id.clone(); + self.toggle_data_breakpoint(&id, cx); + } } cx.notify(); } @@ -441,7 +445,7 @@ impl BreakpointList { let row = line_breakpoint.breakpoint.row; self.edit_line_breakpoint(path, row, BreakpointEditAction::Toggle, cx); } - BreakpointEntryKind::ExceptionBreakpoint(_) => {} + _ => {} } cx.notify(); } @@ -490,6 +494,14 @@ impl BreakpointList { cx.notify(); } + fn toggle_data_breakpoint(&mut self, id: &str, cx: &mut Context) { + if let Some(session) = &self.session { + session.update(cx, |this, cx| { + this.toggle_data_breakpoint(&id, cx); + }); + } + } + fn toggle_exception_breakpoint(&mut self, id: &str, cx: &mut Context) { if let Some(session) = &self.session { session.update(cx, |this, cx| { @@ -548,21 +560,6 @@ impl BreakpointList { Ok(()) } - fn hide_scrollbar(&mut self, window: &mut Window, cx: &mut Context) { - const SCROLLBAR_SHOW_INTERVAL: Duration = Duration::from_secs(1); - self.hide_scrollbar_task = Some(cx.spawn_in(window, async move |panel, cx| { - cx.background_executor() - .timer(SCROLLBAR_SHOW_INTERVAL) - .await; - panel - .update(cx, |panel, cx| { - panel.show_scrollbar = false; - cx.notify(); - }) - .log_err(); - })) - } - fn render_list(&mut self, cx: &mut Context) -> impl IntoElement { let selected_ix = self.selected_ix; let focus_handle = self.focus_handle.clone(); @@ -597,43 +594,39 @@ impl BreakpointList { .flex_grow() } - fn render_vertical_scrollbar(&self, cx: &mut Context) -> Option> { - if !(self.show_scrollbar || self.scrollbar_state.is_dragging()) { - return None; - } - Some( - div() - .occlude() - .id("breakpoint-list-vertical-scrollbar") - .on_mouse_move(cx.listener(|_, _, _, cx| { - cx.notify(); - cx.stop_propagation() - })) - .on_hover(|_, _, cx| { - cx.stop_propagation(); - }) - .on_any_mouse_down(|_, _, cx| { + fn render_vertical_scrollbar(&self, cx: &mut Context) -> Stateful
{ + div() + .occlude() + .id("breakpoint-list-vertical-scrollbar") + .on_mouse_move(cx.listener(|_, _, _, cx| { + cx.notify(); + cx.stop_propagation() + })) + .on_hover(|_, _, cx| { + cx.stop_propagation(); + }) + .on_any_mouse_down(|_, _, cx| { + cx.stop_propagation(); + }) + .on_mouse_up( + MouseButton::Left, + cx.listener(|_, _, _, cx| { cx.stop_propagation(); - }) - .on_mouse_up( - MouseButton::Left, - cx.listener(|_, _, _, cx| { - cx.stop_propagation(); - }), - ) - .on_scroll_wheel(cx.listener(|_, _, _, cx| { - cx.notify(); - })) - .h_full() - .absolute() - .right_1() - .top_1() - .bottom_0() - .w(px(12.)) - .cursor_default() - .children(Scrollbar::vertical(self.scrollbar_state.clone())), - ) + }), + ) + .on_scroll_wheel(cx.listener(|_, _, _, cx| { + cx.notify(); + })) + .h_full() + .absolute() + .right_1() + .top_1() + .bottom_0() + .w(px(12.)) + .cursor_default() + .children(Scrollbar::vertical(self.scrollbar_state.clone()).map(|s| s.auto_hide(cx))) } + pub(crate) fn render_control_strip(&self) -> AnyElement { let selection_kind = self.selection_kind(); let focus_handle = self.focus_handle.clone(); @@ -642,6 +635,7 @@ impl BreakpointList { SelectedBreakpointKind::Exception => { "Exception Breakpoints cannot be removed from the breakpoint list" } + SelectedBreakpointKind::Data => "Remove data breakpoint from a breakpoint list", }); let toggle_label = selection_kind.map(|(_, is_enabled)| { if is_enabled { @@ -783,21 +777,24 @@ impl Render for BreakpointList { weak: weak.clone(), }) }); - self.breakpoints - .extend(breakpoints.chain(exception_breakpoints)); + let data_breakpoints = self.session.as_ref().into_iter().flat_map(|session| { + session + .read(cx) + .data_breakpoints() + .map(|state| BreakpointEntry { + kind: BreakpointEntryKind::DataBreakpoint(DataBreakpoint(state.clone())), + weak: weak.clone(), + }) + }); + self.breakpoints.extend( + breakpoints + .chain(data_breakpoints) + .chain(exception_breakpoints), + ); v_flex() .id("breakpoint-list") .key_context("BreakpointList") .track_focus(&self.focus_handle) - .on_hover(cx.listener(|this, hovered, window, cx| { - if *hovered { - this.show_scrollbar = true; - this.hide_scrollbar_task.take(); - cx.notify(); - } else if !this.focus_handle.contains_focused(window, cx) { - this.hide_scrollbar(window, cx); - } - })) .on_action(cx.listener(Self::select_next)) .on_action(cx.listener(Self::select_previous)) .on_action(cx.listener(Self::select_first)) @@ -814,7 +811,7 @@ impl Render for BreakpointList { v_flex() .size_full() .child(self.render_list(cx)) - .children(self.render_vertical_scrollbar(cx)), + .child(self.render_vertical_scrollbar(cx)), ) .when_some(self.strip_mode, |this, _| { this.child(Divider::horizontal()).child( @@ -905,7 +902,11 @@ impl LineBreakpoint { .ok(); } }) - .child(Indicator::icon(Icon::new(icon_name)).color(Color::Debugger)) + .child( + Icon::new(icon_name) + .color(Color::Debugger) + .size(IconSize::XSmall), + ) .on_mouse_down(MouseButton::Left, move |_, _, _| {}); ListItem::new(SharedString::from(format!( @@ -996,6 +997,103 @@ struct ExceptionBreakpoint { data: ExceptionBreakpointsFilter, is_enabled: bool, } +#[derive(Clone, Debug)] +struct DataBreakpoint(project::debugger::session::DataBreakpointState); + +impl DataBreakpoint { + fn render( + &self, + props: SupportedBreakpointProperties, + strip_mode: Option, + ix: usize, + is_selected: bool, + focus_handle: FocusHandle, + list: WeakEntity, + ) -> ListItem { + let color = if self.0.is_enabled { + Color::Debugger + } else { + Color::Muted + }; + let 
is_enabled = self.0.is_enabled; + let id = self.0.dap.data_id.clone(); + ListItem::new(SharedString::from(format!( + "data-breakpoint-ui-item-{}", + self.0.dap.data_id + ))) + .rounded() + .start_slot( + div() + .id(SharedString::from(format!( + "data-breakpoint-ui-item-{}-click-handler", + self.0.dap.data_id + ))) + .tooltip({ + let focus_handle = focus_handle.clone(); + move |window, cx| { + Tooltip::for_action_in( + if is_enabled { + "Disable Data Breakpoint" + } else { + "Enable Data Breakpoint" + }, + &ToggleEnableBreakpoint, + &focus_handle, + window, + cx, + ) + } + }) + .on_click({ + let list = list.clone(); + move |_, _, cx| { + list.update(cx, |this, cx| { + this.toggle_data_breakpoint(&id, cx); + }) + .ok(); + } + }) + .cursor_pointer() + .child( + Icon::new(IconName::Binary) + .color(color) + .size(IconSize::Small), + ), + ) + .child( + h_flex() + .w_full() + .mr_4() + .py_0p5() + .justify_between() + .child( + v_flex() + .py_1() + .gap_1() + .min_h(px(26.)) + .justify_center() + .id(("data-breakpoint-label", ix)) + .child( + Label::new(self.0.context.human_readable_label()) + .size(LabelSize::Small) + .line_height_style(ui::LineHeightStyle::UiLabel), + ), + ) + .child(BreakpointOptionsStrip { + props, + breakpoint: BreakpointEntry { + kind: BreakpointEntryKind::DataBreakpoint(self.clone()), + weak: list, + }, + is_selected, + focus_handle, + strip_mode, + index: ix, + }), + ) + .toggle_state(is_selected) + } +} impl ExceptionBreakpoint { fn render( @@ -1062,7 +1160,11 @@ impl ExceptionBreakpoint { } }) .cursor_pointer() - .child(Indicator::icon(Icon::new(IconName::Flame)).color(color)), + .child( + Icon::new(IconName::Flame) + .color(color) + .size(IconSize::Small), + ), ) .child( h_flex() @@ -1105,6 +1207,7 @@ impl ExceptionBreakpoint { enum BreakpointEntryKind { LineBreakpoint(LineBreakpoint), ExceptionBreakpoint(ExceptionBreakpoint), + DataBreakpoint(DataBreakpoint), } #[derive(Clone, Debug)] @@ -1140,6 +1243,14 @@ impl BreakpointEntry { focus_handle, self.weak.clone(), ), + BreakpointEntryKind::DataBreakpoint(data_breakpoint) => data_breakpoint.render( + props.for_data_breakpoints(), + strip_mode, + ix, + is_selected, + focus_handle, + self.weak.clone(), + ), } } @@ -1155,6 +1266,11 @@ impl BreakpointEntry { exception_breakpoint.id ) .into(), + BreakpointEntryKind::DataBreakpoint(data_breakpoint) => format!( + "data-breakpoint-control-strip--{}", + data_breakpoint.0.dap.data_id + ) + .into(), } } @@ -1172,8 +1288,8 @@ impl BreakpointEntry { BreakpointEntryKind::LineBreakpoint(line_breakpoint) => { line_breakpoint.breakpoint.condition.is_some() } - // We don't support conditions on exception breakpoints - BreakpointEntryKind::ExceptionBreakpoint(_) => false, + // We don't support conditions on exception/data breakpoints + _ => false, } } @@ -1225,6 +1341,10 @@ impl SupportedBreakpointProperties { // TODO: we don't yet support conditions for exception breakpoints at the data layer, hence all props are disabled here. Self::empty() } + fn for_data_breakpoints(self) -> Self { + // TODO: we don't yet support conditions for data breakpoints at the data layer, hence all props are disabled here. 
+ Self::empty() + } } #[derive(IntoElement)] struct BreakpointOptionsStrip { diff --git a/crates/debugger_ui/src/session/running/console.rs b/crates/debugger_ui/src/session/running/console.rs index 9375c8820b0eb335f1d36534f219f339ec587df1..1385bec54ef77222485cd642174d50aa60fa289a 100644 --- a/crates/debugger_ui/src/session/running/console.rs +++ b/crates/debugger_ui/src/session/running/console.rs @@ -12,7 +12,7 @@ use gpui::{ Action as _, AppContext, Context, Corner, Entity, FocusHandle, Focusable, HighlightStyle, Hsla, Render, Subscription, Task, TextStyle, WeakEntity, actions, }; -use language::{Buffer, CodeLabel, ToOffset}; +use language::{Anchor, Buffer, CodeLabel, TextBufferSnapshot, ToOffset}; use menu::{Confirm, SelectNext, SelectPrevious}; use project::{ Completion, CompletionResponse, @@ -637,27 +637,13 @@ impl ConsoleQueryBarCompletionProvider { }); let snapshot = buffer.read(cx).text_snapshot(); - let query = snapshot.text(); - let replace_range = { - let buffer_offset = buffer_position.to_offset(&snapshot); - let reversed_chars = snapshot.reversed_chars_for_range(0..buffer_offset); - let mut word_len = 0; - for ch in reversed_chars { - if ch.is_alphanumeric() || ch == '_' { - word_len += 1; - } else { - break; - } - } - let word_start_offset = buffer_offset - word_len; - let start_anchor = snapshot.anchor_at(word_start_offset, Bias::Left); - start_anchor..buffer_position - }; + let buffer_text = snapshot.text(); + cx.spawn(async move |_, cx| { const LIMIT: usize = 10; let matches = fuzzy::match_strings( &string_matches, - &query, + &buffer_text, true, true, LIMIT, @@ -672,7 +658,12 @@ impl ConsoleQueryBarCompletionProvider { let variable_value = variables.get(&string_match.string)?; Some(project::Completion { - replace_range: replace_range.clone(), + replace_range: Self::replace_range_for_completion( + &buffer_text, + buffer_position, + string_match.string.as_bytes(), + &snapshot, + ), new_text: string_match.string.clone(), label: CodeLabel { filter_range: 0..string_match.string.len(), @@ -697,6 +688,28 @@ impl ConsoleQueryBarCompletionProvider { }) } + fn replace_range_for_completion( + buffer_text: &String, + buffer_position: Anchor, + new_bytes: &[u8], + snapshot: &TextBufferSnapshot, + ) -> Range { + let buffer_offset = buffer_position.to_offset(&snapshot); + let buffer_bytes = &buffer_text.as_bytes()[0..buffer_offset]; + + let mut prefix_len = 0; + for i in (0..new_bytes.len()).rev() { + if buffer_bytes.ends_with(&new_bytes[0..i]) { + prefix_len = i; + break; + } + } + + let start = snapshot.clip_offset(buffer_offset - prefix_len, Bias::Left); + + snapshot.anchor_before(start)..buffer_position + } + const fn completion_type_score(completion_type: CompletionItemType) -> usize { match completion_type { CompletionItemType::Field | CompletionItemType::Property => 0, @@ -744,6 +757,8 @@ impl ConsoleQueryBarCompletionProvider { cx.background_executor().spawn(async move { let completions = completion_task.await?; + let buffer_text = snapshot.text(); + let completions = completions .into_iter() .map(|completion| { @@ -753,26 +768,14 @@ impl ConsoleQueryBarCompletionProvider { .as_ref() .unwrap_or(&completion.label) .to_owned(); - let buffer_text = snapshot.text(); - let buffer_bytes = buffer_text.as_bytes(); - let new_bytes = new_text.as_bytes(); - - let mut prefix_len = 0; - for i in (0..new_bytes.len()).rev() { - if buffer_bytes.ends_with(&new_bytes[0..i]) { - prefix_len = i; - break; - } - } - - let buffer_offset = buffer_position.to_offset(&snapshot); - let start = 
buffer_offset - prefix_len; - let start = snapshot.clip_offset(start, Bias::Left); - let start = snapshot.anchor_before(start); - let replace_range = start..buffer_position; project::Completion { - replace_range, + replace_range: Self::replace_range_for_completion( + &buffer_text, + buffer_position, + new_text.as_bytes(), + &snapshot, + ), new_text, label: CodeLabel { filter_range: 0..completion.label.len(), @@ -944,3 +947,64 @@ fn color_fetcher(color: ansi::Color) -> fn(&Theme) -> Hsla { }; color_fetcher } + +#[cfg(test)] +mod tests { + use super::*; + use crate::tests::init_test; + use editor::test::editor_test_context::EditorTestContext; + use gpui::TestAppContext; + use language::Point; + + #[track_caller] + fn assert_completion_range( + input: &str, + expect: &str, + replacement: &str, + cx: &mut EditorTestContext, + ) { + cx.set_state(input); + + let buffer_position = + cx.editor(|editor, _, cx| editor.selections.newest::(cx).start); + + let snapshot = &cx.buffer_snapshot(); + + let replace_range = ConsoleQueryBarCompletionProvider::replace_range_for_completion( + &cx.buffer_text(), + snapshot.anchor_before(buffer_position), + replacement.as_bytes(), + &snapshot, + ); + + cx.update_editor(|editor, _, cx| { + editor.edit( + vec![( + snapshot.offset_for_anchor(&replace_range.start) + ..snapshot.offset_for_anchor(&replace_range.end), + replacement, + )], + cx, + ); + }); + + pretty_assertions::assert_eq!(expect, cx.display_text()); + } + + #[gpui::test] + async fn test_determine_completion_replace_range(cx: &mut TestAppContext) { + init_test(cx); + + let mut cx = EditorTestContext::new(cx).await; + + assert_completion_range("resˇ", "result", "result", &mut cx); + assert_completion_range("print(resˇ)", "print(result)", "result", &mut cx); + assert_completion_range("$author->nˇ", "$author->name", "$author->name", &mut cx); + assert_completion_range( + "$author->books[ˇ", + "$author->books[0]", + "$author->books[0]", + &mut cx, + ); + } +} diff --git a/crates/debugger_ui/src/session/running/loaded_source_list.rs b/crates/debugger_ui/src/session/running/loaded_source_list.rs index dd5487e0426ac8386a6af04d27e30d786bb29eaf..6b376bb892e1ea5aae64a1d5873b91487e65f3c2 100644 --- a/crates/debugger_ui/src/session/running/loaded_source_list.rs +++ b/crates/debugger_ui/src/session/running/loaded_source_list.rs @@ -13,22 +13,8 @@ pub(crate) struct LoadedSourceList { impl LoadedSourceList { pub fn new(session: Entity, cx: &mut Context) -> Self { - let weak_entity = cx.weak_entity(); let focus_handle = cx.focus_handle(); - - let list = ListState::new( - 0, - gpui::ListAlignment::Top, - px(1000.), - move |ix, _window, cx| { - weak_entity - .upgrade() - .map(|loaded_sources| { - loaded_sources.update(cx, |this, cx| this.render_entry(ix, cx)) - }) - .unwrap_or(div().into_any()) - }, - ); + let list = ListState::new(0, gpui::ListAlignment::Top, px(1000.)); let _subscription = cx.subscribe(&session, |this, _, event, cx| match event { SessionEvent::Stopped(_) | SessionEvent::LoadedSources => { @@ -98,6 +84,12 @@ impl Render for LoadedSourceList { .track_focus(&self.focus_handle) .size_full() .p_1() - .child(list(self.list.clone()).size_full()) + .child( + list( + self.list.clone(), + cx.processor(|this, ix, _window, cx| this.render_entry(ix, cx)), + ) + .size_full(), + ) } } diff --git a/crates/debugger_ui/src/session/running/memory_view.rs b/crates/debugger_ui/src/session/running/memory_view.rs new file mode 100644 index 0000000000000000000000000000000000000000..75b8938371555374fccdb5d77a6a8cc07ebae0e0 --- 
/dev/null +++ b/crates/debugger_ui/src/session/running/memory_view.rs @@ -0,0 +1,951 @@ +use std::{ + cell::LazyCell, + fmt::Write, + ops::RangeInclusive, + sync::{Arc, LazyLock}, + time::Duration, +}; + +use editor::{Editor, EditorElement, EditorStyle}; +use gpui::{ + Action, AppContext, DismissEvent, DragMoveEvent, Empty, Entity, FocusHandle, Focusable, + MouseButton, Point, ScrollStrategy, ScrollWheelEvent, Stateful, Subscription, Task, TextStyle, + UniformList, UniformListScrollHandle, WeakEntity, actions, anchored, deferred, point, + uniform_list, +}; +use notifications::status_toast::{StatusToast, ToastIcon}; +use project::debugger::{MemoryCell, dap_command::DataBreakpointContext, session::Session}; +use settings::Settings; +use theme::ThemeSettings; +use ui::{ + ActiveTheme, AnyElement, App, Color, Context, ContextMenu, Div, Divider, DropdownMenu, Element, + FluentBuilder, Icon, IconName, InteractiveElement, IntoElement, Label, LabelCommon, + ParentElement, Pixels, PopoverMenuHandle, Render, Scrollbar, ScrollbarState, SharedString, + StatefulInteractiveElement, Styled, TextSize, Tooltip, Window, div, h_flex, px, v_flex, +}; +use workspace::Workspace; + +use crate::{ToggleDataBreakpoint, session::running::stack_frame_list::StackFrameList}; + +actions!(debugger, [GoToSelectedAddress]); + +pub(crate) struct MemoryView { + workspace: WeakEntity, + scroll_handle: UniformListScrollHandle, + scroll_state: ScrollbarState, + stack_frame_list: WeakEntity, + focus_handle: FocusHandle, + view_state: ViewState, + query_editor: Entity, + session: Entity, + width_picker_handle: PopoverMenuHandle, + is_writing_memory: bool, + open_context_menu: Option<(Entity, Point, Subscription)>, +} + +impl Focusable for MemoryView { + fn focus_handle(&self, _: &ui::App) -> FocusHandle { + self.focus_handle.clone() + } +} +#[derive(Clone, Debug)] +struct Drag { + start_address: u64, + end_address: u64, +} + +impl Drag { + fn contains(&self, address: u64) -> bool { + let range = self.memory_range(); + range.contains(&address) + } + + fn memory_range(&self) -> RangeInclusive { + if self.start_address < self.end_address { + self.start_address..=self.end_address + } else { + self.end_address..=self.start_address + } + } +} +#[derive(Clone, Debug)] +enum SelectedMemoryRange { + DragUnderway(Drag), + DragComplete(Drag), +} + +impl SelectedMemoryRange { + fn contains(&self, address: u64) -> bool { + match self { + SelectedMemoryRange::DragUnderway(drag) => drag.contains(address), + SelectedMemoryRange::DragComplete(drag) => drag.contains(address), + } + } + fn is_dragging(&self) -> bool { + matches!(self, SelectedMemoryRange::DragUnderway(_)) + } + fn drag(&self) -> &Drag { + match self { + SelectedMemoryRange::DragUnderway(drag) => drag, + SelectedMemoryRange::DragComplete(drag) => drag, + } + } +} + +#[derive(Clone)] +struct ViewState { + /// Uppermost row index + base_row: u64, + /// How many cells per row do we have? + line_width: ViewWidth, + selection: Option, +} + +impl ViewState { + fn new(base_row: u64, line_width: ViewWidth) -> Self { + Self { + base_row, + line_width, + selection: None, + } + } + fn row_count(&self) -> u64 { + // This was picked fully arbitrarily. There's no incentive for us to care about page sizes other than the fact that it seems to be a good + // middle ground for data size. 
+        const PAGE_SIZE: u64 = 4096;
+        PAGE_SIZE / self.line_width.width as u64
+    }
+    fn schedule_scroll_down(&mut self) {
+        self.base_row = self.base_row.saturating_add(1)
+    }
+    fn schedule_scroll_up(&mut self) {
+        self.base_row = self.base_row.saturating_sub(1);
+    }
+}
+
+struct ScrollbarDragging;
+
+static HEX_BYTES_MEMOIZED: LazyLock<[SharedString; 256]> =
+    LazyLock::new(|| std::array::from_fn(|byte| SharedString::from(format!("{byte:02X}"))));
+static UNKNOWN_BYTE: SharedString = SharedString::new_static("??");
+impl MemoryView {
+    pub(crate) fn new(
+        session: Entity,
+        workspace: WeakEntity,
+        stack_frame_list: WeakEntity,
+        window: &mut Window,
+        cx: &mut Context,
+    ) -> Self {
+        let view_state = ViewState::new(0, WIDTHS[4].clone());
+        let scroll_handle = UniformListScrollHandle::default();
+
+        let query_editor = cx.new(|cx| Editor::single_line(window, cx));
+
+        let scroll_state = ScrollbarState::new(scroll_handle.clone());
+        let mut this = Self {
+            workspace,
+            scroll_state,
+            scroll_handle,
+            stack_frame_list,
+            focus_handle: cx.focus_handle(),
+            view_state,
+            query_editor,
+            session,
+            width_picker_handle: Default::default(),
+            is_writing_memory: true,
+            open_context_menu: None,
+        };
+        this.change_query_bar_mode(false, window, cx);
+        cx.on_focus_out(&this.focus_handle, window, |this, _, window, cx| {
+            this.change_query_bar_mode(false, window, cx);
+            cx.notify();
+        })
+        .detach();
+        this
+    }
+
+    fn render_vertical_scrollbar(&self, cx: &mut Context) -> Stateful {
+        div()
+            .occlude()
+            .id("memory-view-vertical-scrollbar")
+            .on_drag_move(cx.listener(|this, evt, _, cx| {
+                let did_handle = this.handle_scroll_drag(evt);
+                cx.notify();
+                if did_handle {
+                    cx.stop_propagation()
+                }
+            }))
+            .on_drag(ScrollbarDragging, |_, _, _, cx| cx.new(|_| Empty))
+            .on_hover(|_, _, cx| {
+                cx.stop_propagation();
+            })
+            .on_any_mouse_down(|_, _, cx| {
+                cx.stop_propagation();
+            })
+            .on_mouse_up(
+                MouseButton::Left,
+                cx.listener(|_, _, _, cx| {
+                    cx.stop_propagation();
+                }),
+            )
+            .on_scroll_wheel(cx.listener(|_, _, _, cx| {
+                cx.notify();
+            }))
+            .h_full()
+            .absolute()
+            .right_1()
+            .top_1()
+            .bottom_0()
+            .w(px(12.))
+            .cursor_default()
+            .children(Scrollbar::vertical(self.scroll_state.clone()).map(|s| s.auto_hide(cx)))
+    }
+
+    fn render_memory(&self, cx: &mut Context) -> UniformList {
+        let weak = cx.weak_entity();
+        let session = self.session.clone();
+        let view_state = self.view_state.clone();
+        uniform_list(
+            "debugger-memory-view",
+            self.view_state.row_count() as usize,
+            move |range, _, cx| {
+                let mut line_buffer = Vec::with_capacity(view_state.line_width.width as usize);
+                let memory_start =
+                    (view_state.base_row + range.start as u64) * view_state.line_width.width as u64;
+                let memory_end = (view_state.base_row + range.end as u64)
+                    * view_state.line_width.width as u64
+                    - 1;
+                let mut memory = session.update(cx, |this, cx| {
+                    this.read_memory(memory_start..=memory_end, cx)
+                });
+                let mut rows = Vec::with_capacity(range.end - range.start);
+                for ix in range {
+                    line_buffer.extend((&mut memory).take(view_state.line_width.width as usize));
+                    rows.push(render_single_memory_view_line(
+                        &line_buffer,
+                        ix as u64,
+                        weak.clone(),
+                        cx,
+                    ));
+                    line_buffer.clear();
+                }
+                rows
+            },
+        )
+        .track_scroll(self.scroll_handle.clone())
+        .on_scroll_wheel(cx.listener(|this, evt: &ScrollWheelEvent, window, _| {
+            let delta = evt.delta.pixel_delta(window.line_height());
+            let scroll_handle = this.scroll_state.scroll_handle();
+            let size = scroll_handle.content_size();
+            let viewport = scroll_handle.viewport();
+            let current_offset = scroll_handle.offset();
+            let first_entry_offset_boundary = size.height / this.view_state.row_count() as f32;
+            let last_entry_offset_boundary = size.height - first_entry_offset_boundary;
+            if first_entry_offset_boundary + viewport.size.height > current_offset.y.abs() {
+                // The topmost entry is visible, hence if we're scrolling up, we need to load extra lines.
+ this.view_state.schedule_scroll_up(); + } else if last_entry_offset_boundary < current_offset.y.abs() + viewport.size.height { + this.view_state.schedule_scroll_down(); + } + scroll_handle.set_offset(current_offset + point(px(0.), delta.y)); + })) + } + fn render_query_bar(&self, cx: &Context) -> impl IntoElement { + EditorElement::new( + &self.query_editor, + Self::editor_style(&self.query_editor, cx), + ) + } + pub(super) fn go_to_memory_reference( + &mut self, + memory_reference: &str, + evaluate_name: Option<&str>, + stack_frame_id: Option, + cx: &mut Context, + ) { + use parse_int::parse; + let Ok(as_address) = parse::(&memory_reference) else { + return; + }; + let access_size = evaluate_name + .map(|typ| { + self.session.update(cx, |this, cx| { + this.data_access_size(stack_frame_id, typ, cx) + }) + }) + .unwrap_or_else(|| Task::ready(None)); + cx.spawn(async move |this, cx| { + let access_size = access_size.await.unwrap_or(1); + this.update(cx, |this, cx| { + this.view_state.selection = Some(SelectedMemoryRange::DragComplete(Drag { + start_address: as_address, + end_address: as_address + access_size - 1, + })); + this.jump_to_address(as_address, cx); + }) + .ok(); + }) + .detach(); + } + + fn handle_memory_drag(&mut self, evt: &DragMoveEvent) { + if !self + .view_state + .selection + .as_ref() + .is_some_and(|selection| selection.is_dragging()) + { + return; + } + let row_count = self.view_state.row_count(); + debug_assert!(row_count > 1); + let scroll_handle = self.scroll_state.scroll_handle(); + let viewport = scroll_handle.viewport(); + + if viewport.bottom() < evt.event.position.y { + self.view_state.schedule_scroll_down(); + } else if viewport.top() > evt.event.position.y { + self.view_state.schedule_scroll_up(); + } + } + + fn handle_scroll_drag(&mut self, evt: &DragMoveEvent) -> bool { + if !self.scroll_state.is_dragging() { + return false; + } + let row_count = self.view_state.row_count(); + debug_assert!(row_count > 1); + let scroll_handle = self.scroll_state.scroll_handle(); + let viewport = scroll_handle.viewport(); + + if viewport.bottom() < evt.event.position.y { + self.view_state.schedule_scroll_down(); + true + } else if viewport.top() > evt.event.position.y { + self.view_state.schedule_scroll_up(); + true + } else { + false + } + } + + fn editor_style(editor: &Entity, cx: &Context) -> EditorStyle { + let is_read_only = editor.read(cx).read_only(cx); + let settings = ThemeSettings::get_global(cx); + let theme = cx.theme(); + let text_style = TextStyle { + color: if is_read_only { + theme.colors().text_muted + } else { + theme.colors().text + }, + font_family: settings.buffer_font.family.clone(), + font_features: settings.buffer_font.features.clone(), + font_size: TextSize::Small.rems(cx).into(), + font_weight: settings.buffer_font.weight, + + ..Default::default() + }; + EditorStyle { + background: theme.colors().editor_background, + local_player: theme.players().local(), + text: text_style, + ..Default::default() + } + } + + fn render_width_picker(&self, window: &mut Window, cx: &mut Context) -> DropdownMenu { + let weak = cx.weak_entity(); + let selected_width = self.view_state.line_width.clone(); + DropdownMenu::new( + "memory-view-width-picker", + selected_width.label.clone(), + ContextMenu::build(window, cx, |mut this, window, cx| { + for width in &WIDTHS { + let weak = weak.clone(); + let width = width.clone(); + this = this.entry(width.label.clone(), None, move |_, cx| { + _ = weak.update(cx, |this, _| { + // Convert base ix between 2 line widths to keep the 
shown memory address roughly the same. + // All widths are powers of 2, so the conversion should be lossless. + match this.view_state.line_width.width.cmp(&width.width) { + std::cmp::Ordering::Less => { + // We're converting up. + let shift = width.width.trailing_zeros() + - this.view_state.line_width.width.trailing_zeros(); + this.view_state.base_row >>= shift; + } + std::cmp::Ordering::Greater => { + // We're converting down. + let shift = this.view_state.line_width.width.trailing_zeros() + - width.width.trailing_zeros(); + this.view_state.base_row <<= shift; + } + _ => {} + } + this.view_state.line_width = width.clone(); + }); + }); + } + if let Some(ix) = WIDTHS + .iter() + .position(|width| width.width == selected_width.width) + { + for _ in 0..=ix { + this.select_next(&Default::default(), window, cx); + } + } + this + }), + ) + .handle(self.width_picker_handle.clone()) + } + + fn page_down(&mut self, _: &menu::SelectLast, _: &mut Window, cx: &mut Context) { + self.view_state.base_row = self + .view_state + .base_row + .overflowing_add(self.view_state.row_count()) + .0; + cx.notify(); + } + fn page_up(&mut self, _: &menu::SelectFirst, _: &mut Window, cx: &mut Context) { + self.view_state.base_row = self + .view_state + .base_row + .overflowing_sub(self.view_state.row_count()) + .0; + cx.notify(); + } + + fn change_query_bar_mode( + &mut self, + is_writing_memory: bool, + window: &mut Window, + cx: &mut Context, + ) { + if is_writing_memory == self.is_writing_memory { + return; + } + if !self.is_writing_memory { + self.query_editor.update(cx, |this, cx| { + this.clear(window, cx); + this.set_placeholder_text("Write to Selected Memory Range", cx); + }); + self.is_writing_memory = true; + self.query_editor.focus_handle(cx).focus(window); + } else { + self.query_editor.update(cx, |this, cx| { + this.clear(window, cx); + this.set_placeholder_text("Go to Memory Address / Expression", cx); + }); + self.is_writing_memory = false; + } + } + + fn toggle_data_breakpoint( + &mut self, + _: &crate::ToggleDataBreakpoint, + _: &mut Window, + cx: &mut Context, + ) { + let Some(SelectedMemoryRange::DragComplete(selection)) = self.view_state.selection.clone() + else { + return; + }; + let range = selection.memory_range(); + let context = Arc::new(DataBreakpointContext::Address { + address: range.start().to_string(), + bytes: Some(*range.end() - *range.start()), + }); + + self.session.update(cx, |this, cx| { + let data_breakpoint_info = this.data_breakpoint_info(context.clone(), None, cx); + cx.spawn(async move |this, cx| { + if let Some(info) = data_breakpoint_info.await { + let Some(data_id) = info.data_id.clone() else { + return; + }; + _ = this.update(cx, |this, cx| { + this.create_data_breakpoint( + context, + data_id.clone(), + dap::DataBreakpoint { + data_id, + access_type: None, + condition: None, + hit_condition: None, + }, + cx, + ); + }); + } + }) + .detach(); + }) + } + + fn confirm(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { + if let Some(SelectedMemoryRange::DragComplete(drag)) = &self.view_state.selection { + // Go into memory writing mode. + if !self.is_writing_memory { + let should_return = self.session.update(cx, |session, cx| { + if !session + .capabilities() + .supports_write_memory_request + .unwrap_or_default() + { + let adapter_name = session.adapter(); + // We cannot write memory with this adapter. 
+ _ = self.workspace.update(cx, |this, cx| { + this.toggle_status_toast( + StatusToast::new(format!( + "Debug Adapter `{adapter_name}` does not support writing to memory" + ), cx, |this, cx| { + cx.spawn(async move |this, cx| { + cx.background_executor().timer(Duration::from_secs(2)).await; + _ = this.update(cx, |_, cx| { + cx.emit(DismissEvent) + }); + }).detach(); + this.icon(ToastIcon::new(IconName::XCircle).color(Color::Error)) + }), + cx, + ); + }); + true + } else { + false + } + }); + if should_return { + return; + } + + self.change_query_bar_mode(true, window, cx); + } else if self.query_editor.focus_handle(cx).is_focused(window) { + let mut text = self.query_editor.read(cx).text(cx); + if text.chars().any(|c| !c.is_ascii_hexdigit()) { + // Interpret this text as a string and oh-so-conveniently convert it. + text = text.bytes().map(|byte| format!("{:02x}", byte)).collect(); + } + self.session.update(cx, |this, cx| { + let range = drag.memory_range(); + + if let Ok(as_hex) = hex::decode(text) { + this.write_memory(*range.start(), &as_hex, cx); + } + }); + self.change_query_bar_mode(false, window, cx); + } + + cx.notify(); + return; + } + // Just change the currently viewed address. + if !self.query_editor.focus_handle(cx).is_focused(window) { + return; + } + self.jump_to_query_bar_address(cx); + } + + fn jump_to_query_bar_address(&mut self, cx: &mut Context) { + use parse_int::parse; + let text = self.query_editor.read(cx).text(cx); + + let Ok(as_address) = parse::(&text) else { + return self.jump_to_expression(text, cx); + }; + self.jump_to_address(as_address, cx); + } + + fn jump_to_address(&mut self, address: u64, cx: &mut Context) { + self.view_state.base_row = (address & !0xfff) / self.view_state.line_width.width as u64; + let line_ix = (address & 0xfff) / self.view_state.line_width.width as u64; + self.scroll_handle + .scroll_to_item(line_ix as usize, ScrollStrategy::Center); + cx.notify(); + } + + fn jump_to_expression(&mut self, expr: String, cx: &mut Context) { + let Ok(selected_frame) = self + .stack_frame_list + .update(cx, |this, _| this.opened_stack_frame_id()) + else { + return; + }; + let expr = format!("?${{{expr}}}"); + let reference = self.session.update(cx, |this, cx| { + this.memory_reference_of_expr(selected_frame, expr, cx) + }); + cx.spawn(async move |this, cx| { + if let Some((reference, typ)) = reference.await { + _ = this.update(cx, |this, cx| { + let sizeof_expr = if typ.as_ref().is_some_and(|t| { + t.chars() + .all(|c| c.is_whitespace() || c.is_alphabetic() || c == '*') + }) { + typ.as_deref() + } else { + None + }; + this.go_to_memory_reference(&reference, sizeof_expr, selected_frame, cx); + }); + } + }) + .detach(); + } + + fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context) { + self.view_state.selection = None; + cx.notify(); + } + + /// Jump to memory pointed to by selected memory range. 
+ fn go_to_address( + &mut self, + _: &GoToSelectedAddress, + window: &mut Window, + cx: &mut Context, + ) { + let Some(SelectedMemoryRange::DragComplete(drag)) = self.view_state.selection.clone() + else { + return; + }; + let range = drag.memory_range(); + let Some(memory): Option> = self.session.update(cx, |this, cx| { + this.read_memory(range, cx).map(|cell| cell.0).collect() + }) else { + return; + }; + if memory.len() > 8 { + return; + } + let zeros_to_write = 8 - memory.len(); + let mut acc = String::from("0x"); + acc.extend(std::iter::repeat("00").take(zeros_to_write)); + let as_query = memory.into_iter().rev().fold(acc, |mut acc, byte| { + _ = write!(&mut acc, "{:02x}", byte); + acc + }); + self.query_editor.update(cx, |this, cx| { + this.set_text(as_query, window, cx); + }); + self.jump_to_query_bar_address(cx); + } + + fn deploy_memory_context_menu( + &mut self, + range: RangeInclusive, + position: Point, + window: &mut Window, + cx: &mut Context, + ) { + let session = self.session.clone(); + let context_menu = ContextMenu::build(window, cx, |menu, _, cx| { + let range_too_large = range.end() - range.start() > std::mem::size_of::() as u64; + let caps = session.read(cx).capabilities(); + let supports_data_breakpoints = caps.supports_data_breakpoints.unwrap_or_default() + && caps.supports_data_breakpoint_bytes.unwrap_or_default(); + let memory_unreadable = LazyCell::new(|| { + session.update(cx, |this, cx| { + this.read_memory(range.clone(), cx) + .any(|cell| cell.0.is_none()) + }) + }); + + let mut menu = menu.action_disabled_when( + range_too_large || *memory_unreadable, + "Go To Selected Address", + GoToSelectedAddress.boxed_clone(), + ); + + if supports_data_breakpoints { + menu = menu.action_disabled_when( + *memory_unreadable, + "Set Data Breakpoint", + ToggleDataBreakpoint { access_type: None }.boxed_clone(), + ); + } + menu.context(self.focus_handle.clone()) + }); + + cx.focus_view(&context_menu, window); + let subscription = cx.subscribe_in( + &context_menu, + window, + |this, _, _: &DismissEvent, window, cx| { + if this.open_context_menu.as_ref().is_some_and(|context_menu| { + context_menu.0.focus_handle(cx).contains_focused(window, cx) + }) { + cx.focus_self(window); + } + this.open_context_menu.take(); + cx.notify(); + }, + ); + + self.open_context_menu = Some((context_menu, position, subscription)); + } +} + +#[derive(Clone)] +struct ViewWidth { + width: u8, + label: SharedString, +} + +impl ViewWidth { + const fn new(width: u8, label: &'static str) -> Self { + Self { + width, + label: SharedString::new_static(label), + } + } +} + +static WIDTHS: [ViewWidth; 7] = [ + ViewWidth::new(1, "1 byte"), + ViewWidth::new(2, "2 bytes"), + ViewWidth::new(4, "4 bytes"), + ViewWidth::new(8, "8 bytes"), + ViewWidth::new(16, "16 bytes"), + ViewWidth::new(32, "32 bytes"), + ViewWidth::new(64, "64 bytes"), +]; + +fn render_single_memory_view_line( + memory: &[MemoryCell], + ix: u64, + weak: gpui::WeakEntity, + cx: &mut App, +) -> AnyElement { + let Ok(view_state) = weak.update(cx, |this, _| this.view_state.clone()) else { + return div().into_any(); + }; + let base_address = (view_state.base_row + ix) * view_state.line_width.width as u64; + + h_flex() + .id(( + "memory-view-row-full", + ix * view_state.line_width.width as u64, + )) + .size_full() + .gap_x_2() + .child( + div() + .child( + Label::new(format!("{:016X}", base_address)) + .buffer_font(cx) + .size(ui::LabelSize::Small) + .color(Color::Muted), + ) + .px_1() + .border_r_1() + .border_color(Color::Muted.color(cx)), + ) + 
.child( + h_flex() + .id(( + "memory-view-row-raw-memory", + ix * view_state.line_width.width as u64, + )) + .px_1() + .children(memory.iter().enumerate().map(|(cell_ix, cell)| { + let weak = weak.clone(); + div() + .id(("memory-view-row-raw-memory-cell", cell_ix as u64)) + .px_0p5() + .when_some(view_state.selection.as_ref(), |this, selection| { + this.when(selection.contains(base_address + cell_ix as u64), |this| { + let weak = weak.clone(); + + this.bg(Color::Selected.color(cx).opacity(0.2)).when( + !selection.is_dragging(), + |this| { + let selection = selection.drag().memory_range(); + this.on_mouse_down( + MouseButton::Right, + move |click, window, cx| { + _ = weak.update(cx, |this, cx| { + this.deploy_memory_context_menu( + selection.clone(), + click.position, + window, + cx, + ) + }); + cx.stop_propagation(); + }, + ) + }, + ) + }) + }) + .child( + Label::new( + cell.0 + .map(|val| HEX_BYTES_MEMOIZED[val as usize].clone()) + .unwrap_or_else(|| UNKNOWN_BYTE.clone()), + ) + .buffer_font(cx) + .when(cell.0.is_none(), |this| this.color(Color::Muted)) + .size(ui::LabelSize::Small), + ) + .on_drag( + Drag { + start_address: base_address + cell_ix as u64, + end_address: base_address + cell_ix as u64, + }, + { + let weak = weak.clone(); + move |drag, _, _, cx| { + _ = weak.update(cx, |this, _| { + this.view_state.selection = + Some(SelectedMemoryRange::DragUnderway(drag.clone())); + }); + + cx.new(|_| Empty) + } + }, + ) + .on_drop({ + let weak = weak.clone(); + move |drag: &Drag, _, cx| { + _ = weak.update(cx, |this, _| { + this.view_state.selection = + Some(SelectedMemoryRange::DragComplete(Drag { + start_address: drag.start_address, + end_address: base_address + cell_ix as u64, + })); + }); + } + }) + .drag_over(move |style, drag: &Drag, _, cx| { + _ = weak.update(cx, |this, _| { + this.view_state.selection = + Some(SelectedMemoryRange::DragUnderway(Drag { + start_address: drag.start_address, + end_address: base_address + cell_ix as u64, + })); + }); + + style + }) + })), + ) + .child( + h_flex() + .id(( + "memory-view-row-ascii-memory", + ix * view_state.line_width.width as u64, + )) + .h_full() + .px_1() + .mr_4() + // .gap_x_1p5() + .border_x_1() + .border_color(Color::Muted.color(cx)) + .children(memory.iter().enumerate().map(|(ix, cell)| { + let as_character = char::from(cell.0.unwrap_or(0)); + let as_visible = if as_character.is_ascii_graphic() { + as_character + } else { + '·' + }; + div() + .px_0p5() + .when_some(view_state.selection.as_ref(), |this, selection| { + this.when(selection.contains(base_address + ix as u64), |this| { + this.bg(Color::Selected.color(cx).opacity(0.2)) + }) + }) + .child( + Label::new(format!("{as_visible}")) + .buffer_font(cx) + .when(cell.0.is_none(), |this| this.color(Color::Muted)) + .size(ui::LabelSize::Small), + ) + })), + ) + .into_any() +} + +impl Render for MemoryView { + fn render( + &mut self, + window: &mut ui::Window, + cx: &mut ui::Context, + ) -> impl ui::IntoElement { + let (icon, tooltip_text) = if self.is_writing_memory { + (IconName::Pencil, "Edit memory at a selected address") + } else { + ( + IconName::LocationEdit, + "Change address of currently viewed memory", + ) + }; + v_flex() + .id("Memory-view") + .on_action(cx.listener(Self::cancel)) + .on_action(cx.listener(Self::go_to_address)) + .p_1() + .on_action(cx.listener(Self::confirm)) + .on_action(cx.listener(Self::toggle_data_breakpoint)) + .on_action(cx.listener(Self::page_down)) + .on_action(cx.listener(Self::page_up)) + .size_full() + .track_focus(&self.focus_handle) + .child( 
+ h_flex() + .w_full() + .mb_0p5() + .gap_1() + .child( + h_flex() + .w_full() + .rounded_md() + .border_1() + .gap_x_2() + .px_2() + .py_0p5() + .mb_0p5() + .bg(cx.theme().colors().editor_background) + .when_else( + self.query_editor + .focus_handle(cx) + .contains_focused(window, cx), + |this| this.border_color(cx.theme().colors().border_focused), + |this| this.border_color(cx.theme().colors().border_transparent), + ) + .child( + div() + .id("memory-view-editor-icon") + .child(Icon::new(icon).size(ui::IconSize::XSmall)) + .tooltip(Tooltip::text(tooltip_text)), + ) + .child(self.render_query_bar(cx)), + ) + .child(self.render_width_picker(window, cx)), + ) + .child(Divider::horizontal()) + .child( + v_flex() + .size_full() + .on_drag_move(cx.listener(|this, evt, _, _| { + this.handle_memory_drag(&evt); + })) + .child(self.render_memory(cx).size_full()) + .children(self.open_context_menu.as_ref().map(|(menu, position, _)| { + deferred( + anchored() + .position(*position) + .anchor(gpui::Corner::TopLeft) + .child(menu.clone()), + ) + .with_priority(1) + })) + .child(self.render_vertical_scrollbar(cx)), + ) + } +} diff --git a/crates/debugger_ui/src/session/running/stack_frame_list.rs b/crates/debugger_ui/src/session/running/stack_frame_list.rs index da3674c8e2eedf18be05fb7ecc1381521e9735b2..2149502f4a5774478555ef66139f163a85a37b3d 100644 --- a/crates/debugger_ui/src/session/running/stack_frame_list.rs +++ b/crates/debugger_ui/src/session/running/stack_frame_list.rs @@ -70,13 +70,7 @@ impl StackFrameList { _ => {} }); - let list_state = ListState::new(0, gpui::ListAlignment::Top, px(1000.), { - let this = cx.weak_entity(); - move |ix, _window, cx| { - this.update(cx, |this, cx| this.render_entry(ix, cx)) - .unwrap_or(div().into_any()) - } - }); + let list_state = ListState::new(0, gpui::ListAlignment::Top, px(1000.)); let scrollbar_state = ScrollbarState::new(list_state.clone()); let mut this = Self { @@ -708,11 +702,14 @@ impl StackFrameList { self.activate_selected_entry(window, cx); } - fn render_list(&mut self, _window: &mut Window, _cx: &mut Context) -> impl IntoElement { - div() - .p_1() - .size_full() - .child(list(self.list_state.clone()).size_full()) + fn render_list(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + div().p_1().size_full().child( + list( + self.list_state.clone(), + cx.processor(|this, ix, _window, cx| this.render_entry(ix, cx)), + ) + .size_full(), + ) } } diff --git a/crates/debugger_ui/src/session/running/variable_list.rs b/crates/debugger_ui/src/session/running/variable_list.rs index bdb095bde3e4295bf96cff7d02012e4a4ea9d5bd..efbc72e8cfc9099a5d699493898440d17fbf615b 100644 --- a/crates/debugger_ui/src/session/running/variable_list.rs +++ b/crates/debugger_ui/src/session/running/variable_list.rs @@ -1,3 +1,5 @@ +use crate::session::running::{RunningState, memory_view::MemoryView}; + use super::stack_frame_list::{StackFrameList, StackFrameListEvent}; use dap::{ ScopePresentationHint, StackFrameId, VariablePresentationHint, VariablePresentationHintKind, @@ -7,13 +9,17 @@ use editor::Editor; use gpui::{ Action, AnyElement, ClickEvent, ClipboardItem, Context, DismissEvent, Empty, Entity, FocusHandle, Focusable, Hsla, MouseButton, MouseDownEvent, Point, Stateful, Subscription, - TextStyleRefinement, UniformListScrollHandle, actions, anchored, deferred, uniform_list, + TextStyleRefinement, UniformListScrollHandle, WeakEntity, actions, anchored, deferred, + uniform_list, }; use menu::{SelectFirst, SelectLast, SelectNext, SelectPrevious}; -use 
project::debugger::session::{Session, SessionEvent, Watcher}; +use project::debugger::{ + dap_command::DataBreakpointContext, + session::{Session, SessionEvent, Watcher}, +}; use std::{collections::HashMap, ops::Range, sync::Arc}; use ui::{ContextMenu, ListItem, ScrollableHandle, Scrollbar, ScrollbarState, Tooltip, prelude::*}; -use util::debug_panic; +use util::{debug_panic, maybe}; actions!( variable_list, @@ -32,6 +38,8 @@ actions!( AddWatch, /// Removes the selected variable from the watch list. RemoveWatch, + /// Jump to variable's memory location. + GoToMemory, ] ); @@ -86,30 +94,30 @@ impl EntryPath { } #[derive(Debug, Clone, PartialEq)] -enum EntryKind { +enum DapEntry { Watcher(Watcher), Variable(dap::Variable), Scope(dap::Scope), } -impl EntryKind { +impl DapEntry { fn as_watcher(&self) -> Option<&Watcher> { match self { - EntryKind::Watcher(watcher) => Some(watcher), + DapEntry::Watcher(watcher) => Some(watcher), _ => None, } } fn as_variable(&self) -> Option<&dap::Variable> { match self { - EntryKind::Variable(dap) => Some(dap), + DapEntry::Variable(dap) => Some(dap), _ => None, } } fn as_scope(&self) -> Option<&dap::Scope> { match self { - EntryKind::Scope(dap) => Some(dap), + DapEntry::Scope(dap) => Some(dap), _ => None, } } @@ -117,38 +125,38 @@ impl EntryKind { #[cfg(test)] fn name(&self) -> &str { match self { - EntryKind::Watcher(watcher) => &watcher.expression, - EntryKind::Variable(dap) => &dap.name, - EntryKind::Scope(dap) => &dap.name, + DapEntry::Watcher(watcher) => &watcher.expression, + DapEntry::Variable(dap) => &dap.name, + DapEntry::Scope(dap) => &dap.name, } } } #[derive(Debug, Clone, PartialEq)] struct ListEntry { - dap_kind: EntryKind, + entry: DapEntry, path: EntryPath, } impl ListEntry { fn as_watcher(&self) -> Option<&Watcher> { - self.dap_kind.as_watcher() + self.entry.as_watcher() } fn as_variable(&self) -> Option<&dap::Variable> { - self.dap_kind.as_variable() + self.entry.as_variable() } fn as_scope(&self) -> Option<&dap::Scope> { - self.dap_kind.as_scope() + self.entry.as_scope() } fn item_id(&self) -> ElementId { use std::fmt::Write; - let mut id = match &self.dap_kind { - EntryKind::Watcher(watcher) => format!("watcher-{}", watcher.expression), - EntryKind::Variable(dap) => format!("variable-{}", dap.name), - EntryKind::Scope(dap) => format!("scope-{}", dap.name), + let mut id = match &self.entry { + DapEntry::Watcher(watcher) => format!("watcher-{}", watcher.expression), + DapEntry::Variable(dap) => format!("variable-{}", dap.name), + DapEntry::Scope(dap) => format!("scope-{}", dap.name), }; for name in self.path.indices.iter() { _ = write!(id, "-{}", name); @@ -158,10 +166,10 @@ impl ListEntry { fn item_value_id(&self) -> ElementId { use std::fmt::Write; - let mut id = match &self.dap_kind { - EntryKind::Watcher(watcher) => format!("watcher-{}", watcher.expression), - EntryKind::Variable(dap) => format!("variable-{}", dap.name), - EntryKind::Scope(dap) => format!("scope-{}", dap.name), + let mut id = match &self.entry { + DapEntry::Watcher(watcher) => format!("watcher-{}", watcher.expression), + DapEntry::Variable(dap) => format!("variable-{}", dap.name), + DapEntry::Scope(dap) => format!("scope-{}", dap.name), }; for name in self.path.indices.iter() { _ = write!(id, "-{}", name); @@ -188,13 +196,17 @@ pub struct VariableList { focus_handle: FocusHandle, edited_path: Option<(EntryPath, Entity)>, disabled: bool, + memory_view: Entity, + weak_running: WeakEntity, _subscriptions: Vec, } impl VariableList { - pub fn new( + pub(crate) fn new( session: 
Entity, stack_frame_list: Entity, + memory_view: Entity, + weak_running: WeakEntity, window: &mut Window, cx: &mut Context, ) -> Self { @@ -211,6 +223,7 @@ impl VariableList { SessionEvent::Variables | SessionEvent::Watchers => { this.build_entries(cx); } + _ => {} }), cx.on_focus_out(&focus_handle, window, |this, _, _, cx| { @@ -234,6 +247,8 @@ impl VariableList { edited_path: None, entries: Default::default(), entry_states: Default::default(), + weak_running, + memory_view, } } @@ -284,7 +299,7 @@ impl VariableList { scope.variables_reference, scope.variables_reference, EntryPath::for_scope(&scope.name), - EntryKind::Scope(scope), + DapEntry::Scope(scope), ) }) .collect::>(); @@ -298,7 +313,7 @@ impl VariableList { watcher.variables_reference, watcher.variables_reference, EntryPath::for_watcher(watcher.expression.clone()), - EntryKind::Watcher(watcher.clone()), + DapEntry::Watcher(watcher.clone()), ) }) .collect::>(), @@ -309,9 +324,9 @@ impl VariableList { while let Some((container_reference, variables_reference, mut path, dap_kind)) = stack.pop() { match &dap_kind { - EntryKind::Watcher(watcher) => path = path.with_child(watcher.expression.clone()), - EntryKind::Variable(dap) => path = path.with_name(dap.name.clone().into()), - EntryKind::Scope(dap) => path = path.with_child(dap.name.clone().into()), + DapEntry::Watcher(watcher) => path = path.with_child(watcher.expression.clone()), + DapEntry::Variable(dap) => path = path.with_name(dap.name.clone().into()), + DapEntry::Scope(dap) => path = path.with_child(dap.name.clone().into()), } let var_state = self @@ -336,7 +351,7 @@ impl VariableList { }); entries.push(ListEntry { - dap_kind, + entry: dap_kind, path: path.clone(), }); @@ -349,7 +364,7 @@ impl VariableList { variables_reference, child.variables_reference, path.with_child(child.name.clone().into()), - EntryKind::Variable(child), + DapEntry::Variable(child), ) })); } @@ -380,9 +395,9 @@ impl VariableList { pub fn completion_variables(&self, _cx: &mut Context) -> Vec { self.entries .iter() - .filter_map(|entry| match &entry.dap_kind { - EntryKind::Variable(dap) => Some(dap.clone()), - EntryKind::Scope(_) | EntryKind::Watcher { .. } => None, + .filter_map(|entry| match &entry.entry { + DapEntry::Variable(dap) => Some(dap.clone()), + DapEntry::Scope(_) | DapEntry::Watcher { .. } => None, }) .collect() } @@ -400,12 +415,12 @@ impl VariableList { .get(ix) .and_then(|entry| Some(entry).zip(self.entry_states.get(&entry.path)))?; - match &entry.dap_kind { - EntryKind::Watcher { .. } => { + match &entry.entry { + DapEntry::Watcher { .. 
} => { Some(self.render_watcher(entry, *state, window, cx)) } - EntryKind::Variable(_) => Some(self.render_variable(entry, *state, window, cx)), - EntryKind::Scope(_) => Some(self.render_scope(entry, *state, cx)), + DapEntry::Variable(_) => Some(self.render_variable(entry, *state, window, cx)), + DapEntry::Scope(_) => Some(self.render_scope(entry, *state, cx)), } }) .collect() @@ -562,6 +577,51 @@ impl VariableList { } } + fn jump_to_variable_memory( + &mut self, + _: &GoToMemory, + window: &mut Window, + cx: &mut Context, + ) { + _ = maybe!({ + let selection = self.selection.as_ref()?; + let entry = self.entries.iter().find(|entry| &entry.path == selection)?; + let var = entry.entry.as_variable()?; + let memory_reference = var.memory_reference.as_deref()?; + + let sizeof_expr = if var.type_.as_ref().is_some_and(|t| { + t.chars() + .all(|c| c.is_whitespace() || c.is_alphabetic() || c == '*') + }) { + var.type_.as_deref() + } else { + var.evaluate_name + .as_deref() + .map(|name| name.strip_prefix("/nat ").unwrap_or_else(|| name)) + }; + self.memory_view.update(cx, |this, cx| { + this.go_to_memory_reference( + memory_reference, + sizeof_expr, + self.selected_stack_frame_id, + cx, + ); + }); + let weak_panel = self.weak_running.clone(); + + window.defer(cx, move |window, cx| { + _ = weak_panel.update(cx, |this, cx| { + this.activate_item( + crate::persistence::DebuggerPaneItem::MemoryView, + window, + cx, + ); + }); + }); + Some(()) + }); + } + fn deploy_list_entry_context_menu( &mut self, entry: ListEntry, @@ -569,49 +629,197 @@ impl VariableList { window: &mut Window, cx: &mut Context, ) { - let supports_set_variable = self - .session - .read(cx) - .capabilities() - .supports_set_variable - .unwrap_or_default(); - - let context_menu = ContextMenu::build(window, cx, |menu, _, _| { - menu.when(entry.as_variable().is_some(), |menu| { - menu.action("Copy Name", CopyVariableName.boxed_clone()) - .action("Copy Value", CopyVariableValue.boxed_clone()) - .when(supports_set_variable, |menu| { - menu.action("Edit Value", EditVariable.boxed_clone()) + let (supports_set_variable, supports_data_breakpoints, supports_go_to_memory) = + self.session.read_with(cx, |session, _| { + ( + session + .capabilities() + .supports_set_variable + .unwrap_or_default(), + session + .capabilities() + .supports_data_breakpoints + .unwrap_or_default(), + session + .capabilities() + .supports_read_memory_request + .unwrap_or_default(), + ) + }); + let can_toggle_data_breakpoint = entry + .as_variable() + .filter(|_| supports_data_breakpoints) + .and_then(|variable| { + let variables_reference = self + .entry_states + .get(&entry.path) + .map(|state| state.parent_reference)?; + Some(self.session.update(cx, |session, cx| { + session.data_breakpoint_info( + Arc::new(DataBreakpointContext::Variable { + variables_reference, + name: variable.name.clone(), + bytes: None, + }), + None, + cx, + ) + })) + }); + + let focus_handle = self.focus_handle.clone(); + cx.spawn_in(window, async move |this, cx| { + let can_toggle_data_breakpoint = if let Some(task) = can_toggle_data_breakpoint { + task.await + } else { + None + }; + cx.update(|window, cx| { + let context_menu = ContextMenu::build(window, cx, |menu, _, _| { + menu.when_some(entry.as_variable(), |menu, _| { + menu.action("Copy Name", CopyVariableName.boxed_clone()) + .action("Copy Value", CopyVariableValue.boxed_clone()) + .when(supports_set_variable, |menu| { + menu.action("Edit Value", EditVariable.boxed_clone()) + }) + .when(supports_go_to_memory, |menu| { + menu.action("Go 
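jump_to_variable_memory above derives a sizing expression for the memory view from the selected variable: it uses the declared type when that looks like a bare or pointer type name, and otherwise falls back to the evaluate name with any leading "/nat " formatter prefix removed. Extracted into a standalone function, the heuristic looks roughly like this:

```rust
/// Decide which expression to hand to the memory view for sizing the selection:
/// prefer the variable's type when it looks like a plain (possibly pointer) type
/// name, otherwise fall back to the evaluate name, dropping a leading "/nat "
/// prefix that some adapters attach.
fn sizeof_expression<'a>(
    type_name: Option<&'a str>,
    evaluate_name: Option<&'a str>,
) -> Option<&'a str> {
    let type_is_simple = type_name.is_some_and(|t| {
        t.chars()
            .all(|c| c.is_whitespace() || c.is_alphabetic() || c == '*')
    });
    if type_is_simple {
        type_name
    } else {
        evaluate_name.map(|name| name.strip_prefix("/nat ").unwrap_or(name))
    }
}

fn main() {
    // A pointer type name passes the "simple type" check and is used directly.
    assert_eq!(
        sizeof_expression(Some("unsigned int *"), None),
        Some("unsigned int *")
    );
    // `std::vector<int>` contains punctuation, so the evaluate name wins instead.
    assert_eq!(
        sizeof_expression(Some("std::vector<int>"), Some("/nat items")),
        Some("items")
    );
}
```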
To Memory", GoToMemory.boxed_clone()) + }) + .action("Watch Variable", AddWatch.boxed_clone()) + .when_some(can_toggle_data_breakpoint, |mut menu, data_info| { + menu = menu.separator(); + if let Some(access_types) = data_info.access_types { + for access in access_types { + menu = menu.action( + format!( + "Toggle {} Data Breakpoint", + match access { + dap::DataBreakpointAccessType::Read => "Read", + dap::DataBreakpointAccessType::Write => "Write", + dap::DataBreakpointAccessType::ReadWrite => + "Read/Write", + } + ), + crate::ToggleDataBreakpoint { + access_type: Some(access), + } + .boxed_clone(), + ); + } + + menu + } else { + menu.action( + "Toggle Data Breakpoint", + crate::ToggleDataBreakpoint { access_type: None } + .boxed_clone(), + ) + } + }) }) - .action("Watch Variable", AddWatch.boxed_clone()) - }) - .when(entry.as_watcher().is_some(), |menu| { - menu.action("Copy Name", CopyVariableName.boxed_clone()) - .action("Copy Value", CopyVariableValue.boxed_clone()) - .when(supports_set_variable, |menu| { - menu.action("Edit Value", EditVariable.boxed_clone()) + .when(entry.as_watcher().is_some(), |menu| { + menu.action("Copy Name", CopyVariableName.boxed_clone()) + .action("Copy Value", CopyVariableValue.boxed_clone()) + .when(supports_set_variable, |menu| { + menu.action("Edit Value", EditVariable.boxed_clone()) + }) + .action("Remove Watch", RemoveWatch.boxed_clone()) }) - .action("Remove Watch", RemoveWatch.boxed_clone()) + .context(focus_handle.clone()) + }); + + _ = this.update(cx, |this, cx| { + cx.focus_view(&context_menu, window); + let subscription = cx.subscribe_in( + &context_menu, + window, + |this, _, _: &DismissEvent, window, cx| { + if this.open_context_menu.as_ref().is_some_and(|context_menu| { + context_menu.0.focus_handle(cx).contains_focused(window, cx) + }) { + cx.focus_self(window); + } + this.open_context_menu.take(); + cx.notify(); + }, + ); + + this.open_context_menu = Some((context_menu, position, subscription)); + }); }) - .context(self.focus_handle.clone()) + }) + .detach(); + } + + fn toggle_data_breakpoint( + &mut self, + data_info: &crate::ToggleDataBreakpoint, + _window: &mut Window, + cx: &mut Context, + ) { + let Some(entry) = self + .selection + .as_ref() + .and_then(|selection| self.entries.iter().find(|entry| &entry.path == selection)) + else { + return; + }; + + let Some((name, var_ref)) = entry.as_variable().map(|var| &var.name).zip( + self.entry_states + .get(&entry.path) + .map(|state| state.parent_reference), + ) else { + return; + }; + + let context = Arc::new(DataBreakpointContext::Variable { + variables_reference: var_ref, + name: name.clone(), + bytes: None, + }); + let data_breakpoint = self.session.update(cx, |session, cx| { + session.data_breakpoint_info(context.clone(), None, cx) }); - cx.focus_view(&context_menu, window); - let subscription = cx.subscribe_in( - &context_menu, - window, - |this, _, _: &DismissEvent, window, cx| { - if this.open_context_menu.as_ref().is_some_and(|context_menu| { - context_menu.0.focus_handle(cx).contains_focused(window, cx) - }) { - cx.focus_self(window); + let session = self.session.downgrade(); + let access_type = data_info.access_type; + cx.spawn(async move |_, cx| { + let Some((data_id, access_types)) = data_breakpoint + .await + .and_then(|info| Some((info.data_id?, info.access_types))) + else { + return; + }; + + // Because user's can manually add this action to the keymap + // we check if access type is supported + let access_type = match access_types { + None => None, + Some(access_types) => { + 
if access_type.is_some_and(|access_type| access_types.contains(&access_type)) { + access_type + } else { + None + } } - this.open_context_menu.take(); + }; + _ = session.update(cx, |session, cx| { + session.create_data_breakpoint( + context, + data_id.clone(), + dap::DataBreakpoint { + data_id, + access_type, + condition: None, + hit_condition: None, + }, + cx, + ); cx.notify(); - }, - ); - - self.open_context_menu = Some((context_menu, position, subscription)); + }); + }) + .detach(); } fn copy_variable_name( @@ -628,10 +836,10 @@ impl VariableList { return; }; - let variable_name = match &entry.dap_kind { - EntryKind::Variable(dap) => dap.name.clone(), - EntryKind::Watcher(watcher) => watcher.expression.to_string(), - EntryKind::Scope(_) => return, + let variable_name = match &entry.entry { + DapEntry::Variable(dap) => dap.name.clone(), + DapEntry::Watcher(watcher) => watcher.expression.to_string(), + DapEntry::Scope(_) => return, }; cx.write_to_clipboard(ClipboardItem::new_string(variable_name)); @@ -651,10 +859,10 @@ impl VariableList { return; }; - let variable_value = match &entry.dap_kind { - EntryKind::Variable(dap) => dap.value.clone(), - EntryKind::Watcher(watcher) => watcher.value.to_string(), - EntryKind::Scope(_) => return, + let variable_value = match &entry.entry { + DapEntry::Variable(dap) => dap.value.clone(), + DapEntry::Watcher(watcher) => watcher.value.to_string(), + DapEntry::Scope(_) => return, }; cx.write_to_clipboard(ClipboardItem::new_string(variable_value)); @@ -669,10 +877,10 @@ impl VariableList { return; }; - let variable_value = match &entry.dap_kind { - EntryKind::Watcher(watcher) => watcher.value.to_string(), - EntryKind::Variable(variable) => variable.value.clone(), - EntryKind::Scope(_) => return, + let variable_value = match &entry.entry { + DapEntry::Watcher(watcher) => watcher.value.to_string(), + DapEntry::Variable(variable) => variable.value.clone(), + DapEntry::Scope(_) => return, }; let editor = Self::create_variable_editor(&variable_value, window, cx); @@ -753,7 +961,7 @@ impl VariableList { "{}{} {}{}", INDENT.repeat(state.depth - 1), if state.is_expanded { "v" } else { ">" }, - entry.dap_kind.name(), + entry.entry.name(), if self.selection.as_ref() == Some(&entry.path) { " <=== selected" } else { @@ -770,8 +978,8 @@ impl VariableList { pub(crate) fn scopes(&self) -> Vec { self.entries .iter() - .filter_map(|entry| match &entry.dap_kind { - EntryKind::Scope(scope) => Some(scope), + .filter_map(|entry| match &entry.entry { + DapEntry::Scope(scope) => Some(scope), _ => None, }) .cloned() @@ -785,10 +993,10 @@ impl VariableList { let mut idx = 0; for entry in self.entries.iter() { - match &entry.dap_kind { - EntryKind::Watcher { .. } => continue, - EntryKind::Variable(dap) => scopes[idx].1.push(dap.clone()), - EntryKind::Scope(scope) => { + match &entry.entry { + DapEntry::Watcher { .. 
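toggle_data_breakpoint above is a two-step DAP exchange: a dataBreakpointInfo request first yields the data_id and the access types the adapter supports, and only then is the breakpoint created. Because the action can be bound in the keymap with an arbitrary access type, the requested value is dropped unless the adapter advertises it. A standalone sketch of that narrowing step, with a local enum standing in for dap::DataBreakpointAccessType:

```rust
/// Mirrors dap::DataBreakpointAccessType for this standalone example.
#[derive(Clone, Copy, Debug, PartialEq)]
enum DataBreakpointAccessType {
    Read,
    Write,
    ReadWrite,
}

/// Narrow a user-requested access type against what the adapter reported in its
/// DataBreakpointInfo response. A `None` list means the adapter did not constrain
/// access types, so the request is left unset as well, matching the patch.
fn effective_access_type(
    requested: Option<DataBreakpointAccessType>,
    supported: Option<&[DataBreakpointAccessType]>,
) -> Option<DataBreakpointAccessType> {
    match supported {
        None => None,
        Some(supported) => requested.filter(|requested| supported.contains(requested)),
    }
}

fn main() {
    use DataBreakpointAccessType::*;
    // Adapter only supports write watchpoints: a read request is dropped.
    assert_eq!(effective_access_type(Some(Read), Some(&[Write])), None);
    // Requested type is advertised, so it is forwarded unchanged.
    assert_eq!(
        effective_access_type(Some(Write), Some(&[Write, ReadWrite])),
        Some(Write)
    );
    // No constraint list from the adapter.
    assert_eq!(effective_access_type(Some(ReadWrite), None), None);
}
```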
} => continue, + DapEntry::Variable(dap) => scopes[idx].1.push(dap.clone()), + DapEntry::Scope(scope) => { if scopes.len() > 0 { idx += 1; } @@ -806,8 +1014,8 @@ impl VariableList { pub(crate) fn variables(&self) -> Vec { self.entries .iter() - .filter_map(|entry| match &entry.dap_kind { - EntryKind::Variable(variable) => Some(variable), + .filter_map(|entry| match &entry.entry { + DapEntry::Variable(variable) => Some(variable), _ => None, }) .cloned() @@ -899,7 +1107,7 @@ impl VariableList { let variable_value = value.clone(); this.on_click(cx.listener( move |this, click: &ClickEvent, window, cx| { - if click.down.click_count < 2 { + if click.click_count() < 2 { return; } let editor = Self::create_variable_editor( @@ -1358,6 +1566,8 @@ impl Render for VariableList { .on_action(cx.listener(Self::edit_variable)) .on_action(cx.listener(Self::add_watcher)) .on_action(cx.listener(Self::remove_watcher)) + .on_action(cx.listener(Self::toggle_data_breakpoint)) + .on_action(cx.listener(Self::jump_to_variable_memory)) .child( uniform_list( "variable-list", diff --git a/crates/debugger_ui/src/tests/debugger_panel.rs b/crates/debugger_ui/src/tests/debugger_panel.rs index 05bca8131ac9734b1635a90c22026424f1c5cf2e..6180831ea9dccfb3c1ee861daac099e54b2242c3 100644 --- a/crates/debugger_ui/src/tests/debugger_panel.rs +++ b/crates/debugger_ui/src/tests/debugger_panel.rs @@ -427,7 +427,7 @@ async fn test_handle_start_debugging_request( let sessions = workspace .update(cx, |workspace, _window, cx| { let debug_panel = workspace.panel::(cx).unwrap(); - debug_panel.read(cx).sessions() + debug_panel.read(cx).sessions().collect::>() }) .unwrap(); assert_eq!(sessions.len(), 1); @@ -451,7 +451,7 @@ async fn test_handle_start_debugging_request( .unwrap() .read(cx) .session(cx); - let current_sessions = debug_panel.read(cx).sessions(); + let current_sessions = debug_panel.read(cx).sessions().collect::>(); assert_eq!(active_session, current_sessions[1].read(cx).session(cx)); assert_eq!( active_session.read(cx).parent_session(), @@ -918,7 +918,7 @@ async fn test_debug_panel_item_thread_status_reset_on_failure( .unwrap(); let client = session.update(cx, |session, _| session.adapter_client().unwrap()); - const THREAD_ID_NUM: u64 = 1; + const THREAD_ID_NUM: i64 = 1; client.on_request::(move |_, _| { Ok(dap::ThreadsResponse { @@ -1796,7 +1796,7 @@ async fn test_debug_adapters_shutdown_on_app_quit( let panel = workspace.panel::(cx).unwrap(); panel.read_with(cx, |panel, _| { assert!( - !panel.sessions().is_empty(), + panel.sessions().next().is_some(), "Debug session should be active" ); }); diff --git a/crates/debugger_ui/src/tests/inline_values.rs b/crates/debugger_ui/src/tests/inline_values.rs index 45cab2a3063a8741d01efb54059667026a646879..9f921ec969debc5247d531469c5132e8485c163b 100644 --- a/crates/debugger_ui/src/tests/inline_values.rs +++ b/crates/debugger_ui/src/tests/inline_values.rs @@ -2241,3 +2241,34 @@ func main() { ) .await; } + +#[gpui::test] +async fn test_trim_multi_line_inline_value(executor: BackgroundExecutor, cx: &mut TestAppContext) { + let variables = [("y", "hello\n world")]; + + let before = r#" +fn main() { + let y = "hello\n world"; +} +"# + .unindent(); + + let after = r#" +fn main() { + let y: hello… = "hello\n world"; +} +"# + .unindent(); + + test_inline_values_util( + &variables, + &[], + &before, + &after, + None, + rust_lang(), + executor, + cx, + ) + .await; +} diff --git a/crates/debugger_ui/src/tests/module_list.rs b/crates/debugger_ui/src/tests/module_list.rs index 
49cfd6fcf88339c7d040d56d575dafce50f8d0f2..09c90cbc4a3af71aa9fb7273cf3535e9f7ece592 100644 --- a/crates/debugger_ui/src/tests/module_list.rs +++ b/crates/debugger_ui/src/tests/module_list.rs @@ -111,7 +111,6 @@ async fn test_module_list(executor: BackgroundExecutor, cx: &mut TestAppContext) }); running_state.update_in(cx, |this, window, cx| { - this.ensure_pane_item(DebuggerPaneItem::Modules, window, cx); this.activate_item(DebuggerPaneItem::Modules, window, cx); cx.refresh_windows(); }); diff --git a/crates/debugger_ui/src/tests/new_process_modal.rs b/crates/debugger_ui/src/tests/new_process_modal.rs index 0805060bf4413a16d4b7242d133e635bdf4d7cd4..d6b0dfa00429f9487eafbe38dca5f072ed547779 100644 --- a/crates/debugger_ui/src/tests/new_process_modal.rs +++ b/crates/debugger_ui/src/tests/new_process_modal.rs @@ -298,7 +298,7 @@ async fn test_dap_adapter_config_conversion_and_validation(cx: &mut TestAppConte let adapter_names = cx.update(|cx| { let registry = DapRegistry::global(cx); - registry.enumerate_adapters() + registry.enumerate_adapters::>() }); let zed_config = ZedDebugConfig { diff --git a/crates/diagnostics/src/diagnostic_renderer.rs b/crates/diagnostics/src/diagnostic_renderer.rs index 77bb249733f612ede3017e1cff592927b40e8d43..ce7b253702a01e24e7f4a457ac418572e0fa2729 100644 --- a/crates/diagnostics/src/diagnostic_renderer.rs +++ b/crates/diagnostics/src/diagnostic_renderer.rs @@ -144,7 +144,6 @@ impl editor::DiagnosticRenderer for DiagnosticRenderer { style: BlockStyle::Flex, render: Arc::new(move |bcx| block.render_block(editor.clone(), bcx)), priority: 1, - render_in_minimap: false, } }) .collect() diff --git a/crates/diagnostics/src/diagnostics.rs b/crates/diagnostics/src/diagnostics.rs index 1daa9025b64f2a783409ba5ebe10214ed55c362b..e7660920da30ddcc088c2bbee6bfb1cf05d51d58 100644 --- a/crates/diagnostics/src/diagnostics.rs +++ b/crates/diagnostics/src/diagnostics.rs @@ -80,6 +80,7 @@ pub(crate) struct ProjectDiagnosticsEditor { include_warnings: bool, update_excerpts_task: Option>>, cargo_diagnostics_fetch: CargoDiagnosticsFetchState, + diagnostic_summary_update: Task<()>, _subscription: Subscription, } @@ -176,16 +177,25 @@ impl ProjectDiagnosticsEditor { } project::Event::DiagnosticsUpdated { language_server_id, - path, + paths, } => { - this.paths_to_update.insert(path.clone()); - this.summary = project.read(cx).diagnostic_summary(false, cx); + this.paths_to_update.extend(paths.clone()); + let project = project.clone(); + this.diagnostic_summary_update = cx.spawn(async move |this, cx| { + cx.background_executor() + .timer(Duration::from_millis(30)) + .await; + this.update(cx, |this, cx| { + this.summary = project.read(cx).diagnostic_summary(false, cx); + }) + .log_err(); + }); cx.emit(EditorEvent::TitleChanged); if this.editor.focus_handle(cx).contains_focused(window, cx) || this.focus_handle.contains_focused(window, cx) { - log::debug!("diagnostics updated for server {language_server_id}, path {path:?}. recording change"); + log::debug!("diagnostics updated for server {language_server_id}, paths {paths:?}. recording change"); } else { - log::debug!("diagnostics updated for server {language_server_id}, path {path:?}. updating excerpts"); + log::debug!("diagnostics updated for server {language_server_id}, paths {paths:?}. 
updating excerpts"); this.update_stale_excerpts(window, cx); } } @@ -276,6 +286,7 @@ impl ProjectDiagnosticsEditor { cancel_task: None, diagnostic_sources: Arc::new(Vec::new()), }, + diagnostic_summary_update: Task::ready(()), _subscription: project_event_subscription, }; this.update_all_diagnostics(true, window, cx); @@ -656,7 +667,6 @@ impl ProjectDiagnosticsEditor { block.render_block(editor.clone(), bcx) }), priority: 1, - render_in_minimap: false, } }); let block_ids = this.editor.update(cx, |editor, cx| { diff --git a/crates/diagnostics/src/diagnostics_tests.rs b/crates/diagnostics/src/diagnostics_tests.rs index 0d47eaf367d6e28708cdf34258fc6080ba500c86..8fb223b2cbfcc7db817059dd92bf1ff869846645 100644 --- a/crates/diagnostics/src/diagnostics_tests.rs +++ b/crates/diagnostics/src/diagnostics_tests.rs @@ -14,7 +14,10 @@ use indoc::indoc; use language::{DiagnosticSourceKind, Rope}; use lsp::LanguageServerId; use pretty_assertions::assert_eq; -use project::FakeFs; +use project::{ + FakeFs, + project_settings::{GoToDiagnosticSeverity, GoToDiagnosticSeverityFilter}, +}; use rand::{Rng, rngs::StdRng, seq::IteratorRandom as _}; use serde_json::json; use settings::SettingsStore; @@ -870,10 +873,10 @@ async fn test_random_diagnostics_with_inlays(cx: &mut TestAppContext, mut rng: S editor.splice_inlays( &[], - vec![Inlay::inline_completion( + vec![Inlay::edit_prediction( post_inc(&mut next_inlay_id), snapshot.buffer_snapshot.anchor_before(position), - format!("Test inlay {next_inlay_id}"), + Rope::from_iter(["Test inlay ", "next_inlay_id"]), )], cx, ); @@ -1005,7 +1008,7 @@ async fn active_diagnostics_dismiss_after_invalidation(cx: &mut TestAppContext) cx.run_until_parked(); cx.update_editor(|editor, window, cx| { - editor.go_to_diagnostic(&GoToDiagnostic, window, cx); + editor.go_to_diagnostic(&GoToDiagnostic::default(), window, cx); assert_eq!( editor .active_diagnostic_group() @@ -1047,7 +1050,7 @@ async fn active_diagnostics_dismiss_after_invalidation(cx: &mut TestAppContext) "}); cx.update_editor(|editor, window, cx| { - editor.go_to_diagnostic(&GoToDiagnostic, window, cx); + editor.go_to_diagnostic(&GoToDiagnostic::default(), window, cx); assert_eq!(editor.active_diagnostic_group(), None); }); cx.assert_editor_state(indoc! {" @@ -1126,7 +1129,7 @@ async fn cycle_through_same_place_diagnostics(cx: &mut TestAppContext) { // Fourth diagnostic cx.update_editor(|editor, window, cx| { - editor.go_to_prev_diagnostic(&GoToPreviousDiagnostic, window, cx); + editor.go_to_prev_diagnostic(&GoToPreviousDiagnostic::default(), window, cx); }); cx.assert_editor_state(indoc! {" fn func(abc def: i32) -> ˇu32 { @@ -1135,7 +1138,7 @@ async fn cycle_through_same_place_diagnostics(cx: &mut TestAppContext) { // Third diagnostic cx.update_editor(|editor, window, cx| { - editor.go_to_prev_diagnostic(&GoToPreviousDiagnostic, window, cx); + editor.go_to_prev_diagnostic(&GoToPreviousDiagnostic::default(), window, cx); }); cx.assert_editor_state(indoc! {" fn func(abc ˇdef: i32) -> u32 { @@ -1144,7 +1147,7 @@ async fn cycle_through_same_place_diagnostics(cx: &mut TestAppContext) { // Second diagnostic, same place cx.update_editor(|editor, window, cx| { - editor.go_to_prev_diagnostic(&GoToPreviousDiagnostic, window, cx); + editor.go_to_prev_diagnostic(&GoToPreviousDiagnostic::default(), window, cx); }); cx.assert_editor_state(indoc! 
{" fn func(abc ˇdef: i32) -> u32 { @@ -1153,7 +1156,7 @@ async fn cycle_through_same_place_diagnostics(cx: &mut TestAppContext) { // First diagnostic cx.update_editor(|editor, window, cx| { - editor.go_to_prev_diagnostic(&GoToPreviousDiagnostic, window, cx); + editor.go_to_prev_diagnostic(&GoToPreviousDiagnostic::default(), window, cx); }); cx.assert_editor_state(indoc! {" fn func(abcˇ def: i32) -> u32 { @@ -1162,7 +1165,7 @@ async fn cycle_through_same_place_diagnostics(cx: &mut TestAppContext) { // Wrapped over, fourth diagnostic cx.update_editor(|editor, window, cx| { - editor.go_to_prev_diagnostic(&GoToPreviousDiagnostic, window, cx); + editor.go_to_prev_diagnostic(&GoToPreviousDiagnostic::default(), window, cx); }); cx.assert_editor_state(indoc! {" fn func(abc def: i32) -> ˇu32 { @@ -1181,7 +1184,7 @@ async fn cycle_through_same_place_diagnostics(cx: &mut TestAppContext) { // First diagnostic cx.update_editor(|editor, window, cx| { - editor.go_to_diagnostic(&GoToDiagnostic, window, cx); + editor.go_to_diagnostic(&GoToDiagnostic::default(), window, cx); }); cx.assert_editor_state(indoc! {" fn func(abcˇ def: i32) -> u32 { @@ -1190,7 +1193,7 @@ async fn cycle_through_same_place_diagnostics(cx: &mut TestAppContext) { // Second diagnostic cx.update_editor(|editor, window, cx| { - editor.go_to_diagnostic(&GoToDiagnostic, window, cx); + editor.go_to_diagnostic(&GoToDiagnostic::default(), window, cx); }); cx.assert_editor_state(indoc! {" fn func(abc ˇdef: i32) -> u32 { @@ -1199,7 +1202,7 @@ async fn cycle_through_same_place_diagnostics(cx: &mut TestAppContext) { // Third diagnostic, same place cx.update_editor(|editor, window, cx| { - editor.go_to_diagnostic(&GoToDiagnostic, window, cx); + editor.go_to_diagnostic(&GoToDiagnostic::default(), window, cx); }); cx.assert_editor_state(indoc! {" fn func(abc ˇdef: i32) -> u32 { @@ -1208,7 +1211,7 @@ async fn cycle_through_same_place_diagnostics(cx: &mut TestAppContext) { // Fourth diagnostic cx.update_editor(|editor, window, cx| { - editor.go_to_diagnostic(&GoToDiagnostic, window, cx); + editor.go_to_diagnostic(&GoToDiagnostic::default(), window, cx); }); cx.assert_editor_state(indoc! {" fn func(abc def: i32) -> ˇu32 { @@ -1217,7 +1220,7 @@ async fn cycle_through_same_place_diagnostics(cx: &mut TestAppContext) { // Wrapped around, first diagnostic cx.update_editor(|editor, window, cx| { - editor.go_to_diagnostic(&GoToDiagnostic, window, cx); + editor.go_to_diagnostic(&GoToDiagnostic::default(), window, cx); }); cx.assert_editor_state(indoc! {" fn func(abcˇ def: i32) -> u32 { @@ -1441,6 +1444,128 @@ async fn test_diagnostics_with_code(cx: &mut TestAppContext) { ); } +#[gpui::test] +async fn go_to_diagnostic_with_severity(cx: &mut TestAppContext) { + init_test(cx); + + let mut cx = EditorTestContext::new(cx).await; + let lsp_store = + cx.update_editor(|editor, _, cx| editor.project.as_ref().unwrap().read(cx).lsp_store()); + + cx.set_state(indoc! 
{"error warning info hiˇnt"}); + + cx.update(|_, cx| { + lsp_store.update(cx, |lsp_store, cx| { + lsp_store + .update_diagnostics( + LanguageServerId(0), + lsp::PublishDiagnosticsParams { + uri: lsp::Url::from_file_path(path!("/root/file")).unwrap(), + version: None, + diagnostics: vec![ + lsp::Diagnostic { + range: lsp::Range::new( + lsp::Position::new(0, 0), + lsp::Position::new(0, 5), + ), + severity: Some(lsp::DiagnosticSeverity::ERROR), + ..Default::default() + }, + lsp::Diagnostic { + range: lsp::Range::new( + lsp::Position::new(0, 6), + lsp::Position::new(0, 13), + ), + severity: Some(lsp::DiagnosticSeverity::WARNING), + ..Default::default() + }, + lsp::Diagnostic { + range: lsp::Range::new( + lsp::Position::new(0, 14), + lsp::Position::new(0, 18), + ), + severity: Some(lsp::DiagnosticSeverity::INFORMATION), + ..Default::default() + }, + lsp::Diagnostic { + range: lsp::Range::new( + lsp::Position::new(0, 19), + lsp::Position::new(0, 23), + ), + severity: Some(lsp::DiagnosticSeverity::HINT), + ..Default::default() + }, + ], + }, + None, + DiagnosticSourceKind::Pushed, + &[], + cx, + ) + .unwrap() + }); + }); + cx.run_until_parked(); + + macro_rules! go { + ($severity:expr) => { + cx.update_editor(|editor, window, cx| { + editor.go_to_diagnostic( + &GoToDiagnostic { + severity: $severity, + }, + window, + cx, + ); + }); + }; + } + + // Default, should cycle through all diagnostics + go!(GoToDiagnosticSeverityFilter::default()); + cx.assert_editor_state(indoc! {"ˇerror warning info hint"}); + go!(GoToDiagnosticSeverityFilter::default()); + cx.assert_editor_state(indoc! {"error ˇwarning info hint"}); + go!(GoToDiagnosticSeverityFilter::default()); + cx.assert_editor_state(indoc! {"error warning ˇinfo hint"}); + go!(GoToDiagnosticSeverityFilter::default()); + cx.assert_editor_state(indoc! {"error warning info ˇhint"}); + go!(GoToDiagnosticSeverityFilter::default()); + cx.assert_editor_state(indoc! {"ˇerror warning info hint"}); + + let only_info = GoToDiagnosticSeverityFilter::Only(GoToDiagnosticSeverity::Information); + go!(only_info); + cx.assert_editor_state(indoc! {"error warning ˇinfo hint"}); + go!(only_info); + cx.assert_editor_state(indoc! {"error warning ˇinfo hint"}); + + let no_hints = GoToDiagnosticSeverityFilter::Range { + min: GoToDiagnosticSeverity::Information, + max: GoToDiagnosticSeverity::Error, + }; + + go!(no_hints); + cx.assert_editor_state(indoc! {"ˇerror warning info hint"}); + go!(no_hints); + cx.assert_editor_state(indoc! {"error ˇwarning info hint"}); + go!(no_hints); + cx.assert_editor_state(indoc! {"error warning ˇinfo hint"}); + go!(no_hints); + cx.assert_editor_state(indoc! {"ˇerror warning info hint"}); + + let warning_info = GoToDiagnosticSeverityFilter::Range { + min: GoToDiagnosticSeverity::Information, + max: GoToDiagnosticSeverity::Warning, + }; + + go!(warning_info); + cx.assert_editor_state(indoc! {"error ˇwarning info hint"}); + go!(warning_info); + cx.assert_editor_state(indoc! {"error warning ˇinfo hint"}); + go!(warning_info); + cx.assert_editor_state(indoc! 
{"error ˇwarning info hint"}); +} + fn init_test(cx: &mut TestAppContext) { cx.update(|cx| { zlog::init_test(); diff --git a/crates/diagnostics/src/items.rs b/crates/diagnostics/src/items.rs index b5f9e901bbc819414c93ed6300a41a1731699379..7ac6d101f315674cec4fd07f4ad2df0830284124 100644 --- a/crates/diagnostics/src/items.rs +++ b/crates/diagnostics/src/items.rs @@ -6,9 +6,10 @@ use gpui::{ WeakEntity, Window, }; use language::Diagnostic; -use project::project_settings::ProjectSettings; +use project::project_settings::{GoToDiagnosticSeverityFilter, ProjectSettings}; use settings::Settings; use ui::{Button, ButtonLike, Color, Icon, IconName, Label, Tooltip, h_flex, prelude::*}; +use util::ResultExt; use workspace::{StatusItemView, ToolbarItemEvent, Workspace, item::ItemHandle}; use crate::{Deploy, IncludeWarnings, ProjectDiagnosticsEditor}; @@ -20,6 +21,7 @@ pub struct DiagnosticIndicator { current_diagnostic: Option, _observe_active_editor: Option, diagnostics_update: Task<()>, + diagnostic_summary_update: Task<()>, } impl Render for DiagnosticIndicator { @@ -77,7 +79,7 @@ impl Render for DiagnosticIndicator { .tooltip(|window, cx| { Tooltip::for_action( "Next Diagnostic", - &editor::actions::GoToDiagnostic, + &editor::actions::GoToDiagnostic::default(), window, cx, ) @@ -135,8 +137,16 @@ impl DiagnosticIndicator { } project::Event::DiagnosticsUpdated { .. } => { - this.summary = project.read(cx).diagnostic_summary(false, cx); - cx.notify(); + this.diagnostic_summary_update = cx.spawn(async move |this, cx| { + cx.background_executor() + .timer(Duration::from_millis(30)) + .await; + this.update(cx, |this, cx| { + this.summary = project.read(cx).diagnostic_summary(false, cx); + cx.notify(); + }) + .log_err(); + }); } _ => {} @@ -150,13 +160,19 @@ impl DiagnosticIndicator { current_diagnostic: None, _observe_active_editor: None, diagnostics_update: Task::ready(()), + diagnostic_summary_update: Task::ready(()), } } fn go_to_next_diagnostic(&mut self, window: &mut Window, cx: &mut Context) { if let Some(editor) = self.active_editor.as_ref().and_then(|e| e.upgrade()) { editor.update(cx, |editor, cx| { - editor.go_to_diagnostic_impl(editor::Direction::Next, window, cx); + editor.go_to_diagnostic_impl( + editor::Direction::Next, + GoToDiagnosticSeverityFilter::default(), + window, + cx, + ); }) } } diff --git a/crates/docs_preprocessor/Cargo.toml b/crates/docs_preprocessor/Cargo.toml index a0df669abe6036859e2f6c73a26541ed1fc25767..e46ceb18db7e75f0f946da1d112509a18a68d4aa 100644 --- a/crates/docs_preprocessor/Cargo.toml +++ b/crates/docs_preprocessor/Cargo.toml @@ -7,17 +7,19 @@ license = "GPL-3.0-or-later" [dependencies] anyhow.workspace = true -clap.workspace = true -mdbook = "0.4.40" +command_palette.workspace = true +gpui.workspace = true +# We are specifically pinning this version of mdbook, as later versions introduce issues with double-nested subdirectories. +# Ask @maxdeviant about this before bumping. 
+mdbook = "= 0.4.40" +regex.workspace = true serde.workspace = true serde_json.workspace = true settings.workspace = true -regex.workspace = true util.workspace = true workspace-hack.workspace = true zed.workspace = true -gpui.workspace = true -command_palette.workspace = true +zlog.workspace = true [lints] workspace = true diff --git a/crates/docs_preprocessor/src/main.rs b/crates/docs_preprocessor/src/main.rs index c8e945c7e83564d162e0b939b92169b905558393..1448f4cb52369b5142856faabd67c0f9c3271220 100644 --- a/crates/docs_preprocessor/src/main.rs +++ b/crates/docs_preprocessor/src/main.rs @@ -1,14 +1,15 @@ -use anyhow::Result; -use clap::{Arg, ArgMatches, Command}; +use anyhow::{Context, Result}; use mdbook::BookItem; use mdbook::book::{Book, Chapter}; use mdbook::preprocess::CmdPreprocessor; use regex::Regex; use settings::KeymapFile; -use std::collections::HashSet; +use std::borrow::Cow; +use std::collections::{HashMap, HashSet}; use std::io::{self, Read}; use std::process; use std::sync::LazyLock; +use util::paths::PathExt; static KEYMAP_MACOS: LazyLock = LazyLock::new(|| { load_keymap("keymaps/default-macos.json").expect("Failed to load MacOS keymap") @@ -20,60 +21,68 @@ static KEYMAP_LINUX: LazyLock = LazyLock::new(|| { static ALL_ACTIONS: LazyLock> = LazyLock::new(dump_all_gpui_actions); -pub fn make_app() -> Command { - Command::new("zed-docs-preprocessor") - .about("Preprocesses Zed Docs content to provide rich action & keybinding support and more") - .subcommand( - Command::new("supports") - .arg(Arg::new("renderer").required(true)) - .about("Check whether a renderer is supported by this preprocessor"), - ) -} +const FRONT_MATTER_COMMENT: &'static str = ""; fn main() -> Result<()> { - let matches = make_app().get_matches(); + zlog::init(); + zlog::init_output_stderr(); // call a zed:: function so everything in `zed` crate is linked and // all actions in the actual app are registered zed::stdout_is_a_pty(); - - if let Some(sub_args) = matches.subcommand_matches("supports") { - handle_supports(sub_args); - } else { - handle_preprocessing()?; + let args = std::env::args().skip(1).collect::>(); + + match args.get(0).map(String::as_str) { + Some("supports") => { + let renderer = args.get(1).expect("Required argument"); + let supported = renderer != "not-supported"; + if supported { + process::exit(0); + } else { + process::exit(1); + } + } + Some("postprocess") => handle_postprocessing()?, + _ => handle_preprocessing()?, } Ok(()) } #[derive(Debug, Clone, PartialEq, Eq, Hash)] -enum Error { +enum PreprocessorError { ActionNotFound { action_name: String }, DeprecatedActionUsed { used: String, should_be: String }, + InvalidFrontmatterLine(String), } -impl Error { +impl PreprocessorError { fn new_for_not_found_action(action_name: String) -> Self { for action in &*ALL_ACTIONS { for alias in action.deprecated_aliases { if alias == &action_name { - return Error::DeprecatedActionUsed { + return PreprocessorError::DeprecatedActionUsed { used: action_name.clone(), should_be: action.name.to_string(), }; } } } - Error::ActionNotFound { + PreprocessorError::ActionNotFound { action_name: action_name.to_string(), } } } -impl std::fmt::Display for Error { +impl std::fmt::Display for PreprocessorError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Error::ActionNotFound { action_name } => write!(f, "Action not found: {}", action_name), - Error::DeprecatedActionUsed { used, should_be } => write!( + PreprocessorError::InvalidFrontmatterLine(line) => { + write!(f, 
"Invalid frontmatter line: {}", line) + } + PreprocessorError::ActionNotFound { action_name } => { + write!(f, "Action not found: {}", action_name) + } + PreprocessorError::DeprecatedActionUsed { used, should_be } => write!( f, "Deprecated action used: {} should be {}", used, should_be @@ -89,8 +98,9 @@ fn handle_preprocessing() -> Result<()> { let (_ctx, mut book) = CmdPreprocessor::parse_input(input.as_bytes())?; - let mut errors = HashSet::::new(); + let mut errors = HashSet::::new(); + handle_frontmatter(&mut book, &mut errors); template_and_validate_keybindings(&mut book, &mut errors); template_and_validate_actions(&mut book, &mut errors); @@ -108,19 +118,41 @@ fn handle_preprocessing() -> Result<()> { Ok(()) } -fn handle_supports(sub_args: &ArgMatches) -> ! { - let renderer = sub_args - .get_one::("renderer") - .expect("Required argument"); - let supported = renderer != "not-supported"; - if supported { - process::exit(0); - } else { - process::exit(1); - } +fn handle_frontmatter(book: &mut Book, errors: &mut HashSet) { + let frontmatter_regex = Regex::new(r"(?s)^\s*---(.*?)---").unwrap(); + for_each_chapter_mut(book, |chapter| { + let new_content = frontmatter_regex.replace(&chapter.content, |caps: ®ex::Captures| { + let frontmatter = caps[1].trim(); + let frontmatter = frontmatter.trim_matches(&[' ', '-', '\n']); + let mut metadata = HashMap::::default(); + for line in frontmatter.lines() { + let Some((name, value)) = line.split_once(':') else { + errors.insert(PreprocessorError::InvalidFrontmatterLine(format!( + "{}: {}", + chapter_breadcrumbs(&chapter), + line + ))); + continue; + }; + let name = name.trim(); + let value = value.trim(); + metadata.insert(name.to_string(), value.to_string()); + } + FRONT_MATTER_COMMENT.replace( + "{}", + &serde_json::to_string(&metadata).expect("Failed to serialize metadata"), + ) + }); + match new_content { + Cow::Owned(content) => { + chapter.content = content; + } + Cow::Borrowed(_) => {} + } + }); } -fn template_and_validate_keybindings(book: &mut Book, errors: &mut HashSet) { +fn template_and_validate_keybindings(book: &mut Book, errors: &mut HashSet) { let regex = Regex::new(r"\{#kb (.*?)\}").unwrap(); for_each_chapter_mut(book, |chapter| { @@ -128,7 +160,9 @@ fn template_and_validate_keybindings(book: &mut Book, errors: &mut HashSet) { +fn template_and_validate_actions(book: &mut Book, errors: &mut HashSet) { let regex = Regex::new(r"\{#action (.*?)\}").unwrap(); for_each_chapter_mut(book, |chapter| { @@ -152,7 +186,9 @@ fn template_and_validate_actions(book: &mut Book, errors: &mut HashSet) { .replace_all(&chapter.content, |caps: ®ex::Captures| { let name = caps[1].trim(); let Some(action) = find_action_by_name(name) else { - errors.insert(Error::new_for_not_found_action(name.to_string())); + errors.insert(PreprocessorError::new_for_not_found_action( + name.to_string(), + )); return String::new(); }; format!("{}", &action.human_name) @@ -217,6 +253,13 @@ fn name_for_action(action_as_str: String) -> String { .unwrap_or(action_as_str) } +fn chapter_breadcrumbs(chapter: &Chapter) -> String { + let mut breadcrumbs = Vec::with_capacity(chapter.parent_names.len() + 1); + breadcrumbs.extend(chapter.parent_names.iter().map(String::as_str)); + breadcrumbs.push(chapter.name.as_str()); + format!("[{:?}] {}", chapter.source_path, breadcrumbs.join(" > ")) +} + fn load_keymap(asset_path: &str) -> Result { let content = util::asset_str::(asset_path); KeymapFile::parse(content.as_ref()) @@ -243,7 +286,6 @@ struct ActionDef { fn dump_all_gpui_actions() -> 
Vec { let mut actions = gpui::generate_list_of_all_registered_actions() - .into_iter() .map(|action| ActionDef { name: action.name, human_name: command_palette::humanize_action_name(action.name), @@ -255,3 +297,126 @@ fn dump_all_gpui_actions() -> Vec { return actions; } + +fn handle_postprocessing() -> Result<()> { + let logger = zlog::scoped!("render"); + let mut ctx = mdbook::renderer::RenderContext::from_json(io::stdin())?; + let output = ctx + .config + .get_mut("output") + .expect("has output") + .as_table_mut() + .expect("output is table"); + let zed_html = output.remove("zed-html").expect("zed-html output defined"); + let default_description = zed_html + .get("default-description") + .expect("Default description not found") + .as_str() + .expect("Default description not a string") + .to_string(); + let default_title = zed_html + .get("default-title") + .expect("Default title not found") + .as_str() + .expect("Default title not a string") + .to_string(); + + output.insert("html".to_string(), zed_html); + mdbook::Renderer::render(&mdbook::renderer::HtmlHandlebars::new(), &ctx)?; + let ignore_list = ["toc.html"]; + + let root_dir = ctx.destination.clone(); + let mut files = Vec::with_capacity(128); + let mut queue = Vec::with_capacity(64); + queue.push(root_dir.clone()); + while let Some(dir) = queue.pop() { + for entry in std::fs::read_dir(&dir).context(dir.to_sanitized_string())? { + let Ok(entry) = entry else { + continue; + }; + let file_type = entry.file_type().context("Failed to determine file type")?; + if file_type.is_dir() { + queue.push(entry.path()); + } + if file_type.is_file() + && matches!( + entry.path().extension().and_then(std::ffi::OsStr::to_str), + Some("html") + ) + { + if ignore_list.contains(&&*entry.file_name().to_string_lossy()) { + zlog::info!(logger => "Ignoring {}", entry.path().to_string_lossy()); + } else { + files.push(entry.path()); + } + } + } + } + + zlog::info!(logger => "Processing {} `.html` files", files.len()); + let meta_regex = Regex::new(&FRONT_MATTER_COMMENT.replace("{}", "(.*)")).unwrap(); + for file in files { + let contents = std::fs::read_to_string(&file)?; + let mut meta_description = None; + let mut meta_title = None; + let contents = meta_regex.replace(&contents, |caps: ®ex::Captures| { + let metadata: HashMap = serde_json::from_str(&caps[1]).with_context(|| format!("JSON Metadata: {:?}", &caps[1])).expect("Failed to deserialize metadata"); + for (kind, content) in metadata { + match kind.as_str() { + "description" => { + meta_description = Some(content); + } + "title" => { + meta_title = Some(content); + } + _ => { + zlog::warn!(logger => "Unrecognized frontmatter key: {} in {:?}", kind, pretty_path(&file, &root_dir)); + } + } + } + String::new() + }); + let meta_description = meta_description.as_ref().unwrap_or_else(|| { + zlog::warn!(logger => "No meta description found for {:?}", pretty_path(&file, &root_dir)); + &default_description + }); + let page_title = extract_title_from_page(&contents, pretty_path(&file, &root_dir)); + let meta_title = meta_title.as_ref().unwrap_or_else(|| { + zlog::debug!(logger => "No meta title found for {:?}", pretty_path(&file, &root_dir)); + &default_title + }); + let meta_title = format!("{} | {}", page_title, meta_title); + zlog::trace!(logger => "Updating {:?}", pretty_path(&file, &root_dir)); + let contents = contents.replace("#description#", meta_description); + let contents = TITLE_REGEX + .replace(&contents, |_: ®ex::Captures| { + format!("{}", meta_title) + }) + .to_string(); + // let contents = 
contents.replace("#title#", &meta_title); + std::fs::write(file, contents)?; + } + return Ok(()); + + fn pretty_path<'a>( + path: &'a std::path::PathBuf, + root: &'a std::path::PathBuf, + ) -> &'a std::path::Path { + &path.strip_prefix(&root).unwrap_or(&path) + } + const TITLE_REGEX: std::cell::LazyCell = + std::cell::LazyCell::new(|| Regex::new(r"\s*(.*?)\s*").unwrap()); + fn extract_title_from_page(contents: &str, pretty_path: &std::path::Path) -> String { + let title_tag_contents = &TITLE_REGEX + .captures(&contents) + .with_context(|| format!("Failed to find title in {:?}", pretty_path)) + .expect("Page has element")[1]; + let title = title_tag_contents + .trim() + .strip_suffix("- Zed") + .unwrap_or(title_tag_contents) + .trim() + .to_string(); + title + } +} diff --git a/crates/inline_completion/Cargo.toml b/crates/edit_prediction/Cargo.toml similarity index 82% rename from crates/inline_completion/Cargo.toml rename to crates/edit_prediction/Cargo.toml index 3a90875def1a8ce491765c24c18f432807292dc9..81c1e5dec20ce9032c4e1422f330b11da56fabe7 100644 --- a/crates/inline_completion/Cargo.toml +++ b/crates/edit_prediction/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "inline_completion" +name = "edit_prediction" version = "0.1.0" edition.workspace = true publish.workspace = true @@ -9,7 +9,7 @@ license = "GPL-3.0-or-later" workspace = true [lib] -path = "src/inline_completion.rs" +path = "src/edit_prediction.rs" [dependencies] client.workspace = true diff --git a/crates/edit_prediction/LICENSE-GPL b/crates/edit_prediction/LICENSE-GPL new file mode 120000 index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4 --- /dev/null +++ b/crates/edit_prediction/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/inline_completion/src/inline_completion.rs b/crates/edit_prediction/src/edit_prediction.rs similarity index 92% rename from crates/inline_completion/src/inline_completion.rs rename to crates/edit_prediction/src/edit_prediction.rs index c8f35bf16a116294edb5d1d2f5359733828e6995..c8502f75de5adac0a1bfdcb8cd8fe4444bb70f84 100644 --- a/crates/inline_completion/src/inline_completion.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -7,7 +7,7 @@ use project::Project; // TODO: Find a better home for `Direction`. // -// This should live in an ancestor crate of `editor` and `inline_completion`, +// This should live in an ancestor crate of `editor` and `edit_prediction`, // but at time of writing there isn't an obvious spot. #[derive(Copy, Clone, PartialEq, Eq)] pub enum Direction { @@ -16,7 +16,7 @@ pub enum Direction { } #[derive(Clone)] -pub struct InlineCompletion { +pub struct EditPrediction { /// The ID of the completion, if it has one. 
pub id: Option<SharedString>, pub edits: Vec<(Range<language::Anchor>, String)>, @@ -61,6 +61,10 @@ pub trait EditPredictionProvider: 'static + Sized { fn show_tab_accept_marker() -> bool { false } + fn supports_jump_to_edit() -> bool { + true + } + fn data_collection_state(&self, _cx: &App) -> DataCollectionState { DataCollectionState::Unsupported } @@ -102,10 +106,10 @@ pub trait EditPredictionProvider: 'static + Sized { buffer: &Entity<Buffer>, cursor_position: language::Anchor, cx: &mut Context<Self>, - ) -> Option<InlineCompletion>; + ) -> Option<EditPrediction>; } -pub trait InlineCompletionProviderHandle { +pub trait EditPredictionProviderHandle { fn name(&self) -> &'static str; fn display_name(&self) -> &'static str; fn is_enabled( @@ -116,6 +120,7 @@ pub trait InlineCompletionProviderHandle { ) -> bool; fn show_completions_in_menu(&self) -> bool; fn show_tab_accept_marker(&self) -> bool; + fn supports_jump_to_edit(&self) -> bool; fn data_collection_state(&self, cx: &App) -> DataCollectionState; fn usage(&self, cx: &App) -> Option<EditPredictionUsage>; fn toggle_data_collection(&self, cx: &mut App); @@ -143,10 +148,10 @@ pub trait InlineCompletionProviderHandle { buffer: &Entity<Buffer>, cursor_position: language::Anchor, cx: &mut App, - ) -> Option<InlineCompletion>; + ) -> Option<EditPrediction>; } -impl<T> InlineCompletionProviderHandle for Entity<T> +impl<T> EditPredictionProviderHandle for Entity<T> where T: EditPredictionProvider, { @@ -166,6 +171,10 @@ where T::show_tab_accept_marker() } + fn supports_jump_to_edit(&self) -> bool { + T::supports_jump_to_edit() + } + fn data_collection_state(&self, cx: &App) -> DataCollectionState { self.read(cx).data_collection_state(cx) } @@ -233,7 +242,7 @@ where buffer: &Entity<Buffer>, cursor_position: language::Anchor, cx: &mut App, - ) -> Option<InlineCompletion> { + ) -> Option<EditPrediction> { self.update(cx, |this, cx| this.suggest(buffer, cursor_position, cx)) } } diff --git a/crates/inline_completion_button/Cargo.toml b/crates/edit_prediction_button/Cargo.toml similarity index 86% rename from crates/inline_completion_button/Cargo.toml rename to crates/edit_prediction_button/Cargo.toml index c2a619d50075271be23aec9aa71dc554cf8075c0..07447280fa0d3b8041f1d35eba9c368288322c25 100644 --- a/crates/inline_completion_button/Cargo.toml +++ b/crates/edit_prediction_button/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "inline_completion_button" +name = "edit_prediction_button" version = "0.1.0" edition.workspace = true publish.workspace = true @@ -9,21 +9,23 @@ license = "GPL-3.0-or-later" workspace = true [lib] -path = "src/inline_completion_button.rs" +path = "src/edit_prediction_button.rs" doctest = false [dependencies] anyhow.workspace = true client.workspace = true +cloud_llm_client.workspace = true copilot.workspace = true editor.workspace = true feature_flags.workspace = true fs.workspace = true gpui.workspace = true indoc.workspace = true -inline_completion.workspace = true +edit_prediction.workspace = true language.workspace = true paths.workspace = true +project.workspace = true regex.workspace = true settings.workspace = true supermaven.workspace = true @@ -32,7 +34,6 @@ ui.workspace = true workspace-hack.workspace = true workspace.workspace = true zed_actions.workspace = true -zed_llm_client.workspace = true zeta.workspace = true [dev-dependencies] diff --git a/crates/edit_prediction_button/LICENSE-GPL b/crates/edit_prediction_button/LICENSE-GPL new file mode 120000 index 
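The supports_jump_to_edit() addition above follows the crate's existing pattern: the statically-dispatched EditPredictionProvider trait exposes a capability as an associated function with a default, and the object-safe EditPredictionProviderHandle implementation forwards to it. Reduced to a standalone example without the Entity/GPUI plumbing, the shape is roughly:

```rust
/// Statically-dispatched provider trait: per-provider capabilities are associated
/// functions with defaults, so most providers do not need to override them.
trait EditPredictionProvider: 'static + Sized {
    fn name() -> &'static str;

    fn supports_jump_to_edit() -> bool {
        true
    }
}

/// Object-safe "handle" trait used where the concrete provider type is erased.
trait EditPredictionProviderHandle {
    fn name(&self) -> &'static str;
    fn supports_jump_to_edit(&self) -> bool;
}

/// In Zed the handle is implemented for `Entity<T>`; a plain wrapper stands in here.
#[allow(dead_code)]
struct Handle<T>(T);

impl<T: EditPredictionProvider> EditPredictionProviderHandle for Handle<T> {
    fn name(&self) -> &'static str {
        T::name()
    }

    fn supports_jump_to_edit(&self) -> bool {
        T::supports_jump_to_edit()
    }
}

struct ExampleProvider;

impl EditPredictionProvider for ExampleProvider {
    fn name() -> &'static str {
        "example"
    }
    // `supports_jump_to_edit` falls back to the default of `true`.
}

fn main() {
    let provider: Box<dyn EditPredictionProviderHandle> = Box::new(Handle(ExampleProvider));
    assert!(provider.supports_jump_to_edit());
    println!("{} supports jump-to-edit", provider.name());
}
```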
0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4 --- /dev/null +++ b/crates/edit_prediction_button/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/inline_completion_button/src/inline_completion_button.rs b/crates/edit_prediction_button/src/edit_prediction_button.rs similarity index 90% rename from crates/inline_completion_button/src/inline_completion_button.rs rename to crates/edit_prediction_button/src/edit_prediction_button.rs index 7e6b77b93deafbb971980d8b2d19f33f2fa348b4..3d3b43d71bc4a0914ed97dac24a278049f4c52f1 100644 --- a/crates/inline_completion_button/src/inline_completion_button.rs +++ b/crates/edit_prediction_button/src/edit_prediction_button.rs @@ -1,11 +1,8 @@ use anyhow::Result; use client::{UserStore, zed_urls}; +use cloud_llm_client::UsageLimit; use copilot::{Copilot, Status}; -use editor::{ - Editor, SelectionEffects, - actions::{ShowEditPrediction, ToggleEditPrediction}, - scroll::Autoscroll, -}; +use editor::{Editor, SelectionEffects, actions::ShowEditPrediction, scroll::Autoscroll}; use feature_flags::{FeatureFlagAppExt, PredictEditsRateCompletionsFeatureFlag}; use fs::Fs; use gpui::{ @@ -18,6 +15,7 @@ use language::{ EditPredictionsMode, File, Language, language_settings::{self, AllLanguageSettings, EditPredictionProvider, all_language_settings}, }; +use project::DisableAiSettings; use regex::Regex; use settings::{Settings, SettingsStore, update_settings_file}; use std::{ @@ -34,29 +32,29 @@ use workspace::{ notifications::NotificationId, }; use zed_actions::OpenBrowser; -use zed_llm_client::UsageLimit; use zeta::RateCompletions; actions!( edit_prediction, [ - /// Toggles the inline completion menu. + /// Toggles the edit prediction menu. ToggleMenu ] ); const COPILOT_SETTINGS_URL: &str = "https://github.com/settings/copilot"; +const PRIVACY_DOCS: &str = "https://zed.dev/docs/ai/privacy-and-security"; struct CopilotErrorToast; -pub struct InlineCompletionButton { +pub struct EditPredictionButton { editor_subscription: Option<(Subscription, usize)>, editor_enabled: Option<bool>, editor_show_predictions: bool, editor_focus_handle: Option<FocusHandle>, language: Option<Arc<Language>>, file: Option<Arc<dyn File>>, - edit_prediction_provider: Option<Arc<dyn inline_completion::InlineCompletionProviderHandle>>, + edit_prediction_provider: Option<Arc<dyn edit_prediction::EditPredictionProviderHandle>>, fs: Arc<dyn Fs>, user_store: Entity<UserStore>, popover_menu_handle: PopoverMenuHandle<ContextMenu>, @@ -69,8 +67,13 @@ enum SupermavenButtonStatus { Initializing, } -impl Render for InlineCompletionButton { +impl Render for EditPredictionButton { fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { + // Return empty div if AI is disabled + if DisableAiSettings::get_global(cx).disable_ai { + return div(); + } + let all_language_settings = all_language_settings(None, cx); match all_language_settings.edit_predictions.provider { @@ -193,13 +196,13 @@ impl Render for InlineCompletionButton { cx.open_url(activate_url.as_str()) }) .entry( - "Use Copilot", + "Use Zed AI", None, move |_, cx| { set_completion_provider( fs.clone(), cx, - EditPredictionProvider::Copilot, + EditPredictionProvider::Zed, ) }, ) @@ -239,21 +242,15 @@ impl Render for InlineCompletionButton { IconName::ZedPredictDisabled }; - let current_user_terms_accepted = - self.user_store.read(cx).current_user_has_accepted_terms(); - let has_subscription = self.user_store.read(cx).current_plan().is_some() - && 
self.user_store.read(cx).subscription_period().is_some(); - - if !has_subscription || !current_user_terms_accepted.unwrap_or(false) { - let signed_in = current_user_terms_accepted.is_some(); - let tooltip_meta = if signed_in { - if has_subscription { - "Read Terms of Service" - } else { + if zeta::should_show_upsell_modal(&self.user_store, cx) { + let tooltip_meta = if self.user_store.read(cx).current_user().is_some() { + if self.user_store.read(cx).has_accepted_terms_of_service() { "Choose a Plan" + } else { + "Accept the Terms of Service" } } else { - "Sign in to use" + "Sign In" }; return div().child( @@ -368,7 +365,7 @@ impl Render for InlineCompletionButton { } } -impl InlineCompletionButton { +impl EditPredictionButton { pub fn new( fs: Arc<dyn Fs>, user_store: Entity<UserStore>, @@ -390,9 +387,9 @@ impl InlineCompletionButton { language: None, file: None, edit_prediction_provider: None, + user_store, popover_menu_handle, fs, - user_store, } } @@ -403,15 +400,16 @@ impl InlineCompletionButton { ) -> Entity<ContextMenu> { let fs = self.fs.clone(); ContextMenu::build(window, cx, |menu, _, _| { - menu.entry("Sign In", None, copilot::initiate_sign_in) + menu.entry("Sign In to Copilot", None, copilot::initiate_sign_in) .entry("Disable Copilot", None, { let fs = fs.clone(); move |_window, cx| hide_copilot(fs.clone(), cx) }) - .entry("Use Supermaven", None, { + .separator() + .entry("Use Zed AI", None, { let fs = fs.clone(); move |_window, cx| { - set_completion_provider(fs.clone(), cx, EditPredictionProvider::Supermaven) + set_completion_provider(fs.clone(), cx, EditPredictionProvider::Zed) } }) }) @@ -439,9 +437,13 @@ impl InlineCompletionButton { if let Some(editor_focus_handle) = self.editor_focus_handle.clone() { let entry = ContextMenuEntry::new("This Buffer") .toggleable(IconPosition::Start, self.editor_show_predictions) - .action(Box::new(ToggleEditPrediction)) + .action(Box::new(editor::actions::ToggleEditPrediction)) .handler(move |window, cx| { - editor_focus_handle.dispatch_action(&ToggleEditPrediction, window, cx); + editor_focus_handle.dispatch_action( + &editor::actions::ToggleEditPrediction, + window, + cx, + ); }); match language_state.clone() { @@ -468,7 +470,7 @@ impl InlineCompletionButton { IconPosition::Start, None, move |_, cx| { - toggle_show_inline_completions_for_language(language.clone(), fs.clone(), cx) + toggle_show_edit_predictions_for_language(language.clone(), fs.clone(), cx) }, ); } @@ -476,17 +478,25 @@ impl InlineCompletionButton { let settings = AllLanguageSettings::get_global(cx); let globally_enabled = settings.show_edit_predictions(None, cx); - menu = menu.toggleable_entry("All Files", globally_enabled, IconPosition::Start, None, { - let fs = fs.clone(); - move |_, cx| toggle_inline_completions_globally(fs.clone(), cx) - }); + let entry = ContextMenuEntry::new("All Files") + .toggleable(IconPosition::Start, globally_enabled) + .action(workspace::ToggleEditPrediction.boxed_clone()) + .handler(|window, cx| { + window.dispatch_action(workspace::ToggleEditPrediction.boxed_clone(), cx) + }); + menu = menu.item(entry); let provider = settings.edit_predictions.provider; let current_mode = settings.edit_predictions_mode(); let subtle_mode = matches!(current_mode, EditPredictionsMode::Subtle); let eager_mode = matches!(current_mode, EditPredictionsMode::Eager); - if matches!(provider, EditPredictionProvider::Zed) { + if matches!( + provider, + EditPredictionProvider::Zed + | EditPredictionProvider::Copilot + | EditPredictionProvider::Supermaven + ) { menu = 
menu .separator() .header("Display Modes") @@ -518,7 +528,7 @@ impl InlineCompletionButton { ); } - menu = menu.separator().header("Privacy Settings"); + menu = menu.separator().header("Privacy"); if let Some(provider) = &self.edit_prediction_provider { let data_collection = provider.data_collection_state(cx); if data_collection.is_supported() { @@ -569,13 +579,15 @@ impl InlineCompletionButton { .child( Label::new(indoc!{ "Help us improve our open dataset model by sharing data from open source repositories. \ - Zed must detect a license file in your repo for this setting to take effect." + Zed must detect a license file in your repo for this setting to take effect. \ + Files with sensitive data and secrets are excluded by default." }) ) .child( h_flex() .items_start() .pt_2() + .pr_1() .flex_1() .gap_1p5() .border_t_1() @@ -635,6 +647,13 @@ impl InlineCompletionButton { .detach_and_log_err(cx); } }), + ).item( + ContextMenuEntry::new("View Documentation") + .icon(IconName::FileGeneric) + .icon_color(Color::Muted) + .handler(move |_, cx| { + cx.open_url(PRIVACY_DOCS); + }) ); if !self.editor_enabled.unwrap_or(true) { @@ -672,6 +691,13 @@ impl InlineCompletionButton { ) -> Entity<ContextMenu> { ContextMenu::build(window, cx, |menu, window, cx| { self.build_language_settings_menu(menu, window, cx) + .separator() + .entry("Use Zed AI instead", None, { + let fs = self.fs.clone(); + move |_window, cx| { + set_completion_provider(fs.clone(), cx, EditPredictionProvider::Zed) + } + }) .separator() .link( "Go to Copilot Settings", @@ -750,44 +776,24 @@ impl InlineCompletionButton { menu = menu .custom_entry( |_window, _cx| { - h_flex() - .gap_1() - .child( - Icon::new(IconName::Warning) - .size(IconSize::Small) - .color(Color::Warning), - ) - .child( - Label::new("Your GitHub account is less than 30 days old") - .size(LabelSize::Small) - .color(Color::Warning), - ) + Label::new("Your GitHub account is less than 30 days old.") + .size(LabelSize::Small) + .color(Color::Warning) .into_any_element() }, |_window, cx| cx.open_url(&zed_urls::account_url(cx)), ) - .entry( - "You need to upgrade to Zed Pro or contact us.", - None, - |_window, cx| cx.open_url(&zed_urls::account_url(cx)), - ) + .entry("Upgrade to Zed Pro or contact us.", None, |_window, cx| { + cx.open_url(&zed_urls::account_url(cx)) + }) .separator(); } else if self.user_store.read(cx).has_overdue_invoices() { menu = menu .custom_entry( |_window, _cx| { - h_flex() - .gap_1() - .child( - Icon::new(IconName::Warning) - .size(IconSize::Small) - .color(Color::Warning), - ) - .child( - Label::new("You have an outstanding invoice") - .size(LabelSize::Small) - .color(Color::Warning), - ) + Label::new("You have an outstanding invoice") + .size(LabelSize::Small) + .color(Color::Warning) .into_any_element() }, |_window, cx| { @@ -837,7 +843,7 @@ impl InlineCompletionButton { } } -impl StatusItemView for InlineCompletionButton { +impl StatusItemView for EditPredictionButton { fn set_active_pane_item( &mut self, item: Option<&dyn ItemHandle>, @@ -907,7 +913,7 @@ async fn open_disabled_globs_setting_in_editor( let settings = cx.global::<SettingsStore>(); - // Ensure that we always have "inline_completions { "disabled_globs": [] }" + // Ensure that we always have "edit_predictions { "disabled_globs": [] }" let edits = settings.edits_for_update::<AllLanguageSettings>(&text, |file| { file.edit_predictions .get_or_insert_with(Default::default) @@ -945,13 +951,6 @@ async fn open_disabled_globs_setting_in_editor( anyhow::Ok(()) } -fn 
toggle_inline_completions_globally(fs: Arc<dyn Fs>, cx: &mut App) { - let show_edit_predictions = all_language_settings(None, cx).show_edit_predictions(None, cx); - update_settings_file::<AllLanguageSettings>(fs, cx, move |file, _| { - file.defaults.show_edit_predictions = Some(!show_edit_predictions) - }); -} - fn set_completion_provider(fs: Arc<dyn Fs>, cx: &mut App, provider: EditPredictionProvider) { update_settings_file::<AllLanguageSettings>(fs, cx, move |file, _| { file.features @@ -960,7 +959,7 @@ fn set_completion_provider(fs: Arc<dyn Fs>, cx: &mut App, provider: EditPredicti }); } -fn toggle_show_inline_completions_for_language( +fn toggle_show_edit_predictions_for_language( language: Arc<Language>, fs: Arc<dyn Fs>, cx: &mut App, diff --git a/crates/editor/Cargo.toml b/crates/editor/Cargo.toml index 4d6939567eb8150883a4eb5e4e9e5b0949a421a0..339f98ae8bd88263f1fea12c535569864faae294 100644 --- a/crates/editor/Cargo.toml +++ b/crates/editor/Cargo.toml @@ -22,6 +22,7 @@ test-support = [ "theme/test-support", "util/test-support", "workspace/test-support", + "tree-sitter-c", "tree-sitter-rust", "tree-sitter-typescript", "tree-sitter-html", @@ -47,7 +48,7 @@ fs.workspace = true git.workspace = true gpui.workspace = true indoc.workspace = true -inline_completion.workspace = true +edit_prediction.workspace = true itertools.workspace = true language.workspace = true linkify.workspace = true @@ -76,6 +77,7 @@ telemetry.workspace = true text.workspace = true time.workspace = true theme.workspace = true +tree-sitter-c = { workspace = true, optional = true } tree-sitter-html = { workspace = true, optional = true } tree-sitter-rust = { workspace = true, optional = true } tree-sitter-typescript = { workspace = true, optional = true } @@ -106,10 +108,12 @@ settings = { workspace = true, features = ["test-support"] } tempfile.workspace = true text = { workspace = true, features = ["test-support"] } theme = { workspace = true, features = ["test-support"] } +tree-sitter-c.workspace = true tree-sitter-html.workspace = true tree-sitter-rust.workspace = true tree-sitter-typescript.workspace = true tree-sitter-yaml.workspace = true +tree-sitter-bash.workspace = true unindent.workspace = true util = { workspace = true, features = ["test-support"] } workspace = { workspace = true, features = ["test-support"] } diff --git a/crates/editor/src/actions.rs b/crates/editor/src/actions.rs index 70ec8ea00f52dccbab9e2a3ad4856599a8a94acf..39433b3c279e101f47ad4b2eed4d180f82a38997 100644 --- a/crates/editor/src/actions.rs +++ b/crates/editor/src/actions.rs @@ -1,6 +1,7 @@ //! This module contains all actions supported by [`Editor`]. use super::*; use gpui::{Action, actions}; +use project::project_settings::GoToDiagnosticSeverityFilter; use schemars::JsonSchema; use util::serde::default_true; @@ -258,6 +259,13 @@ pub struct SpawnNearestTask { pub reveal: task::RevealStrategy, } +#[derive(Clone, PartialEq, Action)] +#[action(no_json, no_register)] +pub struct DiffClipboardWithSelectionData { + pub clipboard_text: String, + pub editor: Entity<Editor>, +} + #[derive(Debug, PartialEq, Eq, Clone, Copy, Deserialize, Default)] pub enum UuidVersion { #[default] @@ -265,6 +273,24 @@ pub enum UuidVersion { V7, } +/// Goes to the next diagnostic in the file. 
+#[derive(PartialEq, Clone, Default, Debug, Deserialize, JsonSchema, Action)] +#[action(namespace = editor)] +#[serde(deny_unknown_fields)] +pub struct GoToDiagnostic { + #[serde(default)] + pub severity: GoToDiagnosticSeverityFilter, +} + +/// Goes to the previous diagnostic in the file. +#[derive(PartialEq, Clone, Default, Debug, Deserialize, JsonSchema, Action)] +#[action(namespace = editor)] +#[serde(deny_unknown_fields)] +pub struct GoToPreviousDiagnostic { + #[serde(default)] + pub severity: GoToDiagnosticSeverityFilter, +} + actions!( debugger, [ @@ -289,9 +315,8 @@ actions!( [ /// Accepts the full edit prediction. AcceptEditPrediction, - /// Accepts a partial Copilot suggestion. - AcceptPartialCopilotSuggestion, /// Accepts a partial edit prediction. + #[action(deprecated_aliases = ["editor::AcceptPartialCopilotSuggestion"])] AcceptPartialEditPrediction, /// Adds a cursor above the current selection. AddSelectionAbove, @@ -303,6 +328,8 @@ actions!( ApplyDiffHunk, /// Deletes the character before the cursor. Backspace, + /// Shows git blame information for the current line. + BlameHover, /// Cancels the current operation. Cancel, /// Cancels the running flycheck operation. @@ -337,6 +364,8 @@ actions!( ConvertToLowerCase, /// Toggles the case of selected text. ConvertToOppositeCase, + /// Converts selected text to sentence case. + ConvertToSentenceCase, /// Converts selected text to snake_case. ConvertToSnakeCase, /// Converts selected text to Title Case. @@ -377,6 +406,8 @@ actions!( DeleteToNextSubwordEnd, /// Deletes to the start of the previous subword. DeleteToPreviousSubwordStart, + /// Diffs the text stored in the clipboard against the current selection. + DiffClipboardWithSelection, /// Displays names of all active cursors. DisplayCursorNames, /// Duplicates the current line below. @@ -406,10 +437,14 @@ actions!( FoldRecursive, /// Folds the selected ranges. FoldSelectedRanges, + /// Toggles focus back to the last active buffer. + ToggleFocus, /// Toggles folding at the current position. ToggleFold, /// Toggles recursive folding at the current position. ToggleFoldRecursive, + /// Toggles all folds in a buffer or all excerpts in multibuffer. + ToggleFoldAll, /// Formats the entire document. Format, /// Formats only the selected text. @@ -422,8 +457,6 @@ actions!( GoToDefinition, /// Goes to definition in a split pane. GoToDefinitionSplit, - /// Goes to the next diagnostic in the file. - GoToDiagnostic, /// Goes to the next diff hunk. GoToHunk, /// Goes to the previous diff hunk. @@ -438,8 +471,6 @@ actions!( GoToParentModule, /// Goes to the previous change in the file. GoToPreviousChange, - /// Goes to the previous diagnostic in the file. - GoToPreviousDiagnostic, /// Goes to the type definition of the symbol at cursor. GoToTypeDefinition, /// Goes to type definition in a split pane. @@ -714,5 +745,6 @@ actions!( UniqueLinesCaseInsensitive, /// Removes duplicate lines (case-sensitive). 
UniqueLinesCaseSensitive, + UnwrapSyntaxNode ] ); diff --git a/crates/editor/src/clangd_ext.rs b/crates/editor/src/clangd_ext.rs index b745bf8c37c4a8b3562e6e6c7da123faa184368b..3239fdc653e0e2acdbdaa3396e30c0546ef259cf 100644 --- a/crates/editor/src/clangd_ext.rs +++ b/crates/editor/src/clangd_ext.rs @@ -29,16 +29,14 @@ pub fn switch_source_header( return; }; - let server_lookup = - find_specific_language_server_in_selection(editor, cx, is_c_language, CLANGD_SERVER_NAME); + let Some((_, _, server_to_query, buffer)) = + find_specific_language_server_in_selection(editor, cx, is_c_language, CLANGD_SERVER_NAME) + else { + return; + }; let project = project.clone(); let upstream_client = project.read(cx).lsp_store().read(cx).upstream_client(); cx.spawn_in(window, async move |_editor, cx| { - let Some((_, _, server_to_query, buffer)) = - server_lookup.await - else { - return Ok(()); - }; let source_file = buffer.read_with(cx, |buffer, _| { buffer.file().map(|file| file.path()).map(|path| path.to_string_lossy().to_string()).unwrap_or_else(|| "Unknown".to_string()) })?; diff --git a/crates/editor/src/code_completion_tests.rs b/crates/editor/src/code_completion_tests.rs index 4f9822b597eee13a8e3bad540b833e3158e2123e..fd8db29584d8eb6944ff674dd8bf5d860ce32428 100644 --- a/crates/editor/src/code_completion_tests.rs +++ b/crates/editor/src/code_completion_tests.rs @@ -94,7 +94,7 @@ async fn test_fuzzy_score(cx: &mut TestAppContext) { filter_and_sort_matches("set_text", &completions, SnippetSortOrder::Top, cx).await; assert_eq!(matches[0].string, "set_text"); assert_eq!(matches[1].string, "set_text_style_refinement"); - assert_eq!(matches[2].string, "set_context_menu_options"); + assert_eq!(matches[2].string, "set_placeholder_text"); } // fuzzy filter text over label, sort_text and sort_kind @@ -216,6 +216,28 @@ async fn test_sort_positions(cx: &mut TestAppContext) { assert_eq!(matches[0].string, "rounded-full"); } +#[gpui::test] +async fn test_fuzzy_over_sort_positions(cx: &mut TestAppContext) { + let completions = vec![ + CompletionBuilder::variable("lsp_document_colors", None, "7fffffff"), // 0.29 fuzzy score + CompletionBuilder::function( + "language_servers_running_disk_based_diagnostics", + None, + "7fffffff", + ), // 0.168 fuzzy score + CompletionBuilder::function("code_lens", None, "7fffffff"), // 3.2 fuzzy score + CompletionBuilder::variable("lsp_code_lens", None, "7fffffff"), // 3.2 fuzzy score + CompletionBuilder::function("fetch_code_lens", None, "7fffffff"), // 3.2 fuzzy score + ]; + + let matches = + filter_and_sort_matches("lens", &completions, SnippetSortOrder::default(), cx).await; + + assert_eq!(matches[0].string, "code_lens"); + assert_eq!(matches[1].string, "lsp_code_lens"); + assert_eq!(matches[2].string, "fetch_code_lens"); +} + async fn test_for_each_prefix<F>( target: &str, completions: &Vec<Completion>, diff --git a/crates/editor/src/code_context_menus.rs b/crates/editor/src/code_context_menus.rs index 8fbae8d6052d89299b10f3cd0c971af79abd3c90..4ae2a14ca730dafa7cfecd9e9b3bacbe3f7bc47b 100644 --- a/crates/editor/src/code_context_menus.rs +++ b/crates/editor/src/code_context_menus.rs @@ -1057,9 +1057,9 @@ impl CompletionsMenu { enum MatchTier<'a> { WordStartMatch { sort_exact: Reverse<i32>, - sort_positions: Vec<usize>, sort_snippet: Reverse<i32>, sort_score: Reverse<OrderedFloat<f64>>, + sort_positions: Vec<usize>, sort_text: Option<&'a str>, sort_kind: usize, sort_label: &'a str, @@ -1074,6 +1074,20 @@ impl CompletionsMenu { .and_then(|q| q.chars().next()) .and_then(|c| 
c.to_lowercase().next()); + if snippet_sort_order == SnippetSortOrder::None { + matches.retain(|string_match| { + let completion = &completions[string_match.candidate_id]; + + let is_snippet = matches!( + &completion.source, + CompletionSource::Lsp { lsp_completion, .. } + if lsp_completion.kind == Some(CompletionItemKind::SNIPPET) + ); + + !is_snippet + }); + } + matches.sort_unstable_by_key(|string_match| { let completion = &completions[string_match.candidate_id]; @@ -1112,6 +1126,7 @@ impl CompletionsMenu { SnippetSortOrder::Top => Reverse(if is_snippet { 1 } else { 0 }), SnippetSortOrder::Bottom => Reverse(if is_snippet { 0 } else { 1 }), SnippetSortOrder::Inline => Reverse(0), + SnippetSortOrder::None => Reverse(0), }; let sort_positions = string_match.positions.clone(); let sort_exact = Reverse(if Some(completion.label.filter_text()) == query { @@ -1122,9 +1137,9 @@ impl CompletionsMenu { MatchTier::WordStartMatch { sort_exact, - sort_positions, sort_snippet, sort_score, + sort_positions, sort_text, sort_kind, sort_label, @@ -1369,7 +1384,7 @@ impl CodeActionsMenu { } } - fn visible(&self) -> bool { + pub fn visible(&self) -> bool { !self.actions.is_empty() } diff --git a/crates/editor/src/display_map.rs b/crates/editor/src/display_map.rs index aa2408d6d9b616f2b1436d9bc66f42bd87506d19..a16e516a70c9638965585cc5d6a23d8a9f67b639 100644 --- a/crates/editor/src/display_map.rs +++ b/crates/editor/src/display_map.rs @@ -271,7 +271,6 @@ impl DisplayMap { height: Some(height), style, priority, - render_in_minimap: true, } }), ); @@ -636,7 +635,7 @@ pub(crate) struct Highlights<'a> { } #[derive(Clone, Copy, Debug)] -pub struct InlineCompletionStyles { +pub struct EditPredictionStyles { pub insertion: HighlightStyle, pub whitespace: HighlightStyle, } @@ -644,7 +643,7 @@ pub struct InlineCompletionStyles { #[derive(Default, Debug, Clone, Copy)] pub struct HighlightStyles { pub inlay_hint: Option<HighlightStyle>, - pub inline_completion: Option<InlineCompletionStyles>, + pub edit_prediction: Option<EditPredictionStyles>, } #[derive(Clone)] @@ -959,7 +958,7 @@ impl DisplaySnapshot { language_aware, HighlightStyles { inlay_hint: Some(editor_style.inlay_hints_style), - inline_completion: Some(editor_style.inline_completion_styles), + edit_prediction: Some(editor_style.edit_prediction_styles), }, ) .flat_map(|chunk| { @@ -1663,7 +1662,6 @@ pub mod tests { height: Some(height), render: Arc::new(|_| div().into_any()), priority, - render_in_minimap: true, } }) .collect::<Vec<_>>(); @@ -2029,7 +2027,6 @@ pub mod tests { style: BlockStyle::Sticky, render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }], cx, ); @@ -2039,7 +2036,7 @@ pub mod tests { map.update(cx, |map, cx| { map.splice_inlays( &[], - vec![Inlay::inline_completion( + vec![Inlay::edit_prediction( 0, buffer_snapshot.anchor_after(0), "\n", @@ -2227,7 +2224,6 @@ pub mod tests { style: BlockStyle::Sticky, render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }, BlockProperties { placement: BlockPlacement::Below( @@ -2237,7 +2233,6 @@ pub mod tests { style: BlockStyle::Sticky, render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }, ], cx, @@ -2344,7 +2339,6 @@ pub mod tests { style: BlockStyle::Sticky, render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }], cx, ) @@ -2420,7 +2414,6 @@ pub mod tests { style: BlockStyle::Fixed, render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }], cx, ); diff --git 
a/crates/editor/src/display_map/block_map.rs b/crates/editor/src/display_map/block_map.rs index ea754da03f70ff87e28bb73a614fad6b66d7e4c2..e25c02432d10c16969c963455188474e10ed04a6 100644 --- a/crates/editor/src/display_map/block_map.rs +++ b/crates/editor/src/display_map/block_map.rs @@ -22,7 +22,7 @@ use std::{ atomic::{AtomicUsize, Ordering::SeqCst}, }, }; -use sum_tree::{Bias, SumTree, Summary, TreeMap}; +use sum_tree::{Bias, Dimensions, SumTree, Summary, TreeMap}; use text::{BufferId, Edit}; use ui::ElementId; @@ -193,7 +193,6 @@ pub struct CustomBlock { style: BlockStyle, render: Arc<Mutex<RenderBlock>>, priority: usize, - pub(crate) render_in_minimap: bool, } #[derive(Clone)] @@ -205,7 +204,6 @@ pub struct BlockProperties<P> { pub style: BlockStyle, pub render: RenderBlock, pub priority: usize, - pub render_in_minimap: bool, } impl<P: Debug> Debug for BlockProperties<P> { @@ -418,7 +416,7 @@ struct TransformSummary { } pub struct BlockChunks<'a> { - transforms: sum_tree::Cursor<'a, Transform, (BlockRow, WrapRow)>, + transforms: sum_tree::Cursor<'a, Transform, Dimensions<BlockRow, WrapRow>>, input_chunks: wrap_map::WrapChunks<'a>, input_chunk: Chunk<'a>, output_row: u32, @@ -428,7 +426,7 @@ pub struct BlockChunks<'a> { #[derive(Clone)] pub struct BlockRows<'a> { - transforms: sum_tree::Cursor<'a, Transform, (BlockRow, WrapRow)>, + transforms: sum_tree::Cursor<'a, Transform, Dimensions<BlockRow, WrapRow>>, input_rows: wrap_map::WrapRows<'a>, output_row: BlockRow, started: bool, @@ -526,10 +524,10 @@ impl BlockMap { // * Isomorphic transforms that end *at* the start of the edit // * Below blocks that end at the start of the edit // However, if we hit a replace block that ends at the start of the edit we want to reconstruct it. - new_transforms.append(cursor.slice(&old_start, Bias::Left, &()), &()); + new_transforms.append(cursor.slice(&old_start, Bias::Left), &()); if let Some(transform) = cursor.item() { if transform.summary.input_rows > 0 - && cursor.end(&()) == old_start + && cursor.end() == old_start && transform .block .as_ref() @@ -537,13 +535,13 @@ impl BlockMap { { // Preserve the transform (push and next) new_transforms.push(transform.clone(), &()); - cursor.next(&()); + cursor.next(); // Preserve below blocks at end of edit while let Some(transform) = cursor.item() { if transform.block.as_ref().map_or(false, |b| b.place_below()) { new_transforms.push(transform.clone(), &()); - cursor.next(&()); + cursor.next(); } else { break; } @@ -581,8 +579,8 @@ impl BlockMap { let mut new_end = WrapRow(edit.new.end); loop { // Seek to the transform starting at or after the end of the edit - cursor.seek(&old_end, Bias::Left, &()); - cursor.next(&()); + cursor.seek(&old_end, Bias::Left); + cursor.next(); // Extend edit to the end of the discarded transform so it is reconstructed in full let transform_rows_after_edit = cursor.start().0 - old_end.0; @@ -594,8 +592,8 @@ impl BlockMap { if next_edit.old.start <= cursor.start().0 { old_end = WrapRow(next_edit.old.end); new_end = WrapRow(next_edit.new.end); - cursor.seek(&old_end, Bias::Left, &()); - cursor.next(&()); + cursor.seek(&old_end, Bias::Left); + cursor.next(); edits.next(); } else { break; @@ -610,7 +608,7 @@ impl BlockMap { // Discard below blocks at the end of the edit. They'll be reconstructed. 
while let Some(transform) = cursor.item() { if transform.block.as_ref().map_or(false, |b| b.place_below()) { - cursor.next(&()); + cursor.next(); } else { break; } @@ -722,7 +720,7 @@ impl BlockMap { push_isomorphic(&mut new_transforms, rows_after_last_block, wrap_snapshot); } - new_transforms.append(cursor.suffix(&()), &()); + new_transforms.append(cursor.suffix(), &()); debug_assert_eq!( new_transforms.summary().input_rows, wrap_snapshot.max_point().row() + 1 @@ -972,8 +970,8 @@ impl BlockMapReader<'_> { .unwrap_or(self.wrap_snapshot.max_point().row() + 1), ); - let mut cursor = self.transforms.cursor::<(WrapRow, BlockRow)>(&()); - cursor.seek(&start_wrap_row, Bias::Left, &()); + let mut cursor = self.transforms.cursor::<Dimensions<WrapRow, BlockRow>>(&()); + cursor.seek(&start_wrap_row, Bias::Left); while let Some(transform) = cursor.item() { if cursor.start().0 > end_wrap_row { break; @@ -984,7 +982,7 @@ impl BlockMapReader<'_> { return Some(cursor.start().1); } } - cursor.next(&()); + cursor.next(); } None @@ -1044,7 +1042,6 @@ impl BlockMapWriter<'_> { render: Arc::new(Mutex::new(block.render)), style: block.style, priority: block.priority, - render_in_minimap: block.render_in_minimap, }); self.0.custom_blocks.insert(block_ix, new_block.clone()); self.0.custom_blocks_by_id.insert(id, new_block); @@ -1079,7 +1076,6 @@ impl BlockMapWriter<'_> { style: block.style, render: block.render.clone(), priority: block.priority, - render_in_minimap: block.render_in_minimap, }; let new_block = Arc::new(new_block); *block = new_block.clone(); @@ -1296,8 +1292,8 @@ impl BlockSnapshot { ) -> BlockChunks<'a> { let max_output_row = cmp::min(rows.end, self.transforms.summary().output_rows); - let mut cursor = self.transforms.cursor::<(BlockRow, WrapRow)>(&()); - cursor.seek(&BlockRow(rows.start), Bias::Right, &()); + let mut cursor = self.transforms.cursor::<Dimensions<BlockRow, WrapRow>>(&()); + cursor.seek(&BlockRow(rows.start), Bias::Right); let transform_output_start = cursor.start().0.0; let transform_input_start = cursor.start().1.0; @@ -1328,9 +1324,9 @@ impl BlockSnapshot { } pub(super) fn row_infos(&self, start_row: BlockRow) -> BlockRows<'_> { - let mut cursor = self.transforms.cursor::<(BlockRow, WrapRow)>(&()); - cursor.seek(&start_row, Bias::Right, &()); - let (output_start, input_start) = cursor.start(); + let mut cursor = self.transforms.cursor::<Dimensions<BlockRow, WrapRow>>(&()); + cursor.seek(&start_row, Bias::Right); + let Dimensions(output_start, input_start, _) = cursor.start(); let overshoot = if cursor .item() .map_or(false, |transform| transform.block.is_none()) @@ -1350,9 +1346,9 @@ impl BlockSnapshot { pub fn blocks_in_range(&self, rows: Range<u32>) -> impl Iterator<Item = (u32, &Block)> { let mut cursor = self.transforms.cursor::<BlockRow>(&()); - cursor.seek(&BlockRow(rows.start), Bias::Left, &()); - while cursor.start().0 < rows.start && cursor.end(&()).0 <= rows.start { - cursor.next(&()); + cursor.seek(&BlockRow(rows.start), Bias::Left); + while cursor.start().0 < rows.start && cursor.end().0 <= rows.start { + cursor.next(); } std::iter::from_fn(move || { @@ -1368,10 +1364,10 @@ impl BlockSnapshot { break; } if let Some(block) = &transform.block { - cursor.next(&()); + cursor.next(); return Some((start_row, block)); } else { - cursor.next(&()); + cursor.next(); } } None @@ -1381,7 +1377,7 @@ impl BlockSnapshot { pub fn sticky_header_excerpt(&self, position: f32) -> Option<StickyHeaderExcerpt<'_>> { let top_row = position as u32; let mut cursor = 
self.transforms.cursor::<BlockRow>(&()); - cursor.seek(&BlockRow(top_row), Bias::Right, &()); + cursor.seek(&BlockRow(top_row), Bias::Right); while let Some(transform) = cursor.item() { match &transform.block { @@ -1390,7 +1386,7 @@ impl BlockSnapshot { } Some(block) if block.is_buffer_header() => return None, _ => { - cursor.prev(&()); + cursor.prev(); continue; } } @@ -1418,7 +1414,7 @@ impl BlockSnapshot { let wrap_row = WrapRow(wrap_point.row()); let mut cursor = self.transforms.cursor::<WrapRow>(&()); - cursor.seek(&wrap_row, Bias::Left, &()); + cursor.seek(&wrap_row, Bias::Left); while let Some(transform) = cursor.item() { if let Some(block) = transform.block.as_ref() { @@ -1429,7 +1425,7 @@ impl BlockSnapshot { break; } - cursor.next(&()); + cursor.next(); } None @@ -1445,19 +1441,19 @@ impl BlockSnapshot { } pub fn longest_row_in_range(&self, range: Range<BlockRow>) -> BlockRow { - let mut cursor = self.transforms.cursor::<(BlockRow, WrapRow)>(&()); - cursor.seek(&range.start, Bias::Right, &()); + let mut cursor = self.transforms.cursor::<Dimensions<BlockRow, WrapRow>>(&()); + cursor.seek(&range.start, Bias::Right); let mut longest_row = range.start; let mut longest_row_chars = 0; if let Some(transform) = cursor.item() { if transform.block.is_none() { - let (output_start, input_start) = cursor.start(); + let Dimensions(output_start, input_start, _) = cursor.start(); let overshoot = range.start.0 - output_start.0; let wrap_start_row = input_start.0 + overshoot; let wrap_end_row = cmp::min( input_start.0 + (range.end.0 - output_start.0), - cursor.end(&()).1.0, + cursor.end().1.0, ); let summary = self .wrap_snapshot @@ -1465,12 +1461,12 @@ impl BlockSnapshot { longest_row = BlockRow(range.start.0 + summary.longest_row); longest_row_chars = summary.longest_row_chars; } - cursor.next(&()); + cursor.next(); } let cursor_start_row = cursor.start().0; if range.end > cursor_start_row { - let summary = cursor.summary::<_, TransformSummary>(&range.end, Bias::Right, &()); + let summary = cursor.summary::<_, TransformSummary>(&range.end, Bias::Right); if summary.longest_row_chars > longest_row_chars { longest_row = BlockRow(cursor_start_row.0 + summary.longest_row); longest_row_chars = summary.longest_row_chars; @@ -1478,7 +1474,7 @@ impl BlockSnapshot { if let Some(transform) = cursor.item() { if transform.block.is_none() { - let (output_start, input_start) = cursor.start(); + let Dimensions(output_start, input_start, _) = cursor.start(); let overshoot = range.end.0 - output_start.0; let wrap_start_row = input_start.0; let wrap_end_row = input_start.0 + overshoot; @@ -1496,10 +1492,10 @@ impl BlockSnapshot { } pub(super) fn line_len(&self, row: BlockRow) -> u32 { - let mut cursor = self.transforms.cursor::<(BlockRow, WrapRow)>(&()); - cursor.seek(&BlockRow(row.0), Bias::Right, &()); + let mut cursor = self.transforms.cursor::<Dimensions<BlockRow, WrapRow>>(&()); + cursor.seek(&BlockRow(row.0), Bias::Right); if let Some(transform) = cursor.item() { - let (output_start, input_start) = cursor.start(); + let Dimensions(output_start, input_start, _) = cursor.start(); let overshoot = row.0 - output_start.0; if transform.block.is_some() { 0 @@ -1514,14 +1510,14 @@ impl BlockSnapshot { } pub(super) fn is_block_line(&self, row: BlockRow) -> bool { - let mut cursor = self.transforms.cursor::<(BlockRow, WrapRow)>(&()); - cursor.seek(&row, Bias::Right, &()); + let mut cursor = self.transforms.cursor::<Dimensions<BlockRow, WrapRow>>(&()); + cursor.seek(&row, Bias::Right); cursor.item().map_or(false, |t| 
t.block.is_some()) } pub(super) fn is_folded_buffer_header(&self, row: BlockRow) -> bool { - let mut cursor = self.transforms.cursor::<(BlockRow, WrapRow)>(&()); - cursor.seek(&row, Bias::Right, &()); + let mut cursor = self.transforms.cursor::<Dimensions<BlockRow, WrapRow>>(&()); + cursor.seek(&row, Bias::Right); let Some(transform) = cursor.item() else { return false; }; @@ -1532,8 +1528,8 @@ impl BlockSnapshot { let wrap_point = self .wrap_snapshot .make_wrap_point(Point::new(row.0, 0), Bias::Left); - let mut cursor = self.transforms.cursor::<(WrapRow, BlockRow)>(&()); - cursor.seek(&WrapRow(wrap_point.row()), Bias::Right, &()); + let mut cursor = self.transforms.cursor::<Dimensions<WrapRow, BlockRow>>(&()); + cursor.seek(&WrapRow(wrap_point.row()), Bias::Right); cursor.item().map_or(false, |transform| { transform .block @@ -1543,18 +1539,18 @@ impl BlockSnapshot { } pub fn clip_point(&self, point: BlockPoint, bias: Bias) -> BlockPoint { - let mut cursor = self.transforms.cursor::<(BlockRow, WrapRow)>(&()); - cursor.seek(&BlockRow(point.row), Bias::Right, &()); + let mut cursor = self.transforms.cursor::<Dimensions<BlockRow, WrapRow>>(&()); + cursor.seek(&BlockRow(point.row), Bias::Right); let max_input_row = WrapRow(self.transforms.summary().input_rows); let mut search_left = - (bias == Bias::Left && cursor.start().1.0 > 0) || cursor.end(&()).1 == max_input_row; + (bias == Bias::Left && cursor.start().1.0 > 0) || cursor.end().1 == max_input_row; let mut reversed = false; loop { if let Some(transform) = cursor.item() { - let (output_start_row, input_start_row) = cursor.start(); - let (output_end_row, input_end_row) = cursor.end(&()); + let Dimensions(output_start_row, input_start_row, _) = cursor.start(); + let Dimensions(output_end_row, input_end_row, _) = cursor.end(); let output_start = Point::new(output_start_row.0, 0); let input_start = Point::new(input_start_row.0, 0); let input_end = Point::new(input_end_row.0, 0); @@ -1588,28 +1584,28 @@ impl BlockSnapshot { } if search_left { - cursor.prev(&()); + cursor.prev(); } else { - cursor.next(&()); + cursor.next(); } } else if reversed { return self.max_point(); } else { reversed = true; search_left = !search_left; - cursor.seek(&BlockRow(point.row), Bias::Right, &()); + cursor.seek(&BlockRow(point.row), Bias::Right); } } } pub fn to_block_point(&self, wrap_point: WrapPoint) -> BlockPoint { - let mut cursor = self.transforms.cursor::<(WrapRow, BlockRow)>(&()); - cursor.seek(&WrapRow(wrap_point.row()), Bias::Right, &()); + let mut cursor = self.transforms.cursor::<Dimensions<WrapRow, BlockRow>>(&()); + cursor.seek(&WrapRow(wrap_point.row()), Bias::Right); if let Some(transform) = cursor.item() { if transform.block.is_some() { BlockPoint::new(cursor.start().1.0, 0) } else { - let (input_start_row, output_start_row) = cursor.start(); + let Dimensions(input_start_row, output_start_row, _) = cursor.start(); let input_start = Point::new(input_start_row.0, 0); let output_start = Point::new(output_start_row.0, 0); let input_overshoot = wrap_point.0 - input_start; @@ -1621,8 +1617,8 @@ impl BlockSnapshot { } pub fn to_wrap_point(&self, block_point: BlockPoint, bias: Bias) -> WrapPoint { - let mut cursor = self.transforms.cursor::<(BlockRow, WrapRow)>(&()); - cursor.seek(&BlockRow(block_point.row), Bias::Right, &()); + let mut cursor = self.transforms.cursor::<Dimensions<BlockRow, WrapRow>>(&()); + cursor.seek(&BlockRow(block_point.row), Bias::Right); if let Some(transform) = cursor.item() { match transform.block.as_ref() { Some(block) => { @@ 
-1634,7 +1630,7 @@ impl BlockSnapshot { } else if bias == Bias::Left { WrapPoint::new(cursor.start().1.0, 0) } else { - let wrap_row = cursor.end(&()).1.0 - 1; + let wrap_row = cursor.end().1.0 - 1; WrapPoint::new(wrap_row, self.wrap_snapshot.line_len(wrap_row)) } } @@ -1654,14 +1650,14 @@ impl BlockChunks<'_> { /// Go to the next transform fn advance(&mut self) { self.input_chunk = Chunk::default(); - self.transforms.next(&()); + self.transforms.next(); while let Some(transform) = self.transforms.item() { if transform .block .as_ref() .map_or(false, |block| block.height() == 0) { - self.transforms.next(&()); + self.transforms.next(); } else { break; } @@ -1676,7 +1672,7 @@ impl BlockChunks<'_> { let start_output_row = self.transforms.start().0.0; if start_output_row < self.max_output_row { let end_input_row = cmp::min( - self.transforms.end(&()).1.0, + self.transforms.end().1.0, start_input_row + (self.max_output_row - start_output_row), ); self.input_chunks.seek(start_input_row..end_input_row); @@ -1700,7 +1696,7 @@ impl<'a> Iterator for BlockChunks<'a> { let transform = self.transforms.item()?; if transform.block.is_some() { let block_start = self.transforms.start().0.0; - let mut block_end = self.transforms.end(&()).0.0; + let mut block_end = self.transforms.end().0.0; self.advance(); if self.transforms.item().is_none() { block_end -= 1; @@ -1735,7 +1731,7 @@ impl<'a> Iterator for BlockChunks<'a> { } } - let transform_end = self.transforms.end(&()).0.0; + let transform_end = self.transforms.end().0.0; let (prefix_rows, prefix_bytes) = offset_for_row(self.input_chunk.text, transform_end - self.output_row); self.output_row += prefix_rows; @@ -1774,15 +1770,15 @@ impl Iterator for BlockRows<'_> { self.started = true; } - if self.output_row.0 >= self.transforms.end(&()).0.0 { - self.transforms.next(&()); + if self.output_row.0 >= self.transforms.end().0.0 { + self.transforms.next(); while let Some(transform) = self.transforms.item() { if transform .block .as_ref() .map_or(false, |block| block.height() == 0) { - self.transforms.next(&()); + self.transforms.next(); } else { break; } @@ -1976,7 +1972,6 @@ mod tests { height: Some(1), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }, BlockProperties { style: BlockStyle::Fixed, @@ -1984,7 +1979,6 @@ mod tests { height: Some(2), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }, BlockProperties { style: BlockStyle::Fixed, @@ -1992,7 +1986,6 @@ mod tests { height: Some(3), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }, ]); @@ -2217,7 +2210,6 @@ mod tests { height: Some(1), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }, BlockProperties { style: BlockStyle::Fixed, @@ -2225,7 +2217,6 @@ mod tests { height: Some(2), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }, BlockProperties { style: BlockStyle::Fixed, @@ -2233,7 +2224,6 @@ mod tests { height: Some(3), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }, ]); @@ -2322,7 +2312,6 @@ mod tests { render: Arc::new(|_| div().into_any()), height: Some(1), priority: 0, - render_in_minimap: true, }, BlockProperties { style: BlockStyle::Fixed, @@ -2330,7 +2319,6 @@ mod tests { render: Arc::new(|_| div().into_any()), height: Some(1), priority: 0, - render_in_minimap: true, }, ]); @@ -2370,7 +2358,6 @@ mod tests { height: Some(4), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, 
}])[0]; let blocks_snapshot = block_map.read(wraps_snapshot, Default::default()); @@ -2424,7 +2411,6 @@ mod tests { height: Some(1), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }, BlockProperties { style: BlockStyle::Fixed, @@ -2432,7 +2418,6 @@ mod tests { height: Some(1), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }, BlockProperties { style: BlockStyle::Fixed, @@ -2440,7 +2425,6 @@ mod tests { height: Some(1), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }, ]); let blocks_snapshot = block_map.read(wraps_snapshot.clone(), Default::default()); @@ -2455,7 +2439,6 @@ mod tests { height: Some(1), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }, BlockProperties { style: BlockStyle::Fixed, @@ -2463,7 +2446,6 @@ mod tests { height: Some(1), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }, BlockProperties { style: BlockStyle::Fixed, @@ -2471,7 +2453,6 @@ mod tests { height: Some(1), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }, ]); let blocks_snapshot = block_map.read(wraps_snapshot.clone(), Default::default()); @@ -2571,7 +2552,6 @@ mod tests { height: Some(1), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }, BlockProperties { style: BlockStyle::Fixed, @@ -2579,7 +2559,6 @@ mod tests { height: Some(1), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }, BlockProperties { style: BlockStyle::Fixed, @@ -2587,7 +2566,6 @@ mod tests { height: Some(1), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }, ]); let excerpt_blocks_3 = writer.insert(vec![ @@ -2597,7 +2575,6 @@ mod tests { height: Some(1), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }, BlockProperties { style: BlockStyle::Fixed, @@ -2605,7 +2582,6 @@ mod tests { height: Some(1), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }, ]); @@ -2653,7 +2629,6 @@ mod tests { height: Some(1), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }]); let blocks_snapshot = block_map.read(wrap_snapshot.clone(), Patch::default()); let blocks = blocks_snapshot @@ -3011,7 +2986,6 @@ mod tests { height: Some(height), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, } }) .collect::<Vec<_>>(); @@ -3032,7 +3006,6 @@ mod tests { style: props.style, render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, })); for (block_properties, block_id) in block_properties.iter().zip(block_ids) { @@ -3557,7 +3530,6 @@ mod tests { height: Some(1), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }])[0]; let blocks_snapshot = block_map.read(wraps_snapshot.clone(), Default::default()); diff --git a/crates/editor/src/display_map/crease_map.rs b/crates/editor/src/display_map/crease_map.rs index e6fe4270eccf86dacf38fa1e06a2c5a4f6081546..bdac982fa785e7b6628352572ab143fd978938b2 100644 --- a/crates/editor/src/display_map/crease_map.rs +++ b/crates/editor/src/display_map/crease_map.rs @@ -52,15 +52,15 @@ impl CreaseSnapshot { ) -> Option<&'a Crease<Anchor>> { let start = snapshot.anchor_before(Point::new(row.0, 0)); let mut cursor = self.creases.cursor::<ItemSummary>(snapshot); - cursor.seek(&start, Bias::Left, snapshot); + cursor.seek(&start, Bias::Left); while let Some(item) = cursor.item() { match 
Ord::cmp(&item.crease.range().start.to_point(snapshot).row, &row.0) { - Ordering::Less => cursor.next(snapshot), + Ordering::Less => cursor.next(), Ordering::Equal => { if item.crease.range().start.is_valid(snapshot) { return Some(&item.crease); } else { - cursor.next(snapshot); + cursor.next(); } } Ordering::Greater => break, @@ -76,11 +76,11 @@ impl CreaseSnapshot { ) -> impl 'a + Iterator<Item = &'a Crease<Anchor>> { let start = snapshot.anchor_before(Point::new(range.start.0, 0)); let mut cursor = self.creases.cursor::<ItemSummary>(snapshot); - cursor.seek(&start, Bias::Left, snapshot); + cursor.seek(&start, Bias::Left); std::iter::from_fn(move || { while let Some(item) = cursor.item() { - cursor.next(snapshot); + cursor.next(); let crease_range = item.crease.range(); let crease_start = crease_range.start.to_point(snapshot); let crease_end = crease_range.end.to_point(snapshot); @@ -102,13 +102,13 @@ impl CreaseSnapshot { let mut cursor = self.creases.cursor::<ItemSummary>(snapshot); let mut results = Vec::new(); - cursor.next(snapshot); + cursor.next(); while let Some(item) = cursor.item() { let crease_range = item.crease.range(); let start_point = crease_range.start.to_point(snapshot); let end_point = crease_range.end.to_point(snapshot); results.push((item.id, start_point..end_point)); - cursor.next(snapshot); + cursor.next(); } results @@ -298,7 +298,7 @@ impl CreaseMap { let mut cursor = self.snapshot.creases.cursor::<ItemSummary>(snapshot); for crease in creases { let crease_range = crease.range().clone(); - new_creases.append(cursor.slice(&crease_range, Bias::Left, snapshot), snapshot); + new_creases.append(cursor.slice(&crease_range, Bias::Left), snapshot); let id = self.next_id; self.next_id.0 += 1; @@ -306,7 +306,7 @@ impl CreaseMap { new_creases.push(CreaseItem { crease, id }, snapshot); new_ids.push(id); } - new_creases.append(cursor.suffix(snapshot), snapshot); + new_creases.append(cursor.suffix(), snapshot); new_creases }; new_ids @@ -332,9 +332,9 @@ impl CreaseMap { let mut cursor = self.snapshot.creases.cursor::<ItemSummary>(snapshot); for (id, range) in &removals { - new_creases.append(cursor.slice(range, Bias::Left, snapshot), snapshot); + new_creases.append(cursor.slice(range, Bias::Left), snapshot); while let Some(item) = cursor.item() { - cursor.next(snapshot); + cursor.next(); if item.id == *id { break; } else { @@ -343,7 +343,7 @@ impl CreaseMap { } } - new_creases.append(cursor.suffix(snapshot), snapshot); + new_creases.append(cursor.suffix(), snapshot); new_creases }; diff --git a/crates/editor/src/display_map/fold_map.rs b/crates/editor/src/display_map/fold_map.rs index f37e7063e7228176b0f5455c278f331ed31d6ba0..c4e53a0f4361d83429158f106bd81326c8ddb573 100644 --- a/crates/editor/src/display_map/fold_map.rs +++ b/crates/editor/src/display_map/fold_map.rs @@ -17,7 +17,7 @@ use std::{ sync::Arc, usize, }; -use sum_tree::{Bias, Cursor, FilterCursor, SumTree, Summary, TreeMap}; +use sum_tree::{Bias, Cursor, Dimensions, FilterCursor, SumTree, Summary, TreeMap}; use ui::IntoElement as _; use util::post_inc; @@ -98,8 +98,10 @@ impl FoldPoint { } pub fn to_inlay_point(self, snapshot: &FoldSnapshot) -> InlayPoint { - let mut cursor = snapshot.transforms.cursor::<(FoldPoint, InlayPoint)>(&()); - cursor.seek(&self, Bias::Right, &()); + let mut cursor = snapshot + .transforms + .cursor::<Dimensions<FoldPoint, InlayPoint>>(&()); + cursor.seek(&self, Bias::Right); let overshoot = self.0 - cursor.start().0.0; InlayPoint(cursor.start().1.0 + overshoot) } @@ -107,8 +109,8 @@ impl 
FoldPoint { pub fn to_offset(self, snapshot: &FoldSnapshot) -> FoldOffset { let mut cursor = snapshot .transforms - .cursor::<(FoldPoint, TransformSummary)>(&()); - cursor.seek(&self, Bias::Right, &()); + .cursor::<Dimensions<FoldPoint, TransformSummary>>(&()); + cursor.seek(&self, Bias::Right); let overshoot = self.0 - cursor.start().1.output.lines; let mut offset = cursor.start().1.output.len; if !overshoot.is_zero() { @@ -187,10 +189,10 @@ impl FoldMapWriter<'_> { width: None, }, ); - new_tree.append(cursor.slice(&fold.range, Bias::Right, buffer), buffer); + new_tree.append(cursor.slice(&fold.range, Bias::Right), buffer); new_tree.push(fold, buffer); } - new_tree.append(cursor.suffix(buffer), buffer); + new_tree.append(cursor.suffix(), buffer); new_tree }; @@ -252,7 +254,7 @@ impl FoldMapWriter<'_> { fold_ixs_to_delete.push(*folds_cursor.start()); self.0.snapshot.fold_metadata_by_id.remove(&fold.id); } - folds_cursor.next(buffer); + folds_cursor.next(); } } @@ -263,10 +265,10 @@ impl FoldMapWriter<'_> { let mut cursor = self.0.snapshot.folds.cursor::<usize>(buffer); let mut folds = SumTree::new(buffer); for fold_ix in fold_ixs_to_delete { - folds.append(cursor.slice(&fold_ix, Bias::Right, buffer), buffer); - cursor.next(buffer); + folds.append(cursor.slice(&fold_ix, Bias::Right), buffer); + cursor.next(); } - folds.append(cursor.suffix(buffer), buffer); + folds.append(cursor.suffix(), buffer); folds }; @@ -412,7 +414,7 @@ impl FoldMap { let mut new_transforms = SumTree::<Transform>::default(); let mut cursor = self.snapshot.transforms.cursor::<InlayOffset>(&()); - cursor.seek(&InlayOffset(0), Bias::Right, &()); + cursor.seek(&InlayOffset(0), Bias::Right); while let Some(mut edit) = inlay_edits_iter.next() { if let Some(item) = cursor.item() { @@ -421,19 +423,19 @@ impl FoldMap { |transform| { if !transform.is_fold() { transform.summary.add_summary(&item.summary, &()); - cursor.next(&()); + cursor.next(); } }, &(), ); } } - new_transforms.append(cursor.slice(&edit.old.start, Bias::Left, &()), &()); + new_transforms.append(cursor.slice(&edit.old.start, Bias::Left), &()); edit.new.start -= edit.old.start - *cursor.start(); edit.old.start = *cursor.start(); - cursor.seek(&edit.old.end, Bias::Right, &()); - cursor.next(&()); + cursor.seek(&edit.old.end, Bias::Right); + cursor.next(); let mut delta = edit.new_len().0 as isize - edit.old_len().0 as isize; loop { @@ -449,8 +451,8 @@ impl FoldMap { if next_edit.old.end >= edit.old.end { edit.old.end = next_edit.old.end; - cursor.seek(&edit.old.end, Bias::Right, &()); - cursor.next(&()); + cursor.seek(&edit.old.end, Bias::Right); + cursor.next(); } } else { break; @@ -467,11 +469,7 @@ impl FoldMap { .snapshot .folds .cursor::<FoldRange>(&inlay_snapshot.buffer); - folds_cursor.seek( - &FoldRange(anchor..Anchor::max()), - Bias::Left, - &inlay_snapshot.buffer, - ); + folds_cursor.seek(&FoldRange(anchor..Anchor::max()), Bias::Left); let mut folds = iter::from_fn({ let inlay_snapshot = &inlay_snapshot; @@ -485,7 +483,7 @@ impl FoldMap { ..inlay_snapshot.to_inlay_offset(buffer_end), ) }); - folds_cursor.next(&inlay_snapshot.buffer); + folds_cursor.next(); item } }) @@ -558,7 +556,7 @@ impl FoldMap { } } - new_transforms.append(cursor.suffix(&()), &()); + new_transforms.append(cursor.suffix(), &()); if new_transforms.is_empty() { let text_summary = inlay_snapshot.text_summary(); push_isomorphic(&mut new_transforms, text_summary); @@ -571,35 +569,36 @@ impl FoldMap { let mut old_transforms = self .snapshot .transforms - .cursor::<(InlayOffset, 
FoldOffset)>(&()); - let mut new_transforms = new_transforms.cursor::<(InlayOffset, FoldOffset)>(&()); + .cursor::<Dimensions<InlayOffset, FoldOffset>>(&()); + let mut new_transforms = + new_transforms.cursor::<Dimensions<InlayOffset, FoldOffset>>(&()); for mut edit in inlay_edits { - old_transforms.seek(&edit.old.start, Bias::Left, &()); + old_transforms.seek(&edit.old.start, Bias::Left); if old_transforms.item().map_or(false, |t| t.is_fold()) { edit.old.start = old_transforms.start().0; } let old_start = old_transforms.start().1.0 + (edit.old.start - old_transforms.start().0).0; - old_transforms.seek_forward(&edit.old.end, Bias::Right, &()); + old_transforms.seek_forward(&edit.old.end, Bias::Right); if old_transforms.item().map_or(false, |t| t.is_fold()) { - old_transforms.next(&()); + old_transforms.next(); edit.old.end = old_transforms.start().0; } let old_end = old_transforms.start().1.0 + (edit.old.end - old_transforms.start().0).0; - new_transforms.seek(&edit.new.start, Bias::Left, &()); + new_transforms.seek(&edit.new.start, Bias::Left); if new_transforms.item().map_or(false, |t| t.is_fold()) { edit.new.start = new_transforms.start().0; } let new_start = new_transforms.start().1.0 + (edit.new.start - new_transforms.start().0).0; - new_transforms.seek_forward(&edit.new.end, Bias::Right, &()); + new_transforms.seek_forward(&edit.new.end, Bias::Right); if new_transforms.item().map_or(false, |t| t.is_fold()) { - new_transforms.next(&()); + new_transforms.next(); edit.new.end = new_transforms.start().0; } let new_end = @@ -655,11 +654,13 @@ impl FoldSnapshot { pub fn text_summary_for_range(&self, range: Range<FoldPoint>) -> TextSummary { let mut summary = TextSummary::default(); - let mut cursor = self.transforms.cursor::<(FoldPoint, InlayPoint)>(&()); - cursor.seek(&range.start, Bias::Right, &()); + let mut cursor = self + .transforms + .cursor::<Dimensions<FoldPoint, InlayPoint>>(&()); + cursor.seek(&range.start, Bias::Right); if let Some(transform) = cursor.item() { let start_in_transform = range.start.0 - cursor.start().0.0; - let end_in_transform = cmp::min(range.end, cursor.end(&()).0).0 - cursor.start().0.0; + let end_in_transform = cmp::min(range.end, cursor.end().0).0 - cursor.start().0.0; if let Some(placeholder) = transform.placeholder.as_ref() { summary = TextSummary::from( &placeholder.text @@ -678,10 +679,10 @@ impl FoldSnapshot { } } - if range.end > cursor.end(&()).0 { - cursor.next(&()); + if range.end > cursor.end().0 { + cursor.next(); summary += &cursor - .summary::<_, TransformSummary>(&range.end, Bias::Right, &()) + .summary::<_, TransformSummary>(&range.end, Bias::Right) .output; if let Some(transform) = cursor.item() { let end_in_transform = range.end.0 - cursor.start().0.0; @@ -704,20 +705,19 @@ impl FoldSnapshot { } pub fn to_fold_point(&self, point: InlayPoint, bias: Bias) -> FoldPoint { - let mut cursor = self.transforms.cursor::<(InlayPoint, FoldPoint)>(&()); - cursor.seek(&point, Bias::Right, &()); + let mut cursor = self + .transforms + .cursor::<Dimensions<InlayPoint, FoldPoint>>(&()); + cursor.seek(&point, Bias::Right); if cursor.item().map_or(false, |t| t.is_fold()) { if bias == Bias::Left || point == cursor.start().0 { cursor.start().1 } else { - cursor.end(&()).1 + cursor.end().1 } } else { let overshoot = point.0 - cursor.start().0.0; - FoldPoint(cmp::min( - cursor.start().1.0 + overshoot, - cursor.end(&()).1.0, - )) + FoldPoint(cmp::min(cursor.start().1.0 + overshoot, cursor.end().1.0)) } } @@ -741,8 +741,10 @@ impl FoldSnapshot { } let fold_point = 
FoldPoint::new(start_row, 0); - let mut cursor = self.transforms.cursor::<(FoldPoint, InlayPoint)>(&()); - cursor.seek(&fold_point, Bias::Left, &()); + let mut cursor = self + .transforms + .cursor::<Dimensions<FoldPoint, InlayPoint>>(&()); + cursor.seek(&fold_point, Bias::Left); let overshoot = fold_point.0 - cursor.start().0.0; let inlay_point = InlayPoint(cursor.start().1.0 + overshoot); @@ -773,7 +775,7 @@ impl FoldSnapshot { let mut folds = intersecting_folds(&self.inlay_snapshot, &self.folds, range, false); iter::from_fn(move || { let item = folds.item(); - folds.next(&self.inlay_snapshot.buffer); + folds.next(); item }) } @@ -785,7 +787,7 @@ impl FoldSnapshot { let buffer_offset = offset.to_offset(&self.inlay_snapshot.buffer); let inlay_offset = self.inlay_snapshot.to_inlay_offset(buffer_offset); let mut cursor = self.transforms.cursor::<InlayOffset>(&()); - cursor.seek(&inlay_offset, Bias::Right, &()); + cursor.seek(&inlay_offset, Bias::Right); cursor.item().map_or(false, |t| t.placeholder.is_some()) } @@ -794,7 +796,7 @@ impl FoldSnapshot { .inlay_snapshot .to_inlay_point(Point::new(buffer_row.0, 0)); let mut cursor = self.transforms.cursor::<InlayPoint>(&()); - cursor.seek(&inlay_point, Bias::Right, &()); + cursor.seek(&inlay_point, Bias::Right); loop { match cursor.item() { Some(transform) => { @@ -808,11 +810,11 @@ impl FoldSnapshot { None => return false, } - if cursor.end(&()).row() == inlay_point.row() { - cursor.next(&()); + if cursor.end().row() == inlay_point.row() { + cursor.next(); } else { inlay_point.0 += Point::new(1, 0); - cursor.seek(&inlay_point, Bias::Right, &()); + cursor.seek(&inlay_point, Bias::Right); } } } @@ -823,15 +825,17 @@ impl FoldSnapshot { language_aware: bool, highlights: Highlights<'a>, ) -> FoldChunks<'a> { - let mut transform_cursor = self.transforms.cursor::<(FoldOffset, InlayOffset)>(&()); - transform_cursor.seek(&range.start, Bias::Right, &()); + let mut transform_cursor = self + .transforms + .cursor::<Dimensions<FoldOffset, InlayOffset>>(&()); + transform_cursor.seek(&range.start, Bias::Right); let inlay_start = { let overshoot = range.start.0 - transform_cursor.start().0.0; transform_cursor.start().1 + InlayOffset(overshoot) }; - let transform_end = transform_cursor.end(&()); + let transform_end = transform_cursor.end(); let inlay_end = if transform_cursor .item() @@ -878,15 +882,17 @@ impl FoldSnapshot { } pub fn clip_point(&self, point: FoldPoint, bias: Bias) -> FoldPoint { - let mut cursor = self.transforms.cursor::<(FoldPoint, InlayPoint)>(&()); - cursor.seek(&point, Bias::Right, &()); + let mut cursor = self + .transforms + .cursor::<Dimensions<FoldPoint, InlayPoint>>(&()); + cursor.seek(&point, Bias::Right); if let Some(transform) = cursor.item() { let transform_start = cursor.start().0.0; if transform.placeholder.is_some() { if point.0 == transform_start || matches!(bias, Bias::Left) { FoldPoint(transform_start) } else { - FoldPoint(cursor.end(&()).0.0) + FoldPoint(cursor.end().0.0) } } else { let overshoot = InlayPoint(point.0 - transform_start); @@ -945,7 +951,7 @@ fn intersecting_folds<'a>( start_cmp == Ordering::Less && end_cmp == Ordering::Greater } }); - cursor.next(buffer); + cursor.next(); cursor } @@ -1203,7 +1209,7 @@ impl<'a> sum_tree::Dimension<'a, FoldSummary> for usize { #[derive(Clone)] pub struct FoldRows<'a> { - cursor: Cursor<'a, Transform, (FoldPoint, InlayPoint)>, + cursor: Cursor<'a, Transform, Dimensions<FoldPoint, InlayPoint>>, input_rows: InlayBufferRows<'a>, fold_point: FoldPoint, } @@ -1211,7 +1217,7 @@ pub 
struct FoldRows<'a> { impl FoldRows<'_> { pub(crate) fn seek(&mut self, row: u32) { let fold_point = FoldPoint::new(row, 0); - self.cursor.seek(&fold_point, Bias::Left, &()); + self.cursor.seek(&fold_point, Bias::Left); let overshoot = fold_point.0 - self.cursor.start().0.0; let inlay_point = InlayPoint(self.cursor.start().1.0 + overshoot); self.input_rows.seek(inlay_point.row()); @@ -1224,8 +1230,8 @@ impl Iterator for FoldRows<'_> { fn next(&mut self) -> Option<Self::Item> { let mut traversed_fold = false; - while self.fold_point > self.cursor.end(&()).0 { - self.cursor.next(&()); + while self.fold_point > self.cursor.end().0 { + self.cursor.next(); traversed_fold = true; if self.cursor.item().is_none() { break; @@ -1320,7 +1326,7 @@ impl DerefMut for ChunkRendererContext<'_, '_> { } pub struct FoldChunks<'a> { - transform_cursor: Cursor<'a, Transform, (FoldOffset, InlayOffset)>, + transform_cursor: Cursor<'a, Transform, Dimensions<FoldOffset, InlayOffset>>, inlay_chunks: InlayChunks<'a>, inlay_chunk: Option<(InlayOffset, InlayChunk<'a>)>, inlay_offset: InlayOffset, @@ -1330,14 +1336,14 @@ pub struct FoldChunks<'a> { impl FoldChunks<'_> { pub(crate) fn seek(&mut self, range: Range<FoldOffset>) { - self.transform_cursor.seek(&range.start, Bias::Right, &()); + self.transform_cursor.seek(&range.start, Bias::Right); let inlay_start = { let overshoot = range.start.0 - self.transform_cursor.start().0.0; self.transform_cursor.start().1 + InlayOffset(overshoot) }; - let transform_end = self.transform_cursor.end(&()); + let transform_end = self.transform_cursor.end(); let inlay_end = if self .transform_cursor @@ -1376,10 +1382,10 @@ impl<'a> Iterator for FoldChunks<'a> { self.inlay_chunk.take(); self.inlay_offset += InlayOffset(transform.summary.input.len); - while self.inlay_offset >= self.transform_cursor.end(&()).1 + while self.inlay_offset >= self.transform_cursor.end().1 && self.transform_cursor.item().is_some() { - self.transform_cursor.next(&()); + self.transform_cursor.next(); } self.output_offset.0 += placeholder.text.len(); @@ -1396,7 +1402,7 @@ impl<'a> Iterator for FoldChunks<'a> { && self.inlay_chunks.offset() != self.inlay_offset { let transform_start = self.transform_cursor.start(); - let transform_end = self.transform_cursor.end(&()); + let transform_end = self.transform_cursor.end(); let inlay_end = if self.max_output_offset < transform_end.0 { let overshoot = self.max_output_offset.0 - transform_start.0.0; transform_start.1 + InlayOffset(overshoot) @@ -1417,14 +1423,14 @@ impl<'a> Iterator for FoldChunks<'a> { if let Some((buffer_chunk_start, mut inlay_chunk)) = self.inlay_chunk.clone() { let chunk = &mut inlay_chunk.chunk; let buffer_chunk_end = buffer_chunk_start + InlayOffset(chunk.text.len()); - let transform_end = self.transform_cursor.end(&()).1; + let transform_end = self.transform_cursor.end().1; let chunk_end = buffer_chunk_end.min(transform_end); chunk.text = &chunk.text [(self.inlay_offset - buffer_chunk_start).0..(chunk_end - buffer_chunk_start).0]; if chunk_end == transform_end { - self.transform_cursor.next(&()); + self.transform_cursor.next(); } else if chunk_end == buffer_chunk_end { self.inlay_chunk.take(); } @@ -1455,8 +1461,8 @@ impl FoldOffset { pub fn to_point(self, snapshot: &FoldSnapshot) -> FoldPoint { let mut cursor = snapshot .transforms - .cursor::<(FoldOffset, TransformSummary)>(&()); - cursor.seek(&self, Bias::Right, &()); + .cursor::<Dimensions<FoldOffset, TransformSummary>>(&()); + cursor.seek(&self, Bias::Right); let overshoot = if 
cursor.item().map_or(true, |t| t.is_fold()) { Point::new(0, (self.0 - cursor.start().0.0) as u32) } else { @@ -1469,8 +1475,10 @@ impl FoldOffset { #[cfg(test)] pub fn to_inlay_offset(self, snapshot: &FoldSnapshot) -> InlayOffset { - let mut cursor = snapshot.transforms.cursor::<(FoldOffset, InlayOffset)>(&()); - cursor.seek(&self, Bias::Right, &()); + let mut cursor = snapshot + .transforms + .cursor::<Dimensions<FoldOffset, InlayOffset>>(&()); + cursor.seek(&self, Bias::Right); let overshoot = self.0 - cursor.start().0.0; InlayOffset(cursor.start().1.0 + overshoot) } diff --git a/crates/editor/src/display_map/inlay_map.rs b/crates/editor/src/display_map/inlay_map.rs index f7a696860a1c85d6955fe9e6f5aa00c0fa32a156..b296b3e62a39aa2ec8671676e051e94f5f9622cf 100644 --- a/crates/editor/src/display_map/inlay_map.rs +++ b/crates/editor/src/display_map/inlay_map.rs @@ -10,7 +10,7 @@ use std::{ ops::{Add, AddAssign, Range, Sub, SubAssign}, sync::Arc, }; -use sum_tree::{Bias, Cursor, SumTree}; +use sum_tree::{Bias, Cursor, Dimensions, SumTree}; use text::{Patch, Rope}; use ui::{ActiveTheme, IntoElement as _, ParentElement as _, Styled as _, div}; @@ -48,16 +48,16 @@ pub struct Inlay { impl Inlay { pub fn hint(id: usize, position: Anchor, hint: &project::InlayHint) -> Self { let mut text = hint.text(); - if hint.padding_right && !text.ends_with(' ') { - text.push(' '); + if hint.padding_right && text.chars_at(text.len().saturating_sub(1)).next() != Some(' ') { + text.push(" "); } - if hint.padding_left && !text.starts_with(' ') { - text.insert(0, ' '); + if hint.padding_left && text.chars_at(0).next() != Some(' ') { + text.push_front(" "); } Self { id: InlayId::Hint(id), position, - text: text.into(), + text, color: None, } } @@ -81,9 +81,9 @@ impl Inlay { } } - pub fn inline_completion<T: Into<Rope>>(id: usize, position: Anchor, text: T) -> Self { + pub fn edit_prediction<T: Into<Rope>>(id: usize, position: Anchor, text: T) -> Self { Self { - id: InlayId::InlineCompletion(id), + id: InlayId::EditPrediction(id), position, text: text.into(), color: None, @@ -235,14 +235,14 @@ impl<'a> sum_tree::Dimension<'a, TransformSummary> for Point { #[derive(Clone)] pub struct InlayBufferRows<'a> { - transforms: Cursor<'a, Transform, (InlayPoint, Point)>, + transforms: Cursor<'a, Transform, Dimensions<InlayPoint, Point>>, buffer_rows: MultiBufferRows<'a>, inlay_row: u32, max_buffer_row: MultiBufferRow, } pub struct InlayChunks<'a> { - transforms: Cursor<'a, Transform, (InlayOffset, usize)>, + transforms: Cursor<'a, Transform, Dimensions<InlayOffset, usize>>, buffer_chunks: CustomHighlightsChunks<'a>, buffer_chunk: Option<Chunk<'a>>, inlay_chunks: Option<text::Chunks<'a>>, @@ -263,7 +263,7 @@ pub struct InlayChunk<'a> { impl InlayChunks<'_> { pub fn seek(&mut self, new_range: Range<InlayOffset>) { - self.transforms.seek(&new_range.start, Bias::Right, &()); + self.transforms.seek(&new_range.start, Bias::Right); let buffer_range = self.snapshot.to_buffer_offset(new_range.start) ..self.snapshot.to_buffer_offset(new_range.end); @@ -296,12 +296,12 @@ impl<'a> Iterator for InlayChunks<'a> { *chunk = self.buffer_chunks.next().unwrap(); } - let desired_bytes = self.transforms.end(&()).0.0 - self.output_offset.0; + let desired_bytes = self.transforms.end().0.0 - self.output_offset.0; // If we're already at the transform boundary, skip to the next transform if desired_bytes == 0 { self.inlay_chunks = None; - self.transforms.next(&()); + self.transforms.next(); return self.next(); } @@ -340,15 +340,13 @@ impl<'a> Iterator 
for InlayChunks<'a> { let mut renderer = None; let mut highlight_style = match inlay.id { - InlayId::InlineCompletion(_) => { - self.highlight_styles.inline_completion.map(|s| { - if inlay.text.chars().all(|c| c.is_whitespace()) { - s.whitespace - } else { - s.insertion - } - }) - } + InlayId::EditPrediction(_) => self.highlight_styles.edit_prediction.map(|s| { + if inlay.text.chars().all(|c| c.is_whitespace()) { + s.whitespace + } else { + s.insertion + } + }), InlayId::Hint(_) => self.highlight_styles.inlay_hint, InlayId::DebuggerValue(_) => self.highlight_styles.inlay_hint, InlayId::Color(_) => { @@ -397,7 +395,7 @@ impl<'a> Iterator for InlayChunks<'a> { let inlay_chunks = self.inlay_chunks.get_or_insert_with(|| { let start = offset_in_inlay; - let end = cmp::min(self.max_output_offset, self.transforms.end(&()).0) + let end = cmp::min(self.max_output_offset, self.transforms.end().0) - self.transforms.start().0; inlay.text.chunks_in_range(start.0..end.0) }); @@ -441,9 +439,9 @@ impl<'a> Iterator for InlayChunks<'a> { } }; - if self.output_offset >= self.transforms.end(&()).0 { + if self.output_offset >= self.transforms.end().0 { self.inlay_chunks = None; - self.transforms.next(&()); + self.transforms.next(); } Some(chunk) @@ -453,7 +451,7 @@ impl<'a> Iterator for InlayChunks<'a> { impl InlayBufferRows<'_> { pub fn seek(&mut self, row: u32) { let inlay_point = InlayPoint::new(row, 0); - self.transforms.seek(&inlay_point, Bias::Left, &()); + self.transforms.seek(&inlay_point, Bias::Left); let mut buffer_point = self.transforms.start().1; let buffer_row = MultiBufferRow(if row == 0 { @@ -487,7 +485,7 @@ impl Iterator for InlayBufferRows<'_> { self.inlay_row += 1; self.transforms - .seek_forward(&InlayPoint::new(self.inlay_row, 0), Bias::Left, &()); + .seek_forward(&InlayPoint::new(self.inlay_row, 0), Bias::Left); Some(buffer_row) } @@ -553,21 +551,23 @@ impl InlayMap { } else { let mut inlay_edits = Patch::default(); let mut new_transforms = SumTree::default(); - let mut cursor = snapshot.transforms.cursor::<(usize, InlayOffset)>(&()); + let mut cursor = snapshot + .transforms + .cursor::<Dimensions<usize, InlayOffset>>(&()); let mut buffer_edits_iter = buffer_edits.iter().peekable(); while let Some(buffer_edit) = buffer_edits_iter.next() { - new_transforms.append(cursor.slice(&buffer_edit.old.start, Bias::Left, &()), &()); + new_transforms.append(cursor.slice(&buffer_edit.old.start, Bias::Left), &()); if let Some(Transform::Isomorphic(transform)) = cursor.item() { - if cursor.end(&()).0 == buffer_edit.old.start { + if cursor.end().0 == buffer_edit.old.start { push_isomorphic(&mut new_transforms, *transform); - cursor.next(&()); + cursor.next(); } } // Remove all the inlays and transforms contained by the edit. let old_start = cursor.start().1 + InlayOffset(buffer_edit.old.start - cursor.start().0); - cursor.seek(&buffer_edit.old.end, Bias::Right, &()); + cursor.seek(&buffer_edit.old.end, Bias::Right); let old_end = cursor.start().1 + InlayOffset(buffer_edit.old.end - cursor.start().0); @@ -625,20 +625,20 @@ impl InlayMap { // we can push its remainder. 
if buffer_edits_iter .peek() - .map_or(true, |edit| edit.old.start >= cursor.end(&()).0) + .map_or(true, |edit| edit.old.start >= cursor.end().0) { let transform_start = new_transforms.summary().input.len; let transform_end = - buffer_edit.new.end + (cursor.end(&()).0 - buffer_edit.old.end); + buffer_edit.new.end + (cursor.end().0 - buffer_edit.old.end); push_isomorphic( &mut new_transforms, buffer_snapshot.text_summary_for_range(transform_start..transform_end), ); - cursor.next(&()); + cursor.next(); } } - new_transforms.append(cursor.suffix(&()), &()); + new_transforms.append(cursor.suffix(), &()); if new_transforms.is_empty() { new_transforms.push(Transform::Isomorphic(Default::default()), &()); } @@ -737,13 +737,13 @@ impl InlayMap { Inlay::mock_hint( post_inc(next_inlay_id), snapshot.buffer.anchor_at(position, bias), - text.clone(), + &text, ) } else { - Inlay::inline_completion( + Inlay::edit_prediction( post_inc(next_inlay_id), snapshot.buffer.anchor_at(position, bias), - text.clone(), + &text, ) }; let inlay_id = next_inlay.id; @@ -772,20 +772,20 @@ impl InlaySnapshot { pub fn to_point(&self, offset: InlayOffset) -> InlayPoint { let mut cursor = self .transforms - .cursor::<(InlayOffset, (InlayPoint, usize))>(&()); - cursor.seek(&offset, Bias::Right, &()); + .cursor::<Dimensions<InlayOffset, InlayPoint, usize>>(&()); + cursor.seek(&offset, Bias::Right); let overshoot = offset.0 - cursor.start().0.0; match cursor.item() { Some(Transform::Isomorphic(_)) => { - let buffer_offset_start = cursor.start().1.1; + let buffer_offset_start = cursor.start().2; let buffer_offset_end = buffer_offset_start + overshoot; let buffer_start = self.buffer.offset_to_point(buffer_offset_start); let buffer_end = self.buffer.offset_to_point(buffer_offset_end); - InlayPoint(cursor.start().1.0.0 + (buffer_end - buffer_start)) + InlayPoint(cursor.start().1.0 + (buffer_end - buffer_start)) } Some(Transform::Inlay(inlay)) => { let overshoot = inlay.text.offset_to_point(overshoot); - InlayPoint(cursor.start().1.0.0 + overshoot) + InlayPoint(cursor.start().1.0 + overshoot) } None => self.max_point(), } @@ -802,27 +802,27 @@ impl InlaySnapshot { pub fn to_offset(&self, point: InlayPoint) -> InlayOffset { let mut cursor = self .transforms - .cursor::<(InlayPoint, (InlayOffset, Point))>(&()); - cursor.seek(&point, Bias::Right, &()); + .cursor::<Dimensions<InlayPoint, InlayOffset, Point>>(&()); + cursor.seek(&point, Bias::Right); let overshoot = point.0 - cursor.start().0.0; match cursor.item() { Some(Transform::Isomorphic(_)) => { - let buffer_point_start = cursor.start().1.1; + let buffer_point_start = cursor.start().2; let buffer_point_end = buffer_point_start + overshoot; let buffer_offset_start = self.buffer.point_to_offset(buffer_point_start); let buffer_offset_end = self.buffer.point_to_offset(buffer_point_end); - InlayOffset(cursor.start().1.0.0 + (buffer_offset_end - buffer_offset_start)) + InlayOffset(cursor.start().1.0 + (buffer_offset_end - buffer_offset_start)) } Some(Transform::Inlay(inlay)) => { let overshoot = inlay.text.point_to_offset(overshoot); - InlayOffset(cursor.start().1.0.0 + overshoot) + InlayOffset(cursor.start().1.0 + overshoot) } None => self.len(), } } pub fn to_buffer_point(&self, point: InlayPoint) -> Point { - let mut cursor = self.transforms.cursor::<(InlayPoint, Point)>(&()); - cursor.seek(&point, Bias::Right, &()); + let mut cursor = self.transforms.cursor::<Dimensions<InlayPoint, Point>>(&()); + cursor.seek(&point, Bias::Right); match cursor.item() { Some(Transform::Isomorphic(_)) 
=> { let overshoot = point.0 - cursor.start().0.0; @@ -833,8 +833,10 @@ impl InlaySnapshot { } } pub fn to_buffer_offset(&self, offset: InlayOffset) -> usize { - let mut cursor = self.transforms.cursor::<(InlayOffset, usize)>(&()); - cursor.seek(&offset, Bias::Right, &()); + let mut cursor = self + .transforms + .cursor::<Dimensions<InlayOffset, usize>>(&()); + cursor.seek(&offset, Bias::Right); match cursor.item() { Some(Transform::Isomorphic(_)) => { let overshoot = offset - cursor.start().0; @@ -846,20 +848,22 @@ impl InlaySnapshot { } pub fn to_inlay_offset(&self, offset: usize) -> InlayOffset { - let mut cursor = self.transforms.cursor::<(usize, InlayOffset)>(&()); - cursor.seek(&offset, Bias::Left, &()); + let mut cursor = self + .transforms + .cursor::<Dimensions<usize, InlayOffset>>(&()); + cursor.seek(&offset, Bias::Left); loop { match cursor.item() { Some(Transform::Isomorphic(_)) => { - if offset == cursor.end(&()).0 { + if offset == cursor.end().0 { while let Some(Transform::Inlay(inlay)) = cursor.next_item() { if inlay.position.bias() == Bias::Right { break; } else { - cursor.next(&()); + cursor.next(); } } - return cursor.end(&()).1; + return cursor.end().1; } else { let overshoot = offset - cursor.start().0; return InlayOffset(cursor.start().1.0 + overshoot); @@ -867,7 +871,7 @@ impl InlaySnapshot { } Some(Transform::Inlay(inlay)) => { if inlay.position.bias() == Bias::Left { - cursor.next(&()); + cursor.next(); } else { return cursor.start().1; } @@ -879,20 +883,20 @@ impl InlaySnapshot { } } pub fn to_inlay_point(&self, point: Point) -> InlayPoint { - let mut cursor = self.transforms.cursor::<(Point, InlayPoint)>(&()); - cursor.seek(&point, Bias::Left, &()); + let mut cursor = self.transforms.cursor::<Dimensions<Point, InlayPoint>>(&()); + cursor.seek(&point, Bias::Left); loop { match cursor.item() { Some(Transform::Isomorphic(_)) => { - if point == cursor.end(&()).0 { + if point == cursor.end().0 { while let Some(Transform::Inlay(inlay)) = cursor.next_item() { if inlay.position.bias() == Bias::Right { break; } else { - cursor.next(&()); + cursor.next(); } } - return cursor.end(&()).1; + return cursor.end().1; } else { let overshoot = point - cursor.start().0; return InlayPoint(cursor.start().1.0 + overshoot); @@ -900,7 +904,7 @@ impl InlaySnapshot { } Some(Transform::Inlay(inlay)) => { if inlay.position.bias() == Bias::Left { - cursor.next(&()); + cursor.next(); } else { return cursor.start().1; } @@ -913,8 +917,8 @@ impl InlaySnapshot { } pub fn clip_point(&self, mut point: InlayPoint, mut bias: Bias) -> InlayPoint { - let mut cursor = self.transforms.cursor::<(InlayPoint, Point)>(&()); - cursor.seek(&point, Bias::Left, &()); + let mut cursor = self.transforms.cursor::<Dimensions<InlayPoint, Point>>(&()); + cursor.seek(&point, Bias::Left); loop { match cursor.item() { Some(Transform::Isomorphic(transform)) => { @@ -923,7 +927,7 @@ impl InlaySnapshot { if inlay.position.bias() == Bias::Left { return point; } else if bias == Bias::Left { - cursor.prev(&()); + cursor.prev(); } else if transform.first_line_chars == 0 { point.0 += Point::new(1, 0); } else { @@ -932,12 +936,12 @@ impl InlaySnapshot { } else { return point; } - } else if cursor.end(&()).0 == point { + } else if cursor.end().0 == point { if let Some(Transform::Inlay(inlay)) = cursor.next_item() { if inlay.position.bias() == Bias::Right { return point; } else if bias == Bias::Right { - cursor.next(&()); + cursor.next(); } else if point.0.column == 0 { point.0.row -= 1; point.0.column = self.line_len(point.0.row); 
@@ -970,7 +974,7 @@ impl InlaySnapshot { } _ => return point, } - } else if point == cursor.end(&()).0 && inlay.position.bias() == Bias::Left { + } else if point == cursor.end().0 && inlay.position.bias() == Bias::Left { match cursor.next_item() { Some(Transform::Inlay(inlay)) => { if inlay.position.bias() == Bias::Right { @@ -983,9 +987,9 @@ impl InlaySnapshot { if bias == Bias::Left { point = cursor.start().0; - cursor.prev(&()); + cursor.prev(); } else { - cursor.next(&()); + cursor.next(); point = cursor.start().0; } } @@ -993,9 +997,9 @@ impl InlaySnapshot { bias = bias.invert(); if bias == Bias::Left { point = cursor.start().0; - cursor.prev(&()); + cursor.prev(); } else { - cursor.next(&()); + cursor.next(); point = cursor.start().0; } } @@ -1010,8 +1014,10 @@ impl InlaySnapshot { pub fn text_summary_for_range(&self, range: Range<InlayOffset>) -> TextSummary { let mut summary = TextSummary::default(); - let mut cursor = self.transforms.cursor::<(InlayOffset, usize)>(&()); - cursor.seek(&range.start, Bias::Right, &()); + let mut cursor = self + .transforms + .cursor::<Dimensions<InlayOffset, usize>>(&()); + cursor.seek(&range.start, Bias::Right); let overshoot = range.start.0 - cursor.start().0.0; match cursor.item() { @@ -1019,22 +1025,22 @@ impl InlaySnapshot { let buffer_start = cursor.start().1; let suffix_start = buffer_start + overshoot; let suffix_end = - buffer_start + (cmp::min(cursor.end(&()).0, range.end).0 - cursor.start().0.0); + buffer_start + (cmp::min(cursor.end().0, range.end).0 - cursor.start().0.0); summary = self.buffer.text_summary_for_range(suffix_start..suffix_end); - cursor.next(&()); + cursor.next(); } Some(Transform::Inlay(inlay)) => { let suffix_start = overshoot; - let suffix_end = cmp::min(cursor.end(&()).0, range.end).0 - cursor.start().0.0; + let suffix_end = cmp::min(cursor.end().0, range.end).0 - cursor.start().0.0; summary = inlay.text.cursor(suffix_start).summary(suffix_end); - cursor.next(&()); + cursor.next(); } None => {} } if range.end > cursor.start().0 { summary += cursor - .summary::<_, TransformSummary>(&range.end, Bias::Right, &()) + .summary::<_, TransformSummary>(&range.end, Bias::Right) .output; let overshoot = range.end.0 - cursor.start().0.0; @@ -1058,9 +1064,9 @@ impl InlaySnapshot { } pub fn row_infos(&self, row: u32) -> InlayBufferRows<'_> { - let mut cursor = self.transforms.cursor::<(InlayPoint, Point)>(&()); + let mut cursor = self.transforms.cursor::<Dimensions<InlayPoint, Point>>(&()); let inlay_point = InlayPoint::new(row, 0); - cursor.seek(&inlay_point, Bias::Left, &()); + cursor.seek(&inlay_point, Bias::Left); let max_buffer_row = self.buffer.max_row(); let mut buffer_point = cursor.start().1; @@ -1100,8 +1106,10 @@ impl InlaySnapshot { language_aware: bool, highlights: Highlights<'a>, ) -> InlayChunks<'a> { - let mut cursor = self.transforms.cursor::<(InlayOffset, usize)>(&()); - cursor.seek(&range.start, Bias::Right, &()); + let mut cursor = self + .transforms + .cursor::<Dimensions<InlayOffset, usize>>(&()); + cursor.seek(&range.start, Bias::Right); let buffer_range = self.to_buffer_offset(range.start)..self.to_buffer_offset(range.end); let buffer_chunks = CustomHighlightsChunks::new( @@ -1389,7 +1397,7 @@ mod tests { buffer.read(cx).snapshot(cx).anchor_before(3), "|123|", ), - Inlay::inline_completion( + Inlay::edit_prediction( post_inc(&mut next_inlay_id), buffer.read(cx).snapshot(cx).anchor_after(3), "|456|", @@ -1609,7 +1617,7 @@ mod tests { buffer.read(cx).snapshot(cx).anchor_before(4), "|456|", ), - 
Inlay::inline_completion( + Inlay::edit_prediction( post_inc(&mut next_inlay_id), buffer.read(cx).snapshot(cx).anchor_before(7), "\n|567|\n", @@ -1686,7 +1694,7 @@ mod tests { (offset, inlay.clone()) }) .collect::<Vec<_>>(); - let mut expected_text = Rope::from(buffer_snapshot.text()); + let mut expected_text = Rope::from(&buffer_snapshot.text()); for (offset, inlay) in inlays.iter().rev() { expected_text.replace(*offset..*offset, &inlay.text.to_string()); } diff --git a/crates/editor/src/display_map/wrap_map.rs b/crates/editor/src/display_map/wrap_map.rs index a29bf5388271e422bea6aa890d5617ab0cc3f5ee..269f8f0c409cee7fa87b64cccb5310e4beb1edd2 100644 --- a/crates/editor/src/display_map/wrap_map.rs +++ b/crates/editor/src/display_map/wrap_map.rs @@ -9,7 +9,7 @@ use multi_buffer::{MultiBufferSnapshot, RowInfo}; use smol::future::yield_now; use std::sync::LazyLock; use std::{cmp, collections::VecDeque, mem, ops::Range, time::Duration}; -use sum_tree::{Bias, Cursor, SumTree}; +use sum_tree::{Bias, Cursor, Dimensions, SumTree}; use text::Patch; pub use super::tab_map::TextSummary; @@ -55,7 +55,7 @@ pub struct WrapChunks<'a> { input_chunk: Chunk<'a>, output_position: WrapPoint, max_output_row: u32, - transforms: Cursor<'a, Transform, (WrapPoint, TabPoint)>, + transforms: Cursor<'a, Transform, Dimensions<WrapPoint, TabPoint>>, snapshot: &'a WrapSnapshot, } @@ -66,13 +66,13 @@ pub struct WrapRows<'a> { output_row: u32, soft_wrapped: bool, max_output_row: u32, - transforms: Cursor<'a, Transform, (WrapPoint, TabPoint)>, + transforms: Cursor<'a, Transform, Dimensions<WrapPoint, TabPoint>>, } impl WrapRows<'_> { pub(crate) fn seek(&mut self, start_row: u32) { self.transforms - .seek(&WrapPoint::new(start_row, 0), Bias::Left, &()); + .seek(&WrapPoint::new(start_row, 0), Bias::Left); let mut input_row = self.transforms.start().1.row(); if self.transforms.item().map_or(false, |t| t.is_isomorphic()) { input_row += start_row - self.transforms.start().0.row(); @@ -340,7 +340,7 @@ impl WrapSnapshot { let mut tab_edits_iter = tab_edits.iter().peekable(); new_transforms = - old_cursor.slice(&tab_edits_iter.peek().unwrap().old.start, Bias::Right, &()); + old_cursor.slice(&tab_edits_iter.peek().unwrap().old.start, Bias::Right); while let Some(edit) = tab_edits_iter.next() { if edit.new.start > TabPoint::from(new_transforms.summary().input.lines) { @@ -356,31 +356,29 @@ impl WrapSnapshot { )); } - old_cursor.seek_forward(&edit.old.end, Bias::Right, &()); + old_cursor.seek_forward(&edit.old.end, Bias::Right); if let Some(next_edit) = tab_edits_iter.peek() { - if next_edit.old.start > old_cursor.end(&()) { - if old_cursor.end(&()) > edit.old.end { + if next_edit.old.start > old_cursor.end() { + if old_cursor.end() > edit.old.end { let summary = self .tab_snapshot - .text_summary_for_range(edit.old.end..old_cursor.end(&())); + .text_summary_for_range(edit.old.end..old_cursor.end()); new_transforms.push_or_extend(Transform::isomorphic(summary)); } - old_cursor.next(&()); - new_transforms.append( - old_cursor.slice(&next_edit.old.start, Bias::Right, &()), - &(), - ); + old_cursor.next(); + new_transforms + .append(old_cursor.slice(&next_edit.old.start, Bias::Right), &()); } } else { - if old_cursor.end(&()) > edit.old.end { + if old_cursor.end() > edit.old.end { let summary = self .tab_snapshot - .text_summary_for_range(edit.old.end..old_cursor.end(&())); + .text_summary_for_range(edit.old.end..old_cursor.end()); new_transforms.push_or_extend(Transform::isomorphic(summary)); } - old_cursor.next(&()); - 
new_transforms.append(old_cursor.suffix(&()), &()); + old_cursor.next(); + new_transforms.append(old_cursor.suffix(), &()); } } } @@ -441,7 +439,6 @@ impl WrapSnapshot { new_transforms = old_cursor.slice( &TabPoint::new(row_edits.peek().unwrap().old_rows.start, 0), Bias::Right, - &(), ); while let Some(edit) = row_edits.next() { @@ -516,34 +513,31 @@ impl WrapSnapshot { } new_transforms.extend(edit_transforms, &()); - old_cursor.seek_forward(&TabPoint::new(edit.old_rows.end, 0), Bias::Right, &()); + old_cursor.seek_forward(&TabPoint::new(edit.old_rows.end, 0), Bias::Right); if let Some(next_edit) = row_edits.peek() { - if next_edit.old_rows.start > old_cursor.end(&()).row() { - if old_cursor.end(&()) > TabPoint::new(edit.old_rows.end, 0) { + if next_edit.old_rows.start > old_cursor.end().row() { + if old_cursor.end() > TabPoint::new(edit.old_rows.end, 0) { let summary = self.tab_snapshot.text_summary_for_range( - TabPoint::new(edit.old_rows.end, 0)..old_cursor.end(&()), + TabPoint::new(edit.old_rows.end, 0)..old_cursor.end(), ); new_transforms.push_or_extend(Transform::isomorphic(summary)); } - old_cursor.next(&()); + old_cursor.next(); new_transforms.append( - old_cursor.slice( - &TabPoint::new(next_edit.old_rows.start, 0), - Bias::Right, - &(), - ), + old_cursor + .slice(&TabPoint::new(next_edit.old_rows.start, 0), Bias::Right), &(), ); } } else { - if old_cursor.end(&()) > TabPoint::new(edit.old_rows.end, 0) { + if old_cursor.end() > TabPoint::new(edit.old_rows.end, 0) { let summary = self.tab_snapshot.text_summary_for_range( - TabPoint::new(edit.old_rows.end, 0)..old_cursor.end(&()), + TabPoint::new(edit.old_rows.end, 0)..old_cursor.end(), ); new_transforms.push_or_extend(Transform::isomorphic(summary)); } - old_cursor.next(&()); - new_transforms.append(old_cursor.suffix(&()), &()); + old_cursor.next(); + new_transforms.append(old_cursor.suffix(), &()); } } } @@ -570,19 +564,19 @@ impl WrapSnapshot { tab_edit.new.start.0.column = 0; tab_edit.new.end.0 += Point::new(1, 0); - old_cursor.seek(&tab_edit.old.start, Bias::Right, &()); + old_cursor.seek(&tab_edit.old.start, Bias::Right); let mut old_start = old_cursor.start().output.lines; old_start += tab_edit.old.start.0 - old_cursor.start().input.lines; - old_cursor.seek(&tab_edit.old.end, Bias::Right, &()); + old_cursor.seek(&tab_edit.old.end, Bias::Right); let mut old_end = old_cursor.start().output.lines; old_end += tab_edit.old.end.0 - old_cursor.start().input.lines; - new_cursor.seek(&tab_edit.new.start, Bias::Right, &()); + new_cursor.seek(&tab_edit.new.start, Bias::Right); let mut new_start = new_cursor.start().output.lines; new_start += tab_edit.new.start.0 - new_cursor.start().input.lines; - new_cursor.seek(&tab_edit.new.end, Bias::Right, &()); + new_cursor.seek(&tab_edit.new.end, Bias::Right); let mut new_end = new_cursor.start().output.lines; new_end += tab_edit.new.end.0 - new_cursor.start().input.lines; @@ -604,8 +598,10 @@ impl WrapSnapshot { ) -> WrapChunks<'a> { let output_start = WrapPoint::new(rows.start, 0); let output_end = WrapPoint::new(rows.end, 0); - let mut transforms = self.transforms.cursor::<(WrapPoint, TabPoint)>(&()); - transforms.seek(&output_start, Bias::Right, &()); + let mut transforms = self + .transforms + .cursor::<Dimensions<WrapPoint, TabPoint>>(&()); + transforms.seek(&output_start, Bias::Right); let mut input_start = TabPoint(transforms.start().1.0); if transforms.item().map_or(false, |t| t.is_isomorphic()) { input_start.0 += output_start.0 - transforms.start().0.0; @@ -632,8 +628,10 @@ impl 
WrapSnapshot { } pub fn line_len(&self, row: u32) -> u32 { - let mut cursor = self.transforms.cursor::<(WrapPoint, TabPoint)>(&()); - cursor.seek(&WrapPoint::new(row + 1, 0), Bias::Left, &()); + let mut cursor = self + .transforms + .cursor::<Dimensions<WrapPoint, TabPoint>>(&()); + cursor.seek(&WrapPoint::new(row + 1, 0), Bias::Left); if cursor .item() .map_or(false, |transform| transform.is_isomorphic()) @@ -657,11 +655,13 @@ impl WrapSnapshot { let start = WrapPoint::new(rows.start, 0); let end = WrapPoint::new(rows.end, 0); - let mut cursor = self.transforms.cursor::<(WrapPoint, TabPoint)>(&()); - cursor.seek(&start, Bias::Right, &()); + let mut cursor = self + .transforms + .cursor::<Dimensions<WrapPoint, TabPoint>>(&()); + cursor.seek(&start, Bias::Right); if let Some(transform) = cursor.item() { let start_in_transform = start.0 - cursor.start().0.0; - let end_in_transform = cmp::min(end, cursor.end(&()).0).0 - cursor.start().0.0; + let end_in_transform = cmp::min(end, cursor.end().0).0 - cursor.start().0.0; if transform.is_isomorphic() { let tab_start = TabPoint(cursor.start().1.0 + start_in_transform); let tab_end = TabPoint(cursor.start().1.0 + end_in_transform); @@ -678,12 +678,12 @@ impl WrapSnapshot { }; } - cursor.next(&()); + cursor.next(); } if rows.end > cursor.start().0.row() { summary += &cursor - .summary::<_, TransformSummary>(&WrapPoint::new(rows.end, 0), Bias::Right, &()) + .summary::<_, TransformSummary>(&WrapPoint::new(rows.end, 0), Bias::Right) .output; if let Some(transform) = cursor.item() { @@ -712,7 +712,7 @@ impl WrapSnapshot { pub fn soft_wrap_indent(&self, row: u32) -> Option<u32> { let mut cursor = self.transforms.cursor::<WrapPoint>(&()); - cursor.seek(&WrapPoint::new(row + 1, 0), Bias::Right, &()); + cursor.seek(&WrapPoint::new(row + 1, 0), Bias::Right); cursor.item().and_then(|transform| { if transform.is_isomorphic() { None @@ -727,8 +727,10 @@ impl WrapSnapshot { } pub fn row_infos(&self, start_row: u32) -> WrapRows<'_> { - let mut transforms = self.transforms.cursor::<(WrapPoint, TabPoint)>(&()); - transforms.seek(&WrapPoint::new(start_row, 0), Bias::Left, &()); + let mut transforms = self + .transforms + .cursor::<Dimensions<WrapPoint, TabPoint>>(&()); + transforms.seek(&WrapPoint::new(start_row, 0), Bias::Left); let mut input_row = transforms.start().1.row(); if transforms.item().map_or(false, |t| t.is_isomorphic()) { input_row += start_row - transforms.start().0.row(); @@ -747,8 +749,10 @@ impl WrapSnapshot { } pub fn to_tab_point(&self, point: WrapPoint) -> TabPoint { - let mut cursor = self.transforms.cursor::<(WrapPoint, TabPoint)>(&()); - cursor.seek(&point, Bias::Right, &()); + let mut cursor = self + .transforms + .cursor::<Dimensions<WrapPoint, TabPoint>>(&()); + cursor.seek(&point, Bias::Right); let mut tab_point = cursor.start().1.0; if cursor.item().map_or(false, |t| t.is_isomorphic()) { tab_point += point.0 - cursor.start().0.0; @@ -765,15 +769,17 @@ impl WrapSnapshot { } pub fn tab_point_to_wrap_point(&self, point: TabPoint) -> WrapPoint { - let mut cursor = self.transforms.cursor::<(TabPoint, WrapPoint)>(&()); - cursor.seek(&point, Bias::Right, &()); + let mut cursor = self + .transforms + .cursor::<Dimensions<TabPoint, WrapPoint>>(&()); + cursor.seek(&point, Bias::Right); WrapPoint(cursor.start().1.0 + (point.0 - cursor.start().0.0)) } pub fn clip_point(&self, mut point: WrapPoint, bias: Bias) -> WrapPoint { if bias == Bias::Left { let mut cursor = self.transforms.cursor::<WrapPoint>(&()); - cursor.seek(&point, Bias::Right, &()); + 
cursor.seek(&point, Bias::Right); if cursor.item().map_or(false, |t| !t.is_isomorphic()) { point = *cursor.start(); *point.column_mut() -= 1; @@ -790,17 +796,19 @@ impl WrapSnapshot { *point.column_mut() = 0; - let mut cursor = self.transforms.cursor::<(WrapPoint, TabPoint)>(&()); - cursor.seek(&point, Bias::Right, &()); + let mut cursor = self + .transforms + .cursor::<Dimensions<WrapPoint, TabPoint>>(&()); + cursor.seek(&point, Bias::Right); if cursor.item().is_none() { - cursor.prev(&()); + cursor.prev(); } while let Some(transform) = cursor.item() { if transform.is_isomorphic() && cursor.start().1.column() == 0 { - return cmp::min(cursor.end(&()).0.row(), point.row()); + return cmp::min(cursor.end().0.row(), point.row()); } else { - cursor.prev(&()); + cursor.prev(); } } @@ -810,13 +818,15 @@ impl WrapSnapshot { pub fn next_row_boundary(&self, mut point: WrapPoint) -> Option<u32> { point.0 += Point::new(1, 0); - let mut cursor = self.transforms.cursor::<(WrapPoint, TabPoint)>(&()); - cursor.seek(&point, Bias::Right, &()); + let mut cursor = self + .transforms + .cursor::<Dimensions<WrapPoint, TabPoint>>(&()); + cursor.seek(&point, Bias::Right); while let Some(transform) = cursor.item() { if transform.is_isomorphic() && cursor.start().1.column() == 0 { return Some(cmp::max(cursor.start().0.row(), point.row())); } else { - cursor.next(&()); + cursor.next(); } } @@ -889,7 +899,7 @@ impl WrapChunks<'_> { pub(crate) fn seek(&mut self, rows: Range<u32>) { let output_start = WrapPoint::new(rows.start, 0); let output_end = WrapPoint::new(rows.end, 0); - self.transforms.seek(&output_start, Bias::Right, &()); + self.transforms.seek(&output_start, Bias::Right); let mut input_start = TabPoint(self.transforms.start().1.0); if self.transforms.item().map_or(false, |t| t.is_isomorphic()) { input_start.0 += output_start.0 - self.transforms.start().0.0; @@ -930,7 +940,7 @@ impl<'a> Iterator for WrapChunks<'a> { } self.output_position.0 += summary; - self.transforms.next(&()); + self.transforms.next(); return Some(Chunk { text: &display_text[start_ix..end_ix], ..Default::default() @@ -942,7 +952,7 @@ impl<'a> Iterator for WrapChunks<'a> { } let mut input_len = 0; - let transform_end = self.transforms.end(&()).0; + let transform_end = self.transforms.end().0; for c in self.input_chunk.text.chars() { let char_len = c.len_utf8(); input_len += char_len; @@ -954,7 +964,7 @@ impl<'a> Iterator for WrapChunks<'a> { } if self.output_position >= transform_end { - self.transforms.next(&()); + self.transforms.next(); break; } } @@ -982,7 +992,7 @@ impl Iterator for WrapRows<'_> { self.output_row += 1; self.transforms - .seek_forward(&WrapPoint::new(self.output_row, 0), Bias::Left, &()); + .seek_forward(&WrapPoint::new(self.output_row, 0), Bias::Left); if self.transforms.item().map_or(false, |t| t.is_isomorphic()) { self.input_buffer_row = self.input_buffer_rows.next().unwrap(); self.soft_wrapped = false; diff --git a/crates/editor/src/inline_completion_tests.rs b/crates/editor/src/edit_prediction_tests.rs similarity index 58% rename from crates/editor/src/inline_completion_tests.rs rename to crates/editor/src/edit_prediction_tests.rs index 5ac34c94f52820b4326da59ecaf4afc5253d6525..7bf51e45d72f383b4af34cf6ad493792f8e9d351 100644 --- a/crates/editor/src/inline_completion_tests.rs +++ b/crates/editor/src/edit_prediction_tests.rs @@ -1,26 +1,26 @@ +use edit_prediction::EditPredictionProvider; use gpui::{Entity, prelude::*}; use indoc::indoc; -use inline_completion::EditPredictionProvider; use multi_buffer::{Anchor, 
MultiBufferSnapshot, ToPoint}; use project::Project; use std::ops::Range; use text::{Point, ToOffset}; use crate::{ - InlineCompletion, editor_tests::init_test, test::editor_test_context::EditorTestContext, + EditPrediction, editor_tests::init_test, test::editor_test_context::EditorTestContext, }; #[gpui::test] -async fn test_inline_completion_insert(cx: &mut gpui::TestAppContext) { +async fn test_edit_prediction_insert(cx: &mut gpui::TestAppContext) { init_test(cx, |_| {}); let mut cx = EditorTestContext::new(cx).await; - let provider = cx.new(|_| FakeInlineCompletionProvider::default()); + let provider = cx.new(|_| FakeEditPredictionProvider::default()); assign_editor_completion_provider(provider.clone(), &mut cx); cx.set_state("let absolute_zero_celsius = ˇ;"); propose_edits(&provider, vec![(28..28, "-273.15")], &mut cx); - cx.update_editor(|editor, window, cx| editor.update_visible_inline_completion(window, cx)); + cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx)); assert_editor_active_edit_completion(&mut cx, |_, edits| { assert_eq!(edits.len(), 1); @@ -33,16 +33,16 @@ async fn test_inline_completion_insert(cx: &mut gpui::TestAppContext) { } #[gpui::test] -async fn test_inline_completion_modification(cx: &mut gpui::TestAppContext) { +async fn test_edit_prediction_modification(cx: &mut gpui::TestAppContext) { init_test(cx, |_| {}); let mut cx = EditorTestContext::new(cx).await; - let provider = cx.new(|_| FakeInlineCompletionProvider::default()); + let provider = cx.new(|_| FakeEditPredictionProvider::default()); assign_editor_completion_provider(provider.clone(), &mut cx); cx.set_state("let pi = ˇ\"foo\";"); propose_edits(&provider, vec![(9..14, "3.14159")], &mut cx); - cx.update_editor(|editor, window, cx| editor.update_visible_inline_completion(window, cx)); + cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx)); assert_editor_active_edit_completion(&mut cx, |_, edits| { assert_eq!(edits.len(), 1); @@ -55,11 +55,11 @@ async fn test_inline_completion_modification(cx: &mut gpui::TestAppContext) { } #[gpui::test] -async fn test_inline_completion_jump_button(cx: &mut gpui::TestAppContext) { +async fn test_edit_prediction_jump_button(cx: &mut gpui::TestAppContext) { init_test(cx, |_| {}); let mut cx = EditorTestContext::new(cx).await; - let provider = cx.new(|_| FakeInlineCompletionProvider::default()); + let provider = cx.new(|_| FakeEditPredictionProvider::default()); assign_editor_completion_provider(provider.clone(), &mut cx); // Cursor is 2+ lines above the proposed edit @@ -77,7 +77,7 @@ async fn test_inline_completion_jump_button(cx: &mut gpui::TestAppContext) { &mut cx, ); - cx.update_editor(|editor, window, cx| editor.update_visible_inline_completion(window, cx)); + cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx)); assert_editor_active_move_completion(&mut cx, |snapshot, move_target| { assert_eq!(move_target.to_point(&snapshot), Point::new(4, 3)); }); @@ -107,7 +107,7 @@ async fn test_inline_completion_jump_button(cx: &mut gpui::TestAppContext) { &mut cx, ); - cx.update_editor(|editor, window, cx| editor.update_visible_inline_completion(window, cx)); + cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx)); assert_editor_active_move_completion(&mut cx, |snapshot, move_target| { assert_eq!(move_target.to_point(&snapshot), Point::new(1, 3)); }); @@ -124,11 +124,11 @@ async fn test_inline_completion_jump_button(cx: &mut 
gpui::TestAppContext) { } #[gpui::test] -async fn test_inline_completion_invalidation_range(cx: &mut gpui::TestAppContext) { +async fn test_edit_prediction_invalidation_range(cx: &mut gpui::TestAppContext) { init_test(cx, |_| {}); let mut cx = EditorTestContext::new(cx).await; - let provider = cx.new(|_| FakeInlineCompletionProvider::default()); + let provider = cx.new(|_| FakeEditPredictionProvider::default()); assign_editor_completion_provider(provider.clone(), &mut cx); // Cursor is 3+ lines above the proposed edit @@ -148,7 +148,7 @@ async fn test_inline_completion_invalidation_range(cx: &mut gpui::TestAppContext &mut cx, ); - cx.update_editor(|editor, window, cx| editor.update_visible_inline_completion(window, cx)); + cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx)); assert_editor_active_move_completion(&mut cx, |snapshot, move_target| { assert_eq!(move_target.to_point(&snapshot), edit_location); }); @@ -176,7 +176,7 @@ async fn test_inline_completion_invalidation_range(cx: &mut gpui::TestAppContext line "}); cx.editor(|editor, _, _| { - assert!(editor.active_inline_completion.is_none()); + assert!(editor.active_edit_prediction.is_none()); }); // Cursor is 3+ lines below the proposed edit @@ -196,7 +196,7 @@ async fn test_inline_completion_invalidation_range(cx: &mut gpui::TestAppContext &mut cx, ); - cx.update_editor(|editor, window, cx| editor.update_visible_inline_completion(window, cx)); + cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx)); assert_editor_active_move_completion(&mut cx, |snapshot, move_target| { assert_eq!(move_target.to_point(&snapshot), edit_location); }); @@ -224,7 +224,50 @@ async fn test_inline_completion_invalidation_range(cx: &mut gpui::TestAppContext line ˇ5 "}); cx.editor(|editor, _, _| { - assert!(editor.active_inline_completion.is_none()); + assert!(editor.active_edit_prediction.is_none()); + }); +} + +#[gpui::test] +async fn test_edit_prediction_jump_disabled_for_non_zed_providers(cx: &mut gpui::TestAppContext) { + init_test(cx, |_| {}); + + let mut cx = EditorTestContext::new(cx).await; + let provider = cx.new(|_| FakeNonZedEditPredictionProvider::default()); + assign_editor_completion_provider_non_zed(provider.clone(), &mut cx); + + // Cursor is 2+ lines above the proposed edit + cx.set_state(indoc! {" + line 0 + line ˇ1 + line 2 + line 3 + line + "}); + + propose_edits_non_zed( + &provider, + vec![(Point::new(4, 3)..Point::new(4, 3), " 4")], + &mut cx, + ); + + cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx)); + + // For non-Zed providers, there should be no move completion (jump functionality disabled) + cx.editor(|editor, _, _| { + if let Some(completion_state) = &editor.active_edit_prediction { + // Should be an Edit prediction, not a Move prediction + match &completion_state.completion { + EditPrediction::Edit { .. } => { + // This is expected for non-Zed providers + } + EditPrediction::Move { .. } => { + panic!( + "Non-Zed providers should not show Move predictions (jump functionality)" + ); + } + } + } }); } @@ -234,11 +277,11 @@ fn assert_editor_active_edit_completion( ) { cx.editor(|editor, _, cx| { let completion_state = editor - .active_inline_completion + .active_edit_prediction .as_ref() .expect("editor has no active completion"); - if let InlineCompletion::Edit { edits, .. } = &completion_state.completion { + if let EditPrediction::Edit { edits, .. 
} = &completion_state.completion { assert(editor.buffer().read(cx).snapshot(cx), edits); } else { panic!("expected edit completion"); @@ -252,11 +295,11 @@ fn assert_editor_active_move_completion( ) { cx.editor(|editor, _, cx| { let completion_state = editor - .active_inline_completion + .active_edit_prediction .as_ref() .expect("editor has no active completion"); - if let InlineCompletion::Move { target, .. } = &completion_state.completion { + if let EditPrediction::Move { target, .. } = &completion_state.completion { assert(editor.buffer().read(cx).snapshot(cx), *target); } else { panic!("expected move completion"); @@ -271,7 +314,7 @@ fn accept_completion(cx: &mut EditorTestContext) { } fn propose_edits<T: ToOffset>( - provider: &Entity<FakeInlineCompletionProvider>, + provider: &Entity<FakeEditPredictionProvider>, edits: Vec<(Range<T>, &str)>, cx: &mut EditorTestContext, ) { @@ -283,7 +326,7 @@ fn propose_edits<T: ToOffset>( cx.update(|_, cx| { provider.update(cx, |provider, _| { - provider.set_inline_completion(Some(inline_completion::InlineCompletion { + provider.set_edit_prediction(Some(edit_prediction::EditPrediction { id: None, edits: edits.collect(), edit_preview: None, @@ -293,7 +336,38 @@ fn propose_edits<T: ToOffset>( } fn assign_editor_completion_provider( - provider: Entity<FakeInlineCompletionProvider>, + provider: Entity<FakeEditPredictionProvider>, + cx: &mut EditorTestContext, +) { + cx.update_editor(|editor, window, cx| { + editor.set_edit_prediction_provider(Some(provider), window, cx); + }) +} + +fn propose_edits_non_zed<T: ToOffset>( + provider: &Entity<FakeNonZedEditPredictionProvider>, + edits: Vec<(Range<T>, &str)>, + cx: &mut EditorTestContext, +) { + let snapshot = cx.buffer_snapshot(); + let edits = edits.into_iter().map(|(range, text)| { + let range = snapshot.anchor_after(range.start)..snapshot.anchor_before(range.end); + (range, text.into()) + }); + + cx.update(|_, cx| { + provider.update(cx, |provider, _| { + provider.set_edit_prediction(Some(edit_prediction::EditPrediction { + id: None, + edits: edits.collect(), + edit_preview: None, + })) + }) + }); +} + +fn assign_editor_completion_provider_non_zed( + provider: Entity<FakeNonZedEditPredictionProvider>, cx: &mut EditorTestContext, ) { cx.update_editor(|editor, window, cx| { @@ -302,20 +376,17 @@ fn assign_editor_completion_provider( } #[derive(Default, Clone)] -pub struct FakeInlineCompletionProvider { - pub completion: Option<inline_completion::InlineCompletion>, +pub struct FakeEditPredictionProvider { + pub completion: Option<edit_prediction::EditPrediction>, } -impl FakeInlineCompletionProvider { - pub fn set_inline_completion( - &mut self, - completion: Option<inline_completion::InlineCompletion>, - ) { +impl FakeEditPredictionProvider { + pub fn set_edit_prediction(&mut self, completion: Option<edit_prediction::EditPrediction>) { self.completion = completion; } } -impl EditPredictionProvider for FakeInlineCompletionProvider { +impl EditPredictionProvider for FakeEditPredictionProvider { fn name() -> &'static str { "fake-completion-provider" } @@ -328,6 +399,84 @@ impl EditPredictionProvider for FakeInlineCompletionProvider { false } + fn supports_jump_to_edit() -> bool { + true + } + + fn is_enabled( + &self, + _buffer: &gpui::Entity<language::Buffer>, + _cursor_position: language::Anchor, + _cx: &gpui::App, + ) -> bool { + true + } + + fn is_refreshing(&self) -> bool { + false + } + + fn refresh( + &mut self, + _project: Option<Entity<Project>>, + _buffer: gpui::Entity<language::Buffer>, + 
_cursor_position: language::Anchor, + _debounce: bool, + _cx: &mut gpui::Context<Self>, + ) { + } + + fn cycle( + &mut self, + _buffer: gpui::Entity<language::Buffer>, + _cursor_position: language::Anchor, + _direction: edit_prediction::Direction, + _cx: &mut gpui::Context<Self>, + ) { + } + + fn accept(&mut self, _cx: &mut gpui::Context<Self>) {} + + fn discard(&mut self, _cx: &mut gpui::Context<Self>) {} + + fn suggest<'a>( + &mut self, + _buffer: &gpui::Entity<language::Buffer>, + _cursor_position: language::Anchor, + _cx: &mut gpui::Context<Self>, + ) -> Option<edit_prediction::EditPrediction> { + self.completion.clone() + } +} + +#[derive(Default, Clone)] +pub struct FakeNonZedEditPredictionProvider { + pub completion: Option<edit_prediction::EditPrediction>, +} + +impl FakeNonZedEditPredictionProvider { + pub fn set_edit_prediction(&mut self, completion: Option<edit_prediction::EditPrediction>) { + self.completion = completion; + } +} + +impl EditPredictionProvider for FakeNonZedEditPredictionProvider { + fn name() -> &'static str { + "fake-non-zed-provider" + } + + fn display_name() -> &'static str { + "Fake Non-Zed Provider" + } + + fn show_completions_in_menu() -> bool { + false + } + + fn supports_jump_to_edit() -> bool { + false + } + fn is_enabled( &self, _buffer: &gpui::Entity<language::Buffer>, @@ -355,7 +504,7 @@ impl EditPredictionProvider for FakeInlineCompletionProvider { &mut self, _buffer: gpui::Entity<language::Buffer>, _cursor_position: language::Anchor, - _direction: inline_completion::Direction, + _direction: edit_prediction::Direction, _cx: &mut gpui::Context<Self>, ) { } @@ -369,7 +518,7 @@ impl EditPredictionProvider for FakeInlineCompletionProvider { _buffer: &gpui::Entity<language::Buffer>, _cursor_position: language::Anchor, _cx: &mut gpui::Context<Self>, - ) -> Option<inline_completion::InlineCompletion> { + ) -> Option<edit_prediction::EditPrediction> { self.completion.clone() } } diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index d7e6e42659dd98f32bbe4bad17ca9411ee8d453b..bd7963a2e2ee448e876a59ba8948182cf8a2ca48 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -43,50 +43,65 @@ pub mod tasks; #[cfg(test)] mod code_completion_tests; #[cfg(test)] -mod editor_tests; +mod edit_prediction_tests; #[cfg(test)] -mod inline_completion_tests; +mod editor_tests; mod signature_help; #[cfg(any(test, feature = "test-support"))] pub mod test; pub(crate) use actions::*; -pub use actions::{AcceptEditPrediction, OpenExcerpts, OpenExcerptsSplit}; +pub use display_map::{ChunkRenderer, ChunkRendererContext, DisplayPoint, FoldPlaceholder}; +pub use edit_prediction::Direction; +pub use editor_settings::{ + CurrentLineHighlight, DocumentColorsRenderMode, EditorSettings, HideMouseMode, + ScrollBeyondLastLine, ScrollbarAxes, SearchSettings, ShowMinimap, ShowScrollbar, +}; +pub use editor_settings_controls::*; +pub use element::{ + CursorLayout, EditorElement, HighlightedRange, HighlightedRangeLine, PointForPosition, +}; +pub use git::blame::BlameRenderer; +pub use hover_popover::hover_markdown_style; +pub use items::MAX_TAB_TITLE_LEN; +pub use lsp::CompletionContext; +pub use lsp_ext::lsp_tasks; +pub use multi_buffer::{ + Anchor, AnchorRangeExt, ExcerptId, ExcerptRange, MultiBuffer, MultiBufferSnapshot, PathKey, + RowInfo, ToOffset, ToPoint, +}; +pub use proposed_changes_editor::{ + ProposedChangeLocation, ProposedChangesEditor, ProposedChangesEditorToolbar, +}; +pub use text::Bias; + +use ::git::{ + Restore, + blame::{BlameEntry, 
ParsedCommitMessage}, +}; use aho_corasick::AhoCorasick; use anyhow::{Context as _, Result, anyhow}; use blink_manager::BlinkManager; use buffer_diff::DiffHunkStatus; use client::{Collaborator, ParticipantIndex}; use clock::{AGENT_REPLICA_ID, ReplicaId}; +use code_context_menus::{ + AvailableCodeAction, CodeActionContents, CodeActionsItem, CodeActionsMenu, CodeContextMenu, + CompletionsMenu, ContextMenuOrigin, +}; use collections::{BTreeMap, HashMap, HashSet, VecDeque}; use convert_case::{Case, Casing}; use dap::TelemetrySpawnLocation; use display_map::*; -pub use display_map::{ChunkRenderer, ChunkRendererContext, DisplayPoint, FoldPlaceholder}; -pub use editor_settings::{ - CurrentLineHighlight, DocumentColorsRenderMode, EditorSettings, HideMouseMode, - ScrollBeyondLastLine, ScrollbarAxes, SearchSettings, ShowScrollbar, -}; +use edit_prediction::{EditPredictionProvider, EditPredictionProviderHandle}; use editor_settings::{GoToDefinitionFallback, Minimap as MinimapSettings}; -pub use editor_settings_controls::*; use element::{AcceptEditPredictionBinding, LineWithInvisibles, PositionMap, layout_line}; -pub use element::{ - CursorLayout, EditorElement, HighlightedRange, HighlightedRangeLine, PointForPosition, -}; use futures::{ FutureExt, StreamExt as _, future::{self, Shared, join}, stream::FuturesUnordered, }; use fuzzy::{StringMatch, StringMatchCandidate}; -use lsp_colors::LspColorData; - -use ::git::blame::BlameEntry; -use ::git::{Restore, blame::ParsedCommitMessage}; -use code_context_menus::{ - AvailableCodeAction, CodeActionContents, CodeActionsItem, CodeActionsMenu, CodeContextMenu, - CompletionsMenu, ContextMenuOrigin, -}; use git::blame::{GitBlame, GlobalBlameRenderer}; use gpui::{ Action, Animation, AnimationExt, AnyElement, App, AppContext, AsyncWindowContext, @@ -100,32 +115,42 @@ use gpui::{ }; use highlight_matching_bracket::refresh_matching_bracket_highlights; use hover_links::{HoverLink, HoveredLinkState, InlayHighlight, find_file}; -pub use hover_popover::hover_markdown_style; use hover_popover::{HoverState, hide_hover}; use indent_guides::ActiveIndentGuidesState; use inlay_hint_cache::{InlayHintCache, InlaySplice, InvalidationStrategy}; -pub use inline_completion::Direction; -use inline_completion::{EditPredictionProvider, InlineCompletionProviderHandle}; -pub use items::MAX_TAB_TITLE_LEN; use itertools::Itertools; use language::{ - AutoindentMode, BracketMatch, BracketPair, Buffer, Capability, CharKind, CodeLabel, - CursorShape, DiagnosticEntry, DiffOptions, DocumentationConfig, EditPredictionsMode, - EditPreview, HighlightedText, IndentKind, IndentSize, Language, OffsetRangeExt, Point, - Selection, SelectionGoal, TextObject, TransactionId, TreeSitterOptions, WordsQuery, + AutoindentMode, BlockCommentConfig, BracketMatch, BracketPair, Buffer, BufferRow, + BufferSnapshot, Capability, CharClassifier, CharKind, CodeLabel, CursorShape, DiagnosticEntry, + DiffOptions, EditPredictionsMode, EditPreview, HighlightedText, IndentKind, IndentSize, + Language, OffsetRangeExt, Point, Runnable, RunnableRange, Selection, SelectionGoal, TextObject, + TransactionId, TreeSitterOptions, WordsQuery, language_settings::{ self, InlayHintSettings, LspInsertMode, RewrapBehavior, WordsCompletionMode, all_language_settings, language_settings, }, - point_from_lsp, text_diff_with_options, + point_from_lsp, point_to_lsp, text_diff_with_options, }; -use language::{BufferRow, CharClassifier, Runnable, RunnableRange, point_to_lsp}; use linked_editing_ranges::refresh_linked_ranges; +use lsp::{ + 
CodeActionKind, CompletionItemKind, CompletionTriggerKind, InsertTextFormat, InsertTextMode, + LanguageServerId, +}; +use lsp_colors::LspColorData; use markdown::Markdown; use mouse_context_menu::MouseContextMenu; +use movement::TextLayoutDetails; +use multi_buffer::{ + ExcerptInfo, ExpandExcerptDirection, MultiBufferDiffHunk, MultiBufferPoint, MultiBufferRow, + MultiOrSingleBufferOffsetRange, ToOffsetUtf16, +}; +use parking_lot::Mutex; use persistence::DB; use project::{ - BreakpointWithPosition, CompletionResponse, ProjectPath, + BreakpointWithPosition, CodeAction, Completion, CompletionIntent, CompletionResponse, + CompletionSource, DisableAiSettings, DocumentHighlight, InlayHint, Location, LocationLink, + PrepareRenameResponse, Project, ProjectItem, ProjectPath, ProjectTransaction, TaskSourceKind, + debugger::breakpoint_store::Breakpoint, debugger::{ breakpoint_store::{ BreakpointEditAction, BreakpointSessionState, BreakpointState, BreakpointStore, @@ -134,44 +159,12 @@ use project::{ session::{Session, SessionEvent}, }, git_store::{GitStoreEvent, RepositoryEvent}, - project_settings::DiagnosticSeverity, -}; - -pub use git::blame::BlameRenderer; -pub use proposed_changes_editor::{ - ProposedChangeLocation, ProposedChangesEditor, ProposedChangesEditorToolbar, -}; -use std::{cell::OnceCell, iter::Peekable, ops::Not}; -use task::{ResolvedTask, RunnableTag, TaskTemplate, TaskVariables}; - -pub use lsp::CompletionContext; -use lsp::{ - CodeActionKind, CompletionItemKind, CompletionTriggerKind, InsertTextFormat, InsertTextMode, - LanguageServerId, LanguageServerName, -}; - -use language::BufferSnapshot; -pub use lsp_ext::lsp_tasks; -use movement::TextLayoutDetails; -pub use multi_buffer::{ - Anchor, AnchorRangeExt, ExcerptId, ExcerptRange, MultiBuffer, MultiBufferSnapshot, PathKey, - RowInfo, ToOffset, ToPoint, -}; -use multi_buffer::{ - ExcerptInfo, ExpandExcerptDirection, MultiBufferDiffHunk, MultiBufferPoint, MultiBufferRow, - MultiOrSingleBufferOffsetRange, ToOffsetUtf16, -}; -use parking_lot::Mutex; -use project::{ - CodeAction, Completion, CompletionIntent, CompletionSource, DocumentHighlight, InlayHint, - Location, LocationLink, PrepareRenameResponse, Project, ProjectItem, ProjectTransaction, - TaskSourceKind, - debugger::breakpoint_store::Breakpoint, lsp_store::{CompletionDocumentation, FormatTrigger, LspFormatTarget, OpenLspBufferHandle}, + project_settings::{DiagnosticSeverity, GoToDiagnosticSeverityFilter}, project_settings::{GitGutterSetting, ProjectSettings}, }; -use rand::prelude::*; -use rpc::{ErrorExt, proto::*}; +use rand::{seq::SliceRandom, thread_rng}; +use rpc::{ErrorCode, ErrorExt, proto::PeerId}; use scroll::{Autoscroll, OngoingScroll, ScrollAnchor, ScrollManager, ScrollbarAutoHide}; use selections_collection::{ MutableSelectionsCollection, SelectionsCollection, resolve_selections, @@ -180,21 +173,24 @@ use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsLocation, SettingsStore, update_settings_file}; use smallvec::{SmallVec, smallvec}; use snippet::Snippet; -use std::sync::Arc; use std::{ any::TypeId, borrow::Cow, + cell::OnceCell, cell::RefCell, cmp::{self, Ordering, Reverse}, + iter::Peekable, mem, num::NonZeroU32, + ops::Not, ops::{ControlFlow, Deref, DerefMut, Range, RangeInclusive}, path::{Path, PathBuf}, rc::Rc, + sync::Arc, time::{Duration, Instant}, }; -pub use sum_tree::Bias; use sum_tree::TreeMap; +use task::{ResolvedTask, RunnableTag, TaskTemplate, TaskVariables}; use text::{BufferId, FromAnchor, OffsetUtf16, Rope}; use theme::{ 
ActiveTheme, PlayerColor, StatusColors, SyntaxTheme, Theme, ThemeSettings, @@ -216,10 +212,8 @@ use workspace::{ use crate::{ code_context_menus::CompletionsMenuSource, - hover_links::{find_url, find_url_from_range}, -}; -use crate::{ editor_settings::MultiCursorModifier, + hover_links::{find_url, find_url_from_range}, signature_help::{SignatureHelpHiddenBy, SignatureHelpState}, }; @@ -274,7 +268,7 @@ impl InlineValueCache { #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum InlayId { - InlineCompletion(usize), + EditPrediction(usize), DebuggerValue(usize), // LSP Hint(usize), @@ -284,7 +278,7 @@ pub enum InlayId { impl InlayId { fn id(&self) -> usize { match self { - Self::InlineCompletion(id) => *id, + Self::EditPrediction(id) => *id, Self::DebuggerValue(id) => *id, Self::Hint(id) => *id, Self::Color(id) => *id, @@ -356,6 +350,7 @@ pub fn init(cx: &mut App) { workspace.register_action(Editor::new_file_vertical); workspace.register_action(Editor::new_file_horizontal); workspace.register_action(Editor::cancel_language_server_work); + workspace.register_action(Editor::toggle_focus); }, ) .detach(); @@ -482,9 +477,7 @@ pub enum SelectMode { #[derive(Clone, PartialEq, Eq, Debug)] pub enum EditorMode { - SingleLine { - auto_width: bool, - }, + SingleLine, AutoHeight { min_lines: usize, max_lines: Option<usize>, @@ -554,7 +547,7 @@ pub struct EditorStyle { pub syntax: Arc<SyntaxTheme>, pub status: StatusColors, pub inlay_hints_style: HighlightStyle, - pub inline_completion_styles: InlineCompletionStyles, + pub edit_prediction_styles: EditPredictionStyles, pub unnecessary_code_fade: f32, pub show_underlines: bool, } @@ -573,7 +566,7 @@ impl Default for EditorStyle { // style and retrieve them directly from the theme. status: StatusColors::dark(), inlay_hints_style: HighlightStyle::default(), - inline_completion_styles: InlineCompletionStyles { + edit_prediction_styles: EditPredictionStyles { insertion: HighlightStyle::default(), whitespace: HighlightStyle::default(), }, @@ -595,8 +588,8 @@ pub fn make_inlay_hints_style(cx: &mut App) -> HighlightStyle { } } -pub fn make_suggestion_styles(cx: &mut App) -> InlineCompletionStyles { - InlineCompletionStyles { +pub fn make_suggestion_styles(cx: &mut App) -> EditPredictionStyles { + EditPredictionStyles { insertion: HighlightStyle { color: Some(cx.theme().status().predictive), ..HighlightStyle::default() @@ -616,7 +609,7 @@ pub(crate) enum EditDisplayMode { Inline, } -enum InlineCompletion { +enum EditPrediction { Edit { edits: Vec<(Range<Anchor>, String)>, edit_preview: Option<EditPreview>, @@ -629,9 +622,9 @@ enum InlineCompletion { }, } -struct InlineCompletionState { +struct EditPredictionState { inlay_ids: Vec<InlayId>, - completion: InlineCompletion, + completion: EditPrediction, completion_id: Option<SharedString>, invalidation_range: Range<Anchor>, } @@ -644,7 +637,7 @@ enum EditPredictionSettings { }, } -enum InlineCompletionHighlight {} +enum EditPredictionHighlight {} #[derive(Debug, Clone)] struct InlineDiagnostic { @@ -655,7 +648,7 @@ struct InlineDiagnostic { severity: lsp::DiagnosticSeverity, } -pub enum MenuInlineCompletionsPolicy { +pub enum MenuEditPredictionsPolicy { Never, ByProvider, } @@ -951,6 +944,7 @@ struct InlineBlamePopover { hide_task: Option<Task<()>>, popover_bounds: Option<Bounds<Pixels>>, popover_state: InlineBlamePopoverState, + keyboard_grace: bool, } enum SelectionDragState { @@ -1093,15 +1087,15 @@ pub struct Editor { pending_mouse_down: Option<Rc<RefCell<Option<MouseDownEvent>>>>, 
gutter_hovered: bool, hovered_link_state: Option<HoveredLinkState>, - edit_prediction_provider: Option<RegisteredInlineCompletionProvider>, + edit_prediction_provider: Option<RegisteredEditPredictionProvider>, code_action_providers: Vec<Rc<dyn CodeActionProvider>>, - active_inline_completion: Option<InlineCompletionState>, + active_edit_prediction: Option<EditPredictionState>, /// Used to prevent flickering as the user types while the menu is open - stale_inline_completion_in_menu: Option<InlineCompletionState>, + stale_edit_prediction_in_menu: Option<EditPredictionState>, edit_prediction_settings: EditPredictionSettings, - inline_completions_hidden_for_vim_mode: bool, - show_inline_completions_override: Option<bool>, - menu_inline_completions_policy: MenuInlineCompletionsPolicy, + edit_predictions_hidden_for_vim_mode: bool, + show_edit_predictions_override: Option<bool>, + menu_edit_predictions_policy: MenuEditPredictionsPolicy, edit_prediction_preview: EditPredictionPreview, edit_prediction_indent_conflict: bool, edit_prediction_requires_modifier_in_indent_conflict: bool, @@ -1304,6 +1298,7 @@ impl Default for SelectionHistoryMode { /// /// Similarly, you might want to disable scrolling if you don't want the viewport to /// move. +#[derive(Clone)] pub struct SelectionEffects { nav_history: Option<bool>, completions: bool, @@ -1515,8 +1510,8 @@ pub struct RenameState { struct InvalidationStack<T>(Vec<T>); -struct RegisteredInlineCompletionProvider { - provider: Arc<dyn InlineCompletionProviderHandle>, +struct RegisteredEditPredictionProvider { + provider: Arc<dyn EditPredictionProviderHandle>, _subscription: Subscription, } @@ -1662,13 +1657,7 @@ impl Editor { pub fn single_line(window: &mut Window, cx: &mut Context<Self>) -> Self { let buffer = cx.new(|cx| Buffer::local("", cx)); let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx)); - Self::new( - EditorMode::SingleLine { auto_width: false }, - buffer, - None, - window, - cx, - ) + Self::new(EditorMode::SingleLine, buffer, None, window, cx) } pub fn multi_line(window: &mut Window, cx: &mut Context<Self>) -> Self { @@ -1677,18 +1666,6 @@ impl Editor { Self::new(EditorMode::full(), buffer, None, window, cx) } - pub fn auto_width(window: &mut Window, cx: &mut Context<Self>) -> Self { - let buffer = cx.new(|cx| Buffer::local("", cx)); - let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx)); - Self::new( - EditorMode::SingleLine { auto_width: true }, - buffer, - None, - window, - cx, - ) - } - pub fn auto_height( min_lines: usize, max_lines: usize, @@ -1791,10 +1768,11 @@ impl Editor { ) -> Self { debug_assert!( display_map.is_none() || mode.is_minimap(), - "Providing a display map for a new editor is only intended for the minimap and might have unindended side effects otherwise!" + "Providing a display map for a new editor is only intended for the minimap and might have unintended side effects otherwise!" ); let full_mode = mode.is_full(); + let is_minimap = mode.is_minimap(); let diagnostics_max_severity = if full_mode { EditorSettings::get_global(cx) .diagnostics_max_severity @@ -1855,13 +1833,19 @@ impl Editor { let selections = SelectionsCollection::new(display_map.clone(), buffer.clone()); - let blink_manager = cx.new(|cx| BlinkManager::new(CURSOR_BLINK_INTERVAL, cx)); + let blink_manager = cx.new(|cx| { + let mut blink_manager = BlinkManager::new(CURSOR_BLINK_INTERVAL, cx); + if is_minimap { + blink_manager.disable(cx); + } + blink_manager + }); let soft_wrap_mode_override = matches!(mode, EditorMode::SingleLine { .. 
}) .then(|| language_settings::SoftWrap::None); let mut project_subscriptions = Vec::new(); - if mode.is_full() { + if full_mode { if let Some(project) = project.as_ref() { project_subscriptions.push(cx.subscribe_in( project, @@ -1880,7 +1864,6 @@ impl Editor { editor.tasks_update_task = Some(editor.refresh_runnables(window, cx)); } - editor.update_lsp_data(true, None, window, cx); } project::Event::SnippetEdit(id, snippet_edits) => { if let Some(buffer) = editor.buffer.read(cx).buffer(*id) { @@ -1902,6 +1885,11 @@ impl Editor { } } } + project::Event::LanguageServerBufferRegistered { buffer_id, .. } => { + if editor.buffer().read(cx).buffer(*buffer_id).is_some() { + editor.update_lsp_data(false, Some(*buffer_id), window, cx); + } + } _ => {} }, )); @@ -1972,18 +1960,23 @@ impl Editor { let inlay_hint_settings = inlay_hint_settings(selections.newest_anchor().head(), &buffer_snapshot, cx); let focus_handle = cx.focus_handle(); - cx.on_focus(&focus_handle, window, Self::handle_focus) - .detach(); - cx.on_focus_in(&focus_handle, window, Self::handle_focus_in) - .detach(); - cx.on_focus_out(&focus_handle, window, Self::handle_focus_out) - .detach(); - cx.on_blur(&focus_handle, window, Self::handle_blur) - .detach(); - cx.observe_pending_input(window, Self::observe_pending_input) - .detach(); - - let show_indent_guides = if matches!(mode, EditorMode::SingleLine { .. }) { + if !is_minimap { + cx.on_focus(&focus_handle, window, Self::handle_focus) + .detach(); + cx.on_focus_in(&focus_handle, window, Self::handle_focus_in) + .detach(); + cx.on_focus_out(&focus_handle, window, Self::handle_focus_out) + .detach(); + cx.on_blur(&focus_handle, window, Self::handle_blur) + .detach(); + cx.observe_pending_input(window, Self::observe_pending_input) + .detach(); + } + + let show_indent_guides = if matches!( + mode, + EditorMode::SingleLine { .. } | EditorMode::Minimap { .. } + ) { Some(false) } else { None @@ -2049,10 +2042,10 @@ impl Editor { minimap_visibility: MinimapVisibility::for_mode(&mode, cx), offset_content: !matches!(mode, EditorMode::SingleLine { .. 
}), show_breadcrumbs: EditorSettings::get_global(cx).toolbar.breadcrumbs, - show_gutter: mode.is_full(), - show_line_numbers: None, + show_gutter: full_mode, + show_line_numbers: (!full_mode).then_some(false), use_relative_line_numbers: None, - disable_expand_excerpt_buttons: false, + disable_expand_excerpt_buttons: !full_mode, show_git_diff_gutter: None, show_code_actions: None, show_runnables: None, @@ -2086,7 +2079,7 @@ impl Editor { document_highlights_task: None, linked_editing_range_task: None, pending_rename: None, - searchable: true, + searchable: !is_minimap, cursor_shape: EditorSettings::get_global(cx) .cursor_shape .unwrap_or_default(), @@ -2094,9 +2087,9 @@ impl Editor { autoindent_mode: Some(AutoindentMode::EachLine), collapse_matches: false, workspace: None, - input_enabled: true, - use_modal_editing: mode.is_full(), - read_only: mode.is_minimap(), + input_enabled: !is_minimap, + use_modal_editing: full_mode, + read_only: is_minimap, use_autoclose: true, use_auto_surround: true, auto_replace_emoji_shortcode: false, @@ -2107,16 +2100,15 @@ impl Editor { pending_mouse_down: None, hovered_link_state: None, edit_prediction_provider: None, - active_inline_completion: None, - stale_inline_completion_in_menu: None, + active_edit_prediction: None, + stale_edit_prediction_in_menu: None, edit_prediction_preview: EditPredictionPreview::Inactive { released_too_fast: false, }, - inline_diagnostics_enabled: mode.is_full(), - diagnostics_enabled: mode.is_full(), + inline_diagnostics_enabled: full_mode, + diagnostics_enabled: full_mode, inline_value_cache: InlineValueCache::new(inlay_hint_settings.show_value_hints), inlay_hint_cache: InlayHintCache::new(inlay_hint_settings), - gutter_hovered: false, pixel_position_of_newest_cursor: None, last_bounds: None, @@ -2128,9 +2120,9 @@ impl Editor { hovered_cursors: HashMap::default(), next_editor_action_id: EditorActionId::default(), editor_actions: Rc::default(), - inline_completions_hidden_for_vim_mode: false, - show_inline_completions_override: None, - menu_inline_completions_policy: MenuInlineCompletionsPolicy::ByProvider, + edit_predictions_hidden_for_vim_mode: false, + show_edit_predictions_override: None, + menu_edit_predictions_policy: MenuEditPredictionsPolicy::ByProvider, edit_prediction_settings: EditPredictionSettings::Disabled, edit_prediction_indent_conflict: false, edit_prediction_requires_modifier_in_indent_conflict: true, @@ -2139,9 +2131,10 @@ impl Editor { show_git_blame_inline: false, show_selection_menu: None, show_git_blame_inline_delay_task: None, - git_blame_inline_enabled: ProjectSettings::get_global(cx).git.inline_blame_enabled(), + git_blame_inline_enabled: full_mode + && ProjectSettings::get_global(cx).git.inline_blame_enabled(), render_diff_hunk_controls: Arc::new(render_diff_hunk_controls), - serialize_dirty_buffers: !mode.is_minimap() + serialize_dirty_buffers: !is_minimap && ProjectSettings::get_global(cx) .session .restore_unsaved_buffers, @@ -2152,27 +2145,31 @@ impl Editor { breakpoint_store, gutter_breakpoint_indicator: (None, None), hovered_diff_hunk_row: None, - _subscriptions: vec![ - cx.observe(&buffer, Self::on_buffer_changed), - cx.subscribe_in(&buffer, window, Self::on_buffer_event), - cx.observe_in(&display_map, window, Self::on_display_map_changed), - cx.observe(&blink_manager, |_, _, cx| cx.notify()), - cx.observe_global_in::<SettingsStore>(window, Self::settings_changed), - observe_buffer_font_size_adjustment(cx, |_, cx| cx.notify()), - cx.observe_window_activation(window, |editor, window, cx| { - let 
active = window.is_window_active(); - editor.blink_manager.update(cx, |blink_manager, cx| { - if active { - blink_manager.enable(cx); - } else { - blink_manager.disable(cx); - } - }); - if active { - editor.show_mouse_cursor(cx); - } - }), - ], + _subscriptions: (!is_minimap) + .then(|| { + vec![ + cx.observe(&buffer, Self::on_buffer_changed), + cx.subscribe_in(&buffer, window, Self::on_buffer_event), + cx.observe_in(&display_map, window, Self::on_display_map_changed), + cx.observe(&blink_manager, |_, _, cx| cx.notify()), + cx.observe_global_in::<SettingsStore>(window, Self::settings_changed), + observe_buffer_font_size_adjustment(cx, |_, cx| cx.notify()), + cx.observe_window_activation(window, |editor, window, cx| { + let active = window.is_window_active(); + editor.blink_manager.update(cx, |blink_manager, cx| { + if active { + blink_manager.enable(cx); + } else { + blink_manager.disable(cx); + } + }); + if active { + editor.show_mouse_cursor(cx); + } + }), + ] + }) + .unwrap_or_default(), tasks_update_task: None, pull_diagnostics_task: Task::ready(()), colors: None, @@ -2203,6 +2200,11 @@ impl Editor { selection_drag_state: SelectionDragState::None, folding_newlines: Task::ready(()), }; + + if is_minimap { + return editor; + } + if let Some(breakpoints) = editor.breakpoint_store.as_ref() { editor ._subscriptions @@ -2322,7 +2324,10 @@ impl Editor { editor.update_lsp_data(false, None, window, cx); } - editor.report_editor_event("Editor Opened", None, cx); + if editor.mode.is_full() { + editor.report_editor_event("Editor Opened", None, cx); + } + editor } @@ -2349,7 +2354,7 @@ impl Editor { } pub fn key_context(&self, window: &Window, cx: &App) -> KeyContext { - self.key_context_internal(self.has_active_inline_completion(), window, cx) + self.key_context_internal(self.has_active_edit_prediction(), window, cx) } fn key_context_internal( @@ -2377,13 +2382,17 @@ impl Editor { } match self.context_menu.borrow().as_ref() { - Some(CodeContextMenu::Completions(_)) => { - key_context.add("menu"); - key_context.add("showing_completions"); + Some(CodeContextMenu::Completions(menu)) => { + if menu.visible() { + key_context.add("menu"); + key_context.add("showing_completions"); + } } - Some(CodeContextMenu::CodeActions(_)) => { - key_context.add("menu"); - key_context.add("showing_code_actions") + Some(CodeContextMenu::CodeActions(menu)) => { + if menu.visible() { + key_context.add("menu"); + key_context.add("showing_code_actions") + } } None => {} } @@ -2696,6 +2705,11 @@ impl Editor { self.completion_provider = provider; } + #[cfg(any(test, feature = "test-support"))] + pub fn completion_provider(&self) -> Option<Rc<dyn CompletionProvider>> { + self.completion_provider.clone() + } + pub fn semantics_provider(&self) -> Option<Rc<dyn SemanticsProvider>> { self.semantics_provider.clone() } @@ -2712,17 +2726,16 @@ impl Editor { ) where T: EditPredictionProvider, { - self.edit_prediction_provider = - provider.map(|provider| RegisteredInlineCompletionProvider { - _subscription: cx.observe_in(&provider, window, |this, _, window, cx| { - if this.focus_handle.is_focused(window) { - this.update_visible_inline_completion(window, cx); - } - }), - provider: Arc::new(provider), - }); + self.edit_prediction_provider = provider.map(|provider| RegisteredEditPredictionProvider { + _subscription: cx.observe_in(&provider, window, |this, _, window, cx| { + if this.focus_handle.is_focused(window) { + this.update_visible_edit_prediction(window, cx); + } + }), + provider: Arc::new(provider), + }); 
self.update_edit_prediction_settings(cx); - self.refresh_inline_completion(false, false, window, cx); + self.refresh_edit_prediction(false, false, window, cx); } pub fn placeholder_text(&self) -> Option<&str> { @@ -2793,24 +2806,24 @@ impl Editor { self.input_enabled = input_enabled; } - pub fn set_inline_completions_hidden_for_vim_mode( + pub fn set_edit_predictions_hidden_for_vim_mode( &mut self, hidden: bool, window: &mut Window, cx: &mut Context<Self>, ) { - if hidden != self.inline_completions_hidden_for_vim_mode { - self.inline_completions_hidden_for_vim_mode = hidden; + if hidden != self.edit_predictions_hidden_for_vim_mode { + self.edit_predictions_hidden_for_vim_mode = hidden; if hidden { - self.update_visible_inline_completion(window, cx); + self.update_visible_edit_prediction(window, cx); } else { - self.refresh_inline_completion(true, false, window, cx); + self.refresh_edit_prediction(true, false, window, cx); } } } - pub fn set_menu_inline_completions_policy(&mut self, value: MenuInlineCompletionsPolicy) { - self.menu_inline_completions_policy = value; + pub fn set_menu_edit_predictions_policy(&mut self, value: MenuEditPredictionsPolicy) { + self.menu_edit_predictions_policy = value; } pub fn set_autoindent(&mut self, autoindent: bool) { @@ -2847,7 +2860,7 @@ impl Editor { window: &mut Window, cx: &mut Context<Self>, ) { - if self.show_inline_completions_override.is_some() { + if self.show_edit_predictions_override.is_some() { self.set_show_edit_predictions(None, window, cx); } else { let show_edit_predictions = !self.edit_predictions_enabled(); @@ -2861,17 +2874,17 @@ impl Editor { window: &mut Window, cx: &mut Context<Self>, ) { - self.show_inline_completions_override = show_edit_predictions; + self.show_edit_predictions_override = show_edit_predictions; self.update_edit_prediction_settings(cx); if let Some(false) = show_edit_predictions { - self.discard_inline_completion(false, cx); + self.discard_edit_prediction(false, cx); } else { - self.refresh_inline_completion(false, true, window, cx); + self.refresh_edit_prediction(false, true, window, cx); } } - fn inline_completions_disabled_in_scope( + fn edit_predictions_disabled_in_scope( &self, buffer: &Entity<Buffer>, buffer_position: language::Anchor, @@ -2933,10 +2946,12 @@ impl Editor { } } + let selection_anchors = self.selections.disjoint_anchors(); + if self.focus_handle.is_focused(window) && self.leader_id.is_none() { self.buffer.update(cx, |buffer, cx| { buffer.set_active_selections( - &self.selections.disjoint_anchors(), + &selection_anchors, self.selections.line_mode, self.cursor_shape, cx, @@ -2953,9 +2968,8 @@ impl Editor { self.select_next_state = None; self.select_prev_state = None; self.select_syntax_node_history.try_clear(); - self.invalidate_autoclose_regions(&self.selections.disjoint_anchors(), buffer); - self.snippet_stack - .invalidate(&self.selections.disjoint_anchors(), buffer); + self.invalidate_autoclose_regions(&selection_anchors, buffer); + self.snippet_stack.invalidate(&selection_anchors, buffer); self.take_rename(false, window, cx); let newest_selection = self.selections.newest_anchor(); @@ -3037,7 +3051,7 @@ impl Editor { self.refresh_document_highlights(cx); self.refresh_selected_text_highlights(false, window, cx); refresh_matching_bracket_highlights(self, window, cx); - self.update_visible_inline_completion(window, cx); + self.update_visible_edit_prediction(window, cx); self.edit_prediction_requires_modifier_in_indent_conflict = true; linked_editing_ranges::refresh_linked_ranges(self, window, 
cx); self.inline_blame_popover.take(); @@ -3827,7 +3841,7 @@ impl Editor { return true; } - if is_user_requested && self.discard_inline_completion(true, cx) { + if is_user_requested && self.discard_edit_prediction(true, cx) { return true; } @@ -4036,7 +4050,8 @@ impl Editor { // then don't insert that closing bracket again; just move the selection // past the closing bracket. let should_skip = selection.end == region.range.end.to_point(&snapshot) - && text.as_ref() == region.pair.end.as_str(); + && text.as_ref() == region.pair.end.as_str() + && snapshot.contains_str_at(region.range.end, text.as_ref()); if should_skip { let anchor = snapshot.anchor_after(selection.end); new_selections @@ -4232,7 +4247,7 @@ impl Editor { ); } - let had_active_inline_completion = this.has_active_inline_completion(); + let had_active_edit_prediction = this.has_active_edit_prediction(); this.change_selections( SelectionEffects::scroll(Autoscroll::fit()).completions(false), window, @@ -4257,7 +4272,7 @@ impl Editor { } let trigger_in_words = - this.show_edit_predictions_in_menu() || !had_active_inline_completion; + this.show_edit_predictions_in_menu() || !had_active_edit_prediction; if this.hard_wrap.is_some() { let latest: Range<Point> = this.selections.newest(cx).range(); if latest.is_empty() @@ -4279,7 +4294,7 @@ impl Editor { } this.trigger_completion_on_input(&text, trigger_in_words, window, cx); linked_editing_ranges::refresh_linked_ranges(this, window, cx); - this.refresh_inline_completion(true, false, window, cx); + this.refresh_edit_prediction(true, false, window, cx); jsx_tag_auto_close::handle_from(this, initial_buffer_versions, window, cx); }); } @@ -4397,7 +4412,9 @@ impl Editor { }) .max_by_key(|(_, len)| *len)?; - if let Some((block_start, _)) = language.block_comment_delimiters() + if let Some(BlockCommentConfig { + start: block_start, .. + }) = language.block_comment() { let block_start_trimmed = block_start.trim_end(); if block_start_trimmed.starts_with(delimiter.trim_end()) { @@ -4434,13 +4451,12 @@ impl Editor { return None; } - let DocumentationConfig { + let BlockCommentConfig { start: start_tag, end: end_tag, prefix: delimiter, tab_size: len, - } = language.documentation()?; - + } = language.documentation_comment()?; let is_within_block_comment = buffer .language_scope_at(start_point) .is_some_and(|scope| scope.override_name() == Some("comment")); @@ -4510,7 +4526,7 @@ impl Editor { let cursor_is_at_start_of_end_tag = column == end_tag_offset; if cursor_is_at_start_of_end_tag { - indent_on_extra_newline.len = (*len).into(); + indent_on_extra_newline.len = *len; } } cursor_is_before_end_tag @@ -4523,7 +4539,7 @@ impl Editor { && cursor_is_before_end_tag_if_exists { if cursor_is_after_start_tag { - indent_on_newline.len = (*len).into(); + indent_on_newline.len = *len; } Some(delimiter.clone()) } else { @@ -4613,7 +4629,7 @@ impl Editor { .collect(); this.change_selections(Default::default(), window, cx, |s| s.select(new_selections)); - this.refresh_inline_completion(true, false, window, cx); + this.refresh_edit_prediction(true, false, window, cx); }); } @@ -4961,13 +4977,17 @@ impl Editor { }) } - /// Remove any autoclose regions that no longer contain their selection. + /// Remove any autoclose regions that no longer contain their selection or have invalid anchors in ranges. 
fn invalidate_autoclose_regions( &mut self, mut selections: &[Selection<Anchor>], buffer: &MultiBufferSnapshot, ) { self.autoclose_regions.retain(|state| { + if !state.range.start.is_valid(buffer) || !state.range.end.is_valid(buffer) { + return false; + } + let mut i = 0; while let Some(selection) = selections.get(i) { if selection.end.cmp(&state.range.start, buffer).is_lt() { @@ -5442,7 +5462,7 @@ impl Editor { }; let (word_replace_range, word_to_exclude) = if let (word_range, Some(CharKind::Word)) = - buffer_snapshot.surrounding_word(buffer_position) + buffer_snapshot.surrounding_word(buffer_position, false) { let word_to_exclude = buffer_snapshot .text_for_range(word_range.clone()) @@ -5657,9 +5677,9 @@ impl Editor { crate::hover_popover::hide_hover(editor, cx); if editor.show_edit_predictions_in_menu() { - editor.update_visible_inline_completion(window, cx); + editor.update_visible_edit_prediction(window, cx); } else { - editor.discard_inline_completion(false, cx); + editor.discard_edit_prediction(false, cx); } cx.notify(); @@ -5670,10 +5690,10 @@ impl Editor { if editor.completion_tasks.len() <= 1 { // If there are no more completion tasks and the last menu was empty, we should hide it. let was_hidden = editor.hide_context_menu(window, cx).is_none(); - // If it was already hidden and we don't show inline completions in the menu, we should - // also show the inline-completion when available. + // If it was already hidden and we don't show edit predictions in the menu, + // we should also show the edit prediction when available. if was_hidden && editor.show_edit_predictions_in_menu() { - editor.update_visible_inline_completion(window, cx); + editor.update_visible_edit_prediction(window, cx); } } }) @@ -5767,7 +5787,7 @@ impl Editor { let entries = completions_menu.entries.borrow(); let mat = entries.get(item_ix.unwrap_or(completions_menu.selected_item))?; if self.show_edit_predictions_in_menu() { - self.discard_inline_completion(true, cx); + self.discard_edit_prediction(true, cx); } mat.candidate_id }; @@ -5879,18 +5899,20 @@ impl Editor { text: new_text[common_prefix_len..].into(), }); - self.transact(window, cx, |this, window, cx| { + self.transact(window, cx, |editor, window, cx| { if let Some(mut snippet) = snippet { snippet.text = new_text.to_string(); - this.insert_snippet(&ranges, snippet, window, cx).log_err(); + editor + .insert_snippet(&ranges, snippet, window, cx) + .log_err(); } else { - this.buffer.update(cx, |buffer, cx| { + editor.buffer.update(cx, |multi_buffer, cx| { let auto_indent = match completion.insert_text_mode { Some(InsertTextMode::AS_IS) => None, - _ => this.autoindent_mode.clone(), + _ => editor.autoindent_mode.clone(), }; let edits = ranges.into_iter().map(|range| (range, new_text.as_str())); - buffer.edit(edits, auto_indent, cx); + multi_buffer.edit(edits, auto_indent, cx); }); } for (buffer, edits) in linked_edits { @@ -5909,8 +5931,9 @@ impl Editor { }) } - this.refresh_inline_completion(true, false, window, cx); + editor.refresh_edit_prediction(true, false, window, cx); }); + self.invalidate_autoclose_regions(&self.selections.disjoint_anchors(), &snapshot); let show_new_completions_on_confirm = completion .confirm @@ -5968,7 +5991,7 @@ impl Editor { let deployed_from = action.deployed_from.clone(); let action = action.clone(); self.completion_tasks.clear(); - self.discard_inline_completion(false, cx); + self.discard_edit_prediction(false, cx); let multibuffer_point = match &action.deployed_from { Some(CodeActionSource::Indicator(row)) | 
Some(CodeActionSource::RunMenu(row)) => { @@ -6388,7 +6411,6 @@ impl Editor { IconButton::new("inline_code_actions", ui::IconName::BoltFilled) .icon_size(icon_size) .shape(ui::IconButtonShape::Square) - .style(ButtonStyle::Transparent) .icon_color(ui::Color::Hidden) .toggle_state(is_active) .when(show_tooltip, |this| { @@ -6508,21 +6530,55 @@ impl Editor { } } + pub fn blame_hover(&mut self, _: &BlameHover, window: &mut Window, cx: &mut Context<Self>) { + let snapshot = self.snapshot(window, cx); + let cursor = self.selections.newest::<Point>(cx).head(); + let Some((buffer, point, _)) = snapshot.buffer_snapshot.point_to_buffer_point(cursor) + else { + return; + }; + + let Some(blame) = self.blame.as_ref() else { + return; + }; + + let row_info = RowInfo { + buffer_id: Some(buffer.remote_id()), + buffer_row: Some(point.row), + ..Default::default() + }; + let Some(blame_entry) = blame + .update(cx, |blame, cx| blame.blame_for_rows(&[row_info], cx).next()) + .flatten() + else { + return; + }; + + let anchor = self.selections.newest_anchor().head(); + let position = self.to_pixel_point(anchor, &snapshot, window); + if let (Some(position), Some(last_bounds)) = (position, self.last_bounds) { + self.show_blame_popover(&blame_entry, position + last_bounds.origin, true, cx); + }; + } + fn show_blame_popover( &mut self, blame_entry: &BlameEntry, position: gpui::Point<Pixels>, + ignore_timeout: bool, cx: &mut Context<Self>, ) { if let Some(state) = &mut self.inline_blame_popover { state.hide_task.take(); } else { - let delay = EditorSettings::get_global(cx).hover_popover_delay; + let blame_popover_delay = EditorSettings::get_global(cx).hover_popover_delay; let blame_entry = blame_entry.clone(); let show_task = cx.spawn(async move |editor, cx| { - cx.background_executor() - .timer(std::time::Duration::from_millis(delay)) - .await; + if !ignore_timeout { + cx.background_executor() + .timer(std::time::Duration::from_millis(blame_popover_delay)) + .await; + } editor .update(cx, |editor, cx| { editor.inline_blame_popover_show_task.take(); @@ -6551,6 +6607,7 @@ impl Editor { commit_message: details, markdown, }, + keyboard_grace: ignore_timeout, }); cx.notify(); }) @@ -6596,8 +6653,8 @@ impl Editor { } let snapshot = cursor_buffer.read(cx).snapshot(); - let (start_word_range, _) = snapshot.surrounding_word(cursor_buffer_position); - let (end_word_range, _) = snapshot.surrounding_word(tail_buffer_position); + let (start_word_range, _) = snapshot.surrounding_word(cursor_buffer_position, false); + let (end_word_range, _) = snapshot.surrounding_word(tail_buffer_position, false); if start_word_range != end_word_range { self.document_highlights_task.take(); self.clear_background_highlights::<DocumentHighlightRead>(cx); @@ -6939,20 +6996,24 @@ impl Editor { } } - pub fn refresh_inline_completion( + pub fn refresh_edit_prediction( &mut self, debounce: bool, user_requested: bool, window: &mut Window, cx: &mut Context<Self>, ) -> Option<()> { + if DisableAiSettings::get_global(cx).disable_ai { + return None; + } + let provider = self.edit_prediction_provider()?; let cursor = self.selections.newest_anchor().head(); let (buffer, cursor_buffer_position) = self.buffer.read(cx).text_anchor_for_position(cursor, cx)?; if !self.edit_predictions_enabled_in_buffer(&buffer, cursor_buffer_position, cx) { - self.discard_inline_completion(false, cx); + self.discard_edit_prediction(false, cx); return None; } @@ -6961,11 +7022,11 @@ impl Editor { || !self.is_focused(window) || buffer.read(cx).is_empty()) { - 
self.discard_inline_completion(false, cx); + self.discard_edit_prediction(false, cx); return None; } - self.update_visible_inline_completion(window, cx); + self.update_visible_edit_prediction(window, cx); provider.refresh( self.project.clone(), buffer, @@ -7001,8 +7062,9 @@ impl Editor { } pub fn update_edit_prediction_settings(&mut self, cx: &mut Context<Self>) { - if self.edit_prediction_provider.is_none() { + if self.edit_prediction_provider.is_none() || DisableAiSettings::get_global(cx).disable_ai { self.edit_prediction_settings = EditPredictionSettings::Disabled; + self.discard_edit_prediction(false, cx); } else { let selection = self.selections.newest_anchor(); let cursor = selection.head(); @@ -7023,8 +7085,8 @@ impl Editor { cx: &App, ) -> EditPredictionSettings { if !self.mode.is_full() - || !self.show_inline_completions_override.unwrap_or(true) - || self.inline_completions_disabled_in_scope(buffer, buffer_position, cx) + || !self.show_edit_predictions_override.unwrap_or(true) + || self.edit_predictions_disabled_in_scope(buffer, buffer_position, cx) { return EditPredictionSettings::Disabled; } @@ -7038,8 +7100,8 @@ impl Editor { }; let by_provider = matches!( - self.menu_inline_completions_policy, - MenuInlineCompletionsPolicy::ByProvider + self.menu_edit_predictions_policy, + MenuEditPredictionsPolicy::ByProvider ); let show_in_menu = by_provider @@ -7109,7 +7171,7 @@ impl Editor { .unwrap_or(false) } - fn cycle_inline_completion( + fn cycle_edit_prediction( &mut self, direction: Direction, window: &mut Window, @@ -7119,28 +7181,28 @@ impl Editor { let cursor = self.selections.newest_anchor().head(); let (buffer, cursor_buffer_position) = self.buffer.read(cx).text_anchor_for_position(cursor, cx)?; - if self.inline_completions_hidden_for_vim_mode || !self.should_show_edit_predictions() { + if self.edit_predictions_hidden_for_vim_mode || !self.should_show_edit_predictions() { return None; } provider.cycle(buffer, cursor_buffer_position, direction, cx); - self.update_visible_inline_completion(window, cx); + self.update_visible_edit_prediction(window, cx); Some(()) } - pub fn show_inline_completion( + pub fn show_edit_prediction( &mut self, _: &ShowEditPrediction, window: &mut Window, cx: &mut Context<Self>, ) { - if !self.has_active_inline_completion() { - self.refresh_inline_completion(false, true, window, cx); + if !self.has_active_edit_prediction() { + self.refresh_edit_prediction(false, true, window, cx); return; } - self.update_visible_inline_completion(window, cx); + self.update_visible_edit_prediction(window, cx); } pub fn display_cursor_names( @@ -7172,11 +7234,11 @@ impl Editor { window: &mut Window, cx: &mut Context<Self>, ) { - if self.has_active_inline_completion() { - self.cycle_inline_completion(Direction::Next, window, cx); + if self.has_active_edit_prediction() { + self.cycle_edit_prediction(Direction::Next, window, cx); } else { let is_copilot_disabled = self - .refresh_inline_completion(false, true, window, cx) + .refresh_edit_prediction(false, true, window, cx) .is_none(); if is_copilot_disabled { cx.propagate(); @@ -7190,11 +7252,11 @@ impl Editor { window: &mut Window, cx: &mut Context<Self>, ) { - if self.has_active_inline_completion() { - self.cycle_inline_completion(Direction::Prev, window, cx); + if self.has_active_edit_prediction() { + self.cycle_edit_prediction(Direction::Prev, window, cx); } else { let is_copilot_disabled = self - .refresh_inline_completion(false, true, window, cx) + .refresh_edit_prediction(false, true, window, cx) .is_none(); if 
is_copilot_disabled { cx.propagate(); @@ -7212,18 +7274,14 @@ impl Editor { self.hide_context_menu(window, cx); } - let Some(active_inline_completion) = self.active_inline_completion.as_ref() else { + let Some(active_edit_prediction) = self.active_edit_prediction.as_ref() else { return; }; - self.report_inline_completion_event( - active_inline_completion.completion_id.clone(), - true, - cx, - ); + self.report_edit_prediction_event(active_edit_prediction.completion_id.clone(), true, cx); - match &active_inline_completion.completion { - InlineCompletion::Move { target, .. } => { + match &active_edit_prediction.completion { + EditPrediction::Move { target, .. } => { let target = *target; if let Some(position_map) = &self.last_position_map { @@ -7265,7 +7323,7 @@ impl Editor { } } } - InlineCompletion::Edit { edits, .. } => { + EditPrediction::Edit { edits, .. } => { if let Some(provider) = self.edit_prediction_provider() { provider.accept(cx); } @@ -7293,9 +7351,9 @@ impl Editor { } } - self.update_visible_inline_completion(window, cx); - if self.active_inline_completion.is_none() { - self.refresh_inline_completion(true, true, window, cx); + self.update_visible_edit_prediction(window, cx); + if self.active_edit_prediction.is_none() { + self.refresh_edit_prediction(true, true, window, cx); } cx.notify(); @@ -7305,27 +7363,23 @@ impl Editor { self.edit_prediction_requires_modifier_in_indent_conflict = false; } - pub fn accept_partial_inline_completion( + pub fn accept_partial_edit_prediction( &mut self, _: &AcceptPartialEditPrediction, window: &mut Window, cx: &mut Context<Self>, ) { - let Some(active_inline_completion) = self.active_inline_completion.as_ref() else { + let Some(active_edit_prediction) = self.active_edit_prediction.as_ref() else { return; }; if self.selections.count() != 1 { return; } - self.report_inline_completion_event( - active_inline_completion.completion_id.clone(), - true, - cx, - ); + self.report_edit_prediction_event(active_edit_prediction.completion_id.clone(), true, cx); - match &active_inline_completion.completion { - InlineCompletion::Move { target, .. } => { + match &active_edit_prediction.completion { + EditPrediction::Move { target, .. } => { let target = *target; self.change_selections( SelectionEffects::scroll(Autoscroll::newest()), @@ -7336,7 +7390,7 @@ impl Editor { }, ); } - InlineCompletion::Edit { edits, .. } => { + EditPrediction::Edit { edits, .. } => { // Find an insertion that starts at the cursor position. 
let snapshot = self.buffer.read(cx).snapshot(cx); let cursor_offset = self.selections.newest::<usize>(cx).head(); @@ -7370,7 +7424,7 @@ impl Editor { self.insert_with_autoindent_mode(&partial_completion, None, window, cx); - self.refresh_inline_completion(true, true, window, cx); + self.refresh_edit_prediction(true, true, window, cx); cx.notify(); } else { self.accept_edit_prediction(&Default::default(), window, cx); @@ -7379,28 +7433,28 @@ impl Editor { } } - fn discard_inline_completion( + fn discard_edit_prediction( &mut self, - should_report_inline_completion_event: bool, + should_report_edit_prediction_event: bool, cx: &mut Context<Self>, ) -> bool { - if should_report_inline_completion_event { + if should_report_edit_prediction_event { let completion_id = self - .active_inline_completion + .active_edit_prediction .as_ref() .and_then(|active_completion| active_completion.completion_id.clone()); - self.report_inline_completion_event(completion_id, false, cx); + self.report_edit_prediction_event(completion_id, false, cx); } if let Some(provider) = self.edit_prediction_provider() { provider.discard(cx); } - self.take_active_inline_completion(cx) + self.take_active_edit_prediction(cx) } - fn report_inline_completion_event(&self, id: Option<SharedString>, accepted: bool, cx: &App) { + fn report_edit_prediction_event(&self, id: Option<SharedString>, accepted: bool, cx: &App) { let Some(provider) = self.edit_prediction_provider() else { return; }; @@ -7431,18 +7485,18 @@ impl Editor { ); } - pub fn has_active_inline_completion(&self) -> bool { - self.active_inline_completion.is_some() + pub fn has_active_edit_prediction(&self) -> bool { + self.active_edit_prediction.is_some() } - fn take_active_inline_completion(&mut self, cx: &mut Context<Self>) -> bool { - let Some(active_inline_completion) = self.active_inline_completion.take() else { + fn take_active_edit_prediction(&mut self, cx: &mut Context<Self>) -> bool { + let Some(active_edit_prediction) = self.active_edit_prediction.take() else { return false; }; - self.splice_inlays(&active_inline_completion.inlay_ids, Default::default(), cx); - self.clear_highlights::<InlineCompletionHighlight>(cx); - self.stale_inline_completion_in_menu = Some(active_inline_completion); + self.splice_inlays(&active_edit_prediction.inlay_ids, Default::default(), cx); + self.clear_highlights::<EditPredictionHighlight>(cx); + self.stale_edit_prediction_in_menu = Some(active_edit_prediction); true } @@ -7587,7 +7641,7 @@ impl Editor { since: Instant::now(), }; - self.update_visible_inline_completion(window, cx); + self.update_visible_edit_prediction(window, cx); cx.notify(); } } else if let EditPredictionPreview::Active { @@ -7610,16 +7664,20 @@ impl Editor { released_too_fast: since.elapsed() < Duration::from_millis(200), }; self.clear_row_highlights::<EditPredictionPreview>(); - self.update_visible_inline_completion(window, cx); + self.update_visible_edit_prediction(window, cx); cx.notify(); } } - fn update_visible_inline_completion( + fn update_visible_edit_prediction( &mut self, _window: &mut Window, cx: &mut Context<Self>, ) -> Option<()> { + if DisableAiSettings::get_global(cx).disable_ai { + return None; + } + let selection = self.selections.newest_anchor(); let cursor = selection.head(); let multibuffer = self.buffer.read(cx).snapshot(cx); @@ -7629,12 +7687,12 @@ impl Editor { let show_in_menu = self.show_edit_predictions_in_menu(); let completions_menu_has_precedence = !show_in_menu && (self.context_menu.borrow().is_some() - || 
(!self.completion_tasks.is_empty() && !self.has_active_inline_completion())); + || (!self.completion_tasks.is_empty() && !self.has_active_edit_prediction())); if completions_menu_has_precedence || !offset_selection.is_empty() || self - .active_inline_completion + .active_edit_prediction .as_ref() .map_or(false, |completion| { let invalidation_range = completion.invalidation_range.to_offset(&multibuffer); @@ -7642,11 +7700,11 @@ impl Editor { !invalidation_range.contains(&offset_selection.head()) }) { - self.discard_inline_completion(false, cx); + self.discard_edit_prediction(false, cx); return None; } - self.take_active_inline_completion(cx); + self.take_active_edit_prediction(cx); let Some(provider) = self.edit_prediction_provider() else { self.edit_prediction_settings = EditPredictionSettings::Disabled; return None; @@ -7672,8 +7730,8 @@ impl Editor { } } - let inline_completion = provider.suggest(&buffer, cursor_buffer_position, cx)?; - let edits = inline_completion + let edit_prediction = provider.suggest(&buffer, cursor_buffer_position, cx)?; + let edits = edit_prediction .edits .into_iter() .flat_map(|(range, new_text)| { @@ -7707,16 +7765,22 @@ impl Editor { } else { None }; - let is_move = - move_invalidation_row_range.is_some() || self.inline_completions_hidden_for_vim_mode; + let supports_jump = self + .edit_prediction_provider + .as_ref() + .map(|provider| provider.provider.supports_jump_to_edit()) + .unwrap_or(true); + + let is_move = supports_jump + && (move_invalidation_row_range.is_some() || self.edit_predictions_hidden_for_vim_mode); let completion = if is_move { invalidation_row_range = move_invalidation_row_range.unwrap_or(edit_start_row..edit_end_row); let target = first_edit_start; - InlineCompletion::Move { target, snapshot } + EditPrediction::Move { target, snapshot } } else { let show_completions_in_buffer = !self.edit_prediction_visible_in_cursor_popover(true) - && !self.inline_completions_hidden_for_vim_mode; + && !self.edit_predictions_hidden_for_vim_mode; if show_completions_in_buffer { if edits @@ -7725,7 +7789,7 @@ impl Editor { { let mut inlays = Vec::new(); for (range, new_text) in &edits { - let inlay = Inlay::inline_completion( + let inlay = Inlay::edit_prediction( post_inc(&mut self.next_inlay_id), range.start, new_text.as_str(), @@ -7737,7 +7801,7 @@ impl Editor { self.splice_inlays(&[], inlays, cx); } else { let background_color = cx.theme().status().deleted_background; - self.highlight_text::<InlineCompletionHighlight>( + self.highlight_text::<EditPredictionHighlight>( edits.iter().map(|(range, _)| range.clone()).collect(), HighlightStyle { background_color: Some(background_color), @@ -7760,9 +7824,9 @@ impl Editor { EditDisplayMode::DiffPopover }; - InlineCompletion::Edit { + EditPrediction::Edit { edits, - edit_preview: inline_completion.edit_preview, + edit_preview: edit_prediction.edit_preview, display_mode, snapshot, } @@ -7775,11 +7839,11 @@ impl Editor { multibuffer.line_len(MultiBufferRow(invalidation_row_range.end)), )); - self.stale_inline_completion_in_menu = None; - self.active_inline_completion = Some(InlineCompletionState { + self.stale_edit_prediction_in_menu = None; + self.active_edit_prediction = Some(EditPredictionState { inlay_ids, completion, - completion_id: inline_completion.id, + completion_id: edit_prediction.id, invalidation_range, }); @@ -7788,7 +7852,7 @@ impl Editor { Some(()) } - pub fn edit_prediction_provider(&self) -> Option<Arc<dyn InlineCompletionProviderHandle>> { + pub fn edit_prediction_provider(&self) -> 
Option<Arc<dyn EditPredictionProviderHandle>> { Some(self.edit_prediction_provider.as_ref()?.provider.clone()) } @@ -8130,7 +8194,7 @@ impl Editor { editor.set_breakpoint_context_menu( row, Some(position), - event.down.position, + event.position(), window, cx, ); @@ -8188,8 +8252,7 @@ impl Editor { return; }; - // Try to find a closest, enclosing node using tree-sitter that has a - // task + // Try to find a closest, enclosing node using tree-sitter that has a task let Some((buffer, buffer_row, tasks)) = self .find_enclosing_node_task(cx) // Or find the task that's closest in row-distance. @@ -8289,26 +8352,33 @@ impl Editor { let color = Color::Muted; let position = breakpoint.as_ref().map(|(anchor, _, _)| *anchor); - IconButton::new(("run_indicator", row.0 as usize), ui::IconName::Play) - .shape(ui::IconButtonShape::Square) - .icon_size(IconSize::XSmall) - .icon_color(color) - .toggle_state(is_active) - .on_click(cx.listener(move |editor, e: &ClickEvent, window, cx| { - let quick_launch = e.down.button == MouseButton::Left; - window.focus(&editor.focus_handle(cx)); - editor.toggle_code_actions( - &ToggleCodeActions { - deployed_from: Some(CodeActionSource::RunMenu(row)), - quick_launch, - }, - window, - cx, - ); - })) - .on_right_click(cx.listener(move |editor, event: &ClickEvent, window, cx| { - editor.set_breakpoint_context_menu(row, position, event.down.position, window, cx); - })) + IconButton::new( + ("run_indicator", row.0 as usize), + ui::IconName::PlayOutlined, + ) + .shape(ui::IconButtonShape::Square) + .icon_size(IconSize::XSmall) + .icon_color(color) + .toggle_state(is_active) + .on_click(cx.listener(move |editor, e: &ClickEvent, window, cx| { + let quick_launch = match e { + ClickEvent::Keyboard(_) => true, + ClickEvent::Mouse(e) => e.down.button == MouseButton::Left, + }; + + window.focus(&editor.focus_handle(cx)); + editor.toggle_code_actions( + &ToggleCodeActions { + deployed_from: Some(CodeActionSource::RunMenu(row)), + quick_launch, + }, + window, + cx, + ); + })) + .on_right_click(cx.listener(move |editor, event: &ClickEvent, window, cx| { + editor.set_breakpoint_context_menu(row, position, event.position(), window, cx); + })) } pub fn context_menu_visible(&self) -> bool { @@ -8355,14 +8425,14 @@ impl Editor { if self.mode().is_minimap() { return None; } - let active_inline_completion = self.active_inline_completion.as_ref()?; + let active_edit_prediction = self.active_edit_prediction.as_ref()?; if self.edit_prediction_visible_in_cursor_popover(true) { return None; } - match &active_inline_completion.completion { - InlineCompletion::Move { target, .. } => { + match &active_edit_prediction.completion { + EditPrediction::Move { target, .. } => { let target_display_point = target.to_display_point(editor_snapshot); if self.edit_prediction_requires_modifier() { @@ -8399,11 +8469,11 @@ impl Editor { ) } } - InlineCompletion::Edit { + EditPrediction::Edit { display_mode: EditDisplayMode::Inline, .. } => None, - InlineCompletion::Edit { + EditPrediction::Edit { display_mode: EditDisplayMode::TabAccept, edits, .. 
@@ -8424,7 +8494,7 @@ impl Editor { cx, ) } - InlineCompletion::Edit { + EditPrediction::Edit { edits, edit_preview, display_mode: EditDisplayMode::DiffPopover, @@ -8740,8 +8810,12 @@ impl Editor { return None; } - let highlighted_edits = - crate::inline_completion_edit_text(&snapshot, edits, edit_preview.as_ref()?, false, cx); + let highlighted_edits = if let Some(edit_preview) = edit_preview.as_ref() { + crate::edit_prediction_edit_text(&snapshot, edits, edit_preview, false, cx) + } else { + // Fallback for providers without edit_preview + crate::edit_prediction_fallback_text(edits, cx) + }; let styled_text = highlighted_edits.to_styled_text(&style.text); let line_count = highlighted_edits.text.lines().count(); @@ -9009,6 +9083,18 @@ impl Editor { let editor_bg_color = cx.theme().colors().editor_background; editor_bg_color.blend(accent_color.opacity(0.6)) } + fn get_prediction_provider_icon_name( + provider: &Option<RegisteredEditPredictionProvider>, + ) -> IconName { + match provider { + Some(provider) => match provider.provider.name() { + "copilot" => IconName::Copilot, + "supermaven" => IconName::Supermaven, + _ => IconName::ZedPredict, + }, + None => IconName::ZedPredict, + } + } fn render_edit_prediction_cursor_popover( &self, @@ -9021,6 +9107,7 @@ impl Editor { cx: &mut Context<Editor>, ) -> Option<AnyElement> { let provider = self.edit_prediction_provider.as_ref()?; + let provider_icon = Self::get_prediction_provider_icon_name(&self.edit_prediction_provider); if provider.provider.needs_terms_acceptance(cx) { return Some( @@ -9047,7 +9134,7 @@ impl Editor { h_flex() .flex_1() .gap_2() - .child(Icon::new(IconName::ZedPredict)) + .child(Icon::new(provider_icon)) .child(Label::new("Accept Terms of Service")) .child(div().w_full()) .child( @@ -9063,15 +9150,11 @@ impl Editor { let is_refreshing = provider.provider.is_refreshing(cx); - fn pending_completion_container() -> Div { - h_flex() - .h_full() - .flex_1() - .gap_2() - .child(Icon::new(IconName::ZedPredict)) + fn pending_completion_container(icon: IconName) -> Div { + h_flex().h_full().flex_1().gap_2().child(Icon::new(icon)) } - let completion = match &self.active_inline_completion { + let completion = match &self.active_edit_prediction { Some(prediction) => { if !self.has_visible_completions_menu() { const RADIUS: Pixels = px(6.); @@ -9089,7 +9172,7 @@ impl Editor { .rounded_tl(px(0.)) .overflow_hidden() .child(div().px_1p5().child(match &prediction.completion { - InlineCompletion::Move { target, snapshot } => { + EditPrediction::Move { target, snapshot } => { use text::ToPoint as _; if target.text_anchor.to_point(&snapshot).row > cursor_point.row { @@ -9098,7 +9181,7 @@ impl Editor { Icon::new(IconName::ZedPredictUp) } } - InlineCompletion::Edit { .. } => Icon::new(IconName::ZedPredict), + EditPrediction::Edit { .. } => Icon::new(provider_icon), })) .child( h_flex() @@ -9157,7 +9240,7 @@ impl Editor { )? 
} - None if is_refreshing => match &self.stale_inline_completion_in_menu { + None if is_refreshing => match &self.stale_edit_prediction_in_menu { Some(stale_completion) => self.render_edit_prediction_cursor_popover_preview( stale_completion, cursor_point, @@ -9165,15 +9248,15 @@ impl Editor { cx, )?, - None => { - pending_completion_container().child(Label::new("...").size(LabelSize::Small)) - } + None => pending_completion_container(provider_icon) + .child(Label::new("...").size(LabelSize::Small)), }, - None => pending_completion_container().child(Label::new("No Prediction")), + None => pending_completion_container(provider_icon) + .child(Label::new("...").size(LabelSize::Small)), }; - let completion = if is_refreshing { + let completion = if is_refreshing || self.active_edit_prediction.is_none() { completion .with_animation( "loading-completion", @@ -9187,7 +9270,7 @@ impl Editor { completion.into_any_element() }; - let has_completion = self.active_inline_completion.is_some(); + let has_completion = self.active_edit_prediction.is_some(); let is_platform_style_mac = PlatformStyle::platform() == PlatformStyle::Mac; Some( @@ -9246,7 +9329,7 @@ impl Editor { fn render_edit_prediction_cursor_popover_preview( &self, - completion: &InlineCompletionState, + completion: &EditPredictionState, cursor_point: Point, style: &EditorStyle, cx: &mut Context<Editor>, @@ -9273,25 +9356,37 @@ impl Editor { .child(Icon::new(arrow).color(Color::Muted).size(IconSize::Small)) } + let supports_jump = self + .edit_prediction_provider + .as_ref() + .map(|provider| provider.provider.supports_jump_to_edit()) + .unwrap_or(true); + match &completion.completion { - InlineCompletion::Move { + EditPrediction::Move { target, snapshot, .. - } => Some( - h_flex() - .px_2() - .gap_2() - .flex_1() - .child( - if target.text_anchor.to_point(&snapshot).row > cursor_point.row { - Icon::new(IconName::ZedPredictDown) - } else { - Icon::new(IconName::ZedPredictUp) - }, - ) - .child(Label::new("Jump to Edit")), - ), + } => { + if !supports_jump { + return None; + } - InlineCompletion::Edit { + Some( + h_flex() + .px_2() + .gap_2() + .flex_1() + .child( + if target.text_anchor.to_point(&snapshot).row > cursor_point.row { + Icon::new(IconName::ZedPredictDown) + } else { + Icon::new(IconName::ZedPredictUp) + }, + ) + .child(Label::new("Jump to Edit")), + ) + } + + EditPrediction::Edit { edits, edit_preview, snapshot, @@ -9299,14 +9394,13 @@ impl Editor { } => { let first_edit_row = edits.first()?.0.start.text_anchor.to_point(&snapshot).row; - let (highlighted_edits, has_more_lines) = crate::inline_completion_edit_text( - &snapshot, - &edits, - edit_preview.as_ref()?, - true, - cx, - ) - .first_line_preview(); + let (highlighted_edits, has_more_lines) = + if let Some(edit_preview) = edit_preview.as_ref() { + crate::edit_prediction_edit_text(&snapshot, &edits, edit_preview, true, cx) + .first_line_preview() + } else { + crate::edit_prediction_fallback_text(&edits, cx).first_line_preview() + }; let styled_text = gpui::StyledText::new(highlighted_edits.text) .with_default_highlights(&style.text, highlighted_edits.highlights); @@ -9317,11 +9411,13 @@ impl Editor { .child(styled_text) .when(has_more_lines, |parent| parent.child("…")); - let left = if first_edit_row != cursor_point.row { + let left = if supports_jump && first_edit_row != cursor_point.row { render_relative_row_jump("", cursor_point.row, first_edit_row) .into_any_element() } else { - Icon::new(IconName::ZedPredict).into_any_element() + let icon_name = + 
Editor::get_prediction_provider_icon_name(&self.edit_prediction_provider); + Icon::new(icon_name).into_any_element() }; Some( @@ -9377,8 +9473,8 @@ impl Editor { cx.notify(); self.completion_tasks.clear(); let context_menu = self.context_menu.borrow_mut().take(); - self.stale_inline_completion_in_menu.take(); - self.update_visible_inline_completion(window, cx); + self.stale_edit_prediction_in_menu.take(); + self.update_visible_edit_prediction(window, cx); if let Some(CodeContextMenu::Completions(_)) = &context_menu { if let Some(completion_provider) = &self.completion_provider { completion_provider.selection_changed(None, window, cx); @@ -9516,27 +9612,46 @@ impl Editor { // Check whether the just-entered snippet ends with an auto-closable bracket. if self.autoclose_regions.is_empty() { let snapshot = self.buffer.read(cx).snapshot(cx); - for selection in &mut self.selections.all::<Point>(cx) { + let mut all_selections = self.selections.all::<Point>(cx); + for selection in &mut all_selections { let selection_head = selection.head(); let Some(scope) = snapshot.language_scope_at(selection_head) else { continue; }; let mut bracket_pair = None; - let next_chars = snapshot.chars_at(selection_head).collect::<String>(); - let prev_chars = snapshot - .reversed_chars_at(selection_head) - .collect::<String>(); - for (pair, enabled) in scope.brackets() { - if enabled - && pair.close - && prev_chars.starts_with(pair.start.as_str()) - && next_chars.starts_with(pair.end.as_str()) - { - bracket_pair = Some(pair.clone()); - break; + let max_lookup_length = scope + .brackets() + .map(|(pair, _)| { + pair.start + .as_str() + .chars() + .count() + .max(pair.end.as_str().chars().count()) + }) + .max(); + if let Some(max_lookup_length) = max_lookup_length { + let next_text = snapshot + .chars_at(selection_head) + .take(max_lookup_length) + .collect::<String>(); + let prev_text = snapshot + .reversed_chars_at(selection_head) + .take(max_lookup_length) + .collect::<String>(); + + for (pair, enabled) in scope.brackets() { + if enabled + && pair.close + && prev_text.starts_with(pair.start.as_str()) + && next_text.starts_with(pair.end.as_str()) + { + bracket_pair = Some(pair.clone()); + break; + } } } + if let Some(pair) = bracket_pair { let snapshot_settings = snapshot.language_settings_at(selection_head, cx); let autoclose_enabled = @@ -9717,7 +9832,7 @@ impl Editor { this.edit(edits, None, cx); }) } - this.refresh_inline_completion(true, false, window, cx); + this.refresh_edit_prediction(true, false, window, cx); linked_editing_ranges::refresh_linked_ranges(this, window, cx); }); } @@ -9736,7 +9851,7 @@ impl Editor { }) }); this.insert("", window, cx); - this.refresh_inline_completion(true, false, window, cx); + this.refresh_edit_prediction(true, false, window, cx); }); } @@ -9869,7 +9984,7 @@ impl Editor { self.transact(window, cx, |this, window, cx| { this.buffer.update(cx, |b, cx| b.edit(edits, None, cx)); this.change_selections(Default::default(), window, cx, |s| s.select(selections)); - this.refresh_inline_completion(true, false, window, cx); + this.refresh_edit_prediction(true, false, window, cx); }); } @@ -10442,7 +10557,6 @@ impl Editor { cloned_prompt.clone().into_any_element() }), priority: 0, - render_in_minimap: true, }]; let focus_handle = bp_prompt.focus_handle(cx); @@ -10832,17 +10946,6 @@ impl Editor { }); } - pub fn toggle_case(&mut self, _: &ToggleCase, window: &mut Window, cx: &mut Context<Self>) { - self.manipulate_text(window, cx, |text| { - let has_upper_case_characters = 
text.chars().any(|c| c.is_uppercase()); - if has_upper_case_characters { - text.to_lowercase() - } else { - text.to_uppercase() - } - }) - } - fn manipulate_immutable_lines<Fn>( &mut self, window: &mut Window, @@ -11098,6 +11201,26 @@ impl Editor { }) } + pub fn convert_to_sentence_case( + &mut self, + _: &ConvertToSentenceCase, + window: &mut Window, + cx: &mut Context<Self>, + ) { + self.manipulate_text(window, cx, |text| text.to_case(Case::Sentence)) + } + + pub fn toggle_case(&mut self, _: &ToggleCase, window: &mut Window, cx: &mut Context<Self>) { + self.manipulate_text(window, cx, |text| { + let has_upper_case_characters = text.chars().any(|c| c.is_uppercase()); + if has_upper_case_characters { + text.to_lowercase() + } else { + text.to_uppercase() + } + }) + } + pub fn convert_to_rot13( &mut self, _: &ConvertToRot13, @@ -12110,6 +12233,41 @@ impl Editor { }); } + pub fn diff_clipboard_with_selection( + &mut self, + _: &DiffClipboardWithSelection, + window: &mut Window, + cx: &mut Context<Self>, + ) { + let selections = self.selections.all::<usize>(cx); + + if selections.is_empty() { + log::warn!("There should always be at least one selection in Zed. This is a bug."); + return; + }; + + let clipboard_text = match cx.read_from_clipboard() { + Some(item) => match item.entries().first() { + Some(ClipboardEntry::String(text)) => Some(text.text().to_string()), + _ => None, + }, + None => None, + }; + + let Some(clipboard_text) = clipboard_text else { + log::warn!("Clipboard doesn't contain text."); + return; + }; + + window.dispatch_action( + Box::new(DiffClipboardWithSelectionData { + clipboard_text, + editor: cx.entity(), + }), + cx, + ); + } + pub fn paste(&mut self, _: &Paste, window: &mut Window, cx: &mut Context<Self>) { self.hide_mouse_cursor(HideMouseCursorOrigin::TypingAction, cx); if let Some(item) = cx.read_from_clipboard() { @@ -12155,7 +12313,7 @@ impl Editor { } self.request_autoscroll(Autoscroll::fit(), cx); self.unmark_text(window, cx); - self.refresh_inline_completion(true, false, window, cx); + self.refresh_edit_prediction(true, false, window, cx); cx.emit(EditorEvent::Edited { transaction_id }); cx.emit(EditorEvent::TransactionUndone { transaction_id }); } @@ -12185,7 +12343,7 @@ impl Editor { } self.request_autoscroll(Autoscroll::fit(), cx); self.unmark_text(window, cx); - self.refresh_inline_completion(true, false, window, cx); + self.refresh_edit_prediction(true, false, window, cx); cx.emit(EditorEvent::Edited { transaction_id }); } } @@ -14269,8 +14427,11 @@ impl Editor { (position..position, first_prefix.clone()) })); } - } else if let Some((full_comment_prefix, comment_suffix)) = - language.block_comment_delimiters() + } else if let Some(BlockCommentConfig { + start: full_comment_prefix, + end: comment_suffix, + .. 
+ }) = language.block_comment() { let comment_prefix = full_comment_prefix.trim_end_matches(' '); let comment_prefix_whitespace = &full_comment_prefix[comment_prefix.len()..]; @@ -14587,6 +14748,81 @@ impl Editor { } } + pub fn unwrap_syntax_node( + &mut self, + _: &UnwrapSyntaxNode, + window: &mut Window, + cx: &mut Context<Self>, + ) { + self.hide_mouse_cursor(HideMouseCursorOrigin::MovementAction, cx); + + let buffer = self.buffer.read(cx).snapshot(cx); + let old_selections: Box<[_]> = self.selections.all::<usize>(cx).into(); + + let edits = old_selections + .iter() + // only consider the first selection for now + .take(1) + .map(|selection| { + // Only requires two branches once if-let-chains stabilize (#53667) + let selection_range = if !selection.is_empty() { + selection.range() + } else if let Some((_, ancestor_range)) = + buffer.syntax_ancestor(selection.start..selection.end) + { + match ancestor_range { + MultiOrSingleBufferOffsetRange::Single(range) => range, + MultiOrSingleBufferOffsetRange::Multi(range) => range, + } + } else { + selection.range() + }; + + let mut new_range = selection_range.clone(); + while let Some((_, ancestor_range)) = buffer.syntax_ancestor(new_range.clone()) { + new_range = match ancestor_range { + MultiOrSingleBufferOffsetRange::Single(range) => range, + MultiOrSingleBufferOffsetRange::Multi(range) => range, + }; + if new_range.start < selection_range.start + || new_range.end > selection_range.end + { + break; + } + } + + (selection, selection_range, new_range) + }) + .collect::<Vec<_>>(); + + self.transact(window, cx, |editor, window, cx| { + for (_, child, parent) in &edits { + let text = buffer.text_for_range(child.clone()).collect::<String>(); + editor.replace_text_in_range(Some(parent.clone()), &text, window, cx); + } + + editor.change_selections( + SelectionEffects::scroll(Autoscroll::fit()), + window, + cx, + |s| { + s.select( + edits + .iter() + .map(|(s, old, new)| Selection { + id: s.id, + start: new.start, + end: new.start + old.len(), + goal: SelectionGoal::None, + reversed: s.reversed, + }) + .collect(), + ); + }, + ); + }); + } + fn refresh_runnables(&mut self, window: &mut Window, cx: &mut Context<Self>) -> Task<()> { if !EditorSettings::get_global(cx).gutter.runnables { self.clear_tasks(); @@ -15063,7 +15299,7 @@ impl Editor { pub fn go_to_diagnostic( &mut self, - _: &GoToDiagnostic, + action: &GoToDiagnostic, window: &mut Window, cx: &mut Context<Self>, ) { @@ -15071,12 +15307,12 @@ impl Editor { return; } self.hide_mouse_cursor(HideMouseCursorOrigin::MovementAction, cx); - self.go_to_diagnostic_impl(Direction::Next, window, cx) + self.go_to_diagnostic_impl(Direction::Next, action.severity, window, cx) } pub fn go_to_prev_diagnostic( &mut self, - _: &GoToPreviousDiagnostic, + action: &GoToPreviousDiagnostic, window: &mut Window, cx: &mut Context<Self>, ) { @@ -15084,12 +15320,13 @@ impl Editor { return; } self.hide_mouse_cursor(HideMouseCursorOrigin::MovementAction, cx); - self.go_to_diagnostic_impl(Direction::Prev, window, cx) + self.go_to_diagnostic_impl(Direction::Prev, action.severity, window, cx) } pub fn go_to_diagnostic_impl( &mut self, direction: Direction, + severity: GoToDiagnosticSeverityFilter, window: &mut Window, cx: &mut Context<Self>, ) { @@ -15105,9 +15342,11 @@ impl Editor { fn filtered( snapshot: EditorSnapshot, + severity: GoToDiagnosticSeverityFilter, diagnostics: impl Iterator<Item = DiagnosticEntry<usize>>, ) -> impl Iterator<Item = DiagnosticEntry<usize>> { diagnostics + .filter(move |entry| 
severity.matches(entry.diagnostic.severity)) .filter(|entry| entry.range.start != entry.range.end) .filter(|entry| !entry.diagnostic.is_unnecessary) .filter(move |entry| !snapshot.intersects_fold(entry.range.start)) @@ -15116,12 +15355,14 @@ impl Editor { let snapshot = self.snapshot(window, cx); let before = filtered( snapshot.clone(), + severity, buffer .diagnostics_in_range(0..selection.start) .filter(|entry| entry.range.start <= selection.start), ); let after = filtered( snapshot, + severity, buffer .diagnostics_in_range(selection.start..buffer.len()) .filter(|entry| entry.range.start >= selection.start), @@ -15164,7 +15405,7 @@ impl Editor { ]) }); self.activate_diagnostics(buffer_id, next_diagnostic, window, cx); - self.refresh_inline_completion(false, true, window, cx); + self.refresh_edit_prediction(false, true, window, cx); } pub fn go_to_next_hunk(&mut self, _: &GoToHunk, window: &mut Window, cx: &mut Context<Self>) { @@ -15725,7 +15966,7 @@ impl Editor { let language_server_name = project .language_server_statuses(cx) .find(|(id, _)| server_id == *id) - .map(|(_, status)| LanguageServerName::from(status.name.as_str())); + .map(|(_, status)| status.name.clone()); language_server_name.map(|language_server_name| { project.open_local_buffer_via_lsp( lsp_location.uri.clone(), @@ -16128,7 +16369,7 @@ impl Editor { font_weight: Some(FontWeight::BOLD), ..make_inlay_hints_style(cx.app) }, - inline_completion_styles: make_suggestion_styles( + edit_prediction_styles: make_suggestion_styles( cx.app, ), ..EditorStyle::default() @@ -16138,7 +16379,6 @@ impl Editor { } }), priority: 0, - render_in_minimap: true, }], Some(Autoscroll::fit()), cx, @@ -16880,7 +17120,7 @@ impl Editor { now: Instant, window: &mut Window, cx: &mut Context<Self>, - ) { + ) -> Option<TransactionId> { self.end_selection(window, cx); if let Some(tx_id) = self .buffer @@ -16890,7 +17130,10 @@ impl Editor { .insert_transaction(tx_id, self.selections.disjoint_anchors()); cx.emit(EditorEvent::TransactionBegun { transaction_id: tx_id, - }) + }); + Some(tx_id) + } else { + None } } @@ -16918,6 +17161,17 @@ impl Editor { } } + pub fn modify_transaction_selection_history( + &mut self, + transaction_id: TransactionId, + modify: impl FnOnce(&mut (Arc<[Selection<Anchor>]>, Option<Arc<[Selection<Anchor>]>>)), + ) -> bool { + self.selection_history + .transaction_mut(transaction_id) + .map(modify) + .is_some() + } + pub fn set_mark(&mut self, _: &actions::SetMark, window: &mut Window, cx: &mut Context<Self>) { if self.selection_mark_mode { self.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { @@ -16947,6 +17201,18 @@ impl Editor { cx.notify(); } + pub fn toggle_focus( + workspace: &mut Workspace, + _: &actions::ToggleFocus, + window: &mut Window, + cx: &mut Context<Workspace>, + ) { + let Some(item) = workspace.recent_active_item_by_type::<Self>(cx) else { + return; + }; + workspace.activate_item(&item, true, true, window, cx); + } + pub fn toggle_fold( &mut self, _: &actions::ToggleFold, @@ -17072,6 +17338,46 @@ impl Editor { } } + pub fn toggle_fold_all( + &mut self, + _: &actions::ToggleFoldAll, + window: &mut Window, + cx: &mut Context<Self>, + ) { + if self.buffer.read(cx).is_singleton() { + let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx)); + let has_folds = display_map + .folds_in_range(0..display_map.buffer_snapshot.len()) + .next() + .is_some(); + + if has_folds { + self.unfold_all(&actions::UnfoldAll, window, cx); + } else { + self.fold_all(&actions::FoldAll, window, cx); + } + } 
else { + let buffer_ids = self.buffer.read(cx).excerpt_buffer_ids(); + let should_unfold = buffer_ids + .iter() + .any(|buffer_id| self.is_buffer_folded(*buffer_id, cx)); + + self.toggle_fold_multiple_buffers = cx.spawn_in(window, async move |editor, cx| { + editor + .update_in(cx, |editor, _, cx| { + for buffer_id in buffer_ids { + if should_unfold { + editor.unfold_buffer(buffer_id, cx); + } else { + editor.fold_buffer(buffer_id, cx); + } + } + }) + .ok(); + }); + } + } + fn fold_at_level( &mut self, fold_at: &FoldAtLevel, @@ -18001,7 +18307,7 @@ impl Editor { parent: cx.weak_entity(), }, self.buffer.clone(), - self.project.clone(), + None, Some(self.display_map.clone()), window, cx, @@ -18837,7 +19143,7 @@ impl Editor { (selection.range(), uuid.to_string()) }); this.edit(edits, cx); - this.refresh_inline_completion(true, false, window, cx); + this.refresh_edit_prediction(true, false, window, cx); }); } @@ -19655,8 +19961,9 @@ impl Editor { Anchor::in_buffer(excerpt_id, buffer_id, hint.position), hint.text(), ); - - new_inlays.push(inlay); + if !inlay.text.chars().contains(&'\n') { + new_inlays.push(inlay); + } }); } @@ -19689,8 +19996,8 @@ impl Editor { self.refresh_selected_text_highlights(true, window, cx); self.refresh_single_line_folds(window, cx); refresh_matching_bracket_highlights(self, window, cx); - if self.has_active_inline_completion() { - self.update_visible_inline_completion(window, cx); + if self.has_active_edit_prediction() { + self.update_visible_edit_prediction(window, cx); } if let Some(project) = self.project.as_ref() { if let Some(edited_buffer) = edited_buffer { @@ -19884,17 +20191,15 @@ impl Editor { } fn settings_changed(&mut self, window: &mut Window, cx: &mut Context<Self>) { - let new_severity = if self.diagnostics_enabled() { - EditorSettings::get_global(cx) + if self.diagnostics_enabled() { + let new_severity = EditorSettings::get_global(cx) .diagnostics_max_severity - .unwrap_or(DiagnosticSeverity::Hint) - } else { - DiagnosticSeverity::Off - }; - self.set_max_diagnostics_severity(new_severity, cx); + .unwrap_or(DiagnosticSeverity::Hint); + self.set_max_diagnostics_severity(new_severity, cx); + } self.tasks_update_task = Some(self.refresh_runnables(window, cx)); self.update_edit_prediction_settings(cx); - self.refresh_inline_completion(true, false, window, cx); + self.refresh_edit_prediction(true, false, window, cx); self.refresh_inline_values(cx); self.refresh_inlay_hints( InlayHintRefreshReason::SettingsChange(inlay_hint_settings( @@ -20503,6 +20808,7 @@ impl Editor { if event.blurred != self.focus_handle { self.last_focused_descendant = Some(event.blurred); } + self.selection_drag_state = SelectionDragState::None; self.refresh_inlay_hints(InlayHintRefreshReason::ModifiersChanged(false), cx); } @@ -20525,7 +20831,7 @@ impl Editor { { self.hide_context_menu(window, cx); } - self.discard_inline_completion(false, cx); + self.discard_edit_prediction(false, cx); cx.emit(EditorEvent::Blurred); cx.notify(); } @@ -20938,13 +21244,6 @@ fn process_completion_for_edit( .is_le(), "replace_range should start before or at cursor position" ); - debug_assert!( - insert_range - .end - .cmp(&cursor_position, &buffer_snapshot) - .is_le(), - "insert_range should end before or at cursor position" - ); let should_replace = match intent { CompletionIntent::CompleteWithInsert => false, @@ -21672,11 +21971,11 @@ impl CodeActionProvider for Entity<Project> { cx: &mut App, ) -> Task<Result<Vec<CodeAction>>> { self.update(cx, |project, cx| { - let code_lens = 
project.code_lens(buffer, range.clone(), cx); + let code_lens_actions = project.code_lens_actions(buffer, range.clone(), cx); let code_actions = project.code_actions(buffer, range, None, cx); cx.background_spawn(async move { - let (code_lens, code_actions) = join(code_lens, code_actions).await; - Ok(code_lens + let (code_lens_actions, code_actions) = join(code_lens_actions, code_actions).await; + Ok(code_lens_actions .context("code lens fetch")? .into_iter() .chain(code_actions.context("code action fetch")?) @@ -22005,7 +22304,6 @@ impl SemanticsProvider for Entity<Project> { } fn supports_inlay_hints(&self, buffer: &Entity<Buffer>, cx: &mut App) -> bool { - // TODO: make this work for remote projects self.update(cx, |project, cx| { if project .active_debug_session(cx) @@ -22073,7 +22371,7 @@ impl SemanticsProvider for Entity<Project> { // Fallback on using TreeSitter info to determine identifier range buffer.read_with(cx, |buffer, _| { let snapshot = buffer.snapshot(); - let (range, kind) = snapshot.surrounding_word(position); + let (range, kind) = snapshot.surrounding_word(position, false); if kind != Some(CharKind::Word) { return None; } @@ -22118,7 +22416,7 @@ fn consume_contiguous_rows( selections: &mut Peekable<std::slice::Iter<Selection<Point>>>, ) -> (MultiBufferRow, MultiBufferRow) { contiguous_row_selections.push(selection.clone()); - let start_row = MultiBufferRow(selection.start.row); + let start_row = starting_row(selection, display_map); let mut end_row = ending_row(selection, display_map); while let Some(next_selection) = selections.peek() { @@ -22132,6 +22430,14 @@ fn consume_contiguous_rows( (start_row, end_row) } +fn starting_row(selection: &Selection<Point>, display_map: &DisplaySnapshot) -> MultiBufferRow { + if selection.start.column > 0 { + MultiBufferRow(display_map.prev_line_boundary(selection.start).0.row) + } else { + MultiBufferRow(selection.start.row) + } +} + fn ending_row(next_selection: &Selection<Point>, display_map: &DisplaySnapshot) -> MultiBufferRow { if next_selection.end.column > 0 || next_selection.is_empty() { MultiBufferRow(display_map.next_line_boundary(next_selection.end).0.row + 1) @@ -22586,7 +22892,7 @@ impl Render for Editor { syntax: cx.theme().syntax().clone(), status: cx.theme().status().clone(), inlay_hints_style: make_inlay_hints_style(cx), - inline_completion_styles: make_suggestion_styles(cx), + edit_prediction_styles: make_suggestion_styles(cx), unnecessary_code_fade: ThemeSettings::get_global(cx).unnecessary_code_fade, show_underlines: self.diagnostics_enabled(), }, @@ -22981,7 +23287,7 @@ impl InvalidationRegion for SnippetState { } } -fn inline_completion_edit_text( +fn edit_prediction_edit_text( current_snapshot: &BufferSnapshot, edits: &[(Range<Anchor>, String)], edit_preview: &EditPreview, @@ -23001,6 +23307,33 @@ fn inline_completion_edit_text( edit_preview.highlight_edits(current_snapshot, &edits, include_deletions, cx) } +fn edit_prediction_fallback_text(edits: &[(Range<Anchor>, String)], cx: &App) -> HighlightedText { + // Fallback for providers that don't provide edit_preview (like Copilot/Supermaven) + // Just show the raw edit text with basic styling + let mut text = String::new(); + let mut highlights = Vec::new(); + + let insertion_highlight_style = HighlightStyle { + color: Some(cx.theme().colors().text), + ..Default::default() + }; + + for (_, edit_text) in edits { + let start_offset = text.len(); + text.push_str(edit_text); + let end_offset = text.len(); + + if start_offset < end_offset { + 
highlights.push((start_offset..end_offset, insertion_highlight_style)); + } + } + + HighlightedText { + text: text.into(), + highlights, + } +} + pub fn diagnostic_style(severity: lsp::DiagnosticSeverity, colors: &StatusColors) -> Hsla { match severity { lsp::DiagnosticSeverity::ERROR => colors.error, diff --git a/crates/editor/src/editor_settings.rs b/crates/editor/src/editor_settings.rs index 5d8379ddfb87600f7cd56d10f5684ed333589e78..14f46c0e60dfc3487430b22ea83913984bae3c24 100644 --- a/crates/editor/src/editor_settings.rs +++ b/crates/editor/src/editor_settings.rs @@ -395,6 +395,8 @@ pub enum SnippetSortOrder { Inline, /// Place snippets at the bottom of the completion list Bottom, + /// Do not show snippets in the completion list + None, } #[derive(Clone, Default, Serialize, Deserialize, JsonSchema)] diff --git a/crates/editor/src/editor_tests.rs b/crates/editor/src/editor_tests.rs index b9078301c771372dccf596c107cb96f9352f0251..b31963c9c8c694acebf072e05f693f36a81185af 100644 --- a/crates/editor/src/editor_tests.rs +++ b/crates/editor/src/editor_tests.rs @@ -2,7 +2,7 @@ use super::*; use crate::{ JoinLines, code_context_menus::CodeContextMenu, - inline_completion_tests::FakeInlineCompletionProvider, + edit_prediction_tests::FakeEditPredictionProvider, linked_editing_ranges::LinkedEditingRanges, scroll::scroll_amount::ScrollAmount, test::{ @@ -55,7 +55,7 @@ use util::{ uri, }; use workspace::{ - CloseActiveItem, CloseAllItems, CloseInactiveItems, MoveItemToPaneInDirection, NavigationEntry, + CloseActiveItem, CloseAllItems, CloseOtherItems, MoveItemToPaneInDirection, NavigationEntry, OpenOptions, ViewId, item::{FollowEvent, FollowableItem, Item, ItemHandle, SaveOptions}, }; @@ -2875,11 +2875,11 @@ async fn test_newline_documentation_comments(cx: &mut TestAppContext) { let language = Arc::new( Language::new( LanguageConfig { - documentation: Some(language::DocumentationConfig { + documentation_comment: Some(language::BlockCommentConfig { start: "/**".into(), end: "*/".into(), prefix: "* ".into(), - tab_size: NonZeroU32::new(1).unwrap(), + tab_size: 1, }), ..LanguageConfig::default() @@ -3089,7 +3089,12 @@ async fn test_newline_comments_with_block_comment(cx: &mut TestAppContext) { let lua_language = Arc::new(Language::new( LanguageConfig { line_comments: vec!["--".into()], - block_comment: Some(("--[[".into(), "]]".into())), + block_comment: Some(language::BlockCommentConfig { + start: "--[[".into(), + prefix: "".into(), + end: "]]".into(), + tab_size: 0, + }), ..LanguageConfig::default() }, None, @@ -4719,6 +4724,23 @@ async fn test_toggle_case(cx: &mut TestAppContext) { "}); } +#[gpui::test] +async fn test_convert_to_sentence_case(cx: &mut TestAppContext) { + init_test(cx, |_| {}); + + let mut cx = EditorTestContext::new(cx).await; + + cx.set_state(indoc! {" + «implement-windows-supportˇ» + "}); + cx.update_editor(|e, window, cx| { + e.convert_to_sentence_case(&ConvertToSentenceCase, window, cx) + }); + cx.assert_editor_state(indoc! 
{" + «Implement windows supportˇ» + "}); +} + #[gpui::test] async fn test_manipulate_text(cx: &mut TestAppContext) { init_test(cx, |_| {}); @@ -5064,6 +5086,33 @@ fn test_move_line_up_down(cx: &mut TestAppContext) { }); } +#[gpui::test] +fn test_move_line_up_selection_at_end_of_fold(cx: &mut TestAppContext) { + init_test(cx, |_| {}); + let editor = cx.add_window(|window, cx| { + let buffer = MultiBuffer::build_simple("\n\n\n\n\n\naaaa\nbbbb\ncccc", cx); + build_editor(buffer, window, cx) + }); + _ = editor.update(cx, |editor, window, cx| { + editor.fold_creases( + vec![Crease::simple( + Point::new(6, 4)..Point::new(7, 4), + FoldPlaceholder::test(), + )], + true, + window, + cx, + ); + editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { + s.select_ranges([Point::new(7, 4)..Point::new(7, 4)]) + }); + assert_eq!(editor.display_text(cx), "\n\n\n\n\n\naaaa⋯\ncccc"); + editor.move_line_up(&MoveLineUp, window, cx); + let buffer_text = editor.buffer.read(cx).snapshot(cx).text(); + assert_eq!(buffer_text, "\n\n\n\n\naaaa\nbbbb\n\ncccc"); + }); +} + #[gpui::test] fn test_move_line_up_down_with_blocks(cx: &mut TestAppContext) { init_test(cx, |_| {}); @@ -5081,7 +5130,6 @@ fn test_move_line_up_down_with_blocks(cx: &mut TestAppContext) { height: Some(1), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }], Some(Autoscroll::fit()), cx, @@ -5124,7 +5172,6 @@ async fn test_selections_and_replace_blocks(cx: &mut TestAppContext) { style: BlockStyle::Sticky, render: Arc::new(|_| gpui::div().into_any_element()), priority: 0, - render_in_minimap: true, }], None, cx, @@ -7204,12 +7251,12 @@ async fn test_undo_format_scrolls_to_last_edit_pos(cx: &mut TestAppContext) { } #[gpui::test] -async fn test_undo_inline_completion_scrolls_to_edit_pos(cx: &mut TestAppContext) { +async fn test_undo_edit_prediction_scrolls_to_edit_pos(cx: &mut TestAppContext) { init_test(cx, |_| {}); let mut cx = EditorTestContext::new(cx).await; - let provider = cx.new(|_| FakeInlineCompletionProvider::default()); + let provider = cx.new(|_| FakeEditPredictionProvider::default()); cx.update_editor(|editor, window, cx| { editor.set_edit_prediction_provider(Some(provider.clone()), window, cx); }); @@ -7232,7 +7279,7 @@ async fn test_undo_inline_completion_scrolls_to_edit_pos(cx: &mut TestAppContext cx.update(|_, cx| { provider.update(cx, |provider, _| { - provider.set_inline_completion(Some(inline_completion::InlineCompletion { + provider.set_edit_prediction(Some(edit_prediction::EditPrediction { id: None, edits: vec![(edit_position..edit_position, "X".into())], edit_preview: None, @@ -7240,7 +7287,7 @@ async fn test_undo_inline_completion_scrolls_to_edit_pos(cx: &mut TestAppContext }) }); - cx.update_editor(|editor, window, cx| editor.update_visible_inline_completion(window, cx)); + cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx)); cx.update_editor(|editor, window, cx| { editor.accept_edit_prediction(&crate::AcceptEditPrediction, window, cx) }); @@ -7922,6 +7969,38 @@ async fn test_select_larger_smaller_syntax_node_for_string(cx: &mut TestAppConte }); } +#[gpui::test] +async fn test_unwrap_syntax_node(cx: &mut gpui::TestAppContext) { + init_test(cx, |_| {}); + + let mut cx = EditorTestContext::new(cx).await; + + let language = Arc::new(Language::new( + LanguageConfig::default(), + Some(tree_sitter_rust::LANGUAGE.into()), + )); + + cx.update_buffer(|buffer, cx| { + buffer.set_language(Some(language), cx); + }); + + cx.set_state( + &r#" + use 
mod1::mod2::{«mod3ˇ», mod4}; + "# + .unindent(), + ); + cx.update_editor(|editor, window, cx| { + editor.unwrap_syntax_node(&UnwrapSyntaxNode, window, cx); + }); + cx.assert_editor_state( + &r#" + use mod1::mod2::«mod3ˇ»; + "# + .unindent(), + ); +} + #[gpui::test] async fn test_fold_function_bodies(cx: &mut TestAppContext) { init_test(cx, |_| {}); @@ -8565,6 +8644,7 @@ async fn test_autoclose_with_embedded_language(cx: &mut TestAppContext) { cx.language_registry().add(html_language.clone()); cx.language_registry().add(javascript_language.clone()); + cx.executor().run_until_parked(); cx.update_buffer(|buffer, cx| { buffer.set_language(Some(html_language), cx); @@ -9572,6 +9652,74 @@ async fn test_document_format_during_save(cx: &mut TestAppContext) { } } +#[gpui::test] +async fn test_redo_after_noop_format(cx: &mut TestAppContext) { + init_test(cx, |settings| { + settings.defaults.ensure_final_newline_on_save = Some(false); + }); + + let fs = FakeFs::new(cx.executor()); + fs.insert_file(path!("/file.txt"), "foo".into()).await; + + let project = Project::test(fs, [path!("/file.txt").as_ref()], cx).await; + + let buffer = project + .update(cx, |project, cx| { + project.open_local_buffer(path!("/file.txt"), cx) + }) + .await + .unwrap(); + + let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx)); + let (editor, cx) = cx.add_window_view(|window, cx| { + build_editor_with_project(project.clone(), buffer, window, cx) + }); + editor.update_in(cx, |editor, window, cx| { + editor.change_selections(SelectionEffects::default(), window, cx, |s| { + s.select_ranges([0..0]) + }); + }); + assert!(!cx.read(|cx| editor.is_dirty(cx))); + + editor.update_in(cx, |editor, window, cx| { + editor.handle_input("\n", window, cx) + }); + cx.run_until_parked(); + save(&editor, &project, cx).await; + assert_eq!("\nfoo", editor.read_with(cx, |editor, cx| editor.text(cx))); + + editor.update_in(cx, |editor, window, cx| { + editor.undo(&Default::default(), window, cx); + }); + save(&editor, &project, cx).await; + assert_eq!("foo", editor.read_with(cx, |editor, cx| editor.text(cx))); + + editor.update_in(cx, |editor, window, cx| { + editor.redo(&Default::default(), window, cx); + }); + cx.run_until_parked(); + assert_eq!("\nfoo", editor.read_with(cx, |editor, cx| editor.text(cx))); + + async fn save(editor: &Entity<Editor>, project: &Entity<Project>, cx: &mut VisualTestContext) { + let save = editor + .update_in(cx, |editor, window, cx| { + editor.save( + SaveOptions { + format: true, + autosave: false, + }, + project.clone(), + window, + cx, + ) + }) + .unwrap(); + cx.executor().start_waiting(); + save.await; + assert!(!cx.read(|cx| editor.is_dirty(cx))); + } +} + #[gpui::test] async fn test_multibuffer_format_during_save(cx: &mut TestAppContext) { init_test(cx, |_| {}); @@ -9957,8 +10105,14 @@ async fn test_autosave_with_dirty_buffers(cx: &mut TestAppContext) { ); } -#[gpui::test] -async fn test_range_format_during_save(cx: &mut TestAppContext) { +async fn setup_range_format_test( + cx: &mut TestAppContext, +) -> ( + Entity<Project>, + Entity<Editor>, + &mut gpui::VisualTestContext, + lsp::FakeLanguageServer, +) { init_test(cx, |_| {}); let fs = FakeFs::new(cx.executor()); @@ -9973,9 +10127,9 @@ async fn test_range_format_during_save(cx: &mut TestAppContext) { FakeLspAdapter { capabilities: lsp::ServerCapabilities { document_range_formatting_provider: Some(lsp::OneOf::Left(true)), - ..Default::default() + ..lsp::ServerCapabilities::default() }, - ..Default::default() + ..FakeLspAdapter::default() }, ); @@ 
-9990,14 +10144,22 @@ async fn test_range_format_during_save(cx: &mut TestAppContext) { let (editor, cx) = cx.add_window_view(|window, cx| { build_editor_with_project(project.clone(), buffer, window, cx) }); + + cx.executor().start_waiting(); + let fake_server = fake_servers.next().await.unwrap(); + + (project, editor, cx, fake_server) +} + +#[gpui::test] +async fn test_range_format_on_save_success(cx: &mut TestAppContext) { + let (project, editor, cx, fake_server) = setup_range_format_test(cx).await; + editor.update_in(cx, |editor, window, cx| { editor.set_text("one\ntwo\nthree\n", window, cx) }); assert!(cx.read(|cx| editor.is_dirty(cx))); - cx.executor().start_waiting(); - let fake_server = fake_servers.next().await.unwrap(); - let save = editor .update_in(cx, |editor, window, cx| { editor.save( @@ -10032,13 +10194,18 @@ async fn test_range_format_during_save(cx: &mut TestAppContext) { "one, two\nthree\n" ); assert!(!cx.read(|cx| editor.is_dirty(cx))); +} + +#[gpui::test] +async fn test_range_format_on_save_timeout(cx: &mut TestAppContext) { + let (project, editor, cx, fake_server) = setup_range_format_test(cx).await; editor.update_in(cx, |editor, window, cx| { editor.set_text("one\ntwo\nthree\n", window, cx) }); assert!(cx.read(|cx| editor.is_dirty(cx))); - // Ensure we can still save even if formatting hangs. + // Test that save still works when formatting hangs fake_server.set_request_handler::<lsp::request::RangeFormatting, _, _>( move |params, _| async move { assert_eq!( @@ -10070,8 +10237,13 @@ async fn test_range_format_during_save(cx: &mut TestAppContext) { "one\ntwo\nthree\n" ); assert!(!cx.read(|cx| editor.is_dirty(cx))); +} + +#[gpui::test] +async fn test_range_format_not_called_for_clean_buffer(cx: &mut TestAppContext) { + let (project, editor, cx, fake_server) = setup_range_format_test(cx).await; - // For non-dirty buffer, no formatting request should be sent + // Buffer starts clean, no formatting should be requested let save = editor .update_in(cx, |editor, window, cx| { editor.save( @@ -10092,6 +10264,12 @@ async fn test_range_format_during_save(cx: &mut TestAppContext) { .next(); cx.executor().start_waiting(); save.await; + cx.run_until_parked(); +} + +#[gpui::test] +async fn test_range_format_respects_language_tab_size_override(cx: &mut TestAppContext) { + let (project, editor, cx, fake_server) = setup_range_format_test(cx).await; // Set Rust language override and assert overridden tabsize is sent to language server update_test_language_settings(cx, |settings| { @@ -10105,7 +10283,7 @@ async fn test_range_format_during_save(cx: &mut TestAppContext) { }); editor.update_in(cx, |editor, window, cx| { - editor.set_text("somehting_new\n", window, cx) + editor.set_text("something_new\n", window, cx) }); assert!(cx.read(|cx| editor.is_dirty(cx))); let save = editor @@ -13255,6 +13433,178 @@ async fn test_as_is_completions(cx: &mut TestAppContext) { cx.assert_editor_state("fn a() {}\n unsafeˇ"); } +#[gpui::test] +async fn test_panic_during_c_completions(cx: &mut TestAppContext) { + init_test(cx, |_| {}); + let language = + Arc::try_unwrap(languages::language("c", tree_sitter_c::LANGUAGE.into())).unwrap(); + let mut cx = EditorLspTestContext::new( + language, + lsp::ServerCapabilities { + completion_provider: Some(lsp::CompletionOptions { + ..lsp::CompletionOptions::default() + }), + ..lsp::ServerCapabilities::default() + }, + cx, + ) + .await; + + cx.set_state( + "#ifndef BAR_H +#define BAR_H + +#include <stdbool.h> + +int fn_branch(bool do_branch1, bool do_branch2); + +#endif 
// BAR_H +ˇ", + ); + cx.executor().run_until_parked(); + cx.update_editor(|editor, window, cx| { + editor.handle_input("#", window, cx); + }); + cx.executor().run_until_parked(); + cx.update_editor(|editor, window, cx| { + editor.handle_input("i", window, cx); + }); + cx.executor().run_until_parked(); + cx.update_editor(|editor, window, cx| { + editor.handle_input("n", window, cx); + }); + cx.executor().run_until_parked(); + cx.assert_editor_state( + "#ifndef BAR_H +#define BAR_H + +#include <stdbool.h> + +int fn_branch(bool do_branch1, bool do_branch2); + +#endif // BAR_H +#inˇ", + ); + + cx.lsp + .set_request_handler::<lsp::request::Completion, _, _>(move |_, _| async move { + Ok(Some(lsp::CompletionResponse::List(lsp::CompletionList { + is_incomplete: false, + item_defaults: None, + items: vec![lsp::CompletionItem { + kind: Some(lsp::CompletionItemKind::SNIPPET), + label_details: Some(lsp::CompletionItemLabelDetails { + detail: Some("header".to_string()), + description: None, + }), + label: " include".to_string(), + text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit { + range: lsp::Range { + start: lsp::Position { + line: 8, + character: 1, + }, + end: lsp::Position { + line: 8, + character: 1, + }, + }, + new_text: "include \"$0\"".to_string(), + })), + sort_text: Some("40b67681include".to_string()), + insert_text_format: Some(lsp::InsertTextFormat::SNIPPET), + filter_text: Some("include".to_string()), + insert_text: Some("include \"$0\"".to_string()), + ..lsp::CompletionItem::default() + }], + }))) + }); + cx.update_editor(|editor, window, cx| { + editor.show_completions(&ShowCompletions { trigger: None }, window, cx); + }); + cx.executor().run_until_parked(); + cx.update_editor(|editor, window, cx| { + editor.confirm_completion(&ConfirmCompletion::default(), window, cx) + }); + cx.executor().run_until_parked(); + cx.assert_editor_state( + "#ifndef BAR_H +#define BAR_H + +#include <stdbool.h> + +int fn_branch(bool do_branch1, bool do_branch2); + +#endif // BAR_H +#include \"ˇ\"", + ); + + cx.lsp + .set_request_handler::<lsp::request::Completion, _, _>(move |_, _| async move { + Ok(Some(lsp::CompletionResponse::List(lsp::CompletionList { + is_incomplete: true, + item_defaults: None, + items: vec![lsp::CompletionItem { + kind: Some(lsp::CompletionItemKind::FILE), + label: "AGL/".to_string(), + text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit { + range: lsp::Range { + start: lsp::Position { + line: 8, + character: 10, + }, + end: lsp::Position { + line: 8, + character: 11, + }, + }, + new_text: "AGL/".to_string(), + })), + sort_text: Some("40b67681AGL/".to_string()), + insert_text_format: Some(lsp::InsertTextFormat::PLAIN_TEXT), + filter_text: Some("AGL/".to_string()), + insert_text: Some("AGL/".to_string()), + ..lsp::CompletionItem::default() + }], + }))) + }); + cx.update_editor(|editor, window, cx| { + editor.show_completions(&ShowCompletions { trigger: None }, window, cx); + }); + cx.executor().run_until_parked(); + cx.update_editor(|editor, window, cx| { + editor.confirm_completion(&ConfirmCompletion::default(), window, cx) + }); + cx.executor().run_until_parked(); + cx.assert_editor_state( + r##"#ifndef BAR_H +#define BAR_H + +#include <stdbool.h> + +int fn_branch(bool do_branch1, bool do_branch2); + +#endif // BAR_H +#include "AGL/ˇ"##, + ); + + cx.update_editor(|editor, window, cx| { + editor.handle_input("\"", window, cx); + }); + cx.executor().run_until_parked(); + cx.assert_editor_state( + r##"#ifndef BAR_H +#define BAR_H + +#include <stdbool.h> + +int 
fn_branch(bool do_branch1, bool do_branch2); + +#endif // BAR_H +#include "AGL/"ˇ"##, + ); +} + #[gpui::test] async fn test_no_duplicated_completion_requests(cx: &mut TestAppContext) { init_test(cx, |_| {}); @@ -13740,7 +14090,12 @@ async fn test_toggle_block_comment(cx: &mut TestAppContext) { Language::new( LanguageConfig { name: "HTML".into(), - block_comment: Some(("<!-- ".into(), " -->".into())), + block_comment: Some(BlockCommentConfig { + start: "<!-- ".into(), + prefix: "".into(), + end: " -->".into(), + tab_size: 0, + }), ..Default::default() }, Some(tree_sitter_html::LANGUAGE.into()), @@ -14736,7 +15091,7 @@ async fn go_to_prev_overlapping_diagnostic(executor: BackgroundExecutor, cx: &mu executor.run_until_parked(); cx.update_editor(|editor, window, cx| { - editor.go_to_prev_diagnostic(&GoToPreviousDiagnostic, window, cx); + editor.go_to_prev_diagnostic(&GoToPreviousDiagnostic::default(), window, cx); }); cx.assert_editor_state(indoc! {" @@ -14745,7 +15100,7 @@ async fn go_to_prev_overlapping_diagnostic(executor: BackgroundExecutor, cx: &mu "}); cx.update_editor(|editor, window, cx| { - editor.go_to_prev_diagnostic(&GoToPreviousDiagnostic, window, cx); + editor.go_to_prev_diagnostic(&GoToPreviousDiagnostic::default(), window, cx); }); cx.assert_editor_state(indoc! {" @@ -14754,7 +15109,7 @@ async fn go_to_prev_overlapping_diagnostic(executor: BackgroundExecutor, cx: &mu "}); cx.update_editor(|editor, window, cx| { - editor.go_to_prev_diagnostic(&GoToPreviousDiagnostic, window, cx); + editor.go_to_prev_diagnostic(&GoToPreviousDiagnostic::default(), window, cx); }); cx.assert_editor_state(indoc! {" @@ -14763,7 +15118,7 @@ async fn go_to_prev_overlapping_diagnostic(executor: BackgroundExecutor, cx: &mu "}); cx.update_editor(|editor, window, cx| { - editor.go_to_prev_diagnostic(&GoToPreviousDiagnostic, window, cx); + editor.go_to_prev_diagnostic(&GoToPreviousDiagnostic::default(), window, cx); }); cx.assert_editor_state(indoc! 
{" @@ -16761,7 +17116,7 @@ async fn test_multibuffer_reverts(cx: &mut TestAppContext) { } #[gpui::test] -async fn test_mutlibuffer_in_navigation_history(cx: &mut TestAppContext) { +async fn test_multibuffer_in_navigation_history(cx: &mut TestAppContext) { init_test(cx, |_| {}); let cols = 4; @@ -20229,7 +20584,7 @@ async fn test_multi_buffer_navigation_with_folded_buffers(cx: &mut TestAppContex } #[gpui::test] -async fn test_inline_completion_text(cx: &mut TestAppContext) { +async fn test_edit_prediction_text(cx: &mut TestAppContext) { init_test(cx, |_| {}); // Simple insertion @@ -20328,7 +20683,7 @@ async fn test_inline_completion_text(cx: &mut TestAppContext) { } #[gpui::test] -async fn test_inline_completion_text_with_deletions(cx: &mut TestAppContext) { +async fn test_edit_prediction_text_with_deletions(cx: &mut TestAppContext) { init_test(cx, |_| {}); // Deletion @@ -20418,7 +20773,7 @@ async fn assert_highlighted_edits( .await; cx.update(|_window, cx| { - let highlighted_edits = inline_completion_edit_text( + let highlighted_edits = edit_prediction_edit_text( &snapshot.as_singleton().unwrap().2, &edits, &edit_preview, @@ -21190,16 +21545,32 @@ async fn test_apply_code_lens_actions_with_commands(cx: &mut gpui::TestAppContex }, ); - let (buffer, _handle) = project - .update(cx, |p, cx| { - p.open_local_buffer_with_lsp(path!("/dir/a.ts"), cx) + let editor = workspace + .update(cx, |workspace, window, cx| { + workspace.open_abs_path( + PathBuf::from(path!("/dir/a.ts")), + OpenOptions::default(), + window, + cx, + ) }) + .unwrap() .await + .unwrap() + .downcast::<Editor>() .unwrap(); cx.executor().run_until_parked(); let fake_server = fake_language_servers.next().await.unwrap(); + let buffer = editor.update(cx, |editor, cx| { + editor + .buffer() + .read(cx) + .as_singleton() + .expect("have opened a single file by path") + }); + let buffer_snapshot = buffer.update(cx, |buffer, _| buffer.snapshot()); let anchor = buffer_snapshot.anchor_at(0, text::Bias::Left); drop(buffer_snapshot); @@ -21257,7 +21628,7 @@ async fn test_apply_code_lens_actions_with_commands(cx: &mut gpui::TestAppContex assert_eq!( actions.len(), 1, - "Should have only one valid action for the 0..0 range" + "Should have only one valid action for the 0..0 range, got: {actions:#?}" ); let action = actions[0].clone(); let apply = project.update(cx, |project, cx| { @@ -21303,7 +21674,7 @@ async fn test_apply_code_lens_actions_with_commands(cx: &mut gpui::TestAppContex .into_iter() .collect(), ), - ..Default::default() + ..lsp::WorkspaceEdit::default() }, }, ) @@ -21326,6 +21697,38 @@ async fn test_apply_code_lens_actions_with_commands(cx: &mut gpui::TestAppContex buffer.undo(cx); assert_eq!(buffer.text(), "a"); }); + + let actions_after_edits = cx + .update_window(*workspace, |_, window, cx| { + project.code_actions(&buffer, anchor..anchor, window, cx) + }) + .unwrap() + .await + .unwrap(); + assert_eq!( + actions, actions_after_edits, + "For the same selection, same code lens actions should be returned" + ); + + let _responses = + fake_server.set_request_handler::<lsp::request::CodeLensRequest, _, _>(|_, _| async move { + panic!("No more code lens requests are expected"); + }); + editor.update_in(cx, |editor, window, cx| { + editor.select_all(&SelectAll, window, cx); + }); + cx.executor().run_until_parked(); + let new_actions = cx + .update_window(*workspace, |_, window, cx| { + project.code_actions(&buffer, anchor..anchor, window, cx) + }) + .unwrap() + .await + .unwrap(); + assert_eq!( + actions, new_actions, + "Code lens 
are queried for the same range and should get the same set back, but without additional LSP queries now" + ); } #[gpui::test] @@ -21465,7 +21868,7 @@ println!("5"); .unwrap(); pane_1 .update_in(cx, |pane, window, cx| { - pane.close_inactive_items(&CloseInactiveItems::default(), window, cx) + pane.close_other_items(&CloseOtherItems::default(), None, window, cx) }) .await .unwrap(); @@ -21501,7 +21904,7 @@ println!("5"); .unwrap(); pane_2 .update_in(cx, |pane, window, cx| { - pane.close_inactive_items(&CloseInactiveItems::default(), window, cx) + pane.close_other_items(&CloseOtherItems::default(), None, window, cx) }) .await .unwrap(); @@ -22465,6 +22868,435 @@ async fn test_indent_on_newline_for_python(cx: &mut TestAppContext) { "}); } +#[gpui::test] +async fn test_tab_in_leading_whitespace_auto_indents_for_bash(cx: &mut TestAppContext) { + init_test(cx, |_| {}); + + let mut cx = EditorTestContext::new(cx).await; + let language = languages::language("bash", tree_sitter_bash::LANGUAGE.into()); + cx.update_buffer(|buffer, cx| buffer.set_language(Some(language), cx)); + + // test cursor move to start of each line on tab + // for `if`, `elif`, `else`, `while`, `for`, `case` and `function` + cx.set_state(indoc! {" + function main() { + ˇ for item in $items; do + ˇ while [ -n \"$item\" ]; do + ˇ if [ \"$value\" -gt 10 ]; then + ˇ continue + ˇ elif [ \"$value\" -lt 0 ]; then + ˇ break + ˇ else + ˇ echo \"$item\" + ˇ fi + ˇ done + ˇ done + ˇ} + "}); + cx.update_editor(|e, window, cx| e.tab(&Tab, window, cx)); + cx.assert_editor_state(indoc! {" + function main() { + ˇfor item in $items; do + ˇwhile [ -n \"$item\" ]; do + ˇif [ \"$value\" -gt 10 ]; then + ˇcontinue + ˇelif [ \"$value\" -lt 0 ]; then + ˇbreak + ˇelse + ˇecho \"$item\" + ˇfi + ˇdone + ˇdone + ˇ} + "}); + // test relative indent is preserved when tab + cx.update_editor(|e, window, cx| e.tab(&Tab, window, cx)); + cx.assert_editor_state(indoc! {" + function main() { + ˇfor item in $items; do + ˇwhile [ -n \"$item\" ]; do + ˇif [ \"$value\" -gt 10 ]; then + ˇcontinue + ˇelif [ \"$value\" -lt 0 ]; then + ˇbreak + ˇelse + ˇecho \"$item\" + ˇfi + ˇdone + ˇdone + ˇ} + "}); + + // test cursor move to start of each line on tab + // for `case` statement with patterns + cx.set_state(indoc! {" + function handle() { + ˇ case \"$1\" in + ˇ start) + ˇ echo \"a\" + ˇ ;; + ˇ stop) + ˇ echo \"b\" + ˇ ;; + ˇ *) + ˇ echo \"c\" + ˇ ;; + ˇ esac + ˇ} + "}); + cx.update_editor(|e, window, cx| e.tab(&Tab, window, cx)); + cx.assert_editor_state(indoc! {" + function handle() { + ˇcase \"$1\" in + ˇstart) + ˇecho \"a\" + ˇ;; + ˇstop) + ˇecho \"b\" + ˇ;; + ˇ*) + ˇecho \"c\" + ˇ;; + ˇesac + ˇ} + "}); +} + +#[gpui::test] +async fn test_indent_after_input_for_bash(cx: &mut TestAppContext) { + init_test(cx, |_| {}); + + let mut cx = EditorTestContext::new(cx).await; + let language = languages::language("bash", tree_sitter_bash::LANGUAGE.into()); + cx.update_buffer(|buffer, cx| buffer.set_language(Some(language), cx)); + + // test indents on comment insert + cx.set_state(indoc! {" + function main() { + ˇ for item in $items; do + ˇ while [ -n \"$item\" ]; do + ˇ if [ \"$value\" -gt 10 ]; then + ˇ continue + ˇ elif [ \"$value\" -lt 0 ]; then + ˇ break + ˇ else + ˇ echo \"$item\" + ˇ fi + ˇ done + ˇ done + ˇ} + "}); + cx.update_editor(|e, window, cx| e.handle_input("#", window, cx)); + cx.assert_editor_state(indoc! 
{" + function main() { + #ˇ for item in $items; do + #ˇ while [ -n \"$item\" ]; do + #ˇ if [ \"$value\" -gt 10 ]; then + #ˇ continue + #ˇ elif [ \"$value\" -lt 0 ]; then + #ˇ break + #ˇ else + #ˇ echo \"$item\" + #ˇ fi + #ˇ done + #ˇ done + #ˇ} + "}); +} + +#[gpui::test] +async fn test_outdent_after_input_for_bash(cx: &mut TestAppContext) { + init_test(cx, |_| {}); + + let mut cx = EditorTestContext::new(cx).await; + let language = languages::language("bash", tree_sitter_bash::LANGUAGE.into()); + cx.update_buffer(|buffer, cx| buffer.set_language(Some(language), cx)); + + // test `else` auto outdents when typed inside `if` block + cx.set_state(indoc! {" + if [ \"$1\" = \"test\" ]; then + echo \"foo bar\" + ˇ + "}); + cx.update_editor(|editor, window, cx| { + editor.handle_input("else", window, cx); + }); + cx.assert_editor_state(indoc! {" + if [ \"$1\" = \"test\" ]; then + echo \"foo bar\" + elseˇ + "}); + + // test `elif` auto outdents when typed inside `if` block + cx.set_state(indoc! {" + if [ \"$1\" = \"test\" ]; then + echo \"foo bar\" + ˇ + "}); + cx.update_editor(|editor, window, cx| { + editor.handle_input("elif", window, cx); + }); + cx.assert_editor_state(indoc! {" + if [ \"$1\" = \"test\" ]; then + echo \"foo bar\" + elifˇ + "}); + + // test `fi` auto outdents when typed inside `else` block + cx.set_state(indoc! {" + if [ \"$1\" = \"test\" ]; then + echo \"foo bar\" + else + echo \"bar baz\" + ˇ + "}); + cx.update_editor(|editor, window, cx| { + editor.handle_input("fi", window, cx); + }); + cx.assert_editor_state(indoc! {" + if [ \"$1\" = \"test\" ]; then + echo \"foo bar\" + else + echo \"bar baz\" + fiˇ + "}); + + // test `done` auto outdents when typed inside `while` block + cx.set_state(indoc! {" + while read line; do + echo \"$line\" + ˇ + "}); + cx.update_editor(|editor, window, cx| { + editor.handle_input("done", window, cx); + }); + cx.assert_editor_state(indoc! {" + while read line; do + echo \"$line\" + doneˇ + "}); + + // test `done` auto outdents when typed inside `for` block + cx.set_state(indoc! {" + for file in *.txt; do + cat \"$file\" + ˇ + "}); + cx.update_editor(|editor, window, cx| { + editor.handle_input("done", window, cx); + }); + cx.assert_editor_state(indoc! {" + for file in *.txt; do + cat \"$file\" + doneˇ + "}); + + // test `esac` auto outdents when typed inside `case` block + cx.set_state(indoc! {" + case \"$1\" in + start) + echo \"foo bar\" + ;; + stop) + echo \"bar baz\" + ;; + ˇ + "}); + cx.update_editor(|editor, window, cx| { + editor.handle_input("esac", window, cx); + }); + cx.assert_editor_state(indoc! {" + case \"$1\" in + start) + echo \"foo bar\" + ;; + stop) + echo \"bar baz\" + ;; + esacˇ + "}); + + // test `*)` auto outdents when typed inside `case` block + cx.set_state(indoc! {" + case \"$1\" in + start) + echo \"foo bar\" + ;; + ˇ + "}); + cx.update_editor(|editor, window, cx| { + editor.handle_input("*)", window, cx); + }); + cx.assert_editor_state(indoc! {" + case \"$1\" in + start) + echo \"foo bar\" + ;; + *)ˇ + "}); + + // test `fi` outdents to correct level with nested if blocks + cx.set_state(indoc! {" + if [ \"$1\" = \"test\" ]; then + echo \"outer if\" + if [ \"$2\" = \"debug\" ]; then + echo \"inner if\" + ˇ + "}); + cx.update_editor(|editor, window, cx| { + editor.handle_input("fi", window, cx); + }); + cx.assert_editor_state(indoc! 
{" + if [ \"$1\" = \"test\" ]; then + echo \"outer if\" + if [ \"$2\" = \"debug\" ]; then + echo \"inner if\" + fiˇ + "}); +} + +#[gpui::test] +async fn test_indent_on_newline_for_bash(cx: &mut TestAppContext) { + init_test(cx, |_| {}); + update_test_language_settings(cx, |settings| { + settings.defaults.extend_comment_on_newline = Some(false); + }); + let mut cx = EditorTestContext::new(cx).await; + let language = languages::language("bash", tree_sitter_bash::LANGUAGE.into()); + cx.update_buffer(|buffer, cx| buffer.set_language(Some(language), cx)); + + // test correct indent after newline on comment + cx.set_state(indoc! {" + # COMMENT:ˇ + "}); + cx.update_editor(|editor, window, cx| { + editor.newline(&Newline, window, cx); + }); + cx.assert_editor_state(indoc! {" + # COMMENT: + ˇ + "}); + + // test correct indent after newline after `then` + cx.set_state(indoc! {" + + if [ \"$1\" = \"test\" ]; thenˇ + "}); + cx.update_editor(|editor, window, cx| { + editor.newline(&Newline, window, cx); + }); + cx.run_until_parked(); + cx.assert_editor_state(indoc! {" + + if [ \"$1\" = \"test\" ]; then + ˇ + "}); + + // test correct indent after newline after `else` + cx.set_state(indoc! {" + if [ \"$1\" = \"test\" ]; then + elseˇ + "}); + cx.update_editor(|editor, window, cx| { + editor.newline(&Newline, window, cx); + }); + cx.run_until_parked(); + cx.assert_editor_state(indoc! {" + if [ \"$1\" = \"test\" ]; then + else + ˇ + "}); + + // test correct indent after newline after `elif` + cx.set_state(indoc! {" + if [ \"$1\" = \"test\" ]; then + elifˇ + "}); + cx.update_editor(|editor, window, cx| { + editor.newline(&Newline, window, cx); + }); + cx.run_until_parked(); + cx.assert_editor_state(indoc! {" + if [ \"$1\" = \"test\" ]; then + elif + ˇ + "}); + + // test correct indent after newline after `do` + cx.set_state(indoc! {" + for file in *.txt; doˇ + "}); + cx.update_editor(|editor, window, cx| { + editor.newline(&Newline, window, cx); + }); + cx.run_until_parked(); + cx.assert_editor_state(indoc! {" + for file in *.txt; do + ˇ + "}); + + // test correct indent after newline after case pattern + cx.set_state(indoc! {" + case \"$1\" in + start)ˇ + "}); + cx.update_editor(|editor, window, cx| { + editor.newline(&Newline, window, cx); + }); + cx.run_until_parked(); + cx.assert_editor_state(indoc! {" + case \"$1\" in + start) + ˇ + "}); + + // test correct indent after newline after case pattern + cx.set_state(indoc! {" + case \"$1\" in + start) + ;; + *)ˇ + "}); + cx.update_editor(|editor, window, cx| { + editor.newline(&Newline, window, cx); + }); + cx.run_until_parked(); + cx.assert_editor_state(indoc! {" + case \"$1\" in + start) + ;; + *) + ˇ + "}); + + // test correct indent after newline after function opening brace + cx.set_state(indoc! {" + function test() {ˇ} + "}); + cx.update_editor(|editor, window, cx| { + editor.newline(&Newline, window, cx); + }); + cx.run_until_parked(); + cx.assert_editor_state(indoc! {" + function test() { + ˇ + } + "}); + + // test no extra indent after semicolon on same line + cx.set_state(indoc! {" + echo \"test\";ˇ + "}); + cx.update_editor(|editor, window, cx| { + editor.newline(&Newline, window, cx); + }); + cx.run_until_parked(); + cx.assert_editor_state(indoc! 
{" + echo \"test\"; + ˇ + "}); +} + fn empty_range(row: usize, column: usize) -> Range<DisplayPoint> { let point = DisplayPoint::new(DisplayRow(row as u32), column as u32); point..point @@ -22710,7 +23542,7 @@ pub(crate) fn init_test(cx: &mut TestAppContext, f: fn(&mut AllLanguageSettingsC workspace::init_settings(cx); crate::init(cx); }); - + zlog::init_test(); update_test_language_settings(cx, f); } diff --git a/crates/editor/src/element.rs b/crates/editor/src/element.rs index 014583912fd3233d3bc1872188a15679d766ff2b..034fff970d5c7d251ba8e246637274f6897a7b26 100644 --- a/crates/editor/src/element.rs +++ b/crates/editor/src/element.rs @@ -3,13 +3,13 @@ use crate::{ CodeActionSource, ColumnarMode, ConflictsOurs, ConflictsOursMarker, ConflictsOuter, ConflictsTheirs, ConflictsTheirsMarker, ContextMenuPlacement, CursorShape, CustomBlockId, DisplayDiffHunk, DisplayPoint, DisplayRow, DocumentHighlightRead, DocumentHighlightWrite, - EditDisplayMode, Editor, EditorMode, EditorSettings, EditorSnapshot, EditorStyle, - FILE_HEADER_HEIGHT, FocusedBlock, GutterDimensions, HalfPageDown, HalfPageUp, HandleInput, - HoveredCursor, InlayHintRefreshReason, InlineCompletion, JumpData, LineDown, LineHighlight, - LineUp, MAX_LINE_LEN, MINIMAP_FONT_SIZE, MULTI_BUFFER_EXCERPT_HEADER_HEIGHT, OpenExcerpts, - PageDown, PageUp, PhantomBreakpointIndicator, Point, RowExt, RowRangeExt, SelectPhase, + EditDisplayMode, EditPrediction, Editor, EditorMode, EditorSettings, EditorSnapshot, + EditorStyle, FILE_HEADER_HEIGHT, FocusedBlock, GutterDimensions, HalfPageDown, HalfPageUp, + HandleInput, HoveredCursor, InlayHintRefreshReason, JumpData, LineDown, LineHighlight, LineUp, + MAX_LINE_LEN, MINIMAP_FONT_SIZE, MULTI_BUFFER_EXCERPT_HEADER_HEIGHT, OpenExcerpts, PageDown, + PageUp, PhantomBreakpointIndicator, Point, RowExt, RowRangeExt, SelectPhase, SelectedTextHighlight, Selection, SelectionDragState, SoftWrap, StickyHeaderExcerpt, ToPoint, - ToggleFold, + ToggleFold, ToggleFoldAll, code_context_menus::{CodeActionsMenu, MENU_ASIDE_MAX_WIDTH, MENU_ASIDE_MIN_WIDTH, MENU_GAP}, display_map::{ Block, BlockContext, BlockStyle, ChunkRendererId, DisplaySnapshot, EditorMargins, @@ -43,11 +43,11 @@ use gpui::{ Bounds, ClickEvent, ContentMask, Context, Corner, Corners, CursorStyle, DispatchPhase, Edges, Element, ElementInputHandler, Entity, Focusable as _, FontId, GlobalElementId, Hitbox, HitboxBehavior, Hsla, InteractiveElement, IntoElement, IsZero, Keystroke, Length, - ModifiersChangedEvent, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent, PaintQuad, - ParentElement, Pixels, ScrollDelta, ScrollHandle, ScrollWheelEvent, ShapedLine, SharedString, - Size, StatefulInteractiveElement, Style, Styled, TextRun, TextStyleRefinement, WeakEntity, - Window, anchored, deferred, div, fill, linear_color_stop, linear_gradient, outline, point, px, - quad, relative, size, solid_background, transparent_black, + ModifiersChangedEvent, MouseButton, MouseClickEvent, MouseDownEvent, MouseMoveEvent, + MouseUpEvent, PaintQuad, ParentElement, Pixels, ScrollDelta, ScrollHandle, ScrollWheelEvent, + ShapedLine, SharedString, Size, StatefulInteractiveElement, Style, Styled, TextRun, + TextStyleRefinement, WeakEntity, Window, anchored, deferred, div, fill, linear_color_stop, + linear_gradient, outline, point, px, quad, relative, size, solid_background, transparent_black, }; use itertools::Itertools; use language::language_settings::{ @@ -216,6 +216,7 @@ impl EditorElement { register_action(editor, window, Editor::newline_above); register_action(editor, 
window, Editor::newline_below); register_action(editor, window, Editor::backspace); + register_action(editor, window, Editor::blame_hover); register_action(editor, window, Editor::delete); register_action(editor, window, Editor::tab); register_action(editor, window, Editor::backtab); @@ -229,7 +230,6 @@ impl EditorElement { register_action(editor, window, Editor::sort_lines_case_insensitive); register_action(editor, window, Editor::reverse_lines); register_action(editor, window, Editor::shuffle_lines); - register_action(editor, window, Editor::toggle_case); register_action(editor, window, Editor::convert_indentation_to_spaces); register_action(editor, window, Editor::convert_indentation_to_tabs); register_action(editor, window, Editor::convert_to_upper_case); @@ -240,6 +240,8 @@ impl EditorElement { register_action(editor, window, Editor::convert_to_upper_camel_case); register_action(editor, window, Editor::convert_to_lower_camel_case); register_action(editor, window, Editor::convert_to_opposite_case); + register_action(editor, window, Editor::convert_to_sentence_case); + register_action(editor, window, Editor::toggle_case); register_action(editor, window, Editor::convert_to_rot13); register_action(editor, window, Editor::convert_to_rot47); register_action(editor, window, Editor::delete_to_previous_word_start); @@ -261,6 +263,7 @@ impl EditorElement { register_action(editor, window, Editor::kill_ring_yank); register_action(editor, window, Editor::copy); register_action(editor, window, Editor::copy_and_trim); + register_action(editor, window, Editor::diff_clipboard_with_selection); register_action(editor, window, Editor::paste); register_action(editor, window, Editor::undo); register_action(editor, window, Editor::redo); @@ -354,6 +357,7 @@ impl EditorElement { register_action(editor, window, Editor::toggle_comments); register_action(editor, window, Editor::select_larger_syntax_node); register_action(editor, window, Editor::select_smaller_syntax_node); + register_action(editor, window, Editor::unwrap_syntax_node); register_action(editor, window, Editor::select_enclosing_symbol); register_action(editor, window, Editor::move_to_enclosing_bracket); register_action(editor, window, Editor::undo_selection); @@ -416,6 +420,7 @@ impl EditorElement { register_action(editor, window, Editor::fold_recursive); register_action(editor, window, Editor::toggle_fold); register_action(editor, window, Editor::toggle_fold_recursive); + register_action(editor, window, Editor::toggle_fold_all); register_action(editor, window, Editor::unfold_lines); register_action(editor, window, Editor::unfold_recursive); register_action(editor, window, Editor::unfold_all); @@ -550,7 +555,7 @@ impl EditorElement { register_action(editor, window, Editor::signature_help_next); register_action(editor, window, Editor::next_edit_prediction); register_action(editor, window, Editor::previous_edit_prediction); - register_action(editor, window, Editor::show_inline_completion); + register_action(editor, window, Editor::show_edit_prediction); register_action(editor, window, Editor::context_menu_first); register_action(editor, window, Editor::context_menu_prev); register_action(editor, window, Editor::context_menu_next); @@ -558,7 +563,7 @@ impl EditorElement { register_action(editor, window, Editor::display_cursor_names); register_action(editor, window, Editor::unique_lines_case_insensitive); register_action(editor, window, Editor::unique_lines_case_sensitive); - register_action(editor, window, Editor::accept_partial_inline_completion); + 
register_action(editor, window, Editor::accept_partial_edit_prediction); register_action(editor, window, Editor::accept_edit_prediction); register_action(editor, window, Editor::restore_file); register_action(editor, window, Editor::git_restore); @@ -945,9 +950,14 @@ impl EditorElement { let hovered_link_modifier = Editor::multi_cursor_modifier(false, &event.modifiers(), cx); - if !pending_nonempty_selections && hovered_link_modifier && text_hitbox.is_hovered(window) { - let point = position_map.point_for_position(event.up.position); + if let Some(mouse_position) = event.mouse_position() + && !pending_nonempty_selections + && hovered_link_modifier + && text_hitbox.is_hovered(window) + { + let point = position_map.point_for_position(mouse_position); editor.handle_click_hovered_link(point, event.modifiers(), window, cx); + editor.selection_drag_state = SelectionDragState::None; cx.stop_propagation(); } @@ -1141,10 +1151,14 @@ impl EditorElement { .as_ref() .and_then(|state| state.popover_bounds) .map_or(false, |bounds| bounds.contains(&event.position)); + let keyboard_grace = editor + .inline_blame_popover + .as_ref() + .map_or(false, |state| state.keyboard_grace); if mouse_over_inline_blame || mouse_over_popover { - editor.show_blame_popover(&blame_entry, event.position, cx); - } else { + editor.show_blame_popover(&blame_entry, event.position, false, cx); + } else if !keyboard_grace { editor.hide_blame_popover(cx); } } else { @@ -2084,7 +2098,7 @@ impl EditorElement { row_block_types: &HashMap<DisplayRow, bool>, content_origin: gpui::Point<Pixels>, scroll_pixel_position: gpui::Point<Pixels>, - inline_completion_popover_origin: Option<gpui::Point<Pixels>>, + edit_prediction_popover_origin: Option<gpui::Point<Pixels>>, start_row: DisplayRow, end_row: DisplayRow, line_height: Pixels, @@ -2093,16 +2107,19 @@ impl EditorElement { window: &mut Window, cx: &mut App, ) -> HashMap<DisplayRow, AnyElement> { - if self.editor.read(cx).mode().is_minimap() { - return HashMap::default(); - } - - let max_severity = match ProjectSettings::get_global(cx) - .diagnostics - .inline - .max_severity - .unwrap_or_else(|| self.editor.read(cx).diagnostics_max_severity) - .into_lsp() + let max_severity = match self + .editor + .read(cx) + .inline_diagnostics_enabled() + .then(|| { + ProjectSettings::get_global(cx) + .diagnostics + .inline + .max_severity + .unwrap_or_else(|| self.editor.read(cx).diagnostics_max_severity) + .into_lsp() + }) + .flatten() { Some(max_severity) => max_severity, None => return HashMap::default(), @@ -2198,12 +2215,13 @@ impl EditorElement { cmp::max(padded_line, min_start) }; - let behind_inline_completion_popover = inline_completion_popover_origin - .as_ref() - .map_or(false, |inline_completion_popover_origin| { - (pos_y..pos_y + line_height).contains(&inline_completion_popover_origin.y) - }); - let opacity = if behind_inline_completion_popover { + let behind_edit_prediction_popover = edit_prediction_popover_origin.as_ref().map_or( + false, + |edit_prediction_popover_origin| { + (pos_y..pos_y + line_height).contains(&edit_prediction_popover_origin.y) + }, + ); + let opacity = if behind_edit_prediction_popover { 0.5 } else { 1.0 @@ -2415,9 +2433,9 @@ impl EditorElement { let mut padding = INLINE_BLAME_PADDING_EM_WIDTHS; - if let Some(inline_completion) = editor.active_inline_completion.as_ref() { - match &inline_completion.completion { - InlineCompletion::Edit { + if let Some(edit_prediction) = editor.active_edit_prediction.as_ref() { + match &edit_prediction.completion { + 
EditPrediction::Edit { display_mode: EditDisplayMode::TabAccept, .. } => padding += INLINE_ACCEPT_SUGGESTION_EM_WIDTHS, @@ -2618,9 +2636,6 @@ impl EditorElement { window: &mut Window, cx: &mut App, ) -> Option<Vec<IndentGuideLayout>> { - if self.editor.read(cx).mode().is_minimap() { - return None; - } let indent_guides = self.editor.update(cx, |editor, cx| { editor.indent_guides(visible_buffer_range, snapshot, cx) })?; @@ -3084,9 +3099,9 @@ impl EditorElement { window: &mut Window, cx: &mut App, ) -> Arc<HashMap<MultiBufferRow, LineNumberLayout>> { - let include_line_numbers = snapshot.show_line_numbers.unwrap_or_else(|| { - EditorSettings::get_global(cx).gutter.line_numbers && snapshot.mode.is_full() - }); + let include_line_numbers = snapshot + .show_line_numbers + .unwrap_or_else(|| EditorSettings::get_global(cx).gutter.line_numbers); if !include_line_numbers { return Arc::default(); } @@ -3399,22 +3414,18 @@ impl EditorElement { div() .size_full() - .children( - (!snapshot.mode.is_minimap() || custom.render_in_minimap).then(|| { - custom.render(&mut BlockContext { - window, - app: cx, - anchor_x, - margins: editor_margins, - line_height, - em_width, - block_id, - selected, - max_width: text_hitbox.size.width.max(*scroll_width), - editor_style: &self.style, - }) - }), - ) + .child(custom.render(&mut BlockContext { + window, + app: cx, + anchor_x, + margins: editor_margins, + line_height, + em_width, + block_id, + selected, + max_width: text_hitbox.size.width.max(*scroll_width), + editor_style: &self.style, + })) .into_any() } @@ -3620,24 +3631,37 @@ impl EditorElement { .tooltip({ let focus_handle = focus_handle.clone(); move |window, cx| { - Tooltip::for_action_in( + Tooltip::with_meta_in( "Toggle Excerpt Fold", - &ToggleFold, + Some(&ToggleFold), + "Alt+click to toggle all", &focus_handle, window, cx, ) } }) - .on_click(move |_, _, cx| { - if is_folded { + .on_click(move |event, window, cx| { + if event.modifiers().alt { + // Alt+click toggles all buffers editor.update(cx, |editor, cx| { - editor.unfold_buffer(buffer_id, cx); + editor.toggle_fold_all( + &ToggleFoldAll, + window, + cx, + ); }); } else { - editor.update(cx, |editor, cx| { - editor.fold_buffer(buffer_id, cx); - }); + // Regular click toggles single buffer + if is_folded { + editor.update(cx, |editor, cx| { + editor.unfold_buffer(buffer_id, cx); + }); + } else { + editor.update(cx, |editor, cx| { + editor.fold_buffer(buffer_id, cx); + }); + } } }), ), @@ -3658,6 +3682,7 @@ impl EditorElement { .id("path header block") .size_full() .justify_between() + .overflow_hidden() .child( h_flex() .gap_2() @@ -3716,7 +3741,7 @@ impl EditorElement { move |editor, e: &ClickEvent, window, cx| { editor.open_excerpts_common( Some(jump_data.clone()), - e.down.modifiers.secondary(), + e.modifiers().secondary(), window, cx, ); @@ -3993,6 +4018,7 @@ impl EditorElement { let available_width = hitbox.bounds.size.width - right_margin; let mut header = v_flex() + .w_full() .relative() .child( div() @@ -4067,8 +4093,7 @@ impl EditorElement { { let editor = self.editor.read(cx); - if editor - .edit_prediction_visible_in_cursor_popover(editor.has_active_inline_completion()) + if editor.edit_prediction_visible_in_cursor_popover(editor.has_active_edit_prediction()) { height_above_menu += editor.edit_prediction_cursor_popover_height() + POPOVER_Y_PADDING; @@ -6657,14 +6682,14 @@ impl EditorElement { } } - fn paint_inline_completion_popover( + fn paint_edit_prediction_popover( &mut self, layout: &mut EditorLayout, window: &mut Window, cx: &mut App, ) { 
- if let Some(inline_completion_popover) = layout.inline_completion_popover.as_mut() { - inline_completion_popover.paint(window, cx); + if let Some(edit_prediction_popover) = layout.edit_prediction_popover.as_mut() { + edit_prediction_popover.paint(window, cx); } } @@ -6762,7 +6787,7 @@ impl EditorElement { } fn paint_mouse_listeners(&mut self, layout: &EditorLayout, window: &mut Window, cx: &mut App) { - if self.editor.read(cx).mode.is_minimap() { + if layout.mode.is_minimap() { return; } @@ -6863,10 +6888,10 @@ impl EditorElement { // Fire click handlers during the bubble phase. DispatchPhase::Bubble => editor.update(cx, |editor, cx| { if let Some(mouse_down) = captured_mouse_down.take() { - let event = ClickEvent { + let event = ClickEvent::Mouse(MouseClickEvent { down: mouse_down, up: event.clone(), - }; + }); Self::click(editor, &event, &position_map, window, cx); } }), @@ -7777,46 +7802,13 @@ impl Element for EditorElement { editor.set_style(self.style.clone(), window, cx); let layout_id = match editor.mode { - EditorMode::SingleLine { auto_width } => { + EditorMode::SingleLine => { let rem_size = window.rem_size(); - let height = self.style.text.line_height_in_pixels(rem_size); - if auto_width { - let editor_handle = cx.entity().clone(); - let style = self.style.clone(); - window.request_measured_layout( - Style::default(), - move |_, _, window, cx| { - let editor_snapshot = editor_handle - .update(cx, |editor, cx| editor.snapshot(window, cx)); - let line = Self::layout_lines( - DisplayRow(0)..DisplayRow(1), - &editor_snapshot, - &style, - px(f32::MAX), - |_| false, // Single lines never soft wrap - window, - cx, - ) - .pop() - .unwrap(); - - let font_id = - window.text_system().resolve_font(&style.text.font()); - let font_size = - style.text.font_size.to_pixels(window.rem_size()); - let em_width = - window.text_system().em_width(font_id, font_size).unwrap(); - - size(line.width + em_width, height) - }, - ) - } else { - let mut style = Style::default(); - style.size.height = height.into(); - style.size.width = relative(1.).into(); - window.request_layout(style, None, cx) - } + let mut style = Style::default(); + style.size.height = height.into(); + style.size.width = relative(1.).into(); + window.request_layout(style, None, cx) } EditorMode::AutoHeight { min_lines, @@ -7889,9 +7881,14 @@ impl Element for EditorElement { line_height: Some(self.style.text.line_height), ..Default::default() }; - let focus_handle = self.editor.focus_handle(cx); - window.set_view_id(self.editor.entity_id()); - window.set_focus_handle(&focus_handle, cx); + + let is_minimap = self.editor.read(cx).mode.is_minimap(); + + if !is_minimap { + let focus_handle = self.editor.focus_handle(cx); + window.set_view_id(self.editor.entity_id()); + window.set_focus_handle(&focus_handle, cx); + } let rem_size = self.rem_size(cx); window.with_rem_size(rem_size, |window| { @@ -7953,17 +7950,11 @@ impl Element for EditorElement { right: right_margin, }; - // Offset the content_bounds from the text_bounds by the gutter margin (which - // is roughly half a character wide) to make hit testing work more like how we want. 
- let content_offset = point(editor_margins.gutter.margin, Pixels::ZERO); - - let editor_content_width = editor_width - content_offset.x; - snapshot = self.editor.update(cx, |editor, cx| { editor.last_bounds = Some(bounds); editor.gutter_dimensions = gutter_dimensions; editor.set_visible_line_count(bounds.size.height / line_height, window, cx); - editor.set_visible_column_count(editor_content_width / em_advance); + editor.set_visible_column_count(editor_width / em_advance); if matches!( editor.mode, @@ -7975,10 +7966,10 @@ impl Element for EditorElement { let wrap_width = match editor.soft_wrap_mode(cx) { SoftWrap::GitDiff => None, SoftWrap::None => Some(wrap_width_for(MAX_LINE_LEN as u32 / 2)), - SoftWrap::EditorWidth => Some(editor_content_width), + SoftWrap::EditorWidth => Some(editor_width), SoftWrap::Column(column) => Some(wrap_width_for(column)), SoftWrap::Bounded(column) => { - Some(editor_content_width.min(wrap_width_for(column))) + Some(editor_width.min(wrap_width_for(column))) } }; @@ -8003,13 +7994,12 @@ impl Element for EditorElement { HitboxBehavior::Normal, ); + // Offset the content_bounds from the text_bounds by the gutter margin (which + // is roughly half a character wide) to make hit testing work more like how we want. + let content_offset = point(editor_margins.gutter.margin, Pixels::ZERO); let content_origin = text_hitbox.origin + content_offset; - let editor_text_bounds = - Bounds::from_corners(content_origin, bounds.bottom_right()); - - let height_in_lines = editor_text_bounds.size.height / line_height; - + let height_in_lines = bounds.size.height / line_height; let max_row = snapshot.max_point().row().as_f32(); // The max scroll position for the top of the window @@ -8035,23 +8025,25 @@ impl Element for EditorElement { } }; - // TODO: Autoscrolling for both axes - let mut autoscroll_request = None; - let mut autoscroll_containing_element = false; - let mut autoscroll_horizontally = false; - self.editor.update(cx, |editor, cx| { - autoscroll_request = editor.autoscroll_request(); - autoscroll_containing_element = + let ( + autoscroll_request, + autoscroll_containing_element, + needs_horizontal_autoscroll, + ) = self.editor.update(cx, |editor, cx| { + let autoscroll_request = editor.autoscroll_request(); + let autoscroll_containing_element = autoscroll_request.is_some() || editor.has_pending_selection(); - // TODO: Is this horizontal or vertical?! - autoscroll_horizontally = editor.autoscroll_vertically( - bounds, - line_height, - max_scroll_top, - window, - cx, - ); - snapshot = editor.snapshot(window, cx); + + let (needs_horizontal_autoscroll, was_scrolled) = editor + .autoscroll_vertically(bounds, line_height, max_scroll_top, window, cx); + if was_scrolled.0 { + snapshot = editor.snapshot(window, cx); + } + ( + autoscroll_request, + autoscroll_containing_element, + needs_horizontal_autoscroll, + ) }); let mut scroll_position = snapshot.scroll_position(); @@ -8327,18 +8319,22 @@ impl Element for EditorElement { window, cx, ); - let new_renrerer_widths = line_layouts - .iter() - .flat_map(|layout| &layout.fragments) - .filter_map(|fragment| { - if let LineFragment::Element { id, size, .. } = fragment { - Some((*id, size.width)) - } else { - None - } - }); - if self.editor.update(cx, |editor, cx| { - editor.update_renderer_widths(new_renrerer_widths, cx) + let new_renderer_widths = (!is_minimap).then(|| { + line_layouts + .iter() + .flat_map(|layout| &layout.fragments) + .filter_map(|fragment| { + if let LineFragment::Element { id, size, .. 
} = fragment { + Some((*id, size.width)) + } else { + None + } + }) + }); + if new_renderer_widths.is_some_and(|new_renderer_widths| { + self.editor.update(cx, |editor, cx| { + editor.update_renderer_widths(new_renderer_widths, cx) + }) }) { // If the fold widths have changed, we need to prepaint // the element again to account for any changes in @@ -8387,7 +8383,6 @@ impl Element for EditorElement { glyph_grid_cell, size(longest_line_width, max_row.as_f32() * line_height), longest_line_blame_width, - editor_width, EditorSettings::get_global(cx), ); @@ -8401,27 +8396,31 @@ impl Element for EditorElement { let sticky_header_excerpt_id = sticky_header_excerpt.as_ref().map(|top| top.excerpt.id); - let blocks = window.with_element_namespace("blocks", |window| { - self.render_blocks( - start_row..end_row, - &snapshot, - &hitbox, - &text_hitbox, - editor_width, - &mut scroll_width, - &editor_margins, - em_width, - gutter_dimensions.full_width(), - line_height, - &mut line_layouts, - &local_selections, - &selected_buffer_ids, - is_row_soft_wrapped, - sticky_header_excerpt_id, - window, - cx, - ) - }); + let blocks = (!is_minimap) + .then(|| { + window.with_element_namespace("blocks", |window| { + self.render_blocks( + start_row..end_row, + &snapshot, + &hitbox, + &text_hitbox, + editor_width, + &mut scroll_width, + &editor_margins, + em_width, + gutter_dimensions.full_width(), + line_height, + &mut line_layouts, + &local_selections, + &selected_buffer_ids, + is_row_soft_wrapped, + sticky_header_excerpt_id, + window, + cx, + ) + }) + }) + .unwrap_or_else(|| Ok((Vec::default(), HashMap::default()))); let (mut blocks, row_block_types) = match blocks { Ok(blocks) => blocks, Err(resized_blocks) => { @@ -8455,30 +8454,27 @@ impl Element for EditorElement { MultiBufferRow(end_anchor.to_point(&snapshot.buffer_snapshot).row); let scroll_max = point( - ((scroll_width - editor_content_width) / em_advance).max(0.0), + ((scroll_width - editor_width) / em_advance).max(0.0), max_scroll_top, ); self.editor.update(cx, |editor, cx| { - let clamped = editor.scroll_manager.clamp_scroll_left(scroll_max.x); + if editor.scroll_manager.clamp_scroll_left(scroll_max.x) { + scroll_position.x = scroll_position.x.min(scroll_max.x); + } - let autoscrolled = if autoscroll_horizontally { - editor.autoscroll_horizontally( + if needs_horizontal_autoscroll.0 + && let Some(new_scroll_position) = editor.autoscroll_horizontally( start_row, - editor_content_width, + editor_width, scroll_width, em_advance, &line_layouts, window, cx, ) - } else { - false - }; - - if clamped || autoscrolled { - snapshot = editor.snapshot(window, cx); - scroll_position = snapshot.scroll_position(); + { + scroll_position = new_scroll_position; } }); @@ -8511,7 +8507,7 @@ impl Element for EditorElement { ) }); - let (inline_completion_popover, inline_completion_popover_origin) = self + let (edit_prediction_popover, edit_prediction_popover_origin) = self .editor .update(cx, |editor, cx| { editor.render_edit_prediction_popover( @@ -8540,7 +8536,7 @@ impl Element for EditorElement { &row_block_types, content_origin, scroll_pixel_position, - inline_completion_popover_origin, + edit_prediction_popover_origin, start_row, end_row, line_height, @@ -8593,7 +8589,9 @@ impl Element for EditorElement { } } else { log::error!( - "bug: line_ix {} is out of bounds - row_infos.len(): {}, line_layouts.len(): {}, crease_trailers.len(): {}", + "bug: line_ix {} is out of bounds - row_infos.len(): {}, \ + line_layouts.len(): {}, \ + crease_trailers.len(): {}", line_ix, 
row_infos.len(), line_layouts.len(), @@ -8839,7 +8837,7 @@ impl Element for EditorElement { underline: None, strikethrough: None, }], - None + None, ); let space_invisible = window.text_system().shape_line( "•".into(), @@ -8852,7 +8850,7 @@ impl Element for EditorElement { underline: None, strikethrough: None, }], - None + None, ); let mode = snapshot.mode.clone(); @@ -8927,7 +8925,7 @@ impl Element for EditorElement { cursors, visible_cursors, selections, - inline_completion_popover, + edit_prediction_popover, diff_hunk_controls, mouse_context_menu, test_indicators, @@ -8954,19 +8952,21 @@ impl Element for EditorElement { window: &mut Window, cx: &mut App, ) { - let focus_handle = self.editor.focus_handle(cx); - let key_context = self - .editor - .update(cx, |editor, cx| editor.key_context(window, cx)); - - window.set_key_context(key_context); - window.handle_input( - &focus_handle, - ElementInputHandler::new(bounds, self.editor.clone()), - cx, - ); - self.register_actions(window, cx); - self.register_key_listeners(window, cx, layout); + if !layout.mode.is_minimap() { + let focus_handle = self.editor.focus_handle(cx); + let key_context = self + .editor + .update(cx, |editor, cx| editor.key_context(window, cx)); + + window.set_key_context(key_context); + window.handle_input( + &focus_handle, + ElementInputHandler::new(bounds, self.editor.clone()), + cx, + ); + self.register_actions(window, cx); + self.register_key_listeners(window, cx, layout); + } let text_style = TextStyleRefinement { font_size: Some(self.style.text.font_size), @@ -9007,7 +9007,7 @@ impl Element for EditorElement { self.paint_minimap(layout, window, cx); self.paint_scrollbars(layout, window, cx); - self.paint_inline_completion_popover(layout, window, cx); + self.paint_edit_prediction_popover(layout, window, cx); self.paint_mouse_context_menu(layout, window, cx); }); }) @@ -9047,7 +9047,6 @@ impl ScrollbarLayoutInformation { glyph_grid_cell: Size<Pixels>, document_size: Size<Pixels>, longest_line_blame_width: Pixels, - editor_width: Pixels, settings: &EditorSettings, ) -> Self { let vertical_overscroll = match settings.scroll_beyond_last_line { @@ -9058,19 +9057,11 @@ impl ScrollbarLayoutInformation { } }; - let right_margin = if document_size.width + longest_line_blame_width >= editor_width { - glyph_grid_cell.width - } else { - px(0.0) - }; - - let overscroll = size(right_margin + longest_line_blame_width, vertical_overscroll); - - let scroll_range = document_size + overscroll; + let overscroll = size(longest_line_blame_width, vertical_overscroll); ScrollbarLayoutInformation { editor_bounds, - scroll_range, + scroll_range: document_size + overscroll, glyph_grid_cell, } } @@ -9117,7 +9108,7 @@ pub struct EditorLayout { expand_toggles: Vec<Option<(AnyElement, gpui::Point<Pixels>)>>, diff_hunk_controls: Vec<AnyElement>, crease_trailers: Vec<Option<CreaseTrailerLayout>>, - inline_completion_popover: Option<AnyElement>, + edit_prediction_popover: Option<AnyElement>, mouse_context_menu: Option<AnyElement>, tab_invisible: ShapedLine, space_invisible: ShapedLine, @@ -9175,7 +9166,7 @@ struct EditorScrollbars { impl EditorScrollbars { pub fn from_scrollbar_axes( - settings_visibility: ScrollbarAxes, + show_scrollbar: ScrollbarAxes, layout_information: &ScrollbarLayoutInformation, content_offset: gpui::Point<Pixels>, scroll_position: gpui::Point<f32>, @@ -9213,22 +9204,13 @@ impl EditorScrollbars { }; let mut create_scrollbar_layout = |axis| { - settings_visibility - .along(axis) + let viewport_size = viewport_size.along(axis); + 
let scroll_range = scroll_range.along(axis); + + // We always want a vertical scrollbar track for scrollbar diagnostic visibility. + (show_scrollbar.along(axis) + && (axis == ScrollbarAxis::Vertical || scroll_range > viewport_size)) .then(|| { - ( - viewport_size.along(axis) - content_offset.along(axis), - scroll_range.along(axis), - ) - }) - .filter(|(viewport_size, scroll_range)| { - // The scrollbar should only be rendered if the content does - // not entirely fit into the editor - // However, this only applies to the horizontal scrollbar, as information about the - // vertical scrollbar layout is always needed for scrollbar diagnostics. - axis != ScrollbarAxis::Horizontal || viewport_size < scroll_range - }) - .map(|(viewport_size, scroll_range)| { ScrollbarLayout::new( window.insert_hitbox(scrollbar_bounds_for(axis), HitboxBehavior::Normal), viewport_size, @@ -10275,7 +10257,6 @@ mod tests { height: Some(3), render: Arc::new(|cx| div().h(3. * cx.window.line_height()).into_any()), priority: 0, - render_in_minimap: true, }], None, cx, @@ -10365,7 +10346,7 @@ mod tests { }); for editor_mode_without_invisibles in [ - EditorMode::SingleLine { auto_width: false }, + EditorMode::SingleLine, EditorMode::AutoHeight { min_lines: 1, max_lines: Some(100), diff --git a/crates/editor/src/git/blame.rs b/crates/editor/src/git/blame.rs index d4c9e37895444aba8045aec94f2c10af9df72a55..fc350a5a15b4f7b105872e61e5a2401d183c1a6d 100644 --- a/crates/editor/src/git/blame.rs +++ b/crates/editor/src/git/blame.rs @@ -296,7 +296,7 @@ impl GitBlame { let row = info .buffer_row .filter(|_| info.buffer_id == Some(buffer_id))?; - cursor.seek_forward(&row, Bias::Right, &()); + cursor.seek_forward(&row, Bias::Right); cursor.item()?.blame.clone() }) } @@ -389,7 +389,7 @@ impl GitBlame { } } - new_entries.append(cursor.slice(&edit.old.start, Bias::Right, &()), &()); + new_entries.append(cursor.slice(&edit.old.start, Bias::Right), &()); if edit.new.start > new_entries.summary().rows { new_entries.push( @@ -401,7 +401,7 @@ impl GitBlame { ); } - cursor.seek(&edit.old.end, Bias::Right, &()); + cursor.seek(&edit.old.end, Bias::Right); if !edit.new.is_empty() { new_entries.push( GitBlameEntry { @@ -412,7 +412,7 @@ impl GitBlame { ); } - let old_end = cursor.end(&()); + let old_end = cursor.end(); if row_edits .peek() .map_or(true, |next_edit| next_edit.old.start >= old_end) @@ -421,18 +421,18 @@ impl GitBlame { if old_end > edit.old.end { new_entries.push( GitBlameEntry { - rows: cursor.end(&()) - edit.old.end, + rows: cursor.end() - edit.old.end, blame: entry.blame.clone(), }, &(), ); } - cursor.next(&()); + cursor.next(); } } } - new_entries.append(cursor.suffix(&()), &()); + new_entries.append(cursor.suffix(), &()); drop(cursor); self.buffer_snapshot = new_snapshot; diff --git a/crates/editor/src/inlay_hint_cache.rs b/crates/editor/src/inlay_hint_cache.rs index db01cc7ad1d668520f9650c7d396156814c50ba1..60ad0e5bf6c5672a3ce651793b8f76a82ab4c0ff 100644 --- a/crates/editor/src/inlay_hint_cache.rs +++ b/crates/editor/src/inlay_hint_cache.rs @@ -3546,7 +3546,7 @@ pub mod tests { let excerpt_hints = excerpt_hints.read(); for id in &excerpt_hints.ordered_hints { let hint = &excerpt_hints.hints_by_id[id]; - let mut label = hint.text(); + let mut label = hint.text().to_string(); if hint.padding_left { label.insert(0, ' '); } diff --git a/crates/editor/src/items.rs b/crates/editor/src/items.rs index 2e4631a62b16db51476c5ce5918bdc973806381e..ca635a2132790e809258c8bd63fbd3a1c3edcdb3 100644 --- a/crates/editor/src/items.rs +++ 
b/crates/editor/src/items.rs @@ -813,7 +813,13 @@ impl Item for Editor { window: &mut Window, cx: &mut Context<Self>, ) -> Task<Result<()>> { - self.report_editor_event("Editor Saved", None, cx); + // Add meta data tracking # of auto saves + if options.autosave { + self.report_editor_event("Editor Autosaved", None, cx); + } else { + self.report_editor_event("Editor Saved", None, cx); + } + let buffers = self.buffer().clone().read(cx).all_buffers(); let buffers = buffers .into_iter() @@ -1220,7 +1226,20 @@ impl SerializableItem for Editor { abs_path: None, contents: None, .. - } => Task::ready(Err(anyhow!("No path or contents found for buffer"))), + } => window.spawn(cx, async move |cx| { + let buffer = project + .update(cx, |project, cx| project.create_buffer(cx))? + .await?; + + cx.update(|window, cx| { + cx.new(|cx| { + let mut editor = Editor::for_buffer(buffer, Some(project), window, cx); + + editor.read_metadata_from_db(item_id, workspace_id, window, cx); + editor + }) + }) + }), } } @@ -2092,5 +2111,38 @@ mod tests { assert!(editor.has_conflict(cx)); // The editor should have a conflict }); } + + // Test case 5: Deserialize with no path, no content, no language, and no old mtime (new, empty, unsaved buffer) + { + let project = Project::test(fs.clone(), [path!("/file.rs").as_ref()], cx).await; + let (workspace, cx) = + cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + + let workspace_id = workspace::WORKSPACE_DB.next_id().await.unwrap(); + + let item_id = 10000 as ItemId; + let serialized_editor = SerializedEditor { + abs_path: None, + contents: None, + language: None, + mtime: None, + }; + + DB.save_serialized_editor(item_id, workspace_id, serialized_editor) + .await + .unwrap(); + + let deserialized = + deserialize_editor(item_id, workspace_id, workspace, project, cx).await; + + deserialized.update(cx, |editor, cx| { + assert_eq!(editor.text(cx), ""); + assert!(!editor.is_dirty(cx)); + assert!(!editor.has_conflict(cx)); + + let buffer = editor.buffer().read(cx).as_singleton().unwrap().read(cx); + assert!(buffer.file().is_none()); + }); + } } } diff --git a/crates/editor/src/linked_editing_ranges.rs b/crates/editor/src/linked_editing_ranges.rs index 7c2672fc0da01e966a3c134402c3246a1cd1a46e..a185de33ca3c9522245de50e98ebf6d983acb4e0 100644 --- a/crates/editor/src/linked_editing_ranges.rs +++ b/crates/editor/src/linked_editing_ranges.rs @@ -95,7 +95,7 @@ pub(super) fn refresh_linked_ranges( let snapshot = buffer.read(cx).snapshot(); let buffer_id = buffer.read(cx).remote_id(); - let linked_edits_task = project.linked_edit(buffer, *start, cx); + let linked_edits_task = project.linked_edits(buffer, *start, cx); let highlights = move || async move { let edits = linked_edits_task.await.log_err()?; // Find the range containing our current selection. 
diff --git a/crates/editor/src/lsp_colors.rs b/crates/editor/src/lsp_colors.rs index ce07dd43fe8ffc2a3705eefec96e2382312301a7..08cf9078f2301e84ec96b49cbc1abb16eb611d68 100644 --- a/crates/editor/src/lsp_colors.rs +++ b/crates/editor/src/lsp_colors.rs @@ -6,7 +6,7 @@ use gpui::{Hsla, Rgba}; use itertools::Itertools; use language::point_from_lsp; use multi_buffer::Anchor; -use project::{DocumentColor, lsp_store::ColorFetchStrategy}; +use project::{DocumentColor, lsp_store::LspFetchStrategy}; use settings::Settings as _; use text::{Bias, BufferId, OffsetRangeExt as _}; use ui::{App, Context, Window}; @@ -180,9 +180,9 @@ impl Editor { .filter_map(|buffer| { let buffer_id = buffer.read(cx).remote_id(); let fetch_strategy = if ignore_cache { - ColorFetchStrategy::IgnoreCache + LspFetchStrategy::IgnoreCache } else { - ColorFetchStrategy::UseCache { + LspFetchStrategy::UseCache { known_cache_version: self.colors.as_ref().and_then(|colors| { Some(colors.buffer_colors.get(&buffer_id)?.cache_version_used) }), diff --git a/crates/editor/src/lsp_ext.rs b/crates/editor/src/lsp_ext.rs index 8d078f304ca9fdc2a3d9371762adb7dc72a65ca1..6161afbbc0377d377e352f357b5a0ea6b0606770 100644 --- a/crates/editor/src/lsp_ext.rs +++ b/crates/editor/src/lsp_ext.rs @@ -3,9 +3,8 @@ use std::time::Duration; use crate::Editor; use collections::HashMap; -use futures::stream::FuturesUnordered; use gpui::AsyncApp; -use gpui::{App, AppContext as _, Entity, Task}; +use gpui::{App, Entity, Task}; use itertools::Itertools; use language::Buffer; use language::Language; @@ -18,7 +17,6 @@ use project::Project; use project::TaskSourceKind; use project::lsp_store::lsp_ext_command::GetLspRunnables; use smol::future::FutureExt as _; -use smol::stream::StreamExt; use task::ResolvedTask; use task::TaskContext; use text::BufferId; @@ -29,52 +27,32 @@ pub(crate) fn find_specific_language_server_in_selection<F>( editor: &Editor, cx: &mut App, filter_language: F, - language_server_name: &str, -) -> Task<Option<(Anchor, Arc<Language>, LanguageServerId, Entity<Buffer>)>> + language_server_name: LanguageServerName, +) -> Option<(Anchor, Arc<Language>, LanguageServerId, Entity<Buffer>)> where F: Fn(&Language) -> bool, { - let Some(project) = &editor.project else { - return Task::ready(None); - }; - - let applicable_buffers = editor + let project = editor.project.clone()?; + editor .selections .disjoint_anchors() .iter() .filter_map(|selection| Some((selection.head(), selection.head().buffer_id?))) .unique_by(|(_, buffer_id)| *buffer_id) - .filter_map(|(trigger_anchor, buffer_id)| { + .find_map(|(trigger_anchor, buffer_id)| { let buffer = editor.buffer().read(cx).buffer(buffer_id)?; let language = buffer.read(cx).language_at(trigger_anchor.text_anchor)?; if filter_language(&language) { - Some((trigger_anchor, buffer, language)) + let server_id = buffer.update(cx, |buffer, cx| { + project + .read(cx) + .language_server_id_for_name(buffer, &language_server_name, cx) + })?; + Some((trigger_anchor, language, server_id, buffer)) } else { None } }) - .collect::<Vec<_>>(); - - let applicable_buffer_tasks = applicable_buffers - .into_iter() - .map(|(trigger_anchor, buffer, language)| { - let task = buffer.update(cx, |buffer, cx| { - project.update(cx, |project, cx| { - project.language_server_id_for_name(buffer, language_server_name, cx) - }) - }); - (trigger_anchor, buffer, language, task) - }) - .collect::<Vec<_>>(); - cx.background_spawn(async move { - for (trigger_anchor, buffer, language, task) in applicable_buffer_tasks { - if let Some(server_id) = 
task.await { - return Some((trigger_anchor, language, server_id, buffer)); - } - } - - None - }) } async fn lsp_task_context( @@ -116,9 +94,9 @@ pub fn lsp_tasks( for_position: Option<text::Anchor>, cx: &mut App, ) -> Task<Vec<(TaskSourceKind, Vec<(Option<LocationLink>, ResolvedTask)>)>> { - let mut lsp_task_sources = task_sources + let lsp_task_sources = task_sources .iter() - .map(|(name, buffer_ids)| { + .filter_map(|(name, buffer_ids)| { let buffers = buffer_ids .iter() .filter(|&&buffer_id| match for_position { @@ -127,61 +105,63 @@ pub fn lsp_tasks( }) .filter_map(|&buffer_id| project.read(cx).buffer_for_id(buffer_id, cx)) .collect::<Vec<_>>(); - language_server_for_buffers(project.clone(), name.clone(), buffers, cx) + + let server_id = buffers.iter().find_map(|buffer| { + project.read_with(cx, |project, cx| { + project.language_server_id_for_name(buffer.read(cx), name, cx) + }) + }); + server_id.zip(Some(buffers)) }) - .collect::<FuturesUnordered<_>>(); + .collect::<Vec<_>>(); cx.spawn(async move |cx| { cx.spawn(async move |cx| { let mut lsp_tasks = HashMap::default(); - while let Some(server_to_query) = lsp_task_sources.next().await { - if let Some((server_id, buffers)) = server_to_query { - let mut new_lsp_tasks = Vec::new(); - for buffer in buffers { - let source_kind = match buffer.update(cx, |buffer, _| { - buffer.language().map(|language| language.name()) - }) { - Ok(Some(language_name)) => TaskSourceKind::Lsp { - server: server_id, - language_name: SharedString::from(language_name), - }, - Ok(None) => continue, - Err(_) => return Vec::new(), - }; - let id_base = source_kind.to_id_base(); - let lsp_buffer_context = lsp_task_context(&project, &buffer, cx) - .await - .unwrap_or_default(); + for (server_id, buffers) in lsp_task_sources { + let mut new_lsp_tasks = Vec::new(); + for buffer in buffers { + let source_kind = match buffer.update(cx, |buffer, _| { + buffer.language().map(|language| language.name()) + }) { + Ok(Some(language_name)) => TaskSourceKind::Lsp { + server: server_id, + language_name: SharedString::from(language_name), + }, + Ok(None) => continue, + Err(_) => return Vec::new(), + }; + let id_base = source_kind.to_id_base(); + let lsp_buffer_context = lsp_task_context(&project, &buffer, cx) + .await + .unwrap_or_default(); - if let Ok(runnables_task) = project.update(cx, |project, cx| { - let buffer_id = buffer.read(cx).remote_id(); - project.request_lsp( - buffer, - LanguageServerToQuery::Other(server_id), - GetLspRunnables { - buffer_id, - position: for_position, + if let Ok(runnables_task) = project.update(cx, |project, cx| { + let buffer_id = buffer.read(cx).remote_id(); + project.request_lsp( + buffer, + LanguageServerToQuery::Other(server_id), + GetLspRunnables { + buffer_id, + position: for_position, + }, + cx, + ) + }) { + if let Some(new_runnables) = runnables_task.await.log_err() { + new_lsp_tasks.extend(new_runnables.runnables.into_iter().filter_map( + |(location, runnable)| { + let resolved_task = + runnable.resolve_task(&id_base, &lsp_buffer_context)?; + Some((location, resolved_task)) }, - cx, - ) - }) { - if let Some(new_runnables) = runnables_task.await.log_err() { - new_lsp_tasks.extend( - new_runnables.runnables.into_iter().filter_map( - |(location, runnable)| { - let resolved_task = runnable - .resolve_task(&id_base, &lsp_buffer_context)?; - Some((location, resolved_task)) - }, - ), - ); - } + )); } - lsp_tasks - .entry(source_kind) - .or_insert_with(Vec::new) - .append(&mut new_lsp_tasks); } + lsp_tasks + .entry(source_kind) + 
.or_insert_with(Vec::new) + .append(&mut new_lsp_tasks); } } lsp_tasks.into_iter().collect() @@ -198,27 +178,3 @@ pub fn lsp_tasks( .await }) } - -fn language_server_for_buffers( - project: Entity<Project>, - name: LanguageServerName, - candidates: Vec<Entity<Buffer>>, - cx: &mut App, -) -> Task<Option<(LanguageServerId, Vec<Entity<Buffer>>)>> { - cx.spawn(async move |cx| { - for buffer in &candidates { - let server_id = buffer - .update(cx, |buffer, cx| { - project.update(cx, |project, cx| { - project.language_server_id_for_name(buffer, &name.0, cx) - }) - }) - .ok()? - .await; - if let Some(server_id) = server_id { - return Some((server_id, candidates)); - } - } - None - }) -} diff --git a/crates/editor/src/mouse_context_menu.rs b/crates/editor/src/mouse_context_menu.rs index cbb6791a2f0c7bba9fa0da9774d71eedd78f2c55..9d5145dec1f380013fbf76776efd077d7b466a37 100644 --- a/crates/editor/src/mouse_context_menu.rs +++ b/crates/editor/src/mouse_context_menu.rs @@ -1,8 +1,8 @@ use crate::{ Copy, CopyAndTrim, CopyPermalinkToLine, Cut, DisplayPoint, DisplaySnapshot, Editor, EvaluateSelectedText, FindAllReferences, GoToDeclaration, GoToDefinition, GoToImplementation, - GoToTypeDefinition, Paste, Rename, RevealInFileManager, SelectMode, SelectionEffects, - SelectionExt, ToDisplayPoint, ToggleCodeActions, + GoToTypeDefinition, Paste, Rename, RevealInFileManager, RunToCursor, SelectMode, + SelectionEffects, SelectionExt, ToDisplayPoint, ToggleCodeActions, actions::{Format, FormatSelections}, selections_collection::SelectionsCollection, }; @@ -200,15 +200,21 @@ pub fn deploy_context_menu( }); let evaluate_selection = window.is_action_available(&EvaluateSelectedText, cx); + let run_to_cursor = window.is_action_available(&RunToCursor, cx); ui::ContextMenu::build(window, cx, |menu, _window, _cx| { let builder = menu .on_blur_subscription(Subscription::new(|| {})) + .when(run_to_cursor, |builder| { + builder.action("Run to Cursor", Box::new(RunToCursor)) + }) .when(evaluate_selection && has_selections, |builder| { - builder - .action("Evaluate Selection", Box::new(EvaluateSelectedText)) - .separator() + builder.action("Evaluate Selection", Box::new(EvaluateSelectedText)) }) + .when( + run_to_cursor || (evaluate_selection && has_selections), + |builder| builder.separator(), + ) .action("Go to Definition", Box::new(GoToDefinition)) .action("Go to Declaration", Box::new(GoToDeclaration)) .action("Go to Type Definition", Box::new(GoToTypeDefinition)) diff --git a/crates/editor/src/movement.rs b/crates/editor/src/movement.rs index b9b7cb2e58c56cb3b1e14e1c52aa7b8b38f510b6..a8850984a191be89400097d20c0992e3664aff44 100644 --- a/crates/editor/src/movement.rs +++ b/crates/editor/src/movement.rs @@ -907,12 +907,12 @@ mod tests { let inlays = (0..buffer_snapshot.len()) .flat_map(|offset| { [ - Inlay::inline_completion( + Inlay::edit_prediction( post_inc(&mut id), buffer_snapshot.anchor_at(offset, Bias::Left), "test", ), - Inlay::inline_completion( + Inlay::edit_prediction( post_inc(&mut id), buffer_snapshot.anchor_at(offset, Bias::Right), "test", diff --git a/crates/editor/src/rust_analyzer_ext.rs b/crates/editor/src/rust_analyzer_ext.rs index da0f11036ff683a59a658b0b22139809d393d7ed..2b8150de67050ccced22100bfedd02be44f63907 100644 --- a/crates/editor/src/rust_analyzer_ext.rs +++ b/crates/editor/src/rust_analyzer_ext.rs @@ -57,21 +57,21 @@ pub fn go_to_parent_module( return; }; - let server_lookup = find_specific_language_server_in_selection( - editor, - cx, - is_rust_language, - RUST_ANALYZER_NAME, - ); + let 
Some((trigger_anchor, _, server_to_query, buffer)) = + find_specific_language_server_in_selection( + editor, + cx, + is_rust_language, + RUST_ANALYZER_NAME, + ) + else { + return; + }; let project = project.clone(); let lsp_store = project.read(cx).lsp_store(); let upstream_client = lsp_store.read(cx).upstream_client(); cx.spawn_in(window, async move |editor, cx| { - let Some((trigger_anchor, _, server_to_query, buffer)) = server_lookup.await else { - return anyhow::Ok(()); - }; - let location_links = if let Some((client, project_id)) = upstream_client { let buffer_id = buffer.read_with(cx, |buffer, _| buffer.remote_id())?; @@ -121,7 +121,7 @@ pub fn go_to_parent_module( ) })? .await?; - Ok(()) + anyhow::Ok(()) }) .detach_and_log_err(cx); } @@ -139,21 +139,19 @@ pub fn expand_macro_recursively( return; }; - let server_lookup = find_specific_language_server_in_selection( - editor, - cx, - is_rust_language, - RUST_ANALYZER_NAME, - ); - + let Some((trigger_anchor, rust_language, server_to_query, buffer)) = + find_specific_language_server_in_selection( + editor, + cx, + is_rust_language, + RUST_ANALYZER_NAME, + ) + else { + return; + }; let project = project.clone(); let upstream_client = project.read(cx).lsp_store().read(cx).upstream_client(); cx.spawn_in(window, async move |_editor, cx| { - let Some((trigger_anchor, rust_language, server_to_query, buffer)) = server_lookup.await - else { - return Ok(()); - }; - let macro_expansion = if let Some((client, project_id)) = upstream_client { let buffer_id = buffer.update(cx, |buffer, _| buffer.remote_id())?; let request = proto::LspExtExpandMacro { @@ -231,20 +229,20 @@ pub fn open_docs(editor: &mut Editor, _: &OpenDocs, window: &mut Window, cx: &mu return; }; - let server_lookup = find_specific_language_server_in_selection( - editor, - cx, - is_rust_language, - RUST_ANALYZER_NAME, - ); + let Some((trigger_anchor, _, server_to_query, buffer)) = + find_specific_language_server_in_selection( + editor, + cx, + is_rust_language, + RUST_ANALYZER_NAME, + ) + else { + return; + }; let project = project.clone(); let upstream_client = project.read(cx).lsp_store().read(cx).upstream_client(); cx.spawn_in(window, async move |_editor, cx| { - let Some((trigger_anchor, _, server_to_query, buffer)) = server_lookup.await else { - return Ok(()); - }; - let docs_urls = if let Some((client, project_id)) = upstream_client { let buffer_id = buffer.read_with(cx, |buffer, _| buffer.remote_id())?; let request = proto::LspExtOpenDocs { diff --git a/crates/editor/src/scroll.rs b/crates/editor/src/scroll.rs index b3007d3091d79074b99b3fe6f2d7b00003f72015..ecaf7c11e41373c96547e13f3d4b83757e2501a8 100644 --- a/crates/editor/src/scroll.rs +++ b/crates/editor/src/scroll.rs @@ -12,7 +12,7 @@ use crate::{ }; pub use autoscroll::{Autoscroll, AutoscrollStrategy}; use core::fmt::Debug; -use gpui::{App, Axis, Context, Global, Pixels, Task, Window, point, px}; +use gpui::{Along, App, Axis, Context, Global, Pixels, Task, Window, point, px}; use language::language_settings::{AllLanguageSettings, SoftWrap}; use language::{Bias, Point}; pub use scroll_amount::ScrollAmount; @@ -27,6 +27,8 @@ use workspace::{ItemId, WorkspaceId}; pub const SCROLL_EVENT_SEPARATION: Duration = Duration::from_millis(28); const SCROLLBAR_SHOW_INTERVAL: Duration = Duration::from_secs(1); +pub struct WasScrolled(pub(crate) bool); + #[derive(Default)] pub struct ScrollbarAutoHide(pub bool); @@ -47,14 +49,14 @@ impl ScrollAnchor { } pub fn scroll_position(&self, snapshot: &DisplaySnapshot) -> gpui::Point<f32> { - let 
mut scroll_position = self.offset; - if self.anchor == Anchor::min() { - scroll_position.y = 0.; - } else { - let scroll_top = self.anchor.to_display_point(snapshot).row().as_f32(); - scroll_position.y += scroll_top; - } - scroll_position + self.offset.apply_along(Axis::Vertical, |offset| { + if self.anchor == Anchor::min() { + 0. + } else { + let scroll_top = self.anchor.to_display_point(snapshot).row().as_f32(); + (offset + scroll_top).max(0.) + } + }) } pub fn top_row(&self, buffer: &MultiBufferSnapshot) -> u32 { @@ -215,87 +217,56 @@ impl ScrollManager { workspace_id: Option<WorkspaceId>, window: &mut Window, cx: &mut Context<Editor>, - ) { - let (new_anchor, top_row) = if scroll_position.y <= 0. && scroll_position.x <= 0. { - ( - ScrollAnchor { - anchor: Anchor::min(), - offset: scroll_position.max(&gpui::Point::default()), - }, - 0, - ) - } else if scroll_position.y <= 0. { - let buffer_point = map - .clip_point( - DisplayPoint::new(DisplayRow(0), scroll_position.x as u32), - Bias::Left, - ) - .to_point(map); - let anchor = map.buffer_snapshot.anchor_at(buffer_point, Bias::Right); - - ( - ScrollAnchor { - anchor: anchor, - offset: scroll_position.max(&gpui::Point::default()), - }, - 0, - ) - } else { - let scroll_top = scroll_position.y; - let scroll_top = match EditorSettings::get_global(cx).scroll_beyond_last_line { - ScrollBeyondLastLine::OnePage => scroll_top, - ScrollBeyondLastLine::Off => { - if let Some(height_in_lines) = self.visible_line_count { - let max_row = map.max_point().row().0 as f32; - scroll_top.min(max_row - height_in_lines + 1.).max(0.) - } else { - scroll_top - } + ) -> WasScrolled { + let scroll_top = scroll_position.y.max(0.); + let scroll_top = match EditorSettings::get_global(cx).scroll_beyond_last_line { + ScrollBeyondLastLine::OnePage => scroll_top, + ScrollBeyondLastLine::Off => { + if let Some(height_in_lines) = self.visible_line_count { + let max_row = map.max_point().row().0 as f32; + scroll_top.min(max_row - height_in_lines + 1.).max(0.) + } else { + scroll_top } - ScrollBeyondLastLine::VerticalScrollMargin => { - if let Some(height_in_lines) = self.visible_line_count { - let max_row = map.max_point().row().0 as f32; - scroll_top - .min(max_row - height_in_lines + 1. + self.vertical_scroll_margin) - .max(0.) - } else { - scroll_top - } + } + ScrollBeyondLastLine::VerticalScrollMargin => { + if let Some(height_in_lines) = self.visible_line_count { + let max_row = map.max_point().row().0 as f32; + scroll_top + .min(max_row - height_in_lines + 1. + self.vertical_scroll_margin) + .max(0.) 
+ } else { + scroll_top } - }; + } + }; - let scroll_top_row = DisplayRow(scroll_top as u32); - let scroll_top_buffer_point = map - .clip_point( - DisplayPoint::new(scroll_top_row, scroll_position.x as u32), - Bias::Left, - ) - .to_point(map); - let top_anchor = map - .buffer_snapshot - .anchor_at(scroll_top_buffer_point, Bias::Right); - - ( - ScrollAnchor { - anchor: top_anchor, - offset: point( - scroll_position.x.max(0.), - scroll_top - top_anchor.to_display_point(map).row().as_f32(), - ), - }, - scroll_top_buffer_point.row, + let scroll_top_row = DisplayRow(scroll_top as u32); + let scroll_top_buffer_point = map + .clip_point( + DisplayPoint::new(scroll_top_row, scroll_position.x as u32), + Bias::Left, ) - }; + .to_point(map); + let top_anchor = map + .buffer_snapshot + .anchor_at(scroll_top_buffer_point, Bias::Right); self.set_anchor( - new_anchor, - top_row, + ScrollAnchor { + anchor: top_anchor, + offset: point( + scroll_position.x.max(0.), + scroll_top - top_anchor.to_display_point(map).row().as_f32(), + ), + }, + scroll_top_buffer_point.row, local, autoscroll, workspace_id, window, cx, - ); + ) } fn set_anchor( @@ -307,7 +278,7 @@ impl ScrollManager { workspace_id: Option<WorkspaceId>, window: &mut Window, cx: &mut Context<Editor>, - ) { + ) -> WasScrolled { let adjusted_anchor = if self.forbid_vertical_scroll { ScrollAnchor { offset: gpui::Point::new(anchor.offset.x, self.anchor.offset.y), @@ -317,10 +288,14 @@ impl ScrollManager { anchor }; + self.autoscroll_request.take(); + if self.anchor == adjusted_anchor { + return WasScrolled(false); + } + self.anchor = adjusted_anchor; cx.emit(EditorEvent::ScrollPositionChanged { local, autoscroll }); self.show_scrollbars(window, cx); - self.autoscroll_request.take(); if let Some(workspace_id) = workspace_id { let item_id = cx.entity().entity_id().as_u64() as ItemId; @@ -342,6 +317,8 @@ impl ScrollManager { .detach() } cx.notify(); + + WasScrolled(true) } pub fn show_scrollbars(&mut self, window: &mut Window, cx: &mut Context<Editor>) { @@ -552,13 +529,13 @@ impl Editor { scroll_position: gpui::Point<f32>, window: &mut Window, cx: &mut Context<Self>, - ) { + ) -> WasScrolled { let mut position = scroll_position; if self.scroll_manager.forbid_vertical_scroll { let current_position = self.scroll_position(cx); position.y = current_position.y; } - self.set_scroll_position_internal(position, true, false, window, cx); + self.set_scroll_position_internal(position, true, false, window, cx) } /// Scrolls so that `row` is at the top of the editor view. 
@@ -590,7 +567,7 @@ impl Editor { autoscroll: bool, window: &mut Window, cx: &mut Context<Self>, - ) { + ) -> WasScrolled { let map = self.display_map.update(cx, |map, cx| map.snapshot(cx)); self.set_scroll_position_taking_display_map( scroll_position, @@ -599,7 +576,7 @@ impl Editor { map, window, cx, - ); + ) } fn set_scroll_position_taking_display_map( @@ -610,7 +587,7 @@ impl Editor { display_map: DisplaySnapshot, window: &mut Window, cx: &mut Context<Self>, - ) { + ) -> WasScrolled { hide_hover(self, cx); let workspace_id = self.workspace.as_ref().and_then(|workspace| workspace.1); @@ -624,7 +601,7 @@ impl Editor { scroll_position }; - self.scroll_manager.set_scroll_position( + let editor_was_scrolled = self.scroll_manager.set_scroll_position( adjusted_position, &display_map, local, @@ -636,6 +613,7 @@ impl Editor { self.refresh_inlay_hints(InlayHintRefreshReason::NewLinesShown, cx); self.refresh_colors(false, None, window, cx); + editor_was_scrolled } pub fn scroll_position(&self, cx: &mut Context<Self>) -> gpui::Point<f32> { diff --git a/crates/editor/src/scroll/autoscroll.rs b/crates/editor/src/scroll/autoscroll.rs index 340277633a2c63131997f9eca76316ccf6c3ad39..e8a1f8da734685f85091b3bd28a2fb1a0be89208 100644 --- a/crates/editor/src/scroll/autoscroll.rs +++ b/crates/editor/src/scroll/autoscroll.rs @@ -1,6 +1,6 @@ use crate::{ DisplayRow, Editor, EditorMode, LineWithInvisibles, RowExt, SelectionEffects, - display_map::ToDisplayPoint, + display_map::ToDisplayPoint, scroll::WasScrolled, }; use gpui::{Bounds, Context, Pixels, Window, px}; use language::Point; @@ -99,19 +99,21 @@ impl AutoscrollStrategy { } } +pub(crate) struct NeedsHorizontalAutoscroll(pub(crate) bool); + impl Editor { pub fn autoscroll_request(&self) -> Option<Autoscroll> { self.scroll_manager.autoscroll_request() } - pub fn autoscroll_vertically( + pub(crate) fn autoscroll_vertically( &mut self, bounds: Bounds<Pixels>, line_height: Pixels, max_scroll_top: f32, window: &mut Window, cx: &mut Context<Editor>, - ) -> bool { + ) -> (NeedsHorizontalAutoscroll, WasScrolled) { let viewport_height = bounds.size.height; let visible_lines = viewport_height / line_height; let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx)); @@ -129,12 +131,14 @@ impl Editor { scroll_position.y = max_scroll_top; } - if original_y != scroll_position.y { - self.set_scroll_position(scroll_position, window, cx); - } + let editor_was_scrolled = if original_y != scroll_position.y { + self.set_scroll_position(scroll_position, window, cx) + } else { + WasScrolled(false) + }; let Some((autoscroll, local)) = self.scroll_manager.autoscroll_request.take() else { - return false; + return (NeedsHorizontalAutoscroll(false), editor_was_scrolled); }; let mut target_top; @@ -212,7 +216,7 @@ impl Editor { target_bottom = target_top + 1.; } - match strategy { + let was_autoscrolled = match strategy { AutoscrollStrategy::Fit | AutoscrollStrategy::Newest => { let margin = margin.min(self.scroll_manager.vertical_scroll_margin); let target_top = (target_top - margin).max(0.0); @@ -225,39 +229,42 @@ impl Editor { if needs_scroll_up && !needs_scroll_down { scroll_position.y = target_top; - self.set_scroll_position_internal(scroll_position, local, true, window, cx); - } - if !needs_scroll_up && needs_scroll_down { + } else if !needs_scroll_up && needs_scroll_down { scroll_position.y = target_bottom - visible_lines; - self.set_scroll_position_internal(scroll_position, local, true, window, cx); + } + + if needs_scroll_up ^ needs_scroll_down { + 
self.set_scroll_position_internal(scroll_position, local, true, window, cx) + } else { + WasScrolled(false) } } AutoscrollStrategy::Center => { scroll_position.y = (target_top - margin).max(0.0); - self.set_scroll_position_internal(scroll_position, local, true, window, cx); + self.set_scroll_position_internal(scroll_position, local, true, window, cx) } AutoscrollStrategy::Focused => { let margin = margin.min(self.scroll_manager.vertical_scroll_margin); scroll_position.y = (target_top - margin).max(0.0); - self.set_scroll_position_internal(scroll_position, local, true, window, cx); + self.set_scroll_position_internal(scroll_position, local, true, window, cx) } AutoscrollStrategy::Top => { scroll_position.y = (target_top).max(0.0); - self.set_scroll_position_internal(scroll_position, local, true, window, cx); + self.set_scroll_position_internal(scroll_position, local, true, window, cx) } AutoscrollStrategy::Bottom => { scroll_position.y = (target_bottom - visible_lines).max(0.0); - self.set_scroll_position_internal(scroll_position, local, true, window, cx); + self.set_scroll_position_internal(scroll_position, local, true, window, cx) } AutoscrollStrategy::TopRelative(lines) => { scroll_position.y = target_top - lines as f32; - self.set_scroll_position_internal(scroll_position, local, true, window, cx); + self.set_scroll_position_internal(scroll_position, local, true, window, cx) } AutoscrollStrategy::BottomRelative(lines) => { scroll_position.y = target_bottom + lines as f32; - self.set_scroll_position_internal(scroll_position, local, true, window, cx); + self.set_scroll_position_internal(scroll_position, local, true, window, cx) } - } + }; self.scroll_manager.last_autoscroll = Some(( self.scroll_manager.anchor.offset, @@ -266,7 +273,8 @@ impl Editor { strategy, )); - true + let was_scrolled = WasScrolled(editor_was_scrolled.0 || was_autoscrolled.0); + (NeedsHorizontalAutoscroll(true), was_scrolled) } pub(crate) fn autoscroll_horizontally( @@ -278,7 +286,7 @@ impl Editor { layouts: &[LineWithInvisibles], window: &mut Window, cx: &mut Context<Self>, - ) -> bool { + ) -> Option<gpui::Point<f32>> { let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx)); let selections = self.selections.all::<Point>(cx); let mut scroll_position = self.scroll_manager.scroll_position(&display_map); @@ -319,22 +327,26 @@ impl Editor { target_right = target_right.min(scroll_width); if target_right - target_left > viewport_width { - return false; + return None; } let scroll_left = self.scroll_manager.anchor.offset.x * em_advance; let scroll_right = scroll_left + viewport_width; - if target_left < scroll_left { + let was_scrolled = if target_left < scroll_left { scroll_position.x = target_left / em_advance; - self.set_scroll_position_internal(scroll_position, true, true, window, cx); - true + self.set_scroll_position_internal(scroll_position, true, true, window, cx) } else if target_right > scroll_right { scroll_position.x = (target_right - viewport_width) / em_advance; - self.set_scroll_position_internal(scroll_position, true, true, window, cx); - true + self.set_scroll_position_internal(scroll_position, true, true, window, cx) + } else { + WasScrolled(false) + }; + + if was_scrolled.0 { + Some(scroll_position) } else { - false + None } } diff --git a/crates/editor/src/signature_help.rs b/crates/editor/src/signature_help.rs index 3447e66ccdb1ac235aa1688f658096ff26f69193..e9f8d2dbd33f71e224ae1c868dab80a7c4bb467a 100644 --- a/crates/editor/src/signature_help.rs +++ 
b/crates/editor/src/signature_help.rs @@ -191,7 +191,7 @@ impl Editor { if let Some(language) = language { for signature in &mut signature_help.signatures { - let text = Rope::from(signature.label.to_string()); + let text = Rope::from(signature.label.as_ref()); let highlights = language .highlight_text(&text, 0..signature.label.len()) .into_iter() diff --git a/crates/editor/src/test/editor_lsp_test_context.rs b/crates/editor/src/test/editor_lsp_test_context.rs index f7f34135f3ccd5432b088351029632acef420cc9..c59786b1eb387835a21e2c155efaf6acefd4ff4a 100644 --- a/crates/editor/src/test/editor_lsp_test_context.rs +++ b/crates/editor/src/test/editor_lsp_test_context.rs @@ -14,7 +14,8 @@ use futures::Future; use gpui::{Context, Entity, Focusable as _, VisualTestContext, Window}; use indoc::indoc; use language::{ - FakeLspAdapter, Language, LanguageConfig, LanguageMatcher, LanguageQueries, point_to_lsp, + BlockCommentConfig, FakeLspAdapter, Language, LanguageConfig, LanguageMatcher, LanguageQueries, + point_to_lsp, }; use lsp::{notification, request}; use multi_buffer::ToPointUtf16; @@ -269,7 +270,12 @@ impl EditorLspTestContext { path_suffixes: vec!["html".into()], ..Default::default() }, - block_comment: Some(("<!-- ".into(), " -->".into())), + block_comment: Some(BlockCommentConfig { + start: "<!--".into(), + prefix: "".into(), + end: "-->".into(), + tab_size: 0, + }), completion_query_characters: ['-'].into_iter().collect(), ..Default::default() }, diff --git a/crates/eval/Cargo.toml b/crates/eval/Cargo.toml index d5db7f71a4593a66ee8218c053109041035428ab..a0214c76a1c7230e071cbc65c1eadbc44c7d6ca8 100644 --- a/crates/eval/Cargo.toml +++ b/crates/eval/Cargo.toml @@ -19,8 +19,8 @@ path = "src/explorer.rs" [dependencies] agent.workspace = true -agent_ui.workspace = true agent_settings.workspace = true +agent_ui.workspace = true anyhow.workspace = true assistant_tool.workspace = true assistant_tools.workspace = true @@ -29,6 +29,7 @@ buffer_diff.workspace = true chrono.workspace = true clap.workspace = true client.workspace = true +cloud_llm_client.workspace = true collections.workspace = true debug_adapter_extension.workspace = true dirs.workspace = true @@ -68,4 +69,3 @@ util.workspace = true uuid.workspace = true watch.workspace = true workspace-hack.workspace = true -zed_llm_client.workspace = true diff --git a/crates/eval/src/eval.rs b/crates/eval/src/eval.rs index a02b4a7f0bb7f306b7c5389336c6113a7d15d096..d638ac171feafd8be72925cc26beff6726a3ab8d 100644 --- a/crates/eval/src/eval.rs +++ b/crates/eval/src/eval.rs @@ -18,7 +18,7 @@ use collections::{HashMap, HashSet}; use extension::ExtensionHostProxy; use futures::future; use gpui::http_client::read_proxy_from_env; -use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal}; +use gpui::{App, AppContext, Application, AsyncApp, Entity, UpdateGlobal}; use gpui_tokio::Tokio; use language::LanguageRegistry; use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry, SelectedModel}; @@ -337,7 +337,8 @@ pub struct AgentAppState { } pub fn init(cx: &mut App) -> Arc<AgentAppState> { - release_channel::init(SemanticVersion::default(), cx); + let app_version = AppVersion::global(cx); + release_channel::init(app_version, cx); gpui_tokio::init(cx); let mut settings_store = SettingsStore::new(cx); @@ -350,7 +351,7 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> { // Set User-Agent so we can download language servers from GitHub let user_agent = format!( "Zed/{} ({}; {})", - AppVersion::global(cx), + 
app_version, std::env::consts::OS, std::env::consts::ARCH ); diff --git a/crates/eval/src/example.rs b/crates/eval/src/example.rs index 904eca83e609dc8766fb3a5a69ed9040c82f0168..23c8814916da2df4016c4196d7767b748da54280 100644 --- a/crates/eval/src/example.rs +++ b/crates/eval/src/example.rs @@ -15,11 +15,11 @@ use agent_settings::AgentProfileId; use anyhow::{Result, anyhow}; use async_trait::async_trait; use buffer_diff::DiffHunkStatus; +use cloud_llm_client::CompletionIntent; use collections::HashMap; use futures::{FutureExt as _, StreamExt, channel::mpsc, select_biased}; use gpui::{App, AppContext, AsyncApp, Entity}; use language_model::{LanguageModel, Role, StopReason}; -use zed_llm_client::CompletionIntent; pub const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2); @@ -221,9 +221,6 @@ impl ExampleContext { ThreadEvent::ShowError(thread_error) => { tx.try_send(Err(anyhow!(thread_error.clone()))).ok(); } - ThreadEvent::RetriesFailed { .. } => { - // Ignore retries failed events - } ThreadEvent::Stopped(reason) => match reason { Ok(StopReason::EndTurn) => { tx.close_channel(); @@ -425,6 +422,13 @@ impl AppContext for ExampleContext { self.app.update_entity(handle, update) } + fn as_mut<'a, T>(&'a mut self, handle: &Entity<T>) -> Self::Result<gpui::GpuiBorrow<'a, T>> + where + T: 'static, + { + self.app.as_mut(handle) + } + fn read_entity<T, R>( &self, handle: &Entity<T>, diff --git a/crates/extension/Cargo.toml b/crates/extension/Cargo.toml index 4fc7da2dcaa6e30ac7cbcd8a16d95b1485beb27d..42189f20b3477b4581103807445a397e65dd89eb 100644 --- a/crates/extension/Cargo.toml +++ b/crates/extension/Cargo.toml @@ -32,7 +32,11 @@ serde.workspace = true serde_json.workspace = true task.workspace = true toml.workspace = true +url.workspace = true util.workspace = true wasm-encoder.workspace = true wasmparser.workspace = true workspace-hack.workspace = true + +[dev-dependencies] +pretty_assertions.workspace = true diff --git a/crates/extension/src/capabilities.rs b/crates/extension/src/capabilities.rs new file mode 100644 index 0000000000000000000000000000000000000000..b8afc4ec0694181179c5bd59d98534ac5002c137 --- /dev/null +++ b/crates/extension/src/capabilities.rs @@ -0,0 +1,20 @@ +mod download_file_capability; +mod npm_install_package_capability; +mod process_exec_capability; + +pub use download_file_capability::*; +pub use npm_install_package_capability::*; +pub use process_exec_capability::*; + +use serde::{Deserialize, Serialize}; + +/// A capability for an extension. 
+#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] +#[serde(tag = "kind", rename_all = "snake_case")] +pub enum ExtensionCapability { + #[serde(rename = "process:exec")] + ProcessExec(ProcessExecCapability), + DownloadFile(DownloadFileCapability), + #[serde(rename = "npm:install")] + NpmInstallPackage(NpmInstallPackageCapability), +} diff --git a/crates/extension/src/capabilities/download_file_capability.rs b/crates/extension/src/capabilities/download_file_capability.rs new file mode 100644 index 0000000000000000000000000000000000000000..a76755b593a2c42d8bbf22e8926b7409fca9061e --- /dev/null +++ b/crates/extension/src/capabilities/download_file_capability.rs @@ -0,0 +1,121 @@ +use serde::{Deserialize, Serialize}; +use url::Url; + +#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub struct DownloadFileCapability { + pub host: String, + pub path: Vec<String>, +} + +impl DownloadFileCapability { + /// Returns whether the capability allows downloading a file from the given URL. + pub fn allows(&self, url: &Url) -> bool { + let Some(desired_host) = url.host_str() else { + return false; + }; + + let Some(desired_path) = url.path_segments() else { + return false; + }; + let desired_path = desired_path.collect::<Vec<_>>(); + + if self.host != desired_host && self.host != "*" { + return false; + } + + for (ix, path_segment) in self.path.iter().enumerate() { + if path_segment == "**" { + return true; + } + + if ix >= desired_path.len() { + return false; + } + + if path_segment != "*" && path_segment != desired_path[ix] { + return false; + } + } + + if self.path.len() < desired_path.len() { + return false; + } + + true + } +} + +#[cfg(test)] +mod tests { + use pretty_assertions::assert_eq; + + use super::*; + + #[test] + fn test_allows() { + let capability = DownloadFileCapability { + host: "*".to_string(), + path: vec!["**".to_string()], + }; + assert_eq!( + capability.allows(&"https://example.com/some/path".parse().unwrap()), + true + ); + + let capability = DownloadFileCapability { + host: "github.com".to_string(), + path: vec!["**".to_string()], + }; + assert_eq!( + capability.allows(&"https://github.com/some-owner/some-repo".parse().unwrap()), + true + ); + assert_eq!( + capability.allows( + &"https://fake-github.com/some-owner/some-repo" + .parse() + .unwrap() + ), + false + ); + + let capability = DownloadFileCapability { + host: "github.com".to_string(), + path: vec!["specific-owner".to_string(), "*".to_string()], + }; + assert_eq!( + capability.allows(&"https://github.com/some-owner/some-repo".parse().unwrap()), + false + ); + assert_eq!( + capability.allows( + &"https://github.com/specific-owner/some-repo" + .parse() + .unwrap() + ), + true + ); + + let capability = DownloadFileCapability { + host: "github.com".to_string(), + path: vec!["specific-owner".to_string(), "*".to_string()], + }; + assert_eq!( + capability.allows( + &"https://github.com/some-owner/some-repo/extra" + .parse() + .unwrap() + ), + false + ); + assert_eq!( + capability.allows( + &"https://github.com/specific-owner/some-repo/extra" + .parse() + .unwrap() + ), + false + ); + } +} diff --git a/crates/extension/src/capabilities/npm_install_package_capability.rs b/crates/extension/src/capabilities/npm_install_package_capability.rs new file mode 100644 index 0000000000000000000000000000000000000000..287645fc7506d95f66440d439777a67facfe1e35 --- /dev/null +++ b/crates/extension/src/capabilities/npm_install_package_capability.rs @@ -0,0 +1,39 @@ +use 
serde::{Deserialize, Serialize}; + +#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub struct NpmInstallPackageCapability { + pub package: String, +} + +impl NpmInstallPackageCapability { + /// Returns whether the capability allows installing the given NPM package. + pub fn allows(&self, package: &str) -> bool { + self.package == "*" || self.package == package + } +} + +#[cfg(test)] +mod tests { + use pretty_assertions::assert_eq; + + use super::*; + + #[test] + fn test_allows() { + let capability = NpmInstallPackageCapability { + package: "*".to_string(), + }; + assert_eq!(capability.allows("package"), true); + + let capability = NpmInstallPackageCapability { + package: "react".to_string(), + }; + assert_eq!(capability.allows("react"), true); + + let capability = NpmInstallPackageCapability { + package: "react".to_string(), + }; + assert_eq!(capability.allows("malicious-package"), false); + } +} diff --git a/crates/extension/src/capabilities/process_exec_capability.rs b/crates/extension/src/capabilities/process_exec_capability.rs new file mode 100644 index 0000000000000000000000000000000000000000..053a7b212b9dc747270c5d3b011e6b27a3b37049 --- /dev/null +++ b/crates/extension/src/capabilities/process_exec_capability.rs @@ -0,0 +1,116 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub struct ProcessExecCapability { + /// The command to execute. + pub command: String, + /// The arguments to pass to the command. Use `*` for a single wildcard argument. + /// If the last element is `**`, then any trailing arguments are allowed. + pub args: Vec<String>, +} + +impl ProcessExecCapability { + /// Returns whether the capability allows the given command and arguments. + pub fn allows( + &self, + desired_command: &str, + desired_args: &[impl AsRef<str> + std::fmt::Debug], + ) -> bool { + if self.command != desired_command && self.command != "*" { + return false; + } + + for (ix, arg) in self.args.iter().enumerate() { + if arg == "**" { + return true; + } + + if ix >= desired_args.len() { + return false; + } + + if arg != "*" && arg != desired_args[ix].as_ref() { + return false; + } + } + + if self.args.len() < desired_args.len() { + return false; + } + + true + } +} + +#[cfg(test)] +mod tests { + use pretty_assertions::assert_eq; + + use super::*; + + #[test] + fn test_allows_with_exact_match() { + let capability = ProcessExecCapability { + command: "ls".to_string(), + args: vec!["-la".to_string()], + }; + + assert_eq!(capability.allows("ls", &["-la"]), true); + assert_eq!(capability.allows("ls", &["-l"]), false); + assert_eq!(capability.allows("pwd", &[] as &[&str]), false); + } + + #[test] + fn test_allows_with_wildcard_arg() { + let capability = ProcessExecCapability { + command: "git".to_string(), + args: vec!["*".to_string()], + }; + + assert_eq!(capability.allows("git", &["status"]), true); + assert_eq!(capability.allows("git", &["commit"]), true); + // Too many args. + assert_eq!(capability.allows("git", &["status", "-s"]), false); + // Wrong command. 
+ assert_eq!(capability.allows("npm", &["install"]), false); + } + + #[test] + fn test_allows_with_double_wildcard() { + let capability = ProcessExecCapability { + command: "cargo".to_string(), + args: vec!["test".to_string(), "**".to_string()], + }; + + assert_eq!(capability.allows("cargo", &["test"]), true); + assert_eq!(capability.allows("cargo", &["test", "--all"]), true); + assert_eq!( + capability.allows("cargo", &["test", "--all", "--no-fail-fast"]), + true + ); + // Wrong first arg. + assert_eq!(capability.allows("cargo", &["build"]), false); + } + + #[test] + fn test_allows_with_mixed_wildcards() { + let capability = ProcessExecCapability { + command: "docker".to_string(), + args: vec!["run".to_string(), "*".to_string(), "**".to_string()], + }; + + assert_eq!(capability.allows("docker", &["run", "nginx"]), true); + assert_eq!(capability.allows("docker", &["run"]), false); + assert_eq!( + capability.allows("docker", &["run", "ubuntu", "bash"]), + true + ); + assert_eq!( + capability.allows("docker", &["run", "alpine", "sh", "-c", "echo hello"]), + true + ); + // Wrong first arg. + assert_eq!(capability.allows("docker", &["ps"]), false); + } +} diff --git a/crates/extension/src/extension.rs b/crates/extension/src/extension.rs index 8b150e19b9a802e3b2115043ff7ae46e037f9c60..35f7f419383cb9f3c6cc518663ad818735eab80e 100644 --- a/crates/extension/src/extension.rs +++ b/crates/extension/src/extension.rs @@ -1,3 +1,4 @@ +mod capabilities; pub mod extension_builder; mod extension_events; mod extension_host_proxy; @@ -16,6 +17,7 @@ use language::LanguageName; use semantic_version::SemanticVersion; use task::{SpawnInTerminal, ZedDebugConfig}; +pub use crate::capabilities::*; pub use crate::extension_events::*; pub use crate::extension_host_proxy::*; pub use crate::extension_manifest::*; diff --git a/crates/extension/src/extension_manifest.rs b/crates/extension/src/extension_manifest.rs index 0a14923c0c1a4ccfb153d9fa7f602d36805799fe..5852b3e3fc32601e8d9527e02d593e02cd66f3c6 100644 --- a/crates/extension/src/extension_manifest.rs +++ b/crates/extension/src/extension_manifest.rs @@ -12,6 +12,8 @@ use std::{ sync::Arc, }; +use crate::ExtensionCapability; + /// This is the old version of the extension manifest, from when it was `extension.json`. #[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] pub struct OldExtensionManifest { @@ -100,24 +102,8 @@ impl ExtensionManifest { desired_args: &[impl AsRef<str> + std::fmt::Debug], ) -> Result<()> { let is_allowed = self.capabilities.iter().any(|capability| match capability { - ExtensionCapability::ProcessExec { command, args } if command == desired_command => { - for (ix, arg) in args.iter().enumerate() { - if arg == "**" { - return true; - } - - if ix >= desired_args.len() { - return false; - } - - if arg != "*" && arg != desired_args[ix].as_ref() { - return false; - } - } - if args.len() < desired_args.len() { - return false; - } - true + ExtensionCapability::ProcessExec(capability) => { + capability.allows(desired_command, desired_args) } _ => false, }); @@ -148,20 +134,6 @@ pub fn build_debug_adapter_schema_path( }) } -/// A capability for an extension. -#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] -#[serde(tag = "kind")] -pub enum ExtensionCapability { - #[serde(rename = "process:exec")] - ProcessExec { - /// The command to execute. - command: String, - /// The arguments to pass to the command. Use `*` for a single wildcard argument. - /// If the last element is `**`, then any trailing arguments are allowed. 
- args: Vec<String>, - }, -} - #[derive(Clone, Default, PartialEq, Eq, Debug, Deserialize, Serialize)] pub struct LibManifestEntry { pub kind: Option<ExtensionLibraryKind>, @@ -191,7 +163,7 @@ pub struct LanguageServerManifestEntry { #[serde(default)] languages: Vec<LanguageName>, #[serde(default)] - pub language_ids: HashMap<String, String>, + pub language_ids: HashMap<LanguageName, String>, #[serde(default)] pub code_action_kinds: Option<Vec<lsp::CodeActionKind>>, } @@ -309,6 +281,10 @@ fn manifest_from_old_manifest( #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; + + use crate::ProcessExecCapability; + use super::*; fn extension_manifest() -> ExtensionManifest { @@ -360,12 +336,12 @@ mod tests { } #[test] - fn test_allow_exact_match() { + fn test_allow_exec_exact_match() { let manifest = ExtensionManifest { - capabilities: vec![ExtensionCapability::ProcessExec { + capabilities: vec![ExtensionCapability::ProcessExec(ProcessExecCapability { command: "ls".to_string(), args: vec!["-la".to_string()], - }], + })], ..extension_manifest() }; @@ -375,12 +351,12 @@ mod tests { } #[test] - fn test_allow_wildcard_arg() { + fn test_allow_exec_wildcard_arg() { let manifest = ExtensionManifest { - capabilities: vec![ExtensionCapability::ProcessExec { + capabilities: vec![ExtensionCapability::ProcessExec(ProcessExecCapability { command: "git".to_string(), args: vec!["*".to_string()], - }], + })], ..extension_manifest() }; @@ -391,12 +367,12 @@ mod tests { } #[test] - fn test_allow_double_wildcard() { + fn test_allow_exec_double_wildcard() { let manifest = ExtensionManifest { - capabilities: vec![ExtensionCapability::ProcessExec { + capabilities: vec![ExtensionCapability::ProcessExec(ProcessExecCapability { command: "cargo".to_string(), args: vec!["test".to_string(), "**".to_string()], - }], + })], ..extension_manifest() }; @@ -411,12 +387,12 @@ mod tests { } #[test] - fn test_allow_mixed_wildcards() { + fn test_allow_exec_mixed_wildcards() { let manifest = ExtensionManifest { - capabilities: vec![ExtensionCapability::ProcessExec { + capabilities: vec![ExtensionCapability::ProcessExec(ProcessExecCapability { command: "docker".to_string(), args: vec!["run".to_string(), "*".to_string(), "**".to_string()], - }], + })], ..extension_manifest() }; diff --git a/crates/extension/src/types.rs b/crates/extension/src/types.rs index cb24e5077b839a0c5ded24c084fbdd7c1cbeab7c..ed9eb2ec2fb96a3b19125355be90e6ba7a5a6e90 100644 --- a/crates/extension/src/types.rs +++ b/crates/extension/src/types.rs @@ -3,7 +3,7 @@ mod dap; mod lsp; mod slash_command; -use std::ops::Range; +use std::{ops::Range, path::PathBuf}; use util::redact::should_redact; @@ -18,7 +18,7 @@ pub type EnvVars = Vec<(String, String)>; /// A command. pub struct Command { /// The command to execute. - pub command: String, + pub command: PathBuf, /// The arguments to pass to the command. pub args: Vec<String>, /// The environment variables to set for the command. 
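
Aside (not part of the patch): the wildcard semantics that the extracted ProcessExecCapability::allows encodes, and that the tests above exercise, boil down to the following minimal sketch. Nothing here is assumed beyond the public struct and method shown above.

use extension::ProcessExecCapability;

fn main() {
    // "*" matches exactly one argument; a trailing "**" allows any remaining arguments.
    let cap = ProcessExecCapability {
        command: "cargo".to_string(),
        args: vec!["test".to_string(), "**".to_string()],
    };

    assert!(cap.allows("cargo", &["test"]));
    assert!(cap.allows("cargo", &["test", "--all", "--no-fail-fast"]));
    assert!(!cap.allows("cargo", &["build"])); // first argument must match literally
    assert!(!cap.allows("rustc", &["test"])); // command must match literally (or be "*")
}
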
diff --git a/crates/extension_cli/src/main.rs b/crates/extension_cli/src/main.rs index 45a7e3b6412ea9fba02a6394394b3ca9fc5bc58f..ab4a9cddb0fa13421677772d1c07c1a8d9234d76 100644 --- a/crates/extension_cli/src/main.rs +++ b/crates/extension_cli/src/main.rs @@ -289,6 +289,24 @@ async fn copy_extension_resources( } } + if let Some(snippets_path) = manifest.snippets.as_ref() { + let parent = snippets_path.parent(); + if let Some(parent) = parent.filter(|p| p.components().next().is_some()) { + fs::create_dir_all(output_dir.join(parent))?; + } + copy_recursive( + fs.as_ref(), + &extension_path.join(&snippets_path), + &output_dir.join(&snippets_path), + CopyOptions { + overwrite: true, + ignore_if_exists: false, + }, + ) + .await + .with_context(|| format!("failed to copy snippets from '{}'", snippets_path.display()))?; + } + Ok(()) } diff --git a/crates/extension_host/benches/extension_compilation_benchmark.rs b/crates/extension_host/benches/extension_compilation_benchmark.rs index 9d867af0417d1b4429de59ad4c1672eda8b1b676..a4fa9bfeff7472e65103e2b4bd91b2240fa8fb32 100644 --- a/crates/extension_host/benches/extension_compilation_benchmark.rs +++ b/crates/extension_host/benches/extension_compilation_benchmark.rs @@ -134,10 +134,12 @@ fn manifest() -> ExtensionManifest { slash_commands: BTreeMap::default(), indexed_docs_providers: BTreeMap::default(), snippets: None, - capabilities: vec![ExtensionCapability::ProcessExec { - command: "echo".into(), - args: vec!["hello!".into()], - }], + capabilities: vec![ExtensionCapability::ProcessExec( + extension::ProcessExecCapability { + command: "echo".into(), + args: vec!["hello!".into()], + }, + )], debug_adapters: Default::default(), debug_locators: Default::default(), } diff --git a/crates/extension_host/src/capability_granter.rs b/crates/extension_host/src/capability_granter.rs new file mode 100644 index 0000000000000000000000000000000000000000..c77e5ecba15b5e10caa331d3b6ee3976b899ed21 --- /dev/null +++ b/crates/extension_host/src/capability_granter.rs @@ -0,0 +1,153 @@ +use std::sync::Arc; + +use anyhow::{Result, bail}; +use extension::{ExtensionCapability, ExtensionManifest}; +use url::Url; + +pub struct CapabilityGranter { + granted_capabilities: Vec<ExtensionCapability>, + manifest: Arc<ExtensionManifest>, +} + +impl CapabilityGranter { + pub fn new( + granted_capabilities: Vec<ExtensionCapability>, + manifest: Arc<ExtensionManifest>, + ) -> Self { + Self { + granted_capabilities, + manifest, + } + } + + pub fn grant_exec( + &self, + desired_command: &str, + desired_args: &[impl AsRef<str> + std::fmt::Debug], + ) -> Result<()> { + self.manifest.allow_exec(desired_command, desired_args)?; + + let is_allowed = self + .granted_capabilities + .iter() + .any(|capability| match capability { + ExtensionCapability::ProcessExec(capability) => { + capability.allows(desired_command, desired_args) + } + _ => false, + }); + + if !is_allowed { + bail!( + "capability for process:exec {desired_command} {desired_args:?} is not granted by the extension host", + ); + } + + Ok(()) + } + + pub fn grant_download_file(&self, desired_url: &Url) -> Result<()> { + let is_allowed = self + .granted_capabilities + .iter() + .any(|capability| match capability { + ExtensionCapability::DownloadFile(capability) => capability.allows(desired_url), + _ => false, + }); + + if !is_allowed { + bail!( + "capability for download_file {desired_url} is not granted by the extension host", + ); + } + + Ok(()) + } + + pub fn grant_npm_install_package(&self, package_name: &str) -> Result<()> { + 
let is_allowed = self + .granted_capabilities + .iter() + .any(|capability| match capability { + ExtensionCapability::NpmInstallPackage(capability) => { + capability.allows(package_name) + } + _ => false, + }); + + if !is_allowed { + bail!("capability for npm:install {package_name} is not granted by the extension host",); + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::collections::BTreeMap; + + use extension::{ProcessExecCapability, SchemaVersion}; + + use super::*; + + fn extension_manifest() -> ExtensionManifest { + ExtensionManifest { + id: "test".into(), + name: "Test".to_string(), + version: "1.0.0".into(), + schema_version: SchemaVersion::ZERO, + description: None, + repository: None, + authors: vec![], + lib: Default::default(), + themes: vec![], + icon_themes: vec![], + languages: vec![], + grammars: BTreeMap::default(), + language_servers: BTreeMap::default(), + context_servers: BTreeMap::default(), + slash_commands: BTreeMap::default(), + indexed_docs_providers: BTreeMap::default(), + snippets: None, + capabilities: vec![], + debug_adapters: Default::default(), + debug_locators: Default::default(), + } + } + + #[test] + fn test_grant_exec() { + let manifest = Arc::new(ExtensionManifest { + capabilities: vec![ExtensionCapability::ProcessExec(ProcessExecCapability { + command: "ls".to_string(), + args: vec!["-la".to_string()], + })], + ..extension_manifest() + }); + + // It returns an error when the extension host has no granted capabilities. + let granter = CapabilityGranter::new(Vec::new(), manifest.clone()); + assert!(granter.grant_exec("ls", &["-la"]).is_err()); + + // It succeeds when the extension host has the exact capability. + let granter = CapabilityGranter::new( + vec![ExtensionCapability::ProcessExec(ProcessExecCapability { + command: "ls".to_string(), + args: vec!["-la".to_string()], + })], + manifest.clone(), + ); + assert!(granter.grant_exec("ls", &["-la"]).is_ok()); + + // It succeeds when the extension host has a wildcard capability. 
+ let granter = CapabilityGranter::new( + vec![ExtensionCapability::ProcessExec(ProcessExecCapability { + command: "*".to_string(), + args: vec!["**".to_string()], + })], + manifest.clone(), + ); + assert!(granter.grant_exec("ls", &["-la"]).is_ok()); + } +} diff --git a/crates/extension_host/src/extension_host.rs b/crates/extension_host/src/extension_host.rs index 075c68d51a1afe6bc67ae17ca88dd35b34191761..dc38c244f1f94b6aeb266d52c7933394412ec269 100644 --- a/crates/extension_host/src/extension_host.rs +++ b/crates/extension_host/src/extension_host.rs @@ -1,3 +1,4 @@ +mod capability_granter; pub mod extension_settings; pub mod headless_host; pub mod wasm_host; @@ -1313,10 +1314,17 @@ impl ExtensionStore { } for snippets_path in &snippets_to_add { - if let Some(snippets_contents) = fs.load(snippets_path).await.log_err() { - proxy - .register_snippet(snippets_path, &snippets_contents) - .log_err(); + match fs + .load(snippets_path) + .await + .with_context(|| format!("Loading snippets from {snippets_path:?}")) + { + Ok(snippets_contents) => { + proxy + .register_snippet(snippets_path, &snippets_contents) + .log_err(); + } + Err(e) => log::error!("Cannot load snippets: {e:#}"), } } } @@ -1331,20 +1339,25 @@ impl ExtensionStore { let extension_path = root_dir.join(extension.manifest.id.as_ref()); let wasm_extension = WasmExtension::load( - extension_path, + &extension_path, &extension.manifest, wasm_host.clone(), &cx, ) - .await; + .await + .with_context(|| format!("Loading extension from {extension_path:?}")); - if let Some(wasm_extension) = wasm_extension.log_err() { - wasm_extensions.push((extension.manifest.clone(), wasm_extension)); - } else { - this.update(cx, |_, cx| { - cx.emit(Event::ExtensionFailedToLoad(extension.manifest.id.clone())) - }) - .ok(); + match wasm_extension { + Ok(wasm_extension) => { + wasm_extensions.push((extension.manifest.clone(), wasm_extension)) + } + Err(e) => { + log::error!("Failed to load extension: {e:#}"); + this.update(cx, |_, cx| { + cx.emit(Event::ExtensionFailedToLoad(extension.manifest.id.clone())) + }) + .ok(); + } } } diff --git a/crates/extension_host/src/extension_store_test.rs b/crates/extension_host/src/extension_store_test.rs index 891ab91852982cd46064292d56f11cc6407ea72c..c31774c20d3e94f829e8de5d6ca822228735ca18 100644 --- a/crates/extension_host/src/extension_store_test.rs +++ b/crates/extension_host/src/extension_store_test.rs @@ -10,7 +10,7 @@ use fs::{FakeFs, Fs, RealFs}; use futures::{AsyncReadExt, StreamExt, io::BufReader}; use gpui::{AppContext as _, SemanticVersion, TestAppContext}; use http_client::{FakeHttpClient, Response}; -use language::{BinaryStatus, LanguageMatcher, LanguageRegistry}; +use language::{BinaryStatus, LanguageMatcher, LanguageName, LanguageRegistry}; use language_extension::LspAccess; use lsp::LanguageServerName; use node_runtime::NodeRuntime; @@ -306,7 +306,11 @@ async fn test_extension_store(cx: &mut TestAppContext) { assert_eq!( language_registry.language_names(), - ["ERB", "Plain Text", "Ruby"] + [ + LanguageName::new("ERB"), + LanguageName::new("Plain Text"), + LanguageName::new("Ruby"), + ] ); assert_eq!( theme_registry.list_names(), @@ -458,7 +462,11 @@ async fn test_extension_store(cx: &mut TestAppContext) { assert_eq!( language_registry.language_names(), - ["ERB", "Plain Text", "Ruby"] + [ + LanguageName::new("ERB"), + LanguageName::new("Plain Text"), + LanguageName::new("Ruby"), + ] ); assert_eq!( language_registry.grammar_names(), @@ -513,7 +521,10 @@ async fn test_extension_store(cx: &mut TestAppContext) 
{ assert_eq!(actual_language.hidden, expected_language.hidden); } - assert_eq!(language_registry.language_names(), ["Plain Text"]); + assert_eq!( + language_registry.language_names(), + [LanguageName::new("Plain Text")] + ); assert_eq!(language_registry.grammar_names(), []); }); } diff --git a/crates/extension_host/src/headless_host.rs b/crates/extension_host/src/headless_host.rs index dbc9bbfe1379a1766be5fc27b55f633d5b004e51..adc9638c2998eb1f122df5137577ca7e0cf4c975 100644 --- a/crates/extension_host/src/headless_host.rs +++ b/crates/extension_host/src/headless_host.rs @@ -173,9 +173,8 @@ impl HeadlessExtensionStore { return Ok(()); } - let wasm_extension: Arc<dyn Extension> = Arc::new( - WasmExtension::load(extension_dir.clone(), &manifest, wasm_host.clone(), &cx).await?, - ); + let wasm_extension: Arc<dyn Extension> = + Arc::new(WasmExtension::load(&extension_dir, &manifest, wasm_host.clone(), &cx).await?); for (language_server_id, language_server_config) in &manifest.language_servers { for language in language_server_config.languages() { diff --git a/crates/extension_host/src/wasm_host.rs b/crates/extension_host/src/wasm_host.rs index 3971fa426306a6f746bfe72de0ff96934205d4db..d990b670f49221aca2f0af901293c70d341cf029 100644 --- a/crates/extension_host/src/wasm_host.rs +++ b/crates/extension_host/src/wasm_host.rs @@ -1,13 +1,15 @@ pub mod wit; use crate::ExtensionManifest; +use crate::capability_granter::CapabilityGranter; use anyhow::{Context as _, Result, anyhow, bail}; use async_trait::async_trait; use dap::{DebugRequest, StartDebuggingRequestArgumentsRequest}; use extension::{ CodeLabel, Command, Completion, ContextServerConfiguration, DebugAdapterBinary, - DebugTaskDefinition, ExtensionHostProxy, KeyValueStoreDelegate, ProjectDelegate, SlashCommand, - SlashCommandArgumentCompletion, SlashCommandOutput, Symbol, WorktreeDelegate, + DebugTaskDefinition, DownloadFileCapability, ExtensionCapability, ExtensionHostProxy, + KeyValueStoreDelegate, NpmInstallPackageCapability, ProcessExecCapability, ProjectDelegate, + SlashCommand, SlashCommandArgumentCompletion, SlashCommandOutput, Symbol, WorktreeDelegate, }; use fs::{Fs, normalize_path}; use futures::future::LocalBoxFuture; @@ -50,6 +52,8 @@ pub struct WasmHost { pub(crate) proxy: Arc<ExtensionHostProxy>, fs: Arc<dyn Fs>, pub work_dir: PathBuf, + /// The capabilities granted to extensions running on the host. + pub(crate) granted_capabilities: Vec<ExtensionCapability>, _main_thread_message_task: Task<()>, main_thread_message_tx: mpsc::UnboundedSender<MainThreadCall>, } @@ -102,7 +106,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn language_server_initialization_options( @@ -127,7 +131,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn language_server_workspace_configuration( @@ -150,7 +154,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn language_server_additional_initialization_options( @@ -175,7 +179,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn language_server_additional_workspace_configuration( @@ -200,7 +204,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn labels_for_completions( @@ -226,7 +230,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn labels_for_symbols( @@ -252,7 +256,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? 
} async fn complete_slash_command_argument( @@ -271,7 +275,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn run_slash_command( @@ -297,7 +301,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn context_server_command( @@ -316,7 +320,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn context_server_configuration( @@ -343,7 +347,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn suggest_docs_packages(&self, provider: Arc<str>) -> Result<Vec<String>> { @@ -358,7 +362,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn index_docs( @@ -384,7 +388,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn get_dap_binary( @@ -406,7 +410,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn dap_request_kind( &self, @@ -423,7 +427,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn dap_config_to_scenario(&self, config: ZedDebugConfig) -> Result<DebugScenario> { @@ -437,7 +441,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn dap_locator_create_scenario( @@ -461,7 +465,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn run_dap_locator( &self, @@ -477,7 +481,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } } @@ -486,6 +490,7 @@ pub struct WasmState { pub table: ResourceTable, ctx: wasi::WasiCtx, pub host: Arc<WasmHost>, + pub(crate) capability_granter: CapabilityGranter, } type MainThreadCall = Box<dyn Send + for<'a> FnOnce(&'a mut AsyncApp) -> LocalBoxFuture<'a, ()>>; @@ -571,6 +576,19 @@ impl WasmHost { node_runtime, proxy, release_channel: ReleaseChannel::global(cx), + granted_capabilities: vec![ + ExtensionCapability::ProcessExec(ProcessExecCapability { + command: "*".to_string(), + args: vec!["**".to_string()], + }), + ExtensionCapability::DownloadFile(DownloadFileCapability { + host: "*".to_string(), + path: vec!["**".to_string()], + }), + ExtensionCapability::NpmInstallPackage(NpmInstallPackageCapability { + package: "*".to_string(), + }), + ], _main_thread_message_task: task, main_thread_message_tx: tx, }) @@ -597,6 +615,10 @@ impl WasmHost { manifest: manifest.clone(), table: ResourceTable::new(), host: this.clone(), + capability_granter: CapabilityGranter::new( + this.granted_capabilities.clone(), + manifest.clone(), + ), }, ); // Store will yield after 1 tick, and get a new deadline of 1 tick after each yield. 
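
Aside (not part of the patch): putting the pieces above together, the blanket defaults granted in WasmHost::new are still intersected with each extension's own manifest, so a check only succeeds when both layers allow it. A rough crate-internal sketch, assuming a manifest whose capabilities include a process:exec entry for `ls -la`, as in the capability_granter tests earlier:

// Sketch only: `manifest` is an Arc<ExtensionManifest> allowing `ls -la`.
let granter = CapabilityGranter::new(
    vec![ExtensionCapability::ProcessExec(ProcessExecCapability {
        command: "*".to_string(),
        args: vec!["**".to_string()],
    })],
    manifest.clone(),
);

// Allowed: the manifest permits `ls -la` and the host grant is a full wildcard.
assert!(granter.grant_exec("ls", &["-la"]).is_ok());
// Rejected by the manifest layer, even though the host grant alone would allow it.
assert!(granter.grant_exec("rm", &["-rf"]).is_err());
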
@@ -715,7 +737,7 @@ fn parse_wasm_extension_version_custom_section(data: &[u8]) -> Option<SemanticVe impl WasmExtension { pub async fn load( - extension_dir: PathBuf, + extension_dir: &Path, manifest: &Arc<ExtensionManifest>, wasm_host: Arc<WasmHost>, cx: &AsyncApp, @@ -739,7 +761,7 @@ impl WasmExtension { .with_context(|| format!("failed to load wasm extension {}", manifest.id)) } - pub async fn call<T, Fn>(&self, f: Fn) -> T + pub async fn call<T, Fn>(&self, f: Fn) -> Result<T> where T: 'static + Send, Fn: 'static @@ -755,8 +777,19 @@ impl WasmExtension { } .boxed() })) - .expect("wasm extension channel should not be closed yet"); - return_rx.await.expect("wasm extension channel") + .map_err(|_| { + anyhow!( + "wasm extension channel should not be closed yet, extension {} (id {})", + self.manifest.name, + self.manifest.id, + ) + })?; + return_rx.await.with_context(|| { + format!( + "wasm extension channel, extension {} (id {})", + self.manifest.name, self.manifest.id, + ) + }) } } @@ -777,8 +810,19 @@ impl WasmState { } .boxed_local() })) - .expect("main thread message channel should not be closed yet"); - async move { return_rx.await.expect("main thread message channel") } + .unwrap_or_else(|_| { + panic!( + "main thread message channel should not be closed yet, extension {} (id {})", + self.manifest.name, self.manifest.id, + ) + }); + let name = self.manifest.name.clone(); + let id = self.manifest.id.clone(); + async move { + return_rx.await.unwrap_or_else(|_| { + panic!("main thread message channel, extension {name} (id {id})") + }) + } } fn work_dir(&self) -> PathBuf { diff --git a/crates/extension_host/src/wasm_host/wit/since_v0_6_0.rs b/crates/extension_host/src/wasm_host/wit/since_v0_6_0.rs index ced2ea9c677022e95f106ac6ba0543303fe5a372..767b9033ade3c81c6ac149363676513c72996b7e 100644 --- a/crates/extension_host/src/wasm_host/wit/since_v0_6_0.rs +++ b/crates/extension_host/src/wasm_host/wit/since_v0_6_0.rs @@ -30,6 +30,7 @@ use std::{ sync::{Arc, OnceLock}, }; use task::{SpawnInTerminal, ZedDebugConfig}; +use url::Url; use util::{archive::extract_zip, fs::make_file_executable, maybe}; use wasmtime::component::{Linker, Resource}; @@ -75,7 +76,7 @@ impl From<Range> for std::ops::Range<usize> { impl From<Command> for extension::Command { fn from(value: Command) -> Self { Self { - command: value.command, + command: value.command.into(), args: value.args, env: value.env, } @@ -744,6 +745,9 @@ impl nodejs::Host for WasmState { package_name: String, version: String, ) -> wasmtime::Result<Result<(), String>> { + self.capability_granter + .grant_npm_install_package(&package_name)?; + self.host .node_runtime .npm_install_packages(&self.work_dir(), &[(&package_name, &version)]) @@ -847,7 +851,8 @@ impl process::Host for WasmState { command: process::Command, ) -> wasmtime::Result<Result<process::Output, String>> { maybe!(async { - self.manifest.allow_exec(&command.command, &command.args)?; + self.capability_granter + .grant_exec(&command.command, &command.args)?; let output = util::command::new_smol_command(command.command.as_str()) .args(&command.args) @@ -958,7 +963,7 @@ impl ExtensionImports for WasmState { command, } => Ok(serde_json::to_string(&settings::ContextServerSettings { command: Some(settings::CommandSettings { - path: Some(command.path), + path: command.path.to_str().map(|path| path.to_string()), arguments: Some(command.args), env: command.env.map(|env| env.into_iter().collect()), }), @@ -1010,6 +1015,9 @@ impl ExtensionImports for WasmState { file_type: DownloadedFileType, ) 
-> wasmtime::Result<Result<(), String>> { maybe!(async { + let parsed_url = Url::parse(&url)?; + self.capability_granter.grant_download_file(&parsed_url)?; + let path = PathBuf::from(path); let extension_work_dir = self.host.work_dir.join(self.manifest.id.as_ref()); diff --git a/crates/extensions_ui/src/extensions_ui.rs b/crates/extensions_ui/src/extensions_ui.rs index 0d00deb10e64ec72e3bf64b1c8ce0929d944104a..fe3e94f5c20dc1a78ae01defc24e290c18a1a3e6 100644 --- a/crates/extensions_ui/src/extensions_ui.rs +++ b/crates/extensions_ui/src/extensions_ui.rs @@ -6,6 +6,7 @@ use std::sync::OnceLock; use std::time::Duration; use std::{ops::Range, sync::Arc}; +use anyhow::Context as _; use client::{ExtensionMetadata, ExtensionProvides}; use collections::{BTreeMap, BTreeSet}; use editor::{Editor, EditorElement, EditorStyle}; @@ -23,7 +24,7 @@ use settings::Settings; use strum::IntoEnumIterator as _; use theme::ThemeSettings; use ui::{ - CheckboxWithLabel, ContextMenu, PopoverMenu, ScrollableHandle, Scrollbar, ScrollbarState, + CheckboxWithLabel, Chip, ContextMenu, PopoverMenu, ScrollableHandle, Scrollbar, ScrollbarState, ToggleButton, Tooltip, prelude::*, }; use vim_mode_setting::VimModeSetting; @@ -80,16 +81,24 @@ pub fn init(cx: &mut App) { .find_map(|item| item.downcast::<ExtensionsPage>()); if let Some(existing) = existing { - if provides_filter.is_some() { - existing.update(cx, |extensions_page, cx| { + existing.update(cx, |extensions_page, cx| { + if provides_filter.is_some() { extensions_page.change_provides_filter(provides_filter, cx); - }); - } + } + if let Some(id) = action.id.as_ref() { + extensions_page.focus_extension(id, window, cx); + } + }); workspace.activate_item(&existing, true, true, window, cx); } else { - let extensions_page = - ExtensionsPage::new(workspace, provides_filter, window, cx); + let extensions_page = ExtensionsPage::new( + workspace, + provides_filter, + action.id.as_deref(), + window, + cx, + ); workspace.add_item_to_active_pane( Box::new(extensions_page), None, @@ -287,6 +296,7 @@ impl ExtensionsPage { pub fn new( workspace: &Workspace, provides_filter: Option<ExtensionProvides>, + focus_extension_id: Option<&str>, window: &mut Window, cx: &mut Context<Workspace>, ) -> Entity<Self> { @@ -317,6 +327,9 @@ impl ExtensionsPage { let query_editor = cx.new(|cx| { let mut input = Editor::single_line(window, cx); input.set_placeholder_text("Search extensions...", cx); + if let Some(id) = focus_extension_id { + input.set_text(format!("id:{id}"), window, cx); + } input }); cx.subscribe(&query_editor, Self::on_query_change).detach(); @@ -340,7 +353,7 @@ impl ExtensionsPage { scrollbar_state: ScrollbarState::new(scroll_handle), }; this.fetch_extensions( - None, + this.search_query(cx), Some(BTreeSet::from_iter(this.provides_filter)), None, cx, @@ -464,9 +477,23 @@ impl ExtensionsPage { .cloned() .collect::<Vec<_>>(); - let remote_extensions = extension_store.update(cx, |store, cx| { - store.fetch_extensions(search.as_deref(), provides_filter.as_ref(), cx) - }); + let remote_extensions = + if let Some(id) = search.as_ref().and_then(|s| s.strip_prefix("id:")) { + let versions = + extension_store.update(cx, |store, cx| store.fetch_extension_versions(id, cx)); + cx.foreground_executor().spawn(async move { + let versions = versions.await?; + let latest = versions + .into_iter() + .max_by_key(|v| v.published_at) + .context("no extension found")?; + Ok(vec![latest]) + }) + } else { + extension_store.update(cx, |store, cx| { + store.fetch_extensions(search.as_deref(), 
provides_filter.as_ref(), cx) + }) + }; cx.spawn(async move |this, cx| { let dev_extensions = if let Some(search) = search { @@ -732,20 +759,7 @@ impl ExtensionsPage { _ => {} } - Some( - div() - .px_1() - .border_1() - .rounded_sm() - .border_color(cx.theme().colors().border) - .bg(cx.theme().colors().element_background) - .child( - Label::new(extension_provides_label( - *provides, - )) - .size(LabelSize::XSmall), - ), - ) + Some(Chip::new(extension_provides_label(*provides))) }) .collect::<Vec<_>>(), ), @@ -1165,6 +1179,13 @@ impl ExtensionsPage { self.refresh_feature_upsells(cx); } + pub fn focus_extension(&mut self, id: &str, window: &mut Window, cx: &mut Context<Self>) { + self.query_editor.update(cx, |editor, cx| { + editor.set_text(format!("id:{id}"), window, cx) + }); + self.refresh_search(cx); + } + pub fn change_provides_filter( &mut self, provides_filter: Option<ExtensionProvides>, diff --git a/crates/feature_flags/src/feature_flags.rs b/crates/feature_flags/src/feature_flags.rs index 9252977f7539ee4ac674e37f1b7c73d4651423d7..ef357adf35997bfb7560f1e1849ef69e780cd1f9 100644 --- a/crates/feature_flags/src/feature_flags.rs +++ b/crates/feature_flags/src/feature_flags.rs @@ -85,6 +85,11 @@ impl FeatureFlag for ThreadAutoCaptureFeatureFlag { false } } +pub struct PanicFeatureFlag; + +impl FeatureFlag for PanicFeatureFlag { + const NAME: &'static str = "panic"; +} pub struct JjUiFeatureFlag {} @@ -98,17 +103,6 @@ impl FeatureFlag for AcpFeatureFlag { const NAME: &'static str = "acp"; } -pub struct ZedCloudFeatureFlag {} - -impl FeatureFlag for ZedCloudFeatureFlag { - const NAME: &'static str = "zed-cloud"; - - fn enabled_for_staff() -> bool { - // Require individual opt-in, for now. - false - } -} - pub trait FeatureFlagViewExt<V: 'static> { fn observe_flag<T: FeatureFlag, F>(&mut self, window: &Window, callback: F) -> Subscription where @@ -164,6 +158,11 @@ where } } +#[derive(Debug)] +pub struct OnFlagsReady { + pub is_staff: bool, +} + pub trait FeatureFlagAppExt { fn wait_for_flag<T: FeatureFlag>(&mut self) -> WaitForFlag; @@ -175,6 +174,10 @@ pub trait FeatureFlagAppExt { fn has_flag<T: FeatureFlag>(&self) -> bool; fn is_staff(&self) -> bool; + fn on_flags_ready<F>(&mut self, callback: F) -> Subscription + where + F: FnMut(OnFlagsReady, &mut App) + 'static; + fn observe_flag<T: FeatureFlag, F>(&mut self, callback: F) -> Subscription where F: FnMut(bool, &mut App) + 'static; @@ -204,6 +207,21 @@ impl FeatureFlagAppExt for App { .unwrap_or(false) } + fn on_flags_ready<F>(&mut self, mut callback: F) -> Subscription + where + F: FnMut(OnFlagsReady, &mut App) + 'static, + { + self.observe_global::<FeatureFlags>(move |cx| { + let feature_flags = cx.global::<FeatureFlags>(); + callback( + OnFlagsReady { + is_staff: feature_flags.staff, + }, + cx, + ); + }) + } + fn observe_flag<T: FeatureFlag, F>(&mut self, mut callback: F) -> Subscription where F: FnMut(bool, &mut App) + 'static, diff --git a/crates/file_finder/src/file_finder.rs b/crates/file_finder/src/file_finder.rs index a4d61dd56f0b3503b09698aa633cf47bf12389e4..e5ac70bb583be004941eee06476dc9318de1adc4 100644 --- a/crates/file_finder/src/file_finder.rs +++ b/crates/file_finder/src/file_finder.rs @@ -1404,14 +1404,21 @@ impl PickerDelegate for FileFinderDelegate { } else { let path_position = PathWithPosition::parse_str(&raw_query); + #[cfg(windows)] + let raw_query = raw_query.trim().to_owned().replace("/", "\\"); + #[cfg(not(windows))] + let raw_query = raw_query.trim().to_owned(); + + let file_query_end = if 
path_position.path.to_str().unwrap_or(&raw_query) == raw_query { + None + } else { + // Safe to unwrap as we won't get here when the unwrap in if fails + Some(path_position.path.to_str().unwrap().len()) + }; + let query = FileSearchQuery { - raw_query: raw_query.trim().to_owned(), - file_query_end: if path_position.path.to_str().unwrap_or(raw_query) == raw_query { - None - } else { - // Safe to unwrap as we won't get here when the unwrap in if fails - Some(path_position.path.to_str().unwrap().len()) - }, + raw_query, + file_query_end, path_position, }; diff --git a/crates/fs/src/fake_git_repo.rs b/crates/fs/src/fake_git_repo.rs index 40a292e0401df931cc17d04ed71219917292ab1f..73da63fd47b01c48a61220a19ec1436bdc774c91 100644 --- a/crates/fs/src/fake_git_repo.rs +++ b/crates/fs/src/fake_git_repo.rs @@ -1,7 +1,7 @@ -use crate::FakeFs; +use crate::{FakeFs, Fs}; use anyhow::{Context as _, Result}; use collections::{HashMap, HashSet}; -use futures::future::{self, BoxFuture}; +use futures::future::{self, BoxFuture, join_all}; use git::{ blame::Blame, repository::{ @@ -10,7 +10,7 @@ use git::{ }, status::{FileStatus, GitStatus, StatusCode, TrackedStatus, UnmergedStatus}, }; -use gpui::{AsyncApp, BackgroundExecutor}; +use gpui::{AsyncApp, BackgroundExecutor, SharedString}; use ignore::gitignore::GitignoreBuilder; use rope::Rope; use smol::future::FutureExt as _; @@ -356,13 +356,49 @@ impl GitRepository for FakeGitRepository { fn stage_paths( &self, - _paths: Vec<RepoPath>, + paths: Vec<RepoPath>, _env: Arc<HashMap<String, String>>, ) -> BoxFuture<'_, Result<()>> { - unimplemented!() + Box::pin(async move { + let contents = paths + .into_iter() + .map(|path| { + let abs_path = self.dot_git_path.parent().unwrap().join(&path); + Box::pin(async move { (path.clone(), self.fs.load(&abs_path).await.ok()) }) + }) + .collect::<Vec<_>>(); + let contents = join_all(contents).await; + self.with_state_async(true, move |state| { + for (path, content) in contents { + if let Some(content) = content { + state.index_contents.insert(path, content); + } else { + state.index_contents.remove(&path); + } + } + Ok(()) + }) + .await + }) } fn unstage_paths( + &self, + paths: Vec<RepoPath>, + _env: Arc<HashMap<String, String>>, + ) -> BoxFuture<'_, Result<()>> { + self.with_state_async(true, move |state| { + for path in paths { + match state.head_contents.get(&path) { + Some(content) => state.index_contents.insert(path, content.clone()), + None => state.index_contents.remove(&path), + }; + } + Ok(()) + }) + } + + fn stash_paths( &self, _paths: Vec<RepoPath>, _env: Arc<HashMap<String, String>>, @@ -370,6 +406,10 @@ impl GitRepository for FakeGitRepository { unimplemented!() } + fn stash_pop(&self, _env: Arc<HashMap<String, String>>) -> BoxFuture<'_, Result<()>> { + unimplemented!() + } + fn commit( &self, _message: gpui::SharedString, @@ -451,4 +491,8 @@ impl GitRepository for FakeGitRepository { ) -> BoxFuture<'_, Result<String>> { unimplemented!() } + + fn default_branch(&self) -> BoxFuture<'_, Result<Option<SharedString>>> { + unimplemented!() + } } diff --git a/crates/fuzzy/src/matcher.rs b/crates/fuzzy/src/matcher.rs index aff639053494caaca115156aaef9226028fd6cc6..e649d47dd646b80e312e2465f0929f630fecf81f 100644 --- a/crates/fuzzy/src/matcher.rs +++ b/crates/fuzzy/src/matcher.rs @@ -208,8 +208,15 @@ impl<'a> Matcher<'a> { return 1.0; } - let path_len = prefix.len() + path.len(); + let limit = self.last_positions[query_idx]; + let max_valid_index = (prefix.len() + path_lowercased.len()).saturating_sub(1); + let safe_limit = 
limit.min(max_valid_index); + + if path_idx > safe_limit { + return 0.0; + } + let path_len = prefix.len() + path.len(); if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] { return memoized; } @@ -218,16 +225,13 @@ impl<'a> Matcher<'a> { let mut best_position = 0; let query_char = self.lowercase_query[query_idx]; - let limit = self.last_positions[query_idx]; - - let max_valid_index = (prefix.len() + path_lowercased.len()).saturating_sub(1); - let safe_limit = limit.min(max_valid_index); let mut last_slash = 0; + for j in path_idx..=safe_limit { let extra_lowercase_chars_count = extra_lowercase_chars .iter() - .take_while(|(i, _)| i < &&j) + .take_while(|&(&i, _)| i < j) .map(|(_, increment)| increment) .sum::<usize>(); let j_regular = j - extra_lowercase_chars_count; @@ -236,10 +240,9 @@ impl<'a> Matcher<'a> { lowercase_prefix[j] } else { let path_index = j - prefix.len(); - if path_index < path_lowercased.len() { - path_lowercased[path_index] - } else { - continue; + match path_lowercased.get(path_index) { + Some(&char) => char, + None => continue, } }; let is_path_sep = path_char == MAIN_SEPARATOR; @@ -255,18 +258,16 @@ impl<'a> Matcher<'a> { #[cfg(target_os = "windows")] let need_to_score = query_char == path_char || (is_path_sep && query_char == '_'); if need_to_score { - let curr = if j_regular < prefix.len() { - prefix[j_regular] - } else { - path[j_regular - prefix.len()] + let curr = match prefix.get(j_regular) { + Some(&curr) => curr, + None => path[j_regular - prefix.len()], }; let mut char_score = 1.0; if j > path_idx { - let last = if j_regular - 1 < prefix.len() { - prefix[j_regular - 1] - } else { - path[j_regular - 1 - prefix.len()] + let last = match prefix.get(j_regular - 1) { + Some(&last) => last, + None => path[j_regular - 1 - prefix.len()], }; if last == MAIN_SEPARATOR { diff --git a/crates/git/src/git.rs b/crates/git/src/git.rs index 92cf58b2adafc692d8407982247d82f03d57fd78..553361e673a0de1842326f14dd2be976a63156eb 100644 --- a/crates/git/src/git.rs +++ b/crates/git/src/git.rs @@ -31,8 +31,10 @@ actions!( git, [ // per-hunk - /// Toggles the staged state of the hunk at cursor. + /// Toggles the staged state of the hunk or status entry at cursor. ToggleStaged, + /// Stage status entries between an anchor entry and the cursor. + StageRange, /// Stages the current hunk and moves to the next one. StageAndNext, /// Unstages the current hunk and moves to the next one. @@ -53,6 +55,10 @@ actions!( StageAll, /// Unstages all changes in the repository. UnstageAll, + /// Stashes all changes in the repository, including untracked files. + StashAll, + /// Pops the most recent stash. + StashPop, /// Restores all tracked files to their last committed state. RestoreTrackedFiles, /// Moves all untracked files to trash. @@ -75,6 +81,8 @@ actions!( Commit, /// Amends the last commit with staged changes. Amend, + /// Enable the --signoff option. + Signoff, /// Cancels the current git operation. Cancel, /// Expands the commit message editor. 
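
Aside (not part of the patch): the new Signoff, StashAll, and StashPop actions correspond to the CommitOptions and GitRepository additions in the next file. A rough caller-side sketch, where `repo`, `paths`, and `env` are assumed to already be in scope with the types used by the trait:

// Sketch only: `repo: &dyn GitRepository`, `paths: Vec<RepoPath>`, and
// `env: Arc<HashMap<String, String>>` are assumed placeholders.
let _commit_options = CommitOptions {
    amend: false,
    signoff: true, // forwarded as `git commit --signoff` by RealGitRepository below
};
// The new stash methods map roughly onto
// `git stash push --quiet --include-untracked <paths>` and `git stash pop`.
repo.stash_paths(paths, env.clone()).await?;
repo.stash_pop(env.clone()).await?;
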
diff --git a/crates/git/src/repository.rs b/crates/git/src/repository.rs index 2ecd4bb894348cf3fc532a8473e43f0712e61700..518b6c4f4626e3701246cbec10e62f847c1074ce 100644 --- a/crates/git/src/repository.rs +++ b/crates/git/src/repository.rs @@ -96,6 +96,7 @@ impl Upstream { #[derive(Clone, Copy, Default)] pub struct CommitOptions { pub amend: bool, + pub signoff: bool, } #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] @@ -394,6 +395,14 @@ pub trait GitRepository: Send + Sync { env: Arc<HashMap<String, String>>, ) -> BoxFuture<'_, Result<()>>; + fn stash_paths( + &self, + paths: Vec<RepoPath>, + env: Arc<HashMap<String, String>>, + ) -> BoxFuture<'_, Result<()>>; + + fn stash_pop(&self, env: Arc<HashMap<String, String>>) -> BoxFuture<'_, Result<()>>; + fn push( &self, branch_name: String, @@ -454,6 +463,8 @@ pub trait GitRepository: Send + Sync { base_checkpoint: GitRepositoryCheckpoint, target_checkpoint: GitRepositoryCheckpoint, ) -> BoxFuture<'_, Result<String>>; + + fn default_branch(&self) -> BoxFuture<'_, Result<Option<SharedString>>>; } pub enum DiffType { @@ -835,14 +846,12 @@ impl GitRepository for RealGitRepository { .stdin(Stdio::piped()) .stdout(Stdio::piped()) .spawn()?; - child - .stdin - .take() - .unwrap() - .write_all(content.as_bytes()) - .await?; + let mut stdin = child.stdin.take().unwrap(); + stdin.write_all(content.as_bytes()).await?; + stdin.flush().await?; + drop(stdin); let output = child.output().await?.stdout; - let sha = String::from_utf8(output)?; + let sha = str::from_utf8(&output)?.trim(); log::debug!("indexing SHA: {sha}, path {path:?}"); @@ -860,6 +869,7 @@ impl GitRepository for RealGitRepository { String::from_utf8_lossy(&output.stderr) ); } else { + log::debug!("removing path {path:?} from the index"); let output = new_smol_command(&git_binary_path) .current_dir(&working_directory) .envs(env.iter()) @@ -910,6 +920,7 @@ impl GitRepository for RealGitRepository { for rev in &revs { write!(&mut stdin, "{rev}\n")?; } + stdin.flush()?; drop(stdin); let output = process.wait_with_output()?; @@ -1188,6 +1199,55 @@ impl GitRepository for RealGitRepository { .boxed() } + fn stash_paths( + &self, + paths: Vec<RepoPath>, + env: Arc<HashMap<String, String>>, + ) -> BoxFuture<'_, Result<()>> { + let working_directory = self.working_directory(); + self.executor + .spawn(async move { + let mut cmd = new_smol_command("git"); + cmd.current_dir(&working_directory?) + .envs(env.iter()) + .args(["stash", "push", "--quiet"]) + .arg("--include-untracked"); + + cmd.args(paths.iter().map(|p| p.as_ref())); + + let output = cmd.output().await?; + + anyhow::ensure!( + output.status.success(), + "Failed to stash:\n{}", + String::from_utf8_lossy(&output.stderr) + ); + Ok(()) + }) + .boxed() + } + + fn stash_pop(&self, env: Arc<HashMap<String, String>>) -> BoxFuture<'_, Result<()>> { + let working_directory = self.working_directory(); + self.executor + .spawn(async move { + let mut cmd = new_smol_command("git"); + cmd.current_dir(&working_directory?) 
+ .envs(env.iter()) + .args(["stash", "pop"]); + + let output = cmd.output().await?; + + anyhow::ensure!( + output.status.success(), + "Failed to stash pop:\n{}", + String::from_utf8_lossy(&output.stderr) + ); + Ok(()) + }) + .boxed() + } + fn commit( &self, message: SharedString, @@ -1209,6 +1269,10 @@ impl GitRepository for RealGitRepository { cmd.arg("--amend"); } + if options.signoff { + cmd.arg("--signoff"); + } + if let Some((name, email)) = name_and_email { cmd.arg("--author").arg(&format!("{name} <{email}>")); } @@ -1545,6 +1609,37 @@ impl GitRepository for RealGitRepository { }) .boxed() } + + fn default_branch(&self) -> BoxFuture<'_, Result<Option<SharedString>>> { + let working_directory = self.working_directory(); + let git_binary_path = self.git_binary_path.clone(); + + let executor = self.executor.clone(); + self.executor + .spawn(async move { + let working_directory = working_directory?; + let git = GitBinary::new(git_binary_path, working_directory, executor); + + if let Ok(output) = git + .run(&["symbolic-ref", "refs/remotes/upstream/HEAD"]) + .await + { + let output = output + .strip_prefix("refs/remotes/upstream/") + .map(|s| SharedString::from(s.to_owned())); + return Ok(output); + } + + let output = git + .run(&["symbolic-ref", "refs/remotes/origin/HEAD"]) + .await?; + + Ok(output + .strip_prefix("refs/remotes/origin/") + .map(|s| SharedString::from(s.to_owned()))) + }) + .boxed() + } } fn git_status_args(path_prefixes: &[RepoPath]) -> Vec<OsString> { diff --git a/crates/git_hosting_providers/src/providers/github.rs b/crates/git_hosting_providers/src/providers/github.rs index 649b2f30aeef92be46317a0039c24738d1981bd5..30f8d058a7c46798209685930518f4b040dbe714 100644 --- a/crates/git_hosting_providers/src/providers/github.rs +++ b/crates/git_hosting_providers/src/providers/github.rs @@ -159,7 +159,11 @@ impl GitHostingProvider for Github { } let mut path_segments = url.path_segments()?; - let owner = path_segments.next()?; + let mut owner = path_segments.next()?; + if owner.is_empty() { + owner = path_segments.next()?; + } + let repo = path_segments.next()?.trim_end_matches(".git"); Some(ParsedGitRemote { @@ -244,6 +248,22 @@ mod tests { use super::*; + #[test] + fn test_remote_url_with_root_slash() { + let remote_url = "git@github.com:/zed-industries/zed"; + let parsed_remote = Github::public_instance() + .parse_remote_url(remote_url) + .unwrap(); + + assert_eq!( + parsed_remote, + ParsedGitRemote { + owner: "zed-industries".into(), + repo: "zed".into(), + } + ); + } + #[test] fn test_invalid_self_hosted_remote_url() { let remote_url = "git@github.com:zed-industries/zed.git"; diff --git a/crates/git_ui/Cargo.toml b/crates/git_ui/Cargo.toml index 6e04dcb656636e9323b6171c6b95012154f52a52..35f7a603544ae72134a2c6c1b08dcb8a0119b79b 100644 --- a/crates/git_ui/Cargo.toml +++ b/crates/git_ui/Cargo.toml @@ -23,6 +23,7 @@ askpass.workspace = true buffer_diff.workspace = true call.workspace = true chrono.workspace = true +cloud_llm_client.workspace = true collections.workspace = true command_palette_hooks.workspace = true component.workspace = true @@ -61,7 +62,6 @@ watch.workspace = true workspace-hack.workspace = true workspace.workspace = true zed_actions.workspace = true -zed_llm_client.workspace = true [target.'cfg(windows)'.dependencies] windows.workspace = true @@ -70,6 +70,7 @@ windows.workspace = true ctor.workspace = true editor = { workspace = true, features = ["test-support"] } gpui = { workspace = true, features = ["test-support"] } +indoc.workspace = true 
pretty_assertions.workspace = true project = { workspace = true, features = ["test-support"] } settings = { workspace = true, features = ["test-support"] } diff --git a/crates/git_ui/src/branch_picker.rs b/crates/git_ui/src/branch_picker.rs index 9eac3ce5aff6dd440fd18fde3ea70042e71a4ce7..b74fa649b04ddc2ee5643965f3e0cdf0eb27e235 100644 --- a/crates/git_ui/src/branch_picker.rs +++ b/crates/git_ui/src/branch_picker.rs @@ -13,7 +13,7 @@ use project::git_store::Repository; use std::sync::Arc; use time::OffsetDateTime; use time_format::format_local_timestamp; -use ui::{HighlightedLabel, ListItem, ListItemSpacing, prelude::*}; +use ui::{HighlightedLabel, ListItem, ListItemSpacing, Tooltip, prelude::*}; use util::ResultExt; use workspace::notifications::DetachAndPromptErr; use workspace::{ModalView, Workspace}; @@ -90,11 +90,21 @@ impl BranchList { let all_branches_request = repository .clone() .map(|repository| repository.update(cx, |repository, _| repository.branches())); + let default_branch_request = repository + .clone() + .map(|repository| repository.update(cx, |repository, _| repository.default_branch())); cx.spawn_in(window, async move |this, cx| { let mut all_branches = all_branches_request .context("No active repository")? .await??; + let default_branch = default_branch_request + .context("No active repository")? + .await + .map(Result::ok) + .ok() + .flatten() + .flatten(); let all_branches = cx .background_spawn(async move { @@ -124,6 +134,7 @@ impl BranchList { this.update_in(cx, |this, window, cx| { this.picker.update(cx, |picker, cx| { + picker.delegate.default_branch = default_branch; picker.delegate.all_branches = Some(all_branches); picker.refresh(window, cx); }) @@ -169,6 +180,7 @@ impl Focusable for BranchList { impl Render for BranchList { fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { v_flex() + .key_context("GitBranchSelector") .w(self.width) .on_modifiers_changed(cx.listener(Self::handle_modifiers_changed)) .child(self.picker.clone()) @@ -192,6 +204,7 @@ struct BranchEntry { pub struct BranchListDelegate { matches: Vec<BranchEntry>, all_branches: Option<Vec<Branch>>, + default_branch: Option<SharedString>, repo: Option<Entity<Repository>>, style: BranchListStyle, selected_index: usize, @@ -206,6 +219,7 @@ impl BranchListDelegate { repo, style, all_branches: None, + default_branch: None, selected_index: 0, last_query: Default::default(), modifiers: Default::default(), @@ -214,6 +228,7 @@ impl BranchListDelegate { fn create_branch( &self, + from_branch: Option<SharedString>, new_branch_name: SharedString, window: &mut Window, cx: &mut Context<Picker<Self>>, @@ -223,6 +238,11 @@ impl BranchListDelegate { }; let new_branch_name = new_branch_name.to_string().replace(' ', "-"); cx.spawn(async move |_, cx| { + if let Some(based_branch) = from_branch { + repo.update(cx, |repo, _| repo.change_branch(based_branch.to_string()))? + .await??; + } + repo.update(cx, |repo, _| { repo.create_branch(new_branch_name.to_string()) })? 
@@ -353,12 +373,22 @@ impl PickerDelegate for BranchListDelegate { }) } - fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) { + fn confirm(&mut self, secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) { let Some(entry) = self.matches.get(self.selected_index()) else { return; }; if entry.is_new { - self.create_branch(entry.branch.name().to_owned().into(), window, cx); + let from_branch = if secondary { + self.default_branch.clone() + } else { + None + }; + self.create_branch( + from_branch, + entry.branch.name().to_owned().into(), + window, + cx, + ); return; } @@ -439,6 +469,28 @@ impl PickerDelegate for BranchListDelegate { }) .unwrap_or_else(|| (None, None)); + let icon = if let Some(default_branch) = self.default_branch.clone() + && entry.is_new + { + Some( + IconButton::new("branch-from-default", IconName::GitBranchSmall) + .on_click(cx.listener(move |this, _, window, cx| { + this.delegate.set_selected_index(ix, window, cx); + this.delegate.confirm(true, window, cx); + })) + .tooltip(move |window, cx| { + Tooltip::for_action( + format!("Create branch based off default: {default_branch}"), + &menu::SecondaryConfirm, + window, + cx, + ) + }), + ) + } else { + None + }; + let branch_name = if entry.is_new { h_flex() .gap_1() @@ -504,7 +556,8 @@ impl PickerDelegate for BranchListDelegate { .color(Color::Muted) })) }), - ), + ) + .end_slot::<IconButton>(icon), ) } diff --git a/crates/git_ui/src/commit_modal.rs b/crates/git_ui/src/commit_modal.rs index 15d0bec3139ca004e11982bc06b6f458b406be36..5dfa800ae59e5a5b5cc804a417d0d13400a13ba9 100644 --- a/crates/git_ui/src/commit_modal.rs +++ b/crates/git_ui/src/commit_modal.rs @@ -1,8 +1,10 @@ use crate::branch_picker::{self, BranchList}; use crate::git_panel::{GitPanel, commit_message_editor}; use git::repository::CommitOptions; -use git::{Amend, Commit, GenerateCommitMessage}; -use panel::{panel_button, panel_editor_style, panel_filled_button}; +use git::{Amend, Commit, GenerateCommitMessage, Signoff}; +use panel::{panel_button, panel_editor_style}; +use project::DisableAiSettings; +use settings::Settings; use ui::{ ContextMenu, KeybindingHint, PopoverMenu, PopoverMenuHandle, SplitButton, Tooltip, prelude::*, }; @@ -273,14 +275,53 @@ impl CommitModal { .child(Icon::new(IconName::ChevronDownSmall).size(IconSize::XSmall)), ), ) - .menu(move |window, cx| { - Some(ContextMenu::build(window, cx, |context_menu, _, _| { - context_menu - .when_some(keybinding_target.clone(), |el, keybinding_target| { - el.context(keybinding_target.clone()) - }) - .action("Amend", Amend.boxed_clone()) - })) + .menu({ + let git_panel_entity = self.git_panel.clone(); + move |window, cx| { + let git_panel = git_panel_entity.read(cx); + let amend_enabled = git_panel.amend_pending(); + let signoff_enabled = git_panel.signoff_enabled(); + let has_previous_commit = git_panel.head_commit(cx).is_some(); + + Some(ContextMenu::build(window, cx, |context_menu, _, _| { + context_menu + .when_some(keybinding_target.clone(), |el, keybinding_target| { + el.context(keybinding_target.clone()) + }) + .when(has_previous_commit, |this| { + this.toggleable_entry( + "Amend", + amend_enabled, + IconPosition::Start, + Some(Box::new(Amend)), + { + let git_panel = git_panel_entity.downgrade(); + move |_, cx| { + git_panel + .update(cx, |git_panel, cx| { + git_panel.toggle_amend_pending(cx); + }) + .ok(); + } + }, + ) + }) + .toggleable_entry( + "Signoff", + signoff_enabled, + IconPosition::Start, + Some(Box::new(Signoff)), + { + let 
git_panel = git_panel_entity.clone(); + move |window, cx| { + git_panel.update(cx, |git_panel, cx| { + git_panel.toggle_signoff_enabled(&Signoff, window, cx); + }) + } + }, + ) + })) + } }) .with_handle(self.commit_menu_handle.clone()) .anchor(Corner::TopRight) @@ -295,7 +336,7 @@ impl CommitModal { generate_commit_message, active_repo, is_amend_pending, - has_previous_commit, + is_signoff_enabled, ) = self.git_panel.update(cx, |git_panel, cx| { let (can_commit, tooltip) = git_panel.configure_commit_button(cx); let title = git_panel.commit_button_title(); @@ -303,10 +344,7 @@ impl CommitModal { let generate_commit_message = git_panel.render_generate_commit_message_button(cx); let active_repo = git_panel.active_repository.clone(); let is_amend_pending = git_panel.amend_pending(); - let has_previous_commit = active_repo - .as_ref() - .and_then(|repo| repo.read(cx).head_commit.as_ref()) - .is_some(); + let is_signoff_enabled = git_panel.signoff_enabled(); ( can_commit, tooltip, @@ -315,7 +353,7 @@ impl CommitModal { generate_commit_message, active_repo, is_amend_pending, - has_previous_commit, + is_signoff_enabled, ) }); @@ -396,126 +434,59 @@ impl CommitModal { .px_1() .gap_4() .children(close_kb_hint) - .when(is_amend_pending, |this| { - let focus_handle = focus_handle.clone(); - this.child( - panel_filled_button(commit_label) - .tooltip(move |window, cx| { - if can_commit { - Tooltip::for_action_in( - tooltip, - &Amend, - &focus_handle, - window, - cx, - ) - } else { - Tooltip::simple(tooltip, cx) - } - }) - .disabled(!can_commit) - .on_click(cx.listener(move |this, _: &ClickEvent, window, cx| { - telemetry::event!("Git Amended", source = "Git Modal"); - this.git_panel.update(cx, |git_panel, cx| { - git_panel.set_amend_pending(false, cx); - git_panel.commit_changes( - CommitOptions { amend: true }, - window, - cx, - ); - }); - cx.emit(DismissEvent); - })), + .child(SplitButton::new( + ui::ButtonLike::new_rounded_left(ElementId::Name( + format!("split-button-left-{}", commit_label).into(), + )) + .layer(ui::ElevationIndex::ModalSurface) + .size(ui::ButtonSize::Compact) + .child( + div() + .child(Label::new(commit_label).size(LabelSize::Small)) + .mr_0p5(), ) - }) - .when(!is_amend_pending, |this| { - this.when(has_previous_commit, |this| { - this.child(SplitButton::new( - ui::ButtonLike::new_rounded_left(ElementId::Name( - format!("split-button-left-{}", commit_label).into(), - )) - .layer(ui::ElevationIndex::ModalSurface) - .size(ui::ButtonSize::Compact) - .child( - div() - .child(Label::new(commit_label).size(LabelSize::Small)) - .mr_0p5(), + .on_click(cx.listener(move |this, _: &ClickEvent, window, cx| { + telemetry::event!("Git Committed", source = "Git Modal"); + this.git_panel.update(cx, |git_panel, cx| { + git_panel.commit_changes( + CommitOptions { + amend: is_amend_pending, + signoff: is_signoff_enabled, + }, + window, + cx, ) - .on_click(cx.listener(move |this, _: &ClickEvent, window, cx| { - telemetry::event!("Git Committed", source = "Git Modal"); - this.git_panel.update(cx, |git_panel, cx| { - git_panel.commit_changes( - CommitOptions { amend: false }, - window, - cx, - ) - }); - cx.emit(DismissEvent); - })) - .disabled(!can_commit) - .tooltip({ - let focus_handle = focus_handle.clone(); - move |window, cx| { - if can_commit { - Tooltip::with_meta_in( - tooltip, - Some(&git::Commit), - "git commit", - &focus_handle.clone(), - window, - cx, - ) - } else { - Tooltip::simple(tooltip, cx) - } - } - }), - self.render_git_commit_menu( - ElementId::Name( - 
format!("split-button-right-{}", commit_label).into(), - ), - Some(focus_handle.clone()), - ) - .into_any_element(), - )) - }) - .when(!has_previous_commit, |this| { - this.child( - panel_filled_button(commit_label) - .tooltip(move |window, cx| { - if can_commit { - Tooltip::with_meta_in( - tooltip, - Some(&git::Commit), - "git commit", - &focus_handle, - window, - cx, - ) - } else { - Tooltip::simple(tooltip, cx) - } - }) - .disabled(!can_commit) - .on_click(cx.listener( - move |this, _: &ClickEvent, window, cx| { - telemetry::event!( - "Git Committed", - source = "Git Modal" - ); - this.git_panel.update(cx, |git_panel, cx| { - git_panel.commit_changes( - CommitOptions { amend: false }, - window, - cx, - ) - }); - cx.emit(DismissEvent); - }, - )), - ) - }) - }), + }); + cx.emit(DismissEvent); + })) + .disabled(!can_commit) + .tooltip({ + let focus_handle = focus_handle.clone(); + move |window, cx| { + if can_commit { + Tooltip::with_meta_in( + tooltip, + Some(&git::Commit), + format!( + "git commit{}{}", + if is_amend_pending { " --amend" } else { "" }, + if is_signoff_enabled { " --signoff" } else { "" } + ), + &focus_handle.clone(), + window, + cx, + ) + } else { + Tooltip::simple(tooltip, cx) + } + } + }), + self.render_git_commit_menu( + ElementId::Name(format!("split-button-right-{}", commit_label).into()), + Some(focus_handle.clone()), + ) + .into_any_element(), + )), ) } @@ -534,7 +505,14 @@ impl CommitModal { } telemetry::event!("Git Committed", source = "Git Modal"); self.git_panel.update(cx, |git_panel, cx| { - git_panel.commit_changes(CommitOptions { amend: false }, window, cx) + git_panel.commit_changes( + CommitOptions { + amend: false, + signoff: git_panel.signoff_enabled(), + }, + window, + cx, + ) }); cx.emit(DismissEvent); } @@ -559,7 +537,14 @@ impl CommitModal { telemetry::event!("Git Amended", source = "Git Modal"); self.git_panel.update(cx, |git_panel, cx| { git_panel.set_amend_pending(false, cx); - git_panel.commit_changes(CommitOptions { amend: true }, window, cx); + git_panel.commit_changes( + CommitOptions { + amend: true, + signoff: git_panel.signoff_enabled(), + }, + window, + cx, + ); }); cx.emit(DismissEvent); } @@ -588,11 +573,13 @@ impl Render for CommitModal { .on_action(cx.listener(Self::dismiss)) .on_action(cx.listener(Self::commit)) .on_action(cx.listener(Self::amend)) - .on_action(cx.listener(|this, _: &GenerateCommitMessage, _, cx| { - this.git_panel.update(cx, |panel, cx| { - panel.generate_commit_message(cx); - }) - })) + .when(!DisableAiSettings::get_global(cx).disable_ai, |this| { + this.on_action(cx.listener(|this, _: &GenerateCommitMessage, _, cx| { + this.git_panel.update(cx, |panel, cx| { + panel.generate_commit_message(cx); + }) + })) + }) .on_action( cx.listener(|this, _: &zed_actions::git::Branch, window, cx| { this.toggle_branch_selector(window, cx); diff --git a/crates/git_ui/src/conflict_view.rs b/crates/git_ui/src/conflict_view.rs index 8eadf70830fbdf75e9077f3859d443f0aec12849..0bbb9411be9ef8a8e6b73d11cc4d01126570741f 100644 --- a/crates/git_ui/src/conflict_view.rs +++ b/crates/git_ui/src/conflict_view.rs @@ -11,10 +11,7 @@ use gpui::{ use language::{Anchor, Buffer, BufferId}; use project::{ConflictRegion, ConflictSet, ConflictSetUpdate, ProjectItem as _}; use std::{ops::Range, sync::Arc}; -use ui::{ - ActiveTheme, AnyElement, Element as _, StatefulInteractiveElement, Styled, - StyledTypography as _, Window, div, h_flex, rems, -}; +use ui::{ActiveTheme, Element as _, Styled, Window, prelude::*}; use util::{ResultExt as _, debug_panic, 
maybe}; pub(crate) struct ConflictAddon { @@ -300,7 +297,6 @@ fn conflicts_updated( move |cx| render_conflict_buttons(&conflict, excerpt_id, editor_handle.clone(), cx) }), priority: 0, - render_in_minimap: true, }) } let new_block_ids = editor.insert_blocks(blocks, None, cx); @@ -391,20 +387,15 @@ fn render_conflict_buttons( cx: &mut BlockContext, ) -> AnyElement { h_flex() + .id(cx.block_id) .h(cx.line_height) - .items_end() .ml(cx.margins.gutter.width) - .id(cx.block_id) - .gap_0p5() + .items_end() + .gap_1() + .bg(cx.theme().colors().editor_background) .child( - div() - .id("ours") - .px_1() - .child("Take Ours") - .rounded_t(rems(0.2)) - .text_ui_sm(cx) - .hover(|this| this.bg(cx.theme().colors().element_background)) - .cursor_pointer() + Button::new("head", "Use HEAD") + .label_size(LabelSize::Small) .on_click({ let editor = editor.clone(); let conflict = conflict.clone(); @@ -423,14 +414,8 @@ fn render_conflict_buttons( }), ) .child( - div() - .id("theirs") - .px_1() - .child("Take Theirs") - .rounded_t(rems(0.2)) - .text_ui_sm(cx) - .hover(|this| this.bg(cx.theme().colors().element_background)) - .cursor_pointer() + Button::new("origin", "Use Origin") + .label_size(LabelSize::Small) .on_click({ let editor = editor.clone(); let conflict = conflict.clone(); @@ -449,14 +434,8 @@ fn render_conflict_buttons( }), ) .child( - div() - .id("both") - .px_1() - .child("Take Both") - .rounded_t(rems(0.2)) - .text_ui_sm(cx) - .hover(|this| this.bg(cx.theme().colors().element_background)) - .cursor_pointer() + Button::new("both", "Use Both") + .label_size(LabelSize::Small) .on_click({ let editor = editor.clone(); let conflict = conflict.clone(); diff --git a/crates/git_ui/src/diff_view.rs b/crates/git_ui/src/file_diff_view.rs similarity index 89% rename from crates/git_ui/src/diff_view.rs rename to crates/git_ui/src/file_diff_view.rs index 9e03dd5f38c016873e34ab38eb1f8dfd717c886a..2f8a744ed893761f6491f23a31e19bfb55a4db62 100644 --- a/crates/git_ui/src/diff_view.rs +++ b/crates/git_ui/src/file_diff_view.rs @@ -1,4 +1,4 @@ -//! DiffView provides a UI for displaying differences between two buffers. +//! FileDiffView provides a UI for displaying differences between two buffers. 
use anyhow::Result; use buffer_diff::{BufferDiff, BufferDiffSnapshot}; @@ -25,7 +25,7 @@ use workspace::{ searchable::SearchableItemHandle, }; -pub struct DiffView { +pub struct FileDiffView { editor: Entity<Editor>, old_buffer: Entity<Buffer>, new_buffer: Entity<Buffer>, @@ -35,7 +35,7 @@ pub struct DiffView { const RECALCULATE_DIFF_DEBOUNCE: Duration = Duration::from_millis(250); -impl DiffView { +impl FileDiffView { pub fn open( old_path: PathBuf, new_path: PathBuf, @@ -57,7 +57,7 @@ impl DiffView { workspace.update_in(cx, |workspace, window, cx| { let diff_view = cx.new(|cx| { - DiffView::new( + FileDiffView::new( old_buffer, new_buffer, buffer_diff, @@ -190,15 +190,15 @@ async fn build_buffer_diff( }) } -impl EventEmitter<EditorEvent> for DiffView {} +impl EventEmitter<EditorEvent> for FileDiffView {} -impl Focusable for DiffView { +impl Focusable for FileDiffView { fn focus_handle(&self, cx: &App) -> FocusHandle { self.editor.focus_handle(cx) } } -impl Item for DiffView { +impl Item for FileDiffView { type Event = EditorEvent; fn tab_icon(&self, _window: &Window, _cx: &App) -> Option<Icon> { @@ -216,48 +216,37 @@ impl Item for DiffView { } fn tab_content_text(&self, _detail: usize, cx: &App) -> SharedString { - let old_filename = self - .old_buffer - .read(cx) - .file() - .and_then(|file| { - Some( - file.full_path(cx) - .file_name()? - .to_string_lossy() - .to_string(), - ) - }) - .unwrap_or_else(|| "untitled".into()); - let new_filename = self - .new_buffer - .read(cx) - .file() - .and_then(|file| { - Some( - file.full_path(cx) - .file_name()? - .to_string_lossy() - .to_string(), - ) - }) - .unwrap_or_else(|| "untitled".into()); + let title_text = |buffer: &Entity<Buffer>| { + buffer + .read(cx) + .file() + .and_then(|file| { + Some( + file.full_path(cx) + .file_name()? 
+ .to_string_lossy() + .to_string(), + ) + }) + .unwrap_or_else(|| "untitled".into()) + }; + let old_filename = title_text(&self.old_buffer); + let new_filename = title_text(&self.new_buffer); + format!("{old_filename} ↔ {new_filename}").into() } fn tab_tooltip_text(&self, cx: &App) -> Option<ui::SharedString> { - let old_path = self - .old_buffer - .read(cx) - .file() - .map(|file| file.full_path(cx).compact().to_string_lossy().to_string()) - .unwrap_or_else(|| "untitled".into()); - let new_path = self - .new_buffer - .read(cx) - .file() - .map(|file| file.full_path(cx).compact().to_string_lossy().to_string()) - .unwrap_or_else(|| "untitled".into()); + let path = |buffer: &Entity<Buffer>| { + buffer + .read(cx) + .file() + .map(|file| file.full_path(cx).compact().to_string_lossy().to_string()) + .unwrap_or_else(|| "untitled".into()) + }; + let old_path = path(&self.old_buffer); + let new_path = path(&self.new_buffer); + Some(format!("{old_path} ↔ {new_path}").into()) } @@ -363,7 +352,7 @@ impl Item for DiffView { } } -impl Render for DiffView { +impl Render for FileDiffView { fn render(&mut self, _: &mut Window, _: &mut Context<Self>) -> impl IntoElement { self.editor.clone() } @@ -407,16 +396,16 @@ mod tests { ) .await; - let project = Project::test(fs.clone(), ["/test".as_ref()], cx).await; + let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await; let (workspace, mut cx) = cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); let diff_view = workspace .update_in(cx, |workspace, window, cx| { - DiffView::open( - PathBuf::from(path!("/test/old_file.txt")), - PathBuf::from(path!("/test/new_file.txt")), + FileDiffView::open( + path!("/test/old_file.txt").into(), + path!("/test/new_file.txt").into(), workspace, window, cx, @@ -510,6 +499,21 @@ mod tests { ", ), ); + + diff_view.read_with(cx, |diff_view, cx| { + assert_eq!( + diff_view.tab_content_text(0, cx), + "old_file.txt ↔ new_file.txt" + ); + assert_eq!( + diff_view.tab_tooltip_text(cx).unwrap(), + format!( + "{} ↔ {}", + path!("test/old_file.txt"), + path!("test/new_file.txt") + ) + ); + }) } #[gpui::test] @@ -533,7 +537,7 @@ mod tests { let diff_view = workspace .update_in(cx, |workspace, window, cx| { - DiffView::open( + FileDiffView::open( PathBuf::from(path!("/test/old_file.txt")), PathBuf::from(path!("/test/new_file.txt")), workspace, diff --git a/crates/git_ui/src/git_panel.rs b/crates/git_ui/src/git_panel.rs index c50e2f8912ef5b4570a7141378f55701151f3f71..44222b829917cc7d3160dd618dbb6bd16ad1fe98 100644 --- a/crates/git_ui/src/git_panel.rs +++ b/crates/git_ui/src/git_panel.rs @@ -25,15 +25,17 @@ use git::repository::{ UpstreamTrackingStatus, get_git_committer, }; use git::status::StageStatus; -use git::{Amend, ToggleStaged, repository::RepoPath, status::FileStatus}; -use git::{ExpandCommitEditor, RestoreTrackedFiles, StageAll, TrashUntrackedFiles, UnstageAll}; +use git::{Amend, Signoff, ToggleStaged, repository::RepoPath, status::FileStatus}; +use git::{ + ExpandCommitEditor, RestoreTrackedFiles, StageAll, StashAll, StashPop, TrashUntrackedFiles, + UnstageAll, +}; use gpui::{ Action, Animation, AnimationExt as _, AsyncApp, AsyncWindowContext, Axis, ClickEvent, Corner, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, KeyContext, - ListHorizontalSizingBehavior, ListSizingBehavior, Modifiers, ModifiersChangedEvent, - MouseButton, MouseDownEvent, Point, PromptLevel, ScrollStrategy, Subscription, Task, - Transformation, UniformListScrollHandle, WeakEntity, actions, 
anchored, deferred, percentage, - uniform_list, + ListHorizontalSizingBehavior, ListSizingBehavior, MouseButton, MouseDownEvent, Point, + PromptLevel, ScrollStrategy, Subscription, Task, Transformation, UniformListScrollHandle, + WeakEntity, actions, anchored, deferred, percentage, uniform_list, }; use itertools::Itertools; use language::{Buffer, File}; @@ -48,13 +50,12 @@ use panel::{ PanelHeader, panel_button, panel_editor_container, panel_editor_style, panel_filled_button, panel_icon_button, }; -use project::git_store::RepositoryEvent; use project::{ - Fs, Project, ProjectPath, - git_store::{GitStoreEvent, Repository}, + DisableAiSettings, Fs, Project, ProjectPath, + git_store::{GitStoreEvent, Repository, RepositoryEvent, RepositoryId}, }; use serde::{Deserialize, Serialize}; -use settings::{Settings as _, SettingsStore}; +use settings::{Settings, SettingsStore}; use std::future::Future; use std::ops::Range; use std::path::{Path, PathBuf}; @@ -62,17 +63,18 @@ use std::{collections::HashSet, sync::Arc, time::Duration, usize}; use strum::{IntoEnumIterator, VariantNames}; use time::OffsetDateTime; use ui::{ - Checkbox, ContextMenu, ElevationIndex, PopoverMenu, Scrollbar, ScrollbarState, SplitButton, - Tooltip, prelude::*, + Checkbox, ContextMenu, ElevationIndex, IconPosition, Label, LabelSize, PopoverMenu, Scrollbar, + ScrollbarState, SplitButton, Tooltip, prelude::*, }; use util::{ResultExt, TryFutureExt, maybe}; +use workspace::SERIALIZATION_THROTTLE_TIME; +use cloud_llm_client::CompletionIntent; use workspace::{ Workspace, dock::{DockPosition, Panel, PanelEvent}, notifications::{DetachAndPromptErr, ErrorMessagePrompt, NotificationId}, }; -use zed_llm_client::CompletionIntent; actions!( git_panel, @@ -139,6 +141,13 @@ fn git_panel_context_menu( UnstageAll.boxed_clone(), ) .separator() + .action_disabled_when( + !(state.has_new_changes || state.has_tracked_changes), + "Stash All", + StashAll.boxed_clone(), + ) + .action("Stash Pop", StashPop.boxed_clone()) + .separator() .action("Open Diff", project_diff::Diff.boxed_clone()) .separator() .action_disabled_when( @@ -175,6 +184,10 @@ pub enum Event { #[derive(Serialize, Deserialize)] struct SerializedGitPanel { width: Option<Pixels>, + #[serde(default)] + amend_pending: bool, + #[serde(default)] + signoff_enabled: bool, } #[derive(Debug, PartialEq, Eq, Clone, Copy)] @@ -212,14 +225,14 @@ impl GitHeaderEntry { #[derive(Debug, PartialEq, Eq, Clone)] enum GitListEntry { - GitStatusEntry(GitStatusEntry), + Status(GitStatusEntry), Header(GitHeaderEntry), } impl GitListEntry { fn status_entry(&self) -> Option<&GitStatusEntry> { match self { - GitListEntry::GitStatusEntry(entry) => Some(entry), + GitListEntry::Status(entry) => Some(entry), _ => None, } } @@ -323,7 +336,6 @@ pub struct GitPanel { pub(crate) commit_editor: Entity<Editor>, conflicted_count: usize, conflicted_staged_count: usize, - current_modifiers: Modifiers, add_coauthors: bool, generate_commit_message_task: Option<Task<Option<()>>>, entries: Vec<GitListEntry>, @@ -339,7 +351,8 @@ pub struct GitPanel { pending: Vec<PendingOperation>, pending_commit: Option<Task<()>>, amend_pending: bool, - pending_serialization: Task<Option<()>>, + signoff_enabled: bool, + pending_serialization: Task<()>, pub(crate) project: Entity<Project>, scroll_handle: UniformListScrollHandle, max_width_item_index: Option<usize>, @@ -355,9 +368,16 @@ pub struct GitPanel { show_placeholders: bool, local_committer: Option<GitCommitter>, local_committer_task: Option<Task<()>>, + bulk_staging: Option<BulkStaging>, 
_settings_subscription: Subscription, } +#[derive(Clone, Debug, PartialEq, Eq)] +struct BulkStaging { + repo_id: RepositoryId, + anchor: RepoPath, +} + const MAX_PANEL_EDITOR_LINES: usize = 6; pub(crate) fn commit_message_editor( @@ -368,6 +388,9 @@ pub(crate) fn commit_message_editor( window: &mut Window, cx: &mut Context<Editor>, ) -> Editor { + project.update(cx, |this, cx| { + this.mark_buffer_as_non_searchable(commit_message_buffer.read(cx).remote_id(), cx); + }); let buffer = cx.new(|cx| MultiBuffer::singleton(commit_message_buffer, cx)); let max_lines = if in_panel { MAX_PANEL_EDITOR_LINES } else { 18 }; let mut commit_editor = Editor::new( @@ -453,9 +476,14 @@ impl GitPanel { }; let mut assistant_enabled = AgentSettings::get_global(cx).enabled; + let mut was_ai_disabled = DisableAiSettings::get_global(cx).disable_ai; let _settings_subscription = cx.observe_global::<SettingsStore>(move |_, cx| { - if assistant_enabled != AgentSettings::get_global(cx).enabled { + let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai; + if assistant_enabled != AgentSettings::get_global(cx).enabled + || was_ai_disabled != is_ai_disabled + { assistant_enabled = AgentSettings::get_global(cx).enabled; + was_ai_disabled = is_ai_disabled; cx.notify(); } }); @@ -497,7 +525,6 @@ impl GitPanel { commit_editor, conflicted_count: 0, conflicted_staged_count: 0, - current_modifiers: window.modifiers(), add_coauthors: true, generate_commit_message_task: None, entries: Vec::new(), @@ -508,7 +535,8 @@ impl GitPanel { pending: Vec::new(), pending_commit: None, amend_pending: false, - pending_serialization: Task::ready(None), + signoff_enabled: false, + pending_serialization: Task::ready(()), single_staged_entry: None, single_tracked_entry: None, project, @@ -529,6 +557,7 @@ impl GitPanel { entry_count: 0, horizontal_scrollbar, vertical_scrollbar, + bulk_staging: None, _settings_subscription, }; @@ -685,20 +714,54 @@ impl GitPanel { cx.notify(); } + fn serialization_key(workspace: &Workspace) -> Option<String> { + workspace + .database_id() + .map(|id| i64::from(id).to_string()) + .or(workspace.session_id()) + .map(|id| format!("{}-{:?}", GIT_PANEL_KEY, id)) + } + fn serialize(&mut self, cx: &mut Context<Self>) { let width = self.width; - self.pending_serialization = cx.background_spawn( - async move { - KEY_VALUE_STORE - .write_kvp( - GIT_PANEL_KEY.into(), - serde_json::to_string(&SerializedGitPanel { width })?, - ) - .await?; - anyhow::Ok(()) - } - .log_err(), - ); + let amend_pending = self.amend_pending; + let signoff_enabled = self.signoff_enabled; + + self.pending_serialization = cx.spawn(async move |git_panel, cx| { + cx.background_executor() + .timer(SERIALIZATION_THROTTLE_TIME) + .await; + let Some(serialization_key) = git_panel + .update(cx, |git_panel, cx| { + git_panel + .workspace + .read_with(cx, |workspace, _| Self::serialization_key(workspace)) + .ok() + .flatten() + }) + .ok() + .flatten() + else { + return; + }; + cx.background_spawn( + async move { + KEY_VALUE_STORE + .write_kvp( + serialization_key, + serde_json::to_string(&SerializedGitPanel { + width, + amend_pending, + signoff_enabled, + })?, + ) + .await?; + anyhow::Ok(()) + } + .log_err(), + ) + .await; + }); } pub(crate) fn set_modal_open(&mut self, open: bool, cx: &mut Context<Self>) { @@ -735,16 +798,6 @@ impl GitPanel { } } - fn handle_modifiers_changed( - &mut self, - event: &ModifiersChangedEvent, - _: &mut Window, - cx: &mut Context<Self>, - ) { - self.current_modifiers = event.modifiers; - cx.notify(); - } - fn 
scroll_to_selected_entry(&mut self, cx: &mut Context<Self>) { if let Some(selected_entry) = self.selected_entry { self.scroll_handle @@ -1265,10 +1318,18 @@ impl GitPanel { return; }; let (stage, repo_paths) = match entry { - GitListEntry::GitStatusEntry(status_entry) => { + GitListEntry::Status(status_entry) => { if status_entry.status.staging().is_fully_staged() { + if let Some(op) = self.bulk_staging.clone() + && op.anchor == status_entry.repo_path + { + self.bulk_staging = None; + } + (false, vec![status_entry.clone()]) } else { + self.set_bulk_staging_anchor(status_entry.repo_path.clone(), cx); + (true, vec![status_entry.clone()]) } } @@ -1362,6 +1423,52 @@ impl GitPanel { self.tracked_staged_count + self.new_staged_count + self.conflicted_staged_count } + pub fn stash_pop(&mut self, _: &StashPop, _window: &mut Window, cx: &mut Context<Self>) { + let Some(active_repository) = self.active_repository.clone() else { + return; + }; + + cx.spawn({ + async move |this, cx| { + let stash_task = active_repository + .update(cx, |repo, cx| repo.stash_pop(cx))? + .await; + this.update(cx, |this, cx| { + stash_task + .map_err(|e| { + this.show_error_toast("stash pop", e, cx); + }) + .ok(); + cx.notify(); + }) + } + }) + .detach(); + } + + pub fn stash_all(&mut self, _: &StashAll, _window: &mut Window, cx: &mut Context<Self>) { + let Some(active_repository) = self.active_repository.clone() else { + return; + }; + + cx.spawn({ + async move |this, cx| { + let stash_task = active_repository + .update(cx, |repo, cx| repo.stash_all(cx))? + .await; + this.update(cx, |this, cx| { + stash_task + .map_err(|e| { + this.show_error_toast("stash", e, cx); + }) + .ok(); + cx.notify(); + }) + } + }) + .detach(); + } + pub fn commit_message_buffer(&self, cx: &App) -> Entity<Buffer> { self.commit_editor .read(cx) @@ -1383,6 +1490,13 @@ impl GitPanel { } } + fn stage_range(&mut self, _: &git::StageRange, _window: &mut Window, cx: &mut Context<Self>) { + let Some(index) = self.selected_entry else { + return; + }; + self.stage_bulk(index, cx); + } + fn stage_selected(&mut self, _: &git::StageFile, _window: &mut Window, cx: &mut Context<Self>) { let Some(selected_entry) = self.get_selected_entry() else { return; @@ -1422,7 +1536,14 @@ impl GitPanel { .contains_focused(window, cx) { telemetry::event!("Git Committed", source = "Git Panel"); - self.commit_changes(CommitOptions { amend: false }, window, cx) + self.commit_changes( + CommitOptions { + amend: false, + signoff: self.signoff_enabled, + }, + window, + cx, + ) } else { cx.propagate(); } @@ -1434,19 +1555,21 @@ impl GitPanel { .focus_handle(cx) .contains_focused(window, cx) { - if self - .active_repository - .as_ref() - .and_then(|repo| repo.read(cx).head_commit.as_ref()) - .is_some() - { + if self.head_commit(cx).is_some() { if !self.amend_pending { self.set_amend_pending(true, cx); self.load_last_commit_message_if_empty(cx); } else { telemetry::event!("Git Amended", source = "Git Panel"); self.set_amend_pending(false, cx); - self.commit_changes(CommitOptions { amend: true }, window, cx); + self.commit_changes( + CommitOptions { + amend: true, + signoff: self.signoff_enabled, + }, + window, + cx, + ); } } } else { @@ -1454,21 +1577,21 @@ impl GitPanel { } } + pub fn head_commit(&self, cx: &App) -> Option<CommitDetails> { + self.active_repository + .as_ref() + .and_then(|repo| repo.read(cx).head_commit.as_ref()) + .cloned() + } + pub fn load_last_commit_message_if_empty(&mut self, cx: &mut Context<Self>) { if !self.commit_editor.read(cx).is_empty(cx) { return; } - 
let Some(active_repository) = self.active_repository.as_ref() else { - return; - }; - let Some(recent_sha) = active_repository - .read(cx) - .head_commit - .as_ref() - .map(|commit| commit.sha.to_string()) - else { + let Some(head_commit) = self.head_commit(cx) else { return; }; + let recent_sha = head_commit.sha.to_string(); let detail_task = self.load_commit_details(recent_sha, cx); cx.spawn(async move |this, cx| { if let Ok(message) = detail_task.await.map(|detail| detail.message) { @@ -1485,12 +1608,6 @@ impl GitPanel { .detach(); } - fn cancel(&mut self, _: &git::Cancel, _: &mut Window, cx: &mut Context<Self>) { - if self.amend_pending { - self.set_amend_pending(false, cx); - } - } - fn custom_or_suggested_commit_message( &self, window: &mut Window, @@ -1752,7 +1869,7 @@ impl GitPanel { /// Generates a commit message using an LLM. pub fn generate_commit_message(&mut self, cx: &mut Context<Self>) { - if !self.can_commit() { + if !self.can_commit() || DisableAiSettings::get_global(cx).disable_ai { return; } @@ -2297,7 +2414,7 @@ impl GitPanel { .committer_name .clone() .or_else(|| participant.user.name.clone()) - .unwrap_or_else(|| participant.user.github_login.clone()); + .unwrap_or_else(|| participant.user.github_login.clone().to_string()); new_co_authors.push((name.clone(), email.clone())) } } @@ -2317,7 +2434,7 @@ impl GitPanel { .name .clone() .or_else(|| user.name.clone()) - .unwrap_or_else(|| user.github_login.clone()); + .unwrap_or_else(|| user.github_login.clone().to_string()); Some((name, email)) } @@ -2449,6 +2566,11 @@ impl GitPanel { } fn update_visible_entries(&mut self, cx: &mut Context<Self>) { + let bulk_staging = self.bulk_staging.take(); + let last_staged_path_prev_index = bulk_staging + .as_ref() + .and_then(|op| self.entry_by_path(&op.anchor, cx)); + self.entries.clear(); self.single_staged_entry.take(); self.single_tracked_entry.take(); @@ -2465,7 +2587,7 @@ impl GitPanel { let mut changed_entries = Vec::new(); let mut new_entries = Vec::new(); let mut conflict_entries = Vec::new(); - let mut last_staged = None; + let mut single_staged_entry = None; let mut staged_count = 0; let mut max_width_item: Option<(RepoPath, usize)> = None; @@ -2503,7 +2625,7 @@ impl GitPanel { if staging.has_staged() { staged_count += 1; - last_staged = Some(entry.clone()); + single_staged_entry = Some(entry.clone()); } let width_estimate = Self::item_width_estimate( @@ -2534,27 +2656,27 @@ impl GitPanel { let mut pending_staged_count = 0; let mut last_pending_staged = None; - let mut pending_status_for_last_staged = None; + let mut pending_status_for_single_staged = None; for pending in self.pending.iter() { if pending.target_status == TargetStatus::Staged { pending_staged_count += pending.entries.len(); last_pending_staged = pending.entries.iter().next().cloned(); } - if let Some(last_staged) = &last_staged { + if let Some(single_staged) = &single_staged_entry { if pending .entries .iter() - .any(|entry| entry.repo_path == last_staged.repo_path) + .any(|entry| entry.repo_path == single_staged.repo_path) { - pending_status_for_last_staged = Some(pending.target_status); + pending_status_for_single_staged = Some(pending.target_status); } } } if conflict_entries.len() == 0 && staged_count == 1 && pending_staged_count == 0 { - match pending_status_for_last_staged { + match pending_status_for_single_staged { Some(TargetStatus::Staged) | None => { - self.single_staged_entry = last_staged; + self.single_staged_entry = single_staged_entry; } _ => {} } @@ -2570,11 +2692,8 @@ impl GitPanel { 
self.entries.push(GitListEntry::Header(GitHeaderEntry { header: Section::Conflict, })); - self.entries.extend( - conflict_entries - .into_iter() - .map(GitListEntry::GitStatusEntry), - ); + self.entries + .extend(conflict_entries.into_iter().map(GitListEntry::Status)); } if changed_entries.len() > 0 { @@ -2583,31 +2702,39 @@ impl GitPanel { header: Section::Tracked, })); } - self.entries.extend( - changed_entries - .into_iter() - .map(GitListEntry::GitStatusEntry), - ); + self.entries + .extend(changed_entries.into_iter().map(GitListEntry::Status)); } if new_entries.len() > 0 { self.entries.push(GitListEntry::Header(GitHeaderEntry { header: Section::New, })); self.entries - .extend(new_entries.into_iter().map(GitListEntry::GitStatusEntry)); + .extend(new_entries.into_iter().map(GitListEntry::Status)); } if let Some((repo_path, _)) = max_width_item { self.max_width_item_index = self.entries.iter().position(|entry| match entry { - GitListEntry::GitStatusEntry(git_status_entry) => { - git_status_entry.repo_path == repo_path - } + GitListEntry::Status(git_status_entry) => git_status_entry.repo_path == repo_path, GitListEntry::Header(_) => false, }); } self.update_counts(repo); + let bulk_staging_anchor_new_index = bulk_staging + .as_ref() + .filter(|op| op.repo_id == repo.id) + .and_then(|op| self.entry_by_path(&op.anchor, cx)); + if bulk_staging_anchor_new_index == last_staged_path_prev_index + && let Some(index) = bulk_staging_anchor_new_index + && let Some(entry) = self.entries.get(index) + && let Some(entry) = entry.status_entry() + && self.entry_staging(entry) == StageStatus::Staged + { + self.bulk_staging = bulk_staging; + } + self.select_first_entry_if_none(cx); let suggested_commit_message = self.suggest_commit_message(cx); @@ -2772,7 +2899,9 @@ impl GitPanel { let status_toast = StatusToast::new(message, cx, move |this, _cx| { use remote_output::SuccessStyle::*; match style { - Toast { .. } => this, + Toast { .. 
} => { + this.icon(ToastIcon::new(IconName::GitBranchSmall).color(Color::Muted)) + } ToastWithLog { output } => this .icon(ToastIcon::new(IconName::GitBranchSmall).color(Color::Muted)) .action("View Log", move |window, cx| { @@ -2785,9 +2914,9 @@ impl GitPanel { }) .ok(); }), - PushPrLink { link } => this + PushPrLink { text, link } => this .icon(ToastIcon::new(IconName::GitBranchSmall).color(Color::Muted)) - .action("Open Pull Request", move |_, cx| cx.open_url(&link)), + .action(text, move |_, cx| cx.open_url(&link)), } }); workspace.toggle_status_toast(status_toast, cx) @@ -2983,14 +3112,45 @@ impl GitPanel { .child(Icon::new(IconName::ChevronDownSmall).size(IconSize::XSmall)), ), ) - .menu(move |window, cx| { - Some(ContextMenu::build(window, cx, |context_menu, _, _| { - context_menu - .when_some(keybinding_target.clone(), |el, keybinding_target| { - el.context(keybinding_target.clone()) - }) - .action("Amend", Amend.boxed_clone()) - })) + .menu({ + let git_panel = cx.entity(); + let has_previous_commit = self.head_commit(cx).is_some(); + let amend = self.amend_pending(); + let signoff = self.signoff_enabled; + + move |window, cx| { + Some(ContextMenu::build(window, cx, |context_menu, _, _| { + context_menu + .when_some(keybinding_target.clone(), |el, keybinding_target| { + el.context(keybinding_target.clone()) + }) + .when(has_previous_commit, |this| { + this.toggleable_entry( + "Amend", + amend, + IconPosition::Start, + Some(Box::new(Amend)), + { + let git_panel = git_panel.downgrade(); + move |_, cx| { + git_panel + .update(cx, |git_panel, cx| { + git_panel.toggle_amend_pending(cx); + }) + .ok(); + } + }, + ) + }) + .toggleable_entry( + "Signoff", + signoff, + IconPosition::Start, + Some(Box::new(Signoff)), + move |window, cx| window.dispatch_action(Box::new(Signoff), cx), + ) + })) + } }) .anchor(Corner::TopRight) } @@ -3167,7 +3327,6 @@ impl GitPanel { let editor_is_long = self.commit_editor.update(cx, |editor, cx| { editor.max_point(cx).row().0 >= MAX_PANEL_EDITOR_LINES as u32 }); - let has_previous_commit = head_commit.is_some(); let footer = v_flex() .child(PanelRepoFooter::new( @@ -3211,7 +3370,7 @@ impl GitPanel { h_flex() .gap_0p5() .children(enable_coauthors) - .child(self.render_commit_button(has_previous_commit, cx)), + .child(self.render_commit_button(cx)), ), ) .child( @@ -3260,14 +3419,12 @@ impl GitPanel { Some(footer) } - fn render_commit_button( - &self, - has_previous_commit: bool, - cx: &mut Context<Self>, - ) -> impl IntoElement { + fn render_commit_button(&self, cx: &mut Context<Self>) -> impl IntoElement { let (can_commit, tooltip) = self.configure_commit_button(cx); let title = self.commit_button_title(); let commit_tooltip_focus_handle = self.commit_editor.focus_handle(cx); + let amend = self.amend_pending(); + let signoff = self.signoff_enabled; div() .id("commit-wrapper") @@ -3276,164 +3433,87 @@ impl GitPanel { *hovered && !this.has_staged_changes() && !this.has_unstaged_conflicts(); cx.notify() })) - .when(self.amend_pending, { - |this| { - this.h_flex() - .gap_1() - .child( - panel_filled_button("Cancel") - .tooltip({ - let handle = commit_tooltip_focus_handle.clone(); - move |window, cx| { - Tooltip::for_action_in( - "Cancel amend", - &git::Cancel, - &handle, - window, - cx, - ) - } - }) - .on_click(move |_, window, cx| { - window.dispatch_action(Box::new(git::Cancel), cx); - }), - ) - .child( - panel_filled_button(title) - .tooltip({ - let handle = commit_tooltip_focus_handle.clone(); - move |window, cx| { - if can_commit { - Tooltip::for_action_in( - 
tooltip, &Amend, &handle, window, cx, - ) - } else { - Tooltip::simple(tooltip, cx) - } - } - }) - .disabled(!can_commit || self.modal_open) - .on_click({ - let git_panel = cx.weak_entity(); - move |_, window, cx| { - telemetry::event!("Git Amended", source = "Git Panel"); - git_panel - .update(cx, |git_panel, cx| { - git_panel.set_amend_pending(false, cx); - git_panel.commit_changes( - CommitOptions { amend: true }, - window, - cx, - ); - }) - .ok(); - } - }), - ) - } - }) - .when(!self.amend_pending, |this| { - this.when(has_previous_commit, |this| { - this.child(SplitButton::new( - ui::ButtonLike::new_rounded_left(ElementId::Name( - format!("split-button-left-{}", title).into(), - )) - .layer(ui::ElevationIndex::ModalSurface) - .size(ui::ButtonSize::Compact) - .child( - div() - .child(Label::new(title).size(LabelSize::Small)) - .mr_0p5(), - ) - .on_click({ - let git_panel = cx.weak_entity(); - move |_, window, cx| { - telemetry::event!("Git Committed", source = "Git Panel"); - git_panel - .update(cx, |git_panel, cx| { - git_panel.commit_changes( - CommitOptions { amend: false }, - window, - cx, - ); - }) - .ok(); - } - }) - .disabled(!can_commit || self.modal_open) - .tooltip({ - let handle = commit_tooltip_focus_handle.clone(); - move |window, cx| { - if can_commit { - Tooltip::with_meta_in( - tooltip, - Some(&git::Commit), - "git commit", - &handle.clone(), - window, - cx, - ) - } else { - Tooltip::simple(tooltip, cx) - } - } - }), - self.render_git_commit_menu( - ElementId::Name(format!("split-button-right-{}", title).into()), - Some(commit_tooltip_focus_handle.clone()), - cx, - ) - .into_any_element(), - )) - }) - .when(!has_previous_commit, |this| { - this.child( - panel_filled_button(title) - .tooltip(move |window, cx| { - if can_commit { - Tooltip::with_meta_in( - tooltip, - Some(&git::Commit), - "git commit", - &commit_tooltip_focus_handle, - window, - cx, - ) - } else { - Tooltip::simple(tooltip, cx) - } + .child(SplitButton::new( + ui::ButtonLike::new_rounded_left(ElementId::Name( + format!("split-button-left-{}", title).into(), + )) + .layer(ui::ElevationIndex::ModalSurface) + .size(ui::ButtonSize::Compact) + .child( + div() + .child(Label::new(title).size(LabelSize::Small)) + .mr_0p5(), + ) + .on_click({ + let git_panel = cx.weak_entity(); + move |_, window, cx| { + telemetry::event!("Git Committed", source = "Git Panel"); + git_panel + .update(cx, |git_panel, cx| { + git_panel.set_amend_pending(false, cx); + git_panel.commit_changes( + CommitOptions { amend, signoff }, + window, + cx, + ); }) - .disabled(!can_commit || self.modal_open) - .on_click({ - let git_panel = cx.weak_entity(); - move |_, window, cx| { - telemetry::event!("Git Committed", source = "Git Panel"); - git_panel - .update(cx, |git_panel, cx| { - git_panel.commit_changes( - CommitOptions { amend: false }, - window, - cx, - ); - }) - .ok(); - } - }), - ) + .ok(); + } }) - }) + .disabled(!can_commit || self.modal_open) + .tooltip({ + let handle = commit_tooltip_focus_handle.clone(); + move |window, cx| { + if can_commit { + Tooltip::with_meta_in( + tooltip, + Some(&git::Commit), + format!( + "git commit{}{}", + if amend { " --amend" } else { "" }, + if signoff { " --signoff" } else { "" } + ), + &handle.clone(), + window, + cx, + ) + } else { + Tooltip::simple(tooltip, cx) + } + } + }), + self.render_git_commit_menu( + ElementId::Name(format!("split-button-right-{}", title).into()), + Some(commit_tooltip_focus_handle.clone()), + cx, + ) + .into_any_element(), + )) } fn render_pending_amend(&self, cx: &mut 
Context<Self>) -> impl IntoElement { - div() - .p_2() + h_flex() + .py_1p5() + .px_2() + .gap_1p5() + .justify_between() .border_t_1() - .border_color(cx.theme().colors().border) + .border_color(cx.theme().colors().border.opacity(0.8)) .child( - Label::new( - "This will update your most recent commit. Cancel to make a new one instead.", - ) - .size(LabelSize::Small), + div() + .flex_grow() + .overflow_hidden() + .max_w(relative(0.85)) + .child( + Label::new("This will update your most recent commit.") + .size(LabelSize::Small) + .truncate(), + ), + ) + .child( + panel_button("Cancel") + .size(ButtonSize::Default) + .on_click(cx.listener(|this, _, _, cx| this.set_amend_pending(false, cx))), ) } @@ -3743,7 +3823,7 @@ impl GitPanel { for ix in range { match &this.entries.get(ix) { - Some(GitListEntry::GitStatusEntry(entry)) => { + Some(GitListEntry::Status(entry)) => { items.push(this.render_entry( ix, entry, @@ -4000,8 +4080,6 @@ impl GitPanel { let marked = self.marked_entries.contains(&ix); let status_style = GitPanelSettings::get_global(cx).status_style; let status = entry.status; - let modifiers = self.current_modifiers; - let shift_held = modifiers.shift; let has_conflict = status.is_conflicted(); let is_modified = status.is_modified(); @@ -4120,12 +4198,6 @@ impl GitPanel { cx.stop_propagation(); }, ) - // .on_secondary_mouse_down(cx.listener( - // move |this, event: &MouseDownEvent, window, cx| { - // this.deploy_entry_context_menu(event.position, ix, window, cx); - // cx.stop_propagation(); - // }, - // )) .child( div() .id(checkbox_wrapper_id) @@ -4137,46 +4209,35 @@ impl GitPanel { .disabled(!has_write_access) .fill() .elevation(ElevationIndex::Surface) - .on_click({ + .on_click_ext({ let entry = entry.clone(); - cx.listener(move |this, _, window, cx| { - if !has_write_access { - return; - } - this.toggle_staged_for_entry( - &GitListEntry::GitStatusEntry(entry.clone()), - window, - cx, - ); - cx.stop_propagation(); - }) + let this = cx.weak_entity(); + move |_, click, window, cx| { + this.update(cx, |this, cx| { + if !has_write_access { + return; + } + if click.modifiers().shift { + this.stage_bulk(ix, cx); + } else { + this.toggle_staged_for_entry( + &GitListEntry::Status(entry.clone()), + window, + cx, + ); + } + cx.stop_propagation(); + }) + .ok(); + } }) .tooltip(move |window, cx| { let is_staged = entry_staging.is_fully_staged(); let action = if is_staged { "Unstage" } else { "Stage" }; - let tooltip_name = if shift_held { - format!("{} section", action) - } else { - action.to_string() - }; - - let meta = if shift_held { - format!( - "Release shift to {} single entry", - action.to_lowercase() - ) - } else { - format!("Shift click to {} section", action.to_lowercase()) - }; + let tooltip_name = action.to_string(); - Tooltip::with_meta( - tooltip_name, - Some(&ToggleStaged), - meta, - window, - cx, - ) + Tooltip::for_action(tooltip_name, &ToggleStaged, window, cx) }), ), ) @@ -4214,20 +4275,50 @@ impl GitPanel { pub fn set_amend_pending(&mut self, value: bool, cx: &mut Context<Self>) { self.amend_pending = value; + self.serialize(cx); + cx.notify(); + } + + pub fn signoff_enabled(&self) -> bool { + self.signoff_enabled + } + + pub fn set_signoff_enabled(&mut self, value: bool, cx: &mut Context<Self>) { + self.signoff_enabled = value; + self.serialize(cx); cx.notify(); } + pub fn toggle_signoff_enabled( + &mut self, + _: &Signoff, + _window: &mut Window, + cx: &mut Context<Self>, + ) { + self.set_signoff_enabled(!self.signoff_enabled, cx); + } + pub async fn load( workspace: 
WeakEntity<Workspace>, mut cx: AsyncWindowContext, ) -> anyhow::Result<Entity<Self>> { - let serialized_panel = cx - .background_spawn(async move { KEY_VALUE_STORE.read_kvp(&GIT_PANEL_KEY) }) - .await - .context("loading git panel") - .log_err() + let serialized_panel = match workspace + .read_with(&cx, |workspace, _| Self::serialization_key(workspace)) + .ok() .flatten() - .and_then(|panel| serde_json::from_str::<SerializedGitPanel>(&panel).log_err()); + { + Some(serialization_key) => cx + .background_spawn(async move { KEY_VALUE_STORE.read_kvp(&serialization_key) }) + .await + .context("loading git panel") + .log_err() + .flatten() + .map(|panel| serde_json::from_str::<SerializedGitPanel>(&panel)) + .transpose() + .log_err() + .flatten(), + None => None, + }; workspace.update_in(&mut cx, |workspace, window, cx| { let panel = GitPanel::new(workspace, window, cx); @@ -4235,6 +4326,8 @@ impl GitPanel { if let Some(serialized_panel) = serialized_panel { panel.update(cx, |panel, cx| { panel.width = serialized_panel.width; + panel.amend_pending = serialized_panel.amend_pending; + panel.signoff_enabled = serialized_panel.signoff_enabled; cx.notify(); }) } @@ -4242,11 +4335,55 @@ impl GitPanel { panel }) } + + fn stage_bulk(&mut self, mut index: usize, cx: &mut Context<'_, Self>) { + let Some(op) = self.bulk_staging.as_ref() else { + return; + }; + let Some(mut anchor_index) = self.entry_by_path(&op.anchor, cx) else { + return; + }; + if let Some(entry) = self.entries.get(index) + && let Some(entry) = entry.status_entry() + { + self.set_bulk_staging_anchor(entry.repo_path.clone(), cx); + } + if index < anchor_index { + std::mem::swap(&mut index, &mut anchor_index); + } + let entries = self + .entries + .get(anchor_index..=index) + .unwrap_or_default() + .iter() + .filter_map(|entry| entry.status_entry().cloned()) + .collect::<Vec<_>>(); + self.change_file_stage(true, entries, cx); + } + + fn set_bulk_staging_anchor(&mut self, path: RepoPath, cx: &mut Context<'_, GitPanel>) { + let Some(repo) = self.active_repository.as_ref() else { + return; + }; + self.bulk_staging = Some(BulkStaging { + repo_id: repo.read(cx).id, + anchor: path, + }); + } + + pub(crate) fn toggle_amend_pending(&mut self, cx: &mut Context<Self>) { + self.set_amend_pending(!self.amend_pending, cx); + if self.amend_pending { + self.load_last_commit_message_if_empty(cx); + } + } } fn current_language_model(cx: &Context<'_, GitPanel>) -> Option<Arc<dyn LanguageModel>> { - agent_settings::AgentSettings::get_global(cx) - .enabled + let is_enabled = agent_settings::AgentSettings::get_global(cx).enabled + && !DisableAiSettings::get_global(cx).disable_ai; + + is_enabled .then(|| { let ConfiguredModel { provider, model } = LanguageModelRegistry::read_global(cx).commit_message_model()?; @@ -4279,12 +4416,12 @@ impl Render for GitPanel { .id("git_panel") .key_context(self.dispatch_context(window, cx)) .track_focus(&self.focus_handle) - .on_modifiers_changed(cx.listener(Self::handle_modifiers_changed)) .when(has_write_access && !project.is_read_only(cx), |this| { this.on_action(cx.listener(Self::toggle_staged_for_selected)) + .on_action(cx.listener(Self::stage_range)) .on_action(cx.listener(GitPanel::commit)) .on_action(cx.listener(GitPanel::amend)) - .on_action(cx.listener(GitPanel::cancel)) + .on_action(cx.listener(GitPanel::toggle_signoff_enabled)) .on_action(cx.listener(Self::stage_all)) .on_action(cx.listener(Self::unstage_all)) .on_action(cx.listener(Self::stage_selected)) @@ -4293,6 +4430,8 @@ impl Render for GitPanel { 
.on_action(cx.listener(Self::revert_selected)) .on_action(cx.listener(Self::clean_all)) .on_action(cx.listener(Self::generate_commit_message_action)) + .on_action(cx.listener(Self::stash_all)) + .on_action(cx.listener(Self::stash_pop)) }) .on_action(cx.listener(Self::select_first)) .on_action(cx.listener(Self::select_next)) @@ -4953,7 +5092,7 @@ impl Component for PanelRepoFooter { #[cfg(test)] mod tests { - use git::status::StatusCode; + use git::status::{StatusCode, UnmergedStatus, UnmergedStatusCode}; use gpui::{TestAppContext, VisualTestContext}; use project::{FakeFs, WorktreeSettings}; use serde_json::json; @@ -5052,13 +5191,13 @@ mod tests { GitListEntry::Header(GitHeaderEntry { header: Section::Tracked }), - GitListEntry::GitStatusEntry(GitStatusEntry { + GitListEntry::Status(GitStatusEntry { abs_path: path!("/root/zed/crates/gpui/gpui.rs").into(), repo_path: "crates/gpui/gpui.rs".into(), status: StatusCode::Modified.worktree(), staging: StageStatus::Unstaged, }), - GitListEntry::GitStatusEntry(GitStatusEntry { + GitListEntry::Status(GitStatusEntry { abs_path: path!("/root/zed/crates/util/util.rs").into(), repo_path: "crates/util/util.rs".into(), status: StatusCode::Modified.worktree(), @@ -5067,54 +5206,6 @@ mod tests { ], ); - // TODO(cole) restore this once repository deduplication is implemented properly. - //cx.update_window_entity(&panel, |panel, window, cx| { - // panel.select_last(&Default::default(), window, cx); - // assert_eq!(panel.selected_entry, Some(2)); - // panel.open_diff(&Default::default(), window, cx); - //}); - //cx.run_until_parked(); - - //let worktree_roots = workspace.update(cx, |workspace, cx| { - // workspace - // .worktrees(cx) - // .map(|worktree| worktree.read(cx).abs_path()) - // .collect::<Vec<_>>() - //}); - //pretty_assertions::assert_eq!( - // worktree_roots, - // vec![ - // Path::new(path!("/root/zed/crates/gpui")).into(), - // Path::new(path!("/root/zed/crates/util/util.rs")).into(), - // ] - //); - - //project.update(cx, |project, cx| { - // let git_store = project.git_store().read(cx); - // // The repo that comes from the single-file worktree can't be selected through the UI. - // let filtered_entries = filtered_repository_entries(git_store, cx) - // .iter() - // .map(|repo| repo.read(cx).worktree_abs_path.clone()) - // .collect::<Vec<_>>(); - // assert_eq!( - // filtered_entries, - // [Path::new(path!("/root/zed/crates/gpui")).into()] - // ); - // // But we can select it artificially here. - // let repo_from_single_file_worktree = git_store - // .repositories() - // .values() - // .find(|repo| { - // repo.read(cx).worktree_abs_path.as_ref() - // == Path::new(path!("/root/zed/crates/util/util.rs")) - // }) - // .unwrap() - // .clone(); - - // // Paths still make sense when we somehow activate a repo that comes from a single-file worktree. 
- // repo_from_single_file_worktree.update(cx, |repo, cx| repo.set_as_active_repository(cx)); - //}); - let handle = cx.update_window_entity(&panel, |panel, _, _| { std::mem::replace(&mut panel.update_visible_entries_task, Task::ready(())) }); @@ -5127,13 +5218,13 @@ mod tests { GitListEntry::Header(GitHeaderEntry { header: Section::Tracked }), - GitListEntry::GitStatusEntry(GitStatusEntry { + GitListEntry::Status(GitStatusEntry { abs_path: path!("/root/zed/crates/gpui/gpui.rs").into(), repo_path: "crates/gpui/gpui.rs".into(), status: StatusCode::Modified.worktree(), staging: StageStatus::Unstaged, }), - GitListEntry::GitStatusEntry(GitStatusEntry { + GitListEntry::Status(GitStatusEntry { abs_path: path!("/root/zed/crates/util/util.rs").into(), repo_path: "crates/util/util.rs".into(), status: StatusCode::Modified.worktree(), @@ -5142,4 +5233,196 @@ mod tests { ], ); } + + #[gpui::test] + async fn test_bulk_staging(cx: &mut TestAppContext) { + use GitListEntry::*; + + init_test(cx); + let fs = FakeFs::new(cx.background_executor.clone()); + fs.insert_tree( + "/root", + json!({ + "project": { + ".git": {}, + "src": { + "main.rs": "fn main() {}", + "lib.rs": "pub fn hello() {}", + "utils.rs": "pub fn util() {}" + }, + "tests": { + "test.rs": "fn test() {}" + }, + "new_file.txt": "new content", + "another_new.rs": "// new file", + "conflict.txt": "conflicted content" + } + }), + ) + .await; + + fs.set_status_for_repo( + Path::new(path!("/root/project/.git")), + &[ + (Path::new("src/main.rs"), StatusCode::Modified.worktree()), + (Path::new("src/lib.rs"), StatusCode::Modified.worktree()), + (Path::new("tests/test.rs"), StatusCode::Modified.worktree()), + (Path::new("new_file.txt"), FileStatus::Untracked), + (Path::new("another_new.rs"), FileStatus::Untracked), + (Path::new("src/utils.rs"), FileStatus::Untracked), + ( + Path::new("conflict.txt"), + UnmergedStatus { + first_head: UnmergedStatusCode::Updated, + second_head: UnmergedStatusCode::Updated, + } + .into(), + ), + ], + ); + + let project = Project::test(fs.clone(), [Path::new(path!("/root/project"))], cx).await; + let workspace = + cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx)); + let cx = &mut VisualTestContext::from_window(*workspace, cx); + + cx.read(|cx| { + project + .read(cx) + .worktrees(cx) + .nth(0) + .unwrap() + .read(cx) + .as_local() + .unwrap() + .scan_complete() + }) + .await; + + cx.executor().run_until_parked(); + + let panel = workspace.update(cx, GitPanel::new).unwrap(); + + let handle = cx.update_window_entity(&panel, |panel, _, _| { + std::mem::replace(&mut panel.update_visible_entries_task, Task::ready(())) + }); + cx.executor().advance_clock(2 * UPDATE_DEBOUNCE); + handle.await; + + let entries = panel.read_with(cx, |panel, _| panel.entries.clone()); + #[rustfmt::skip] + pretty_assertions::assert_matches!( + entries.as_slice(), + &[ + Header(GitHeaderEntry { header: Section::Conflict }), + Status(GitStatusEntry { staging: StageStatus::Unstaged, .. }), + Header(GitHeaderEntry { header: Section::Tracked }), + Status(GitStatusEntry { staging: StageStatus::Unstaged, .. }), + Status(GitStatusEntry { staging: StageStatus::Unstaged, .. }), + Status(GitStatusEntry { staging: StageStatus::Unstaged, .. }), + Header(GitHeaderEntry { header: Section::New }), + Status(GitStatusEntry { staging: StageStatus::Unstaged, .. }), + Status(GitStatusEntry { staging: StageStatus::Unstaged, .. }), + Status(GitStatusEntry { staging: StageStatus::Unstaged, .. 
}), + ], + ); + + let second_status_entry = entries[3].clone(); + panel.update_in(cx, |panel, window, cx| { + panel.toggle_staged_for_entry(&second_status_entry, window, cx); + }); + + panel.update_in(cx, |panel, window, cx| { + panel.selected_entry = Some(7); + panel.stage_range(&git::StageRange, window, cx); + }); + + cx.read(|cx| { + project + .read(cx) + .worktrees(cx) + .nth(0) + .unwrap() + .read(cx) + .as_local() + .unwrap() + .scan_complete() + }) + .await; + + cx.executor().run_until_parked(); + + let handle = cx.update_window_entity(&panel, |panel, _, _| { + std::mem::replace(&mut panel.update_visible_entries_task, Task::ready(())) + }); + cx.executor().advance_clock(2 * UPDATE_DEBOUNCE); + handle.await; + + let entries = panel.read_with(cx, |panel, _| panel.entries.clone()); + #[rustfmt::skip] + pretty_assertions::assert_matches!( + entries.as_slice(), + &[ + Header(GitHeaderEntry { header: Section::Conflict }), + Status(GitStatusEntry { staging: StageStatus::Unstaged, .. }), + Header(GitHeaderEntry { header: Section::Tracked }), + Status(GitStatusEntry { staging: StageStatus::Staged, .. }), + Status(GitStatusEntry { staging: StageStatus::Staged, .. }), + Status(GitStatusEntry { staging: StageStatus::Staged, .. }), + Header(GitHeaderEntry { header: Section::New }), + Status(GitStatusEntry { staging: StageStatus::Staged, .. }), + Status(GitStatusEntry { staging: StageStatus::Unstaged, .. }), + Status(GitStatusEntry { staging: StageStatus::Unstaged, .. }), + ], + ); + + let third_status_entry = entries[4].clone(); + panel.update_in(cx, |panel, window, cx| { + panel.toggle_staged_for_entry(&third_status_entry, window, cx); + }); + + panel.update_in(cx, |panel, window, cx| { + panel.selected_entry = Some(9); + panel.stage_range(&git::StageRange, window, cx); + }); + + cx.read(|cx| { + project + .read(cx) + .worktrees(cx) + .nth(0) + .unwrap() + .read(cx) + .as_local() + .unwrap() + .scan_complete() + }) + .await; + + cx.executor().run_until_parked(); + + let handle = cx.update_window_entity(&panel, |panel, _, _| { + std::mem::replace(&mut panel.update_visible_entries_task, Task::ready(())) + }); + cx.executor().advance_clock(2 * UPDATE_DEBOUNCE); + handle.await; + + let entries = panel.read_with(cx, |panel, _| panel.entries.clone()); + #[rustfmt::skip] + pretty_assertions::assert_matches!( + entries.as_slice(), + &[ + Header(GitHeaderEntry { header: Section::Conflict }), + Status(GitStatusEntry { staging: StageStatus::Unstaged, .. }), + Header(GitHeaderEntry { header: Section::Tracked }), + Status(GitStatusEntry { staging: StageStatus::Staged, .. }), + Status(GitStatusEntry { staging: StageStatus::Unstaged, .. }), + Status(GitStatusEntry { staging: StageStatus::Staged, .. }), + Header(GitHeaderEntry { header: Section::New }), + Status(GitStatusEntry { staging: StageStatus::Staged, .. }), + Status(GitStatusEntry { staging: StageStatus::Staged, .. }), + Status(GitStatusEntry { staging: StageStatus::Staged, .. 
}), + ], + ); + } } diff --git a/crates/git_ui/src/git_ui.rs b/crates/git_ui/src/git_ui.rs index a9ccaf716074783b2bf3a5e4d969c0702320557e..0163175eda0e8135e9aa77d80e8eb4bce0184cf0 100644 --- a/crates/git_ui/src/git_ui.rs +++ b/crates/git_ui/src/git_ui.rs @@ -3,7 +3,7 @@ use std::any::Any; use ::settings::Settings; use command_palette_hooks::CommandPaletteFilter; use commit_modal::CommitModal; -use editor::Editor; +use editor::{Editor, actions::DiffClipboardWithSelectionData}; mod blame_ui; use git::{ repository::{Branch, Upstream, UpstreamTracking, UpstreamTrackingStatus}, @@ -15,6 +15,9 @@ use onboarding::GitOnboardingModal; use project_diff::ProjectDiff; use ui::prelude::*; use workspace::Workspace; +use zed_actions; + +use crate::text_diff_view::TextDiffView; mod askpass_modal; pub mod branch_picker; @@ -22,7 +25,7 @@ mod commit_modal; pub mod commit_tooltip; mod commit_view; mod conflict_view; -pub mod diff_view; +pub mod file_diff_view; pub mod git_panel; mod git_panel_settings; pub mod onboarding; @@ -30,6 +33,7 @@ pub mod picker_prompt; pub mod project_diff; pub(crate) mod remote_output; pub mod repository_selector; +pub mod text_diff_view; actions!( git, @@ -110,6 +114,22 @@ pub fn init(cx: &mut App) { }); }); } + workspace.register_action(|workspace, action: &git::StashAll, window, cx| { + let Some(panel) = workspace.panel::<git_panel::GitPanel>(cx) else { + return; + }; + panel.update(cx, |panel, cx| { + panel.stash_all(action, window, cx); + }); + }); + workspace.register_action(|workspace, action: &git::StashPop, window, cx| { + let Some(panel) = workspace.panel::<git_panel::GitPanel>(cx) else { + return; + }; + panel.update(cx, |panel, cx| { + panel.stash_pop(action, window, cx); + }); + }); workspace.register_action(|workspace, action: &git::StageAll, window, cx| { let Some(panel) = workspace.panel::<git_panel::GitPanel>(cx) else { return; @@ -152,6 +172,13 @@ pub fn init(cx: &mut App) { workspace.register_action(|workspace, _: &git::OpenModifiedFiles, window, cx| { open_modified_files(workspace, window, cx); }); + workspace.register_action( + |workspace, action: &DiffClipboardWithSelectionData, window, cx| { + if let Some(task) = TextDiffView::open(action, workspace, window, cx) { + task.detach(); + }; + }, + ); }) .detach(); } @@ -501,7 +528,7 @@ mod remote_button { ) .into_any_element(); - SplitButton { left, right } + SplitButton::new(left, right) } } diff --git a/crates/git_ui/src/remote_output.rs b/crates/git_ui/src/remote_output.rs index 03fbf4f917ad2f9b0741fab061ee31998146c633..8437bf0d0d37c2b2767624110fed056bbae25d05 100644 --- a/crates/git_ui/src/remote_output.rs +++ b/crates/git_ui/src/remote_output.rs @@ -24,7 +24,7 @@ impl RemoteAction { pub enum SuccessStyle { Toast, ToastWithLog { output: RemoteCommandOutput }, - PushPrLink { link: String }, + PushPrLink { text: String, link: String }, } pub struct SuccessMessage { @@ -37,7 +37,7 @@ pub fn format_output(action: &RemoteAction, output: RemoteCommandOutput) -> Succ RemoteAction::Fetch(remote) => { if output.stderr.is_empty() { SuccessMessage { - message: "Already up to date".into(), + message: "Fetch: Already up to date".into(), style: SuccessStyle::Toast, } } else { @@ -68,10 +68,9 @@ pub fn format_output(action: &RemoteAction, output: RemoteCommandOutput) -> Succ Ok(files_changed) }; - - if output.stderr.starts_with("Everything up to date") { + if output.stdout.ends_with("Already up to date.\n") { SuccessMessage { - message: output.stderr.trim().to_owned(), + message: "Pull: Already up to date".into(), style: 
SuccessStyle::Toast, } } else if output.stdout.starts_with("Updating") { @@ -119,48 +118,42 @@ pub fn format_output(action: &RemoteAction, output: RemoteCommandOutput) -> Succ } } RemoteAction::Push(branch_name, remote_ref) => { - if output.stderr.contains("* [new branch]") { + let message = if output.stderr.ends_with("Everything up-to-date\n") { + "Push: Everything is up-to-date".to_string() + } else { + format!("Pushed {} to {}", branch_name, remote_ref.name) + }; + + let style = if output.stderr.ends_with("Everything up-to-date\n") { + Some(SuccessStyle::Toast) + } else if output.stderr.contains("\nremote: ") { let pr_hints = [ - // GitHub - "Create a pull request", - // Bitbucket - "Create pull request", - // GitLab - "create a merge request", + ("Create a pull request", "Create Pull Request"), // GitHub + ("Create pull request", "Create Pull Request"), // Bitbucket + ("create a merge request", "Create Merge Request"), // GitLab + ("View merge request", "View Merge Request"), // GitLab ]; - let style = if pr_hints + pr_hints .iter() - .any(|indicator| output.stderr.contains(indicator)) - { - let finder = LinkFinder::new(); - let first_link = finder - .links(&output.stderr) - .filter(|link| *link.kind() == LinkKind::Url) - .map(|link| link.start()..link.end()) - .next(); - if let Some(link) = first_link { - let link = output.stderr[link].to_string(); - SuccessStyle::PushPrLink { link } - } else { - SuccessStyle::ToastWithLog { output } - } - } else { - SuccessStyle::ToastWithLog { output } - }; - SuccessMessage { - message: format!("Published {} to {}", branch_name, remote_ref.name), - style, - } - } else if output.stderr.starts_with("Everything up to date") { - SuccessMessage { - message: output.stderr.trim().to_owned(), - style: SuccessStyle::Toast, - } + .find(|(indicator, _)| output.stderr.contains(indicator)) + .and_then(|(_, mapped)| { + let finder = LinkFinder::new(); + finder + .links(&output.stderr) + .filter(|link| *link.kind() == LinkKind::Url) + .map(|link| link.start()..link.end()) + .next() + .map(|link| SuccessStyle::PushPrLink { + text: mapped.to_string(), + link: output.stderr[link].to_string(), + }) + }) } else { - SuccessMessage { - message: format!("Pushed {} to {}", branch_name, remote_ref.name), - style: SuccessStyle::ToastWithLog { output }, - } + None + }; + SuccessMessage { + message, + style: style.unwrap_or(SuccessStyle::ToastWithLog { output }), } } } @@ -169,6 +162,7 @@ pub fn format_output(action: &RemoteAction, output: RemoteCommandOutput) -> Succ #[cfg(test)] mod tests { use super::*; + use indoc::indoc; #[test] fn test_push_new_branch_pull_request() { @@ -181,8 +175,7 @@ mod tests { let output = RemoteCommandOutput { stdout: String::new(), - stderr: String::from( - " + stderr: indoc! { " Total 0 (delta 0), reused 0 (delta 0), pack-reused 0 (from 0) remote: remote: Create a pull request for 'test' on GitHub by visiting: @@ -190,13 +183,14 @@ mod tests { remote: To example.com:test/test.git * [new branch] test -> test - ", - ), + "} + .to_string(), }; let msg = format_output(&action, output); - if let SuccessStyle::PushPrLink { link } = &msg.style { + if let SuccessStyle::PushPrLink { text: hint, link } = &msg.style { + assert_eq!(hint, "Create Pull Request"); assert_eq!(link, "https://example.com/test/test/pull/new/test"); } else { panic!("Expected PushPrLink variant"); @@ -214,7 +208,7 @@ mod tests { let output = RemoteCommandOutput { stdout: String::new(), - stderr: String::from(" + stderr: indoc! 
{" Total 0 (delta 0), reused 0 (delta 0), pack-reused 0 (from 0) remote: remote: To create a merge request for test, visit: @@ -222,12 +216,14 @@ mod tests { remote: To example.com:test/test.git * [new branch] test -> test - "), - }; + "} + .to_string() + }; let msg = format_output(&action, output); - if let SuccessStyle::PushPrLink { link } = &msg.style { + if let SuccessStyle::PushPrLink { text, link } = &msg.style { + assert_eq!(text, "Create Merge Request"); assert_eq!( link, "https://example.com/test/test/-/merge_requests/new?merge_request%5Bsource_branch%5D=test" @@ -237,6 +233,39 @@ mod tests { } } + #[test] + fn test_push_branch_existing_merge_request() { + let action = RemoteAction::Push( + SharedString::new("test_branch"), + Remote { + name: SharedString::new("test_remote"), + }, + ); + + let output = RemoteCommandOutput { + stdout: String::new(), + stderr: indoc! {" + Total 0 (delta 0), reused 0 (delta 0), pack-reused 0 (from 0) + remote: + remote: View merge request for test: + remote: https://example.com/test/test/-/merge_requests/99999 + remote: + To example.com:test/test.git + + 80bd3c83be...e03d499d2e test -> test + "} + .to_string(), + }; + + let msg = format_output(&action, output); + + if let SuccessStyle::PushPrLink { text, link } = &msg.style { + assert_eq!(text, "View Merge Request"); + assert_eq!(link, "https://example.com/test/test/-/merge_requests/99999"); + } else { + panic!("Expected PushPrLink variant"); + } + } + #[test] fn test_push_new_branch_no_link() { let action = RemoteAction::Push( @@ -248,12 +277,12 @@ mod tests { let output = RemoteCommandOutput { stdout: String::new(), - stderr: String::from( - " + stderr: indoc! { " To http://example.com/test/test.git * [new branch] test -> test ", - ), + } + .to_string(), }; let msg = format_output(&action, output); @@ -261,10 +290,7 @@ mod tests { if let SuccessStyle::ToastWithLog { output } = &msg.style { assert_eq!( output.stderr, - " - To http://example.com/test/test.git - * [new branch] test -> test - " + "To http://example.com/test/test.git\n * [new branch] test -> test\n" ); } else { panic!("Expected ToastWithLog variant"); diff --git a/crates/git_ui/src/repository_selector.rs b/crates/git_ui/src/repository_selector.rs index b5865e9a8578e24dffb129eb373b718219344e1c..db080ab0b4974dfc3ef83ffb3a0ec71481c683bc 100644 --- a/crates/git_ui/src/repository_selector.rs +++ b/crates/git_ui/src/repository_selector.rs @@ -109,7 +109,10 @@ impl Focusable for RepositorySelector { impl Render for RepositorySelector { fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement { - div().w(self.width).child(self.picker.clone()) + div() + .key_context("GitRepositorySelector") + .w(self.width) + .child(self.picker.clone()) } } diff --git a/crates/git_ui/src/text_diff_view.rs b/crates/git_ui/src/text_diff_view.rs new file mode 100644 index 0000000000000000000000000000000000000000..005c1e18b40727f42df81437c7038f4e5a7ef905 --- /dev/null +++ b/crates/git_ui/src/text_diff_view.rs @@ -0,0 +1,740 @@ +//! TextDiffView currently provides a UI for displaying differences between the clipboard and selected text. 
+ +use anyhow::Result; +use buffer_diff::{BufferDiff, BufferDiffSnapshot}; +use editor::{Editor, EditorEvent, MultiBuffer, ToPoint, actions::DiffClipboardWithSelectionData}; +use futures::{FutureExt, select_biased}; +use gpui::{ + AnyElement, AnyView, App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, + FocusHandle, Focusable, IntoElement, Render, Task, Window, +}; +use language::{self, Buffer, Point}; +use project::Project; +use std::{ + any::{Any, TypeId}, + cmp, + ops::Range, + pin::pin, + sync::Arc, + time::Duration, +}; +use ui::{Color, Icon, IconName, Label, LabelCommon as _, SharedString}; +use util::paths::PathExt; + +use workspace::{ + Item, ItemHandle as _, ItemNavHistory, ToolbarItemLocation, Workspace, + item::{BreadcrumbText, ItemEvent, SaveOptions, TabContentParams}, + searchable::SearchableItemHandle, +}; + +pub struct TextDiffView { + diff_editor: Entity<Editor>, + title: SharedString, + path: Option<SharedString>, + buffer_changes_tx: watch::Sender<()>, + _recalculate_diff_task: Task<Result<()>>, +} + +const RECALCULATE_DIFF_DEBOUNCE: Duration = Duration::from_millis(250); + +impl TextDiffView { + pub fn open( + diff_data: &DiffClipboardWithSelectionData, + workspace: &Workspace, + window: &mut Window, + cx: &mut App, + ) -> Option<Task<Result<Entity<Self>>>> { + let source_editor = diff_data.editor.clone(); + + let selection_data = source_editor.update(cx, |editor, cx| { + let multibuffer = editor.buffer().read(cx); + let source_buffer = multibuffer.as_singleton()?.clone(); + let selections = editor.selections.all::<Point>(cx); + let buffer_snapshot = source_buffer.read(cx); + let first_selection = selections.first()?; + let max_point = buffer_snapshot.max_point(); + + if first_selection.is_empty() { + let full_range = Point::new(0, 0)..max_point; + return Some((source_buffer, full_range)); + } + + let start = first_selection.start; + let end = first_selection.end; + let expanded_start = Point::new(start.row, 0); + + let expanded_end = if end.column > 0 { + let next_row = end.row + 1; + cmp::min(max_point, Point::new(next_row, 0)) + } else { + end + }; + Some((source_buffer, expanded_start..expanded_end)) + }); + + let Some((source_buffer, expanded_selection_range)) = selection_data else { + log::warn!("There should always be at least one selection in Zed. 
This is a bug."); + return None; + }; + + source_editor.update(cx, |source_editor, cx| { + source_editor.change_selections(Default::default(), window, cx, |s| { + s.select_ranges(vec![ + expanded_selection_range.start..expanded_selection_range.end, + ]); + }) + }); + + let source_buffer_snapshot = source_buffer.read(cx).snapshot(); + let mut clipboard_text = diff_data.clipboard_text.clone(); + + if !clipboard_text.ends_with("\n") { + clipboard_text.push_str("\n"); + } + + let workspace = workspace.weak_handle(); + let diff_buffer = cx.new(|cx| BufferDiff::new(&source_buffer_snapshot.text, cx)); + let clipboard_buffer = build_clipboard_buffer( + clipboard_text, + &source_buffer, + expanded_selection_range.clone(), + cx, + ); + + let task = window.spawn(cx, async move |cx| { + let project = workspace.update(cx, |workspace, _| workspace.project().clone())?; + + update_diff_buffer(&diff_buffer, &source_buffer, &clipboard_buffer, cx).await?; + + workspace.update_in(cx, |workspace, window, cx| { + let diff_view = cx.new(|cx| { + TextDiffView::new( + clipboard_buffer, + source_editor, + source_buffer, + expanded_selection_range, + diff_buffer, + project, + window, + cx, + ) + }); + + let pane = workspace.active_pane(); + pane.update(cx, |pane, cx| { + pane.add_item(Box::new(diff_view.clone()), true, true, None, window, cx); + }); + + diff_view + }) + }); + + Some(task) + } + + pub fn new( + clipboard_buffer: Entity<Buffer>, + source_editor: Entity<Editor>, + source_buffer: Entity<Buffer>, + source_range: Range<Point>, + diff_buffer: Entity<BufferDiff>, + project: Entity<Project>, + window: &mut Window, + cx: &mut Context<Self>, + ) -> Self { + let multibuffer = cx.new(|cx| { + let mut multibuffer = MultiBuffer::new(language::Capability::ReadWrite); + + multibuffer.push_excerpts( + source_buffer.clone(), + [editor::ExcerptRange::new(source_range)], + cx, + ); + + multibuffer.add_diff(diff_buffer.clone(), cx); + multibuffer + }); + let diff_editor = cx.new(|cx| { + let mut editor = Editor::for_multibuffer(multibuffer, Some(project), window, cx); + editor.start_temporary_diff_override(); + editor.disable_diagnostics(cx); + editor.set_expand_all_diff_hunks(cx); + editor.set_render_diff_hunk_controls( + Arc::new(|_, _, _, _, _, _, _, _| gpui::Empty.into_any_element()), + cx, + ); + editor + }); + + let (buffer_changes_tx, mut buffer_changes_rx) = watch::channel(()); + + cx.subscribe(&source_buffer, move |this, _, event, _| match event { + language::BufferEvent::Edited + | language::BufferEvent::LanguageChanged + | language::BufferEvent::Reparsed => { + this.buffer_changes_tx.send(()).ok(); + } + _ => {} + }) + .detach(); + + let editor = source_editor.read(cx); + let title = editor.buffer().read(cx).title(cx).to_string(); + let selection_location_text = selection_location_text(editor, cx); + let selection_location_title = selection_location_text + .as_ref() + .map(|text| format!("{} @ {}", title, text)) + .unwrap_or(title); + + let path = editor + .buffer() + .read(cx) + .as_singleton() + .and_then(|b| { + b.read(cx) + .file() + .map(|f| f.full_path(cx).compact().to_string_lossy().to_string()) + }) + .unwrap_or("untitled".into()); + + let selection_location_path = selection_location_text + .map(|text| format!("{} @ {}", path, text)) + .unwrap_or(path); + + Self { + diff_editor, + title: format!("Clipboard ↔ {selection_location_title}").into(), + path: Some(format!("Clipboard ↔ {selection_location_path}").into()), + buffer_changes_tx, + _recalculate_diff_task: cx.spawn(async move |_, cx| { + while let 
Ok(_) = buffer_changes_rx.recv().await { + loop { + let mut timer = cx + .background_executor() + .timer(RECALCULATE_DIFF_DEBOUNCE) + .fuse(); + let mut recv = pin!(buffer_changes_rx.recv().fuse()); + select_biased! { + _ = timer => break, + _ = recv => continue, + } + } + + log::trace!("start recalculating"); + update_diff_buffer(&diff_buffer, &source_buffer, &clipboard_buffer, cx).await?; + log::trace!("finish recalculating"); + } + Ok(()) + }), + } + } +} + +fn build_clipboard_buffer( + text: String, + source_buffer: &Entity<Buffer>, + replacement_range: Range<Point>, + cx: &mut App, +) -> Entity<Buffer> { + let source_buffer_snapshot = source_buffer.read(cx).snapshot(); + cx.new(|cx| { + let mut buffer = language::Buffer::local(source_buffer_snapshot.text(), cx); + let language = source_buffer.read(cx).language().cloned(); + buffer.set_language(language, cx); + + let range_start = source_buffer_snapshot.point_to_offset(replacement_range.start); + let range_end = source_buffer_snapshot.point_to_offset(replacement_range.end); + buffer.edit([(range_start..range_end, text)], None, cx); + + buffer + }) +} + +async fn update_diff_buffer( + diff: &Entity<BufferDiff>, + source_buffer: &Entity<Buffer>, + clipboard_buffer: &Entity<Buffer>, + cx: &mut AsyncApp, +) -> Result<()> { + let source_buffer_snapshot = source_buffer.read_with(cx, |buffer, _| buffer.snapshot())?; + + let base_buffer_snapshot = clipboard_buffer.read_with(cx, |buffer, _| buffer.snapshot())?; + let base_text = base_buffer_snapshot.text().to_string(); + + let diff_snapshot = cx + .update(|cx| { + BufferDiffSnapshot::new_with_base_buffer( + source_buffer_snapshot.text.clone(), + Some(Arc::new(base_text)), + base_buffer_snapshot, + cx, + ) + })? + .await; + + diff.update(cx, |diff, cx| { + diff.set_snapshot(diff_snapshot, &source_buffer_snapshot.text, cx); + })?; + Ok(()) +} + +impl EventEmitter<EditorEvent> for TextDiffView {} + +impl Focusable for TextDiffView { + fn focus_handle(&self, cx: &App) -> FocusHandle { + self.diff_editor.focus_handle(cx) + } +} + +impl Item for TextDiffView { + type Event = EditorEvent; + + fn tab_icon(&self, _window: &Window, _cx: &App) -> Option<Icon> { + Some(Icon::new(IconName::Diff).color(Color::Muted)) + } + + fn tab_content(&self, params: TabContentParams, _window: &Window, cx: &App) -> AnyElement { + Label::new(self.tab_content_text(params.detail.unwrap_or_default(), cx)) + .color(if params.selected { + Color::Default + } else { + Color::Muted + }) + .into_any_element() + } + + fn tab_content_text(&self, _detail: usize, _: &App) -> SharedString { + self.title.clone() + } + + fn tab_tooltip_text(&self, _: &App) -> Option<SharedString> { + self.path.clone() + } + + fn to_item_events(event: &EditorEvent, f: impl FnMut(ItemEvent)) { + Editor::to_item_events(event, f) + } + + fn telemetry_event_text(&self) -> Option<&'static str> { + Some("Selection Diff View Opened") + } + + fn deactivated(&mut self, window: &mut Window, cx: &mut Context<Self>) { + self.diff_editor + .update(cx, |editor, cx| editor.deactivated(window, cx)); + } + + fn is_singleton(&self, _: &App) -> bool { + false + } + + fn act_as_type<'a>( + &'a self, + type_id: TypeId, + self_handle: &'a Entity<Self>, + _: &'a App, + ) -> Option<AnyView> { + if type_id == TypeId::of::<Self>() { + Some(self_handle.to_any()) + } else if type_id == TypeId::of::<Editor>() { + Some(self.diff_editor.to_any()) + } else { + None + } + } + + fn as_searchable(&self, _: &Entity<Self>) -> Option<Box<dyn SearchableItemHandle>> { + 
Some(Box::new(self.diff_editor.clone())) + } + + fn for_each_project_item( + &self, + cx: &App, + f: &mut dyn FnMut(gpui::EntityId, &dyn project::ProjectItem), + ) { + self.diff_editor.for_each_project_item(cx, f) + } + + fn set_nav_history( + &mut self, + nav_history: ItemNavHistory, + _: &mut Window, + cx: &mut Context<Self>, + ) { + self.diff_editor.update(cx, |editor, _| { + editor.set_nav_history(Some(nav_history)); + }); + } + + fn navigate( + &mut self, + data: Box<dyn Any>, + window: &mut Window, + cx: &mut Context<Self>, + ) -> bool { + self.diff_editor + .update(cx, |editor, cx| editor.navigate(data, window, cx)) + } + + fn breadcrumb_location(&self, _: &App) -> ToolbarItemLocation { + ToolbarItemLocation::PrimaryLeft + } + + fn breadcrumbs(&self, theme: &theme::Theme, cx: &App) -> Option<Vec<BreadcrumbText>> { + self.diff_editor.breadcrumbs(theme, cx) + } + + fn added_to_workspace( + &mut self, + workspace: &mut Workspace, + window: &mut Window, + cx: &mut Context<Self>, + ) { + self.diff_editor.update(cx, |editor, cx| { + editor.added_to_workspace(workspace, window, cx) + }); + } + + fn can_save(&self, cx: &App) -> bool { + // The editor handles the new buffer, so delegate to it + self.diff_editor.read(cx).can_save(cx) + } + + fn save( + &mut self, + options: SaveOptions, + project: Entity<Project>, + window: &mut Window, + cx: &mut Context<Self>, + ) -> Task<Result<()>> { + // Delegate saving to the editor, which manages the new buffer + self.diff_editor + .update(cx, |editor, cx| editor.save(options, project, window, cx)) + } +} + +pub fn selection_location_text(editor: &Editor, cx: &App) -> Option<String> { + let buffer = editor.buffer().read(cx); + let buffer_snapshot = buffer.snapshot(cx); + let first_selection = editor.selections.disjoint.first()?; + + let selection_start = first_selection.start.to_point(&buffer_snapshot); + let selection_end = first_selection.end.to_point(&buffer_snapshot); + + let start_row = selection_start.row; + let start_column = selection_start.column; + let end_row = selection_end.row; + let end_column = selection_end.column; + + let range_text = if start_row == end_row { + format!("L{}:{}-{}", start_row + 1, start_column + 1, end_column + 1) + } else { + format!( + "L{}:{}-L{}:{}", + start_row + 1, + start_column + 1, + end_row + 1, + end_column + 1 + ) + }; + + Some(range_text) +} + +impl Render for TextDiffView { + fn render(&mut self, _: &mut Window, _: &mut Context<Self>) -> impl IntoElement { + self.diff_editor.clone() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use editor::test::editor_test_context::assert_state_with_diff; + use gpui::{TestAppContext, VisualContext}; + use project::{FakeFs, Project}; + use serde_json::json; + use settings::{Settings, SettingsStore}; + use unindent::unindent; + use util::{path, test::marked_text_ranges}; + + fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + Project::init_settings(cx); + workspace::init_settings(cx); + editor::init_settings(cx); + theme::ThemeSettings::register(cx) + }); + } + + #[gpui::test] + async fn test_diffing_clipboard_against_empty_selection_uses_full_buffer_selection( + cx: &mut TestAppContext, + ) { + base_test( + path!("/test"), + path!("/test/text.txt"), + "def process_incoming_inventory(items, warehouse_id):\n pass\n", + "def process_outgoing_inventory(items, warehouse_id):\n passˇ\n", + &unindent( + " + - def process_incoming_inventory(items, warehouse_id): 
+ + ˇdef process_outgoing_inventory(items, warehouse_id): + pass + ", + ), + "Clipboard ↔ text.txt @ L1:1-L3:1", + &format!("Clipboard ↔ {} @ L1:1-L3:1", path!("test/text.txt")), + cx, + ) + .await; + } + + #[gpui::test] + async fn test_diffing_clipboard_against_multiline_selection_expands_to_full_lines( + cx: &mut TestAppContext, + ) { + base_test( + path!("/test"), + path!("/test/text.txt"), + "def process_incoming_inventory(items, warehouse_id):\n pass\n", + "«def process_outgoing_inventory(items, warehouse_id):\n passˇ»\n", + &unindent( + " + - def process_incoming_inventory(items, warehouse_id): + + ˇdef process_outgoing_inventory(items, warehouse_id): + pass + ", + ), + "Clipboard ↔ text.txt @ L1:1-L3:1", + &format!("Clipboard ↔ {} @ L1:1-L3:1", path!("test/text.txt")), + cx, + ) + .await; + } + + #[gpui::test] + async fn test_diffing_clipboard_against_single_line_selection(cx: &mut TestAppContext) { + base_test( + path!("/test"), + path!("/test/text.txt"), + "a", + "«bbˇ»", + &unindent( + " + - a + + ˇbb", + ), + "Clipboard ↔ text.txt @ L1:1-3", + &format!("Clipboard ↔ {} @ L1:1-3", path!("test/text.txt")), + cx, + ) + .await; + } + + #[gpui::test] + async fn test_diffing_clipboard_with_leading_whitespace_against_line(cx: &mut TestAppContext) { + base_test( + path!("/test"), + path!("/test/text.txt"), + " a", + "«bbˇ»", + &unindent( + " + - a + + ˇbb", + ), + "Clipboard ↔ text.txt @ L1:1-3", + &format!("Clipboard ↔ {} @ L1:1-3", path!("test/text.txt")), + cx, + ) + .await; + } + + #[gpui::test] + async fn test_diffing_clipboard_against_line_with_leading_whitespace(cx: &mut TestAppContext) { + base_test( + path!("/test"), + path!("/test/text.txt"), + "a", + " «bbˇ»", + &unindent( + " + - a + + ˇ bb", + ), + "Clipboard ↔ text.txt @ L1:1-7", + &format!("Clipboard ↔ {} @ L1:1-7", path!("test/text.txt")), + cx, + ) + .await; + } + + #[gpui::test] + async fn test_diffing_clipboard_against_line_with_leading_whitespace_included_in_selection( + cx: &mut TestAppContext, + ) { + base_test( + path!("/test"), + path!("/test/text.txt"), + "a", + "« bbˇ»", + &unindent( + " + - a + + ˇ bb", + ), + "Clipboard ↔ text.txt @ L1:1-7", + &format!("Clipboard ↔ {} @ L1:1-7", path!("test/text.txt")), + cx, + ) + .await; + } + + #[gpui::test] + async fn test_diffing_clipboard_with_leading_whitespace_against_line_with_leading_whitespace( + cx: &mut TestAppContext, + ) { + base_test( + path!("/test"), + path!("/test/text.txt"), + " a", + " «bbˇ»", + &unindent( + " + - a + + ˇ bb", + ), + "Clipboard ↔ text.txt @ L1:1-7", + &format!("Clipboard ↔ {} @ L1:1-7", path!("test/text.txt")), + cx, + ) + .await; + } + + #[gpui::test] + async fn test_diffing_clipboard_with_leading_whitespace_against_line_with_leading_whitespace_included_in_selection( + cx: &mut TestAppContext, + ) { + base_test( + path!("/test"), + path!("/test/text.txt"), + " a", + "« bbˇ»", + &unindent( + " + - a + + ˇ bb", + ), + "Clipboard ↔ text.txt @ L1:1-7", + &format!("Clipboard ↔ {} @ L1:1-7", path!("test/text.txt")), + cx, + ) + .await; + } + + #[gpui::test] + async fn test_diffing_clipboard_against_partial_selection_expands_to_include_trailing_characters( + cx: &mut TestAppContext, + ) { + base_test( + path!("/test"), + path!("/test/text.txt"), + "a", + "«bˇ»b", + &unindent( + " + - a + + ˇbb", + ), + "Clipboard ↔ text.txt @ L1:1-3", + &format!("Clipboard ↔ {} @ L1:1-3", path!("test/text.txt")), + cx, + ) + .await; + } + + async fn base_test( + project_root: &str, + file_path: &str, + clipboard_text: &str, + editor_text: &str, + expected_diff: 
&str, + expected_tab_title: &str, + expected_tab_tooltip: &str, + cx: &mut TestAppContext, + ) { + init_test(cx); + + let file_name = std::path::Path::new(file_path) + .file_name() + .unwrap() + .to_str() + .unwrap(); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + project_root, + json!({ + file_name: editor_text + }), + ) + .await; + + let project = Project::test(fs, [project_root.as_ref()], cx).await; + + let (workspace, mut cx) = + cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + + let buffer = project + .update(cx, |project, cx| project.open_local_buffer(file_path, cx)) + .await + .unwrap(); + + let editor = cx.new_window_entity(|window, cx| { + let mut editor = Editor::for_buffer(buffer, None, window, cx); + let (unmarked_text, selection_ranges) = marked_text_ranges(editor_text, false); + editor.set_text(unmarked_text, window, cx); + editor.change_selections(Default::default(), window, cx, |s| { + s.select_ranges(selection_ranges) + }); + + editor + }); + + let diff_view = workspace + .update_in(cx, |workspace, window, cx| { + TextDiffView::open( + &DiffClipboardWithSelectionData { + clipboard_text: clipboard_text.to_string(), + editor, + }, + workspace, + window, + cx, + ) + }) + .unwrap() + .await + .unwrap(); + + cx.executor().run_until_parked(); + + assert_state_with_diff( + &diff_view.read_with(cx, |diff_view, _| diff_view.diff_editor.clone()), + &mut cx, + expected_diff, + ); + + diff_view.read_with(cx, |diff_view, cx| { + assert_eq!(diff_view.tab_content_text(0, cx), expected_tab_title); + assert_eq!( + diff_view.tab_tooltip_text(cx).unwrap(), + expected_tab_tooltip + ); + }); + } +} diff --git a/crates/go_to_line/src/cursor_position.rs b/crates/go_to_line/src/cursor_position.rs index 322a791b13b93b73f204ff99cf4134ec363600ca..29064eb29cb986187b9d86046fd3d78cd2f63451 100644 --- a/crates/go_to_line/src/cursor_position.rs +++ b/crates/go_to_line/src/cursor_position.rs @@ -308,10 +308,14 @@ impl Settings for LineIndicatorFormat { type FileContent = Option<LineIndicatorFormatContent>; fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> anyhow::Result<Self> { - let format = [sources.release_channel, sources.user] - .into_iter() - .find_map(|value| value.copied().flatten()) - .unwrap_or(sources.default.ok_or_else(Self::missing_default)?); + let format = [ + sources.release_channel, + sources.operating_system, + sources.user, + ] + .into_iter() + .find_map(|value| value.copied().flatten()) + .unwrap_or(sources.default.ok_or_else(Self::missing_default)?); Ok(format.0) } diff --git a/crates/gpui/Cargo.toml b/crates/gpui/Cargo.toml index 80e424029a55338b6cd8ae00b684037a1c6f2a44..1d50b56ea595d15c02f517a14c3758dfad6a2661 100644 --- a/crates/gpui/Cargo.toml +++ b/crates/gpui/Cargo.toml @@ -122,7 +122,7 @@ smallvec.workspace = true smol.workspace = true strum.workspace = true sum_tree.workspace = true -taffy = "0.4.3" +taffy = "=0.9.0" thiserror.workspace = true util.workspace = true uuid.workspace = true @@ -217,10 +217,6 @@ xim = { git = "https://github.com/XDeme1/xim-rs", rev = "d50d461764c2213655cd9cf x11-clipboard = { version = "0.9.3", optional = true } [target.'cfg(target_os = "windows")'.dependencies] -blade-util.workspace = true -bytemuck = "1" -blade-graphics.workspace = true -blade-macros.workspace = true flume = "0.11" rand.workspace = true windows.workspace = true @@ -241,7 +237,6 @@ util = { workspace = true, features = ["test-support"] } [target.'cfg(target_os = "windows")'.build-dependencies] embed-resource = 
"3.0" -naga.workspace = true [target.'cfg(target_os = "macos")'.build-dependencies] bindgen = "0.71" @@ -288,6 +283,10 @@ path = "examples/shadow.rs" name = "svg" path = "examples/svg/svg.rs" +[[example]] +name = "tab_stop" +path = "examples/tab_stop.rs" + [[example]] name = "text" path = "examples/text.rs" @@ -296,6 +295,10 @@ path = "examples/text.rs" name = "text_wrapper" path = "examples/text_wrapper.rs" +[[example]] +name = "tree" +path = "examples/tree.rs" + [[example]] name = "uniform_list" path = "examples/uniform_list.rs" diff --git a/crates/gpui/build.rs b/crates/gpui/build.rs index b9496cc01426485cbef625c7e697bbf6082d1a67..93a1c15c41dd173a35ffc0adf06af6c449809890 100644 --- a/crates/gpui/build.rs +++ b/crates/gpui/build.rs @@ -9,7 +9,10 @@ fn main() { let target = env::var("CARGO_CFG_TARGET_OS"); println!("cargo::rustc-check-cfg=cfg(gles)"); - #[cfg(any(not(target_os = "macos"), feature = "macos-blade"))] + #[cfg(any( + not(any(target_os = "macos", target_os = "windows")), + all(target_os = "macos", feature = "macos-blade") + ))] check_wgsl_shaders(); match target.as_deref() { @@ -17,21 +20,18 @@ fn main() { #[cfg(target_os = "macos")] macos::build(); } - #[cfg(all(target_os = "windows", feature = "windows-manifest"))] Ok("windows") => { - let manifest = std::path::Path::new("resources/windows/gpui.manifest.xml"); - let rc_file = std::path::Path::new("resources/windows/gpui.rc"); - println!("cargo:rerun-if-changed={}", manifest.display()); - println!("cargo:rerun-if-changed={}", rc_file.display()); - embed_resource::compile(rc_file, embed_resource::NONE) - .manifest_required() - .unwrap(); + #[cfg(target_os = "windows")] + windows::build(); } _ => (), }; } -#[allow(dead_code)] +#[cfg(any( + not(any(target_os = "macos", target_os = "windows")), + all(target_os = "macos", feature = "macos-blade") +))] fn check_wgsl_shaders() { use std::path::PathBuf; use std::process; @@ -126,8 +126,9 @@ mod macos { "ContentMask".into(), "Uniforms".into(), "AtlasTile".into(), - "PathInputIndex".into(), + "PathRasterizationInputIndex".into(), "PathVertex_ScaledPixels".into(), + "PathRasterizationVertex".into(), "ShadowInputIndex".into(), "Shadow".into(), "QuadInputIndex".into(), @@ -242,3 +243,215 @@ mod macos { } } } + +#[cfg(target_os = "windows")] +mod windows { + use std::{ + fs, + io::Write, + path::{Path, PathBuf}, + process::{self, Command}, + }; + + pub(super) fn build() { + // Compile HLSL shaders + #[cfg(not(debug_assertions))] + compile_shaders(); + + // Embed the Windows manifest and resource file + #[cfg(feature = "windows-manifest")] + embed_resource(); + } + + #[cfg(feature = "windows-manifest")] + fn embed_resource() { + let manifest = std::path::Path::new("resources/windows/gpui.manifest.xml"); + let rc_file = std::path::Path::new("resources/windows/gpui.rc"); + println!("cargo:rerun-if-changed={}", manifest.display()); + println!("cargo:rerun-if-changed={}", rc_file.display()); + embed_resource::compile(rc_file, embed_resource::NONE) + .manifest_required() + .unwrap(); + } + + /// You can set the `GPUI_FXC_PATH` environment variable to specify the path to the fxc.exe compiler. 
+ fn compile_shaders() { + let shader_path = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap()) + .join("src/platform/windows/shaders.hlsl"); + let out_dir = std::env::var("OUT_DIR").unwrap(); + + println!("cargo:rerun-if-changed={}", shader_path.display()); + + // Check if fxc.exe is available + let fxc_path = find_fxc_compiler(); + + // Define all modules + let modules = [ + "quad", + "shadow", + "path_rasterization", + "path_sprite", + "underline", + "monochrome_sprite", + "polychrome_sprite", + ]; + + let rust_binding_path = format!("{}/shaders_bytes.rs", out_dir); + if Path::new(&rust_binding_path).exists() { + fs::remove_file(&rust_binding_path) + .expect("Failed to remove existing Rust binding file"); + } + for module in modules { + compile_shader_for_module( + module, + &out_dir, + &fxc_path, + shader_path.to_str().unwrap(), + &rust_binding_path, + ); + } + + { + let shader_path = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap()) + .join("src/platform/windows/color_text_raster.hlsl"); + compile_shader_for_module( + "emoji_rasterization", + &out_dir, + &fxc_path, + shader_path.to_str().unwrap(), + &rust_binding_path, + ); + } + } + + /// You can set the `GPUI_FXC_PATH` environment variable to specify the path to the fxc.exe compiler. + fn find_fxc_compiler() -> String { + // Check environment variable + if let Ok(path) = std::env::var("GPUI_FXC_PATH") { + if Path::new(&path).exists() { + return path; + } + } + + // Try to find in PATH + // NOTE: This has to be `where.exe` on Windows, not `where`, it must be ended with `.exe` + if let Ok(output) = std::process::Command::new("where.exe") + .arg("fxc.exe") + .output() + { + if output.status.success() { + let path = String::from_utf8_lossy(&output.stdout); + return path.trim().to_string(); + } + } + + // Check the default path + if Path::new(r"C:\Program Files (x86)\Windows Kits\10\bin\10.0.26100.0\x64\fxc.exe") + .exists() + { + return r"C:\Program Files (x86)\Windows Kits\10\bin\10.0.26100.0\x64\fxc.exe" + .to_string(); + } + + panic!("Failed to find fxc.exe"); + } + + fn compile_shader_for_module( + module: &str, + out_dir: &str, + fxc_path: &str, + shader_path: &str, + rust_binding_path: &str, + ) { + // Compile vertex shader + let output_file = format!("{}/{}_vs.h", out_dir, module); + let const_name = format!("{}_VERTEX_BYTES", module.to_uppercase()); + compile_shader_impl( + fxc_path, + &format!("{module}_vertex"), + &output_file, + &const_name, + shader_path, + "vs_4_1", + ); + generate_rust_binding(&const_name, &output_file, &rust_binding_path); + + // Compile fragment shader + let output_file = format!("{}/{}_ps.h", out_dir, module); + let const_name = format!("{}_FRAGMENT_BYTES", module.to_uppercase()); + compile_shader_impl( + fxc_path, + &format!("{module}_fragment"), + &output_file, + &const_name, + shader_path, + "ps_4_1", + ); + generate_rust_binding(&const_name, &output_file, &rust_binding_path); + } + + fn compile_shader_impl( + fxc_path: &str, + entry_point: &str, + output_path: &str, + var_name: &str, + shader_path: &str, + target: &str, + ) { + let output = Command::new(fxc_path) + .args([ + "/T", + target, + "/E", + entry_point, + "/Fh", + output_path, + "/Vn", + var_name, + "/O3", + shader_path, + ]) + .output(); + + match output { + Ok(result) => { + if result.status.success() { + return; + } + eprintln!( + "Shader compilation failed for {}:\n{}", + entry_point, + String::from_utf8_lossy(&result.stderr) + ); + process::exit(1); + } + Err(e) => { + eprintln!("Failed to run fxc for {}: {}", 
entry_point, e); + process::exit(1); + } + } + } + + fn generate_rust_binding(const_name: &str, head_file: &str, output_path: &str) { + let header_content = fs::read_to_string(head_file).expect("Failed to read header file"); + let const_definition = { + let global_var_start = header_content.find("const BYTE").unwrap(); + let global_var = &header_content[global_var_start..]; + let equal = global_var.find('=').unwrap(); + global_var[equal + 1..].trim() + }; + let rust_binding = format!( + "const {}: &[u8] = &{}\n", + const_name, + const_definition.replace('{', "[").replace('}', "]") + ); + let mut options = fs::OpenOptions::new() + .create(true) + .append(true) + .open(output_path) + .expect("Failed to open Rust binding file"); + options + .write_all(rust_binding.as_bytes()) + .expect("Failed to write Rust binding file"); + } +} diff --git a/crates/gpui/examples/painting.rs b/crates/gpui/examples/painting.rs index 9ab58cffc9d417d181634e9958bb64ea5dace478..668aed23772d32a84a81cc0648d6b60dd05e21cf 100644 --- a/crates/gpui/examples/painting.rs +++ b/crates/gpui/examples/painting.rs @@ -1,15 +1,12 @@ use gpui::{ Application, Background, Bounds, ColorSpace, Context, MouseDownEvent, Path, PathBuilder, - PathStyle, Pixels, Point, Render, SharedString, StrokeOptions, Window, WindowBounds, - WindowOptions, canvas, div, linear_color_stop, linear_gradient, point, prelude::*, px, rgb, - size, + PathStyle, Pixels, Point, Render, SharedString, StrokeOptions, Window, WindowOptions, canvas, + div, linear_color_stop, linear_gradient, point, prelude::*, px, quad, rgb, size, }; -const DEFAULT_WINDOW_WIDTH: Pixels = px(1024.0); -const DEFAULT_WINDOW_HEIGHT: Pixels = px(768.0); - struct PaintingViewer { default_lines: Vec<(Path<Pixels>, Background)>, + background_quads: Vec<(Bounds<Pixels>, Background)>, lines: Vec<Vec<Point<Pixels>>>, start: Point<Pixels>, dashed: bool, @@ -20,12 +17,148 @@ impl PaintingViewer { fn new(_window: &mut Window, _cx: &mut Context<Self>) -> Self { let mut lines = vec![]; + // Black squares beneath transparent paths. + let background_quads = vec![ + ( + Bounds { + origin: point(px(70.), px(70.)), + size: size(px(40.), px(40.)), + }, + gpui::black().into(), + ), + ( + Bounds { + origin: point(px(170.), px(70.)), + size: size(px(40.), px(40.)), + }, + gpui::black().into(), + ), + ( + Bounds { + origin: point(px(270.), px(70.)), + size: size(px(40.), px(40.)), + }, + gpui::black().into(), + ), + ( + Bounds { + origin: point(px(370.), px(70.)), + size: size(px(40.), px(40.)), + }, + gpui::black().into(), + ), + ( + Bounds { + origin: point(px(450.), px(50.)), + size: size(px(80.), px(80.)), + }, + gpui::black().into(), + ), + ]; + + // 50% opaque red path that extends across black quad. + let mut builder = PathBuilder::fill(); + builder.move_to(point(px(50.), px(50.))); + builder.line_to(point(px(130.), px(50.))); + builder.line_to(point(px(130.), px(130.))); + builder.line_to(point(px(50.), px(130.))); + builder.close(); + let path = builder.build().unwrap(); + let mut red = rgb(0xFF0000); + red.a = 0.5; + lines.push((path, red.into())); + + // 50% opaque blue path that extends across black quad. 
+ let mut builder = PathBuilder::fill(); + builder.move_to(point(px(150.), px(50.))); + builder.line_to(point(px(230.), px(50.))); + builder.line_to(point(px(230.), px(130.))); + builder.line_to(point(px(150.), px(130.))); + builder.close(); + let path = builder.build().unwrap(); + let mut blue = rgb(0x0000FF); + blue.a = 0.5; + lines.push((path, blue.into())); + + // 50% opaque green path that extends across black quad. + let mut builder = PathBuilder::fill(); + builder.move_to(point(px(250.), px(50.))); + builder.line_to(point(px(330.), px(50.))); + builder.line_to(point(px(330.), px(130.))); + builder.line_to(point(px(250.), px(130.))); + builder.close(); + let path = builder.build().unwrap(); + let mut green = rgb(0x00FF00); + green.a = 0.5; + lines.push((path, green.into())); + + // 50% opaque black path that extends across black quad. + let mut builder = PathBuilder::fill(); + builder.move_to(point(px(350.), px(50.))); + builder.line_to(point(px(430.), px(50.))); + builder.line_to(point(px(430.), px(130.))); + builder.line_to(point(px(350.), px(130.))); + builder.close(); + let path = builder.build().unwrap(); + let mut black = rgb(0x000000); + black.a = 0.5; + lines.push((path, black.into())); + + // Two 50% opaque red circles overlapping - center should be darker red + let mut builder = PathBuilder::fill(); + let center = point(px(530.), px(85.)); + let radius = px(30.); + builder.move_to(point(center.x + radius, center.y)); + builder.arc_to( + point(radius, radius), + px(0.), + false, + false, + point(center.x - radius, center.y), + ); + builder.arc_to( + point(radius, radius), + px(0.), + false, + false, + point(center.x + radius, center.y), + ); + builder.close(); + let path = builder.build().unwrap(); + let mut red1 = rgb(0xFF0000); + red1.a = 0.5; + lines.push((path, red1.into())); + + let mut builder = PathBuilder::fill(); + let center = point(px(570.), px(85.)); + let radius = px(30.); + builder.move_to(point(center.x + radius, center.y)); + builder.arc_to( + point(radius, radius), + px(0.), + false, + false, + point(center.x - radius, center.y), + ); + builder.arc_to( + point(radius, radius), + px(0.), + false, + false, + point(center.x + radius, center.y), + ); + builder.close(); + let path = builder.build().unwrap(); + let mut red2 = rgb(0xFF0000); + red2.a = 0.5; + lines.push((path, red2.into())); + // draw a Rust logo let mut builder = lyon::path::Path::svg_builder(); lyon::extra::rust_logo::build_logo_path(&mut builder); // move down the Path let mut builder: PathBuilder = builder.into(); - builder.translate(point(px(10.), px(100.))); + builder.translate(point(px(10.), px(200.))); builder.scale(0.9); let path = builder.build().unwrap(); lines.push((path, gpui::black().into())); @@ -34,10 +167,10 @@ impl PaintingViewer { let mut builder = PathBuilder::fill(); builder.add_polygon( &[ - point(px(150.), px(200.)), - point(px(200.), px(125.)), - point(px(200.), px(175.)), - point(px(250.), px(100.)), + point(px(150.), px(300.)), + point(px(200.), px(225.)), + point(px(200.), px(275.)), + point(px(250.), px(200.)), ], false, ); @@ -46,17 +179,17 @@ impl PaintingViewer { // draw a ⭐ let mut builder = PathBuilder::fill(); - builder.move_to(point(px(350.), px(100.))); - builder.line_to(point(px(370.), px(160.))); - builder.line_to(point(px(430.), px(160.))); - builder.line_to(point(px(380.), px(200.))); - builder.line_to(point(px(400.), px(260.))); - builder.line_to(point(px(350.), px(220.))); - builder.line_to(point(px(300.), px(260.))); - builder.line_to(point(px(320.), 
px(200.))); - builder.line_to(point(px(270.), px(160.))); - builder.line_to(point(px(330.), px(160.))); - builder.line_to(point(px(350.), px(100.))); + builder.move_to(point(px(350.), px(200.))); + builder.line_to(point(px(370.), px(260.))); + builder.line_to(point(px(430.), px(260.))); + builder.line_to(point(px(380.), px(300.))); + builder.line_to(point(px(400.), px(360.))); + builder.line_to(point(px(350.), px(320.))); + builder.line_to(point(px(300.), px(360.))); + builder.line_to(point(px(320.), px(300.))); + builder.line_to(point(px(270.), px(260.))); + builder.line_to(point(px(330.), px(260.))); + builder.line_to(point(px(350.), px(200.))); let path = builder.build().unwrap(); lines.push(( path, @@ -70,7 +203,7 @@ impl PaintingViewer { // draw linear gradient let square_bounds = Bounds { - origin: point(px(450.), px(100.)), + origin: point(px(450.), px(200.)), size: size(px(200.), px(80.)), }; let height = square_bounds.size.height; @@ -100,31 +233,31 @@ impl PaintingViewer { // draw a pie chart let center = point(px(96.), px(96.)); - let pie_center = point(px(775.), px(155.)); + let pie_center = point(px(775.), px(255.)); let segments = [ ( - point(px(871.), px(155.)), - point(px(747.), px(63.)), + point(px(871.), px(255.)), + point(px(747.), px(163.)), rgb(0x1374e9), ), ( - point(px(747.), px(63.)), - point(px(679.), px(163.)), + point(px(747.), px(163.)), + point(px(679.), px(263.)), rgb(0xe13527), ), ( - point(px(679.), px(163.)), - point(px(754.), px(249.)), + point(px(679.), px(263.)), + point(px(754.), px(349.)), rgb(0x0751ce), ), ( - point(px(754.), px(249.)), - point(px(854.), px(210.)), + point(px(754.), px(349.)), + point(px(854.), px(310.)), rgb(0x209742), ), ( - point(px(854.), px(210.)), - point(px(871.), px(155.)), + point(px(854.), px(310.)), + point(px(871.), px(255.)), rgb(0xfbc10a), ), ]; @@ -144,16 +277,19 @@ impl PaintingViewer { .with_line_width(1.) 
.with_line_join(lyon::path::LineJoin::Bevel); let mut builder = PathBuilder::stroke(px(1.)).with_style(PathStyle::Stroke(options)); - builder.move_to(point(px(40.), px(320.))); + builder.move_to(point(px(40.), px(420.))); for i in 1..50 { builder.line_to(point( px(40.0 + i as f32 * 10.0), - px(320.0 + (i as f32 * 10.0).sin() * 40.0), + px(420.0 + (i as f32 * 10.0).sin() * 40.0), )); } + let path = builder.build().unwrap(); + lines.push((path, gpui::green().into())); Self { default_lines: lines.clone(), + background_quads, lines: vec![], start: point(px(0.), px(0.)), dashed: false, @@ -185,13 +321,10 @@ fn button( } impl Render for PaintingViewer { - fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { - window.request_animation_frame(); - + fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { let default_lines = self.default_lines.clone(); + let background_quads = self.background_quads.clone(); let lines = self.lines.clone(); - let window_size = window.bounds().size; - let scale = window_size.width / DEFAULT_WINDOW_WIDTH; let dashed = self.dashed; div() @@ -227,8 +360,21 @@ impl Render for PaintingViewer { canvas( move |_, _, _| {}, move |_, _, window, _| { + // First draw background quads + for (bounds, color) in background_quads.iter() { + window.paint_quad(quad( + *bounds, + px(0.), + *color, + px(0.), + gpui::transparent_black(), + Default::default(), + )); + } + + // Then draw the default paths on top for (path, color) in default_lines { - window.paint_path(path.clone().scale(scale), color); + window.paint_path(path, color); } for points in lines { @@ -304,16 +450,15 @@ fn main() { cx.open_window( WindowOptions { focus: true, - window_bounds: Some(WindowBounds::Windowed(Bounds::centered( - None, - size(DEFAULT_WINDOW_WIDTH, DEFAULT_WINDOW_HEIGHT), - cx, - ))), ..Default::default() }, |window, cx| cx.new(|cx| PaintingViewer::new(window, cx)), ) .unwrap(); + cx.on_window_closed(|cx| { + cx.quit(); + }) + .detach(); cx.activate(true); }); } diff --git a/crates/gpui/examples/paths_bench.rs b/crates/gpui/examples/paths_bench.rs new file mode 100644 index 0000000000000000000000000000000000000000..a801889ae869ea7c08dce1362036b1d29c4daf36 --- /dev/null +++ b/crates/gpui/examples/paths_bench.rs @@ -0,0 +1,92 @@ +use gpui::{ + Application, Background, Bounds, ColorSpace, Context, Path, PathBuilder, Pixels, Render, + TitlebarOptions, Window, WindowBounds, WindowOptions, canvas, div, linear_color_stop, + linear_gradient, point, prelude::*, px, rgb, size, +}; + +const DEFAULT_WINDOW_WIDTH: Pixels = px(1024.0); +const DEFAULT_WINDOW_HEIGHT: Pixels = px(768.0); + +struct PaintingViewer { + default_lines: Vec<(Path<Pixels>, Background)>, + _painting: bool, +} + +impl PaintingViewer { + fn new(_window: &mut Window, _cx: &mut Context<Self>) -> Self { + let mut lines = vec![]; + + // draw a lightening bolt ⚡ + for _ in 0..2000 { + // draw a ⭐ + let mut builder = PathBuilder::fill(); + builder.move_to(point(px(350.), px(100.))); + builder.line_to(point(px(370.), px(160.))); + builder.line_to(point(px(430.), px(160.))); + builder.line_to(point(px(380.), px(200.))); + builder.line_to(point(px(400.), px(260.))); + builder.line_to(point(px(350.), px(220.))); + builder.line_to(point(px(300.), px(260.))); + builder.line_to(point(px(320.), px(200.))); + builder.line_to(point(px(270.), px(160.))); + builder.line_to(point(px(330.), px(160.))); + builder.line_to(point(px(350.), px(100.))); + let path = builder.build().unwrap(); + lines.push(( + path, 
+ linear_gradient( + 180., + linear_color_stop(rgb(0xFACC15), 0.7), + linear_color_stop(rgb(0xD56D0C), 1.), + ) + .color_space(ColorSpace::Oklab), + )); + } + + Self { + default_lines: lines, + _painting: false, + } + } +} + +impl Render for PaintingViewer { + fn render(&mut self, window: &mut Window, _: &mut Context<Self>) -> impl IntoElement { + window.request_animation_frame(); + let lines = self.default_lines.clone(); + div().size_full().child( + canvas( + move |_, _, _| {}, + move |_, _, window, _| { + for (path, color) in lines { + window.paint_path(path, color); + } + }, + ) + .size_full(), + ) + } +} + +fn main() { + Application::new().run(|cx| { + cx.open_window( + WindowOptions { + titlebar: Some(TitlebarOptions { + title: Some("Vulkan".into()), + ..Default::default() + }), + focus: true, + window_bounds: Some(WindowBounds::Windowed(Bounds::centered( + None, + size(DEFAULT_WINDOW_WIDTH, DEFAULT_WINDOW_HEIGHT), + cx, + ))), + ..Default::default() + }, + |window, cx| cx.new(|cx| PaintingViewer::new(window, cx)), + ) + .unwrap(); + cx.activate(true); + }); +} diff --git a/crates/gpui/examples/set_menus.rs b/crates/gpui/examples/set_menus.rs index 2b302f78f273449b9afeac8cb272a1a8148aaf56..f53fff7c7f7dfca1d2e44faf39347d1716ddad1c 100644 --- a/crates/gpui/examples/set_menus.rs +++ b/crates/gpui/examples/set_menus.rs @@ -34,7 +34,7 @@ fn main() { }); } -// Associate actions using the `actions!` macro (or `impl_actions!` macro) +// Associate actions using the `actions!` macro (or `Action` derive macro) actions!(set_menus, [Quit]); // Define the quit function that is registered with the App diff --git a/crates/gpui/examples/tab_stop.rs b/crates/gpui/examples/tab_stop.rs new file mode 100644 index 0000000000000000000000000000000000000000..8dbcbeccb7351fda18e8d36fe38d8f26c4a70cc9 --- /dev/null +++ b/crates/gpui/examples/tab_stop.rs @@ -0,0 +1,155 @@ +use gpui::{ + App, Application, Bounds, Context, Div, ElementId, FocusHandle, KeyBinding, SharedString, + Stateful, Window, WindowBounds, WindowOptions, actions, div, prelude::*, px, size, +}; + +actions!(example, [Tab, TabPrev]); + +struct Example { + focus_handle: FocusHandle, + items: Vec<FocusHandle>, + message: SharedString, +} + +impl Example { + fn new(window: &mut Window, cx: &mut Context<Self>) -> Self { + let items = vec![ + cx.focus_handle().tab_index(1).tab_stop(true), + cx.focus_handle().tab_index(2).tab_stop(true), + cx.focus_handle().tab_index(3).tab_stop(true), + cx.focus_handle(), + cx.focus_handle().tab_index(2).tab_stop(true), + ]; + + let focus_handle = cx.focus_handle(); + window.focus(&focus_handle); + + Self { + focus_handle, + items, + message: SharedString::from("Press `Tab`, `Shift-Tab` to switch focus."), + } + } + + fn on_tab(&mut self, _: &Tab, window: &mut Window, _: &mut Context<Self>) { + window.focus_next(); + self.message = SharedString::from("You have pressed `Tab`."); + } + + fn on_tab_prev(&mut self, _: &TabPrev, window: &mut Window, _: &mut Context<Self>) { + window.focus_prev(); + self.message = SharedString::from("You have pressed `Shift-Tab`."); + } +} + +impl Render for Example { + fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { + fn tab_stop_style<T: Styled>(this: T) -> T { + this.border_3().border_color(gpui::blue()) + } + + fn button(id: impl Into<ElementId>) -> Stateful<Div> { + div() + .id(id) + .h_10() + .flex_1() + .flex() + .justify_center() + .items_center() + .border_1() + .border_color(gpui::black()) + .bg(gpui::black()) + .text_color(gpui::white()) + 
.focus(tab_stop_style) + .shadow_sm() + } + + div() + .id("app") + .track_focus(&self.focus_handle) + .on_action(cx.listener(Self::on_tab)) + .on_action(cx.listener(Self::on_tab_prev)) + .size_full() + .flex() + .flex_col() + .p_4() + .gap_3() + .bg(gpui::white()) + .text_color(gpui::black()) + .child(self.message.clone()) + .children( + self.items + .clone() + .into_iter() + .enumerate() + .map(|(ix, item_handle)| { + div() + .id(("item", ix)) + .track_focus(&item_handle) + .h_10() + .w_full() + .flex() + .justify_center() + .items_center() + .border_1() + .border_color(gpui::black()) + .when( + item_handle.tab_stop && item_handle.is_focused(window), + tab_stop_style, + ) + .map(|this| match item_handle.tab_stop { + true => this + .hover(|this| this.bg(gpui::black().opacity(0.1))) + .child(format!("tab_index: {}", item_handle.tab_index)), + false => this.opacity(0.4).child("tab_stop: false"), + }) + }), + ) + .child( + div() + .flex() + .flex_row() + .gap_3() + .items_center() + .child( + button("el1") + .tab_index(4) + .child("Button 1") + .on_click(cx.listener(|this, _, _, cx| { + this.message = "You have clicked Button 1.".into(); + cx.notify(); + })), + ) + .child( + button("el2") + .tab_index(5) + .child("Button 2") + .on_click(cx.listener(|this, _, _, cx| { + this.message = "You have clicked Button 2.".into(); + cx.notify(); + })), + ), + ) + } +} + +fn main() { + Application::new().run(|cx: &mut App| { + cx.bind_keys([ + KeyBinding::new("tab", Tab, None), + KeyBinding::new("shift-tab", TabPrev, None), + ]); + + let bounds = Bounds::centered(None, size(px(800.), px(600.0)), cx); + cx.open_window( + WindowOptions { + window_bounds: Some(WindowBounds::Windowed(bounds)), + ..Default::default() + }, + |window, cx| cx.new(|cx| Example::new(window, cx)), + ) + .unwrap(); + + cx.activate(true); + }); +} diff --git a/crates/gpui/examples/text.rs b/crates/gpui/examples/text.rs index 19214aebdefccba9216e1a6e250244eb231d282a..1166bb279541c80eb8686b59c85724b4068895ed 100644 --- a/crates/gpui/examples/text.rs +++ b/crates/gpui/examples/text.rs @@ -198,7 +198,7 @@ impl RenderOnce for CharacterGrid { "χ", "ψ", "∂", "а", "в", "Ж", "ж", "З", "з", "К", "к", "л", "м", "Н", "н", "Р", "р", "У", "у", "ф", "ч", "ь", "ы", "Э", "э", "Я", "я", "ij", "öẋ", ".,", "⣝⣑", "~", "*", "_", "^", "`", "'", "(", "{", "«", "#", "&", "@", "$", "¢", "%", "|", "?", "¶", "µ", - "❮", "<=", "!=", "==", "--", "++", "=>", "->", + "❮", "<=", "!=", "==", "--", "++", "=>", "->", "🏀", "🎊", "😍", "❤️", "👍", "👎", ]; let columns = 11; diff --git a/crates/gpui/examples/tree.rs b/crates/gpui/examples/tree.rs new file mode 100644 index 0000000000000000000000000000000000000000..1bd45920037839c27ea5773f23daa9dcbbceae0e --- /dev/null +++ b/crates/gpui/examples/tree.rs @@ -0,0 +1,46 @@ +//! Renders a div with deep children hierarchy. This example is useful to exemplify that Zed can +//! handle deep hierarchies (even though it cannot just yet!). 
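+//!
+//! The depth defaults to 50 and can be overridden with the `GPUI_TREE_DEPTH` environment variable.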
+use std::sync::LazyLock; + +use gpui::{ + App, Application, Bounds, Context, Window, WindowBounds, WindowOptions, div, prelude::*, px, + size, +}; + +struct Tree {} + +static DEPTH: LazyLock<u64> = LazyLock::new(|| { + std::env::var("GPUI_TREE_DEPTH") + .ok() + .and_then(|depth| depth.parse().ok()) + .unwrap_or_else(|| 50) +}); + +impl Render for Tree { + fn render(&mut self, _: &mut Window, _: &mut Context<Self>) -> impl IntoElement { + let mut depth = *DEPTH; + static COLORS: [gpui::Hsla; 4] = [gpui::red(), gpui::blue(), gpui::green(), gpui::yellow()]; + let mut colors = COLORS.iter().cycle().copied(); + let mut next_div = || div().p_0p5().bg(colors.next().unwrap()); + let mut innermost_node = next_div(); + while depth > 0 { + innermost_node = next_div().child(innermost_node); + depth -= 1; + } + innermost_node + } +} + +fn main() { + Application::new().run(|cx: &mut App| { + let bounds = Bounds::centered(None, size(px(300.0), px(300.0)), cx); + cx.open_window( + WindowOptions { + window_bounds: Some(WindowBounds::Windowed(bounds)), + ..Default::default() + }, + |_, cx| cx.new(|_| Tree {}), + ) + .unwrap(); + }); +} diff --git a/crates/gpui/examples/window_shadow.rs b/crates/gpui/examples/window_shadow.rs index 06dde911330d0b82ba3584cf5fb8054f57920b93..469017da795d54a36353449f0043483526f63708 100644 --- a/crates/gpui/examples/window_shadow.rs +++ b/crates/gpui/examples/window_shadow.rs @@ -165,8 +165,8 @@ impl Render for WindowShadow { }, ) .on_click(|e, window, _| { - if e.down.button == MouseButton::Right { - window.show_window_menu(e.up.position); + if e.is_right_click() { + window.show_window_menu(e.position()); } }) .text_color(black()) diff --git a/crates/gpui/src/action.rs b/crates/gpui/src/action.rs index e099bfec28c3e3f15267348694f60c961df6f086..b179076cd5f0da826ca0d5da5e2a5a41cbb5e806 100644 --- a/crates/gpui/src/action.rs +++ b/crates/gpui/src/action.rs @@ -403,12 +403,10 @@ impl ActionRegistry { /// Useful for transforming the list of available actions into a /// format suited for static analysis such as in validating keymaps, or /// generating documentation. -pub fn generate_list_of_all_registered_actions() -> Vec<MacroActionData> { - let mut actions = Vec::new(); - for builder in inventory::iter::<MacroActionBuilder> { - actions.push(builder.0()); - } - actions +pub fn generate_list_of_all_registered_actions() -> impl Iterator<Item = MacroActionData> { + inventory::iter::<MacroActionBuilder> + .into_iter() + .map(|builder| builder.0()) } mod no_action { diff --git a/crates/gpui/src/app.rs b/crates/gpui/src/app.rs index 6cecfcc0e42b239dc98db9391650f0618530d52c..ded7bae3164ab5b76568290b7335599dd720c320 100644 --- a/crates/gpui/src/app.rs +++ b/crates/gpui/src/app.rs @@ -448,15 +448,23 @@ impl App { } pub(crate) fn update<R>(&mut self, update: impl FnOnce(&mut Self) -> R) -> R { - self.pending_updates += 1; + self.start_update(); let result = update(self); + self.finish_update(); + result + } + + pub(crate) fn start_update(&mut self) { + self.pending_updates += 1; + } + + pub(crate) fn finish_update(&mut self) { if !self.flushing_effects && self.pending_updates == 1 { self.flushing_effects = true; self.flush_effects(); self.flushing_effects = false; } self.pending_updates -= 1; - result } /// Arrange a callback to be invoked when the given entity calls `notify` on its respective context. @@ -688,7 +696,7 @@ impl App { /// Returns a list of available screen capture sources. 
pub fn screen_capture_sources( &self, - ) -> oneshot::Receiver<Result<Vec<Box<dyn ScreenCaptureSource>>>> { + ) -> oneshot::Receiver<Result<Vec<Rc<dyn ScreenCaptureSource>>>> { self.platform.screen_capture_sources() } @@ -868,7 +876,6 @@ impl App { loop { self.release_dropped_entities(); self.release_dropped_focus_handles(); - if let Some(effect) = self.pending_effects.pop_front() { match effect { Effect::Notify { emitter } => { @@ -947,8 +954,8 @@ impl App { self.focus_handles .clone() .write() - .retain(|handle_id, count| { - if count.load(SeqCst) == 0 { + .retain(|handle_id, focus| { + if focus.ref_count.load(SeqCst) == 0 { for window_handle in self.windows() { window_handle .update(self, |_, window, _| { @@ -1250,11 +1257,7 @@ impl App { .downcast::<T>() .unwrap() .update(cx, |entity_state, cx| { - if let Some(window) = window { - on_new(entity_state, Some(window), cx); - } else { - on_new(entity_state, None, cx); - } + on_new(entity_state, window.as_deref_mut(), cx) }) }, ), @@ -1367,7 +1370,9 @@ impl App { self.keymap.clone() } - /// Register a global listener for actions invoked via the keyboard. + /// Register a global handler for actions invoked via the keyboard. These handlers are run at + /// the end of the bubble phase for actions, and so will only be invoked if there are no other + /// handlers or if they called `cx.propagate()`. pub fn on_action<A: Action>(&mut self, listener: impl Fn(&A, &mut Self) + 'static) { self.global_action_listeners .entry(TypeId::of::<A>()) @@ -1823,6 +1828,13 @@ impl AppContext for App { }) } + fn as_mut<'a, T>(&'a mut self, handle: &Entity<T>) -> GpuiBorrow<'a, T> + where + T: 'static, + { + GpuiBorrow::new(handle.clone(), self) + } + fn read_entity<T, R>( &self, handle: &Entity<T>, @@ -2011,6 +2023,10 @@ impl HttpClient for NullHttpClient { .boxed() } + fn user_agent(&self) -> Option<&http_client::http::HeaderValue> { + None + } + fn proxy(&self) -> Option<&Url> { None } @@ -2019,3 +2035,79 @@ impl HttpClient for NullHttpClient { type_name::<Self>() } } + +/// A mutable reference to an entity owned by GPUI +pub struct GpuiBorrow<'a, T> { + inner: Option<Lease<T>>, + app: &'a mut App, +} + +impl<'a, T: 'static> GpuiBorrow<'a, T> { + fn new(inner: Entity<T>, app: &'a mut App) -> Self { + app.start_update(); + let lease = app.entities.lease(&inner); + Self { + inner: Some(lease), + app, + } + } +} + +impl<'a, T: 'static> std::borrow::Borrow<T> for GpuiBorrow<'a, T> { + fn borrow(&self) -> &T { + self.inner.as_ref().unwrap().borrow() + } +} + +impl<'a, T: 'static> std::borrow::BorrowMut<T> for GpuiBorrow<'a, T> { + fn borrow_mut(&mut self) -> &mut T { + self.inner.as_mut().unwrap().borrow_mut() + } +} + +impl<'a, T> Drop for GpuiBorrow<'a, T> { + fn drop(&mut self) { + let lease = self.inner.take().unwrap(); + self.app.notify(lease.id); + self.app.entities.end_lease(lease); + self.app.finish_update(); + } +} + +#[cfg(test)] +mod test { + use std::{cell::RefCell, rc::Rc}; + + use crate::{AppContext, TestAppContext}; + + #[test] + fn test_gpui_borrow() { + let cx = TestAppContext::single(); + let observation_count = Rc::new(RefCell::new(0)); + + let state = cx.update(|cx| { + let state = cx.new(|_| false); + cx.observe(&state, { + let observation_count = observation_count.clone(); + move |_, _| { + let mut count = observation_count.borrow_mut(); + *count += 1; + } + }) + .detach(); + + state + }); + + cx.update(|cx| { + // Calling this like this so that we don't clobber the borrow_mut above + *std::borrow::BorrowMut::borrow_mut(&mut state.as_mut(cx)) = 
true; + }); + + cx.update(|cx| { + state.write(cx, false); + }); + + assert_eq!(*observation_count.borrow(), 2); + } +} diff --git a/crates/gpui/src/app/async_context.rs b/crates/gpui/src/app/async_context.rs index c3b60dd580483771f683b6d76fd76e52b3f531ad..d9d21c024461cab68d62d685a40b61c9c74d46dd 100644 --- a/crates/gpui/src/app/async_context.rs +++ b/crates/gpui/src/app/async_context.rs @@ -3,7 +3,7 @@ use crate::{ Entity, EventEmitter, Focusable, ForegroundExecutor, Global, PromptButton, PromptLevel, Render, Reservation, Result, Subscription, Task, VisualContext, Window, WindowHandle, }; -use anyhow::Context as _; +use anyhow::{Context as _, anyhow}; use derive_more::{Deref, DerefMut}; use futures::channel::oneshot; use std::{future::Future, rc::Weak}; @@ -58,6 +58,15 @@ impl AppContext for AsyncApp { Ok(app.update_entity(handle, update)) } + fn as_mut<'a, T>(&'a mut self, _handle: &Entity<T>) -> Self::Result<super::GpuiBorrow<'a, T>> + where + T: 'static, + { + Err(anyhow!( + "Cannot as_mut with an async context. Try calling update() first" + )) + } + fn read_entity<T, R>( &self, handle: &Entity<T>, @@ -364,6 +373,15 @@ impl AppContext for AsyncWindowContext { .update(self, |_, _, cx| cx.update_entity(handle, update)) } + fn as_mut<'a, T>(&'a mut self, _: &Entity<T>) -> Self::Result<super::GpuiBorrow<'a, T>> + where + T: 'static, + { + Err(anyhow!( + "Cannot use as_mut() from an async context, call `update`" + )) + } + fn read_entity<T, R>( &self, handle: &Entity<T>, diff --git a/crates/gpui/src/app/context.rs b/crates/gpui/src/app/context.rs index 2d90ff35b1b47c44e19de15adad64b7b569a0ec1..392be2ffe9ce4eed9397a11770b7133db145a7a8 100644 --- a/crates/gpui/src/app/context.rs +++ b/crates/gpui/src/app/context.rs @@ -726,6 +726,13 @@ impl<T> AppContext for Context<'_, T> { self.app.update_entity(handle, update) } + fn as_mut<'a, E>(&'a mut self, handle: &Entity<E>) -> Self::Result<super::GpuiBorrow<'a, E>> + where + E: 'static, + { + self.app.as_mut(handle) + } + fn read_entity<U, R>( &self, handle: &Entity<U>, diff --git a/crates/gpui/src/app/entity_map.rs b/crates/gpui/src/app/entity_map.rs index f1aafa55e871567a58fe0696a4e84287e82bd437..fccb417caa70c7526a0f15a307d74caeabcdab77 100644 --- a/crates/gpui/src/app/entity_map.rs +++ b/crates/gpui/src/app/entity_map.rs @@ -1,4 +1,4 @@ -use crate::{App, AppContext, VisualContext, Window, seal::Sealed}; +use crate::{App, AppContext, GpuiBorrow, VisualContext, Window, seal::Sealed}; use anyhow::{Context as _, Result}; use collections::FxHashSet; use derive_more::{Deref, DerefMut}; @@ -105,7 +105,7 @@ impl EntityMap { /// Move an entity to the stack. #[track_caller] - pub fn lease<'a, T>(&mut self, pointer: &'a Entity<T>) -> Lease<'a, T> { + pub fn lease<T>(&mut self, pointer: &Entity<T>) -> Lease<T> { self.assert_valid_context(pointer); let mut accessed_entities = self.accessed_entities.borrow_mut(); accessed_entities.insert(pointer.entity_id); @@ -117,15 +117,14 @@ impl EntityMap { ); Lease { entity, - pointer, + id: pointer.entity_id, entity_type: PhantomData, } } /// Returns an entity after moving it to the stack. pub fn end_lease<T>(&mut self, mut lease: Lease<T>) { - self.entities - .insert(lease.pointer.entity_id, lease.entity.take().unwrap()); + self.entities.insert(lease.id, lease.entity.take().unwrap()); } pub fn read<T: 'static>(&self, entity: &Entity<T>) -> &T { @@ -187,13 +186,13 @@ fn double_lease_panic<T>(operation: &str) -> ! 
{ ) } -pub(crate) struct Lease<'a, T> { +pub(crate) struct Lease<T> { entity: Option<Box<dyn Any>>, - pub pointer: &'a Entity<T>, + pub id: EntityId, entity_type: PhantomData<T>, } -impl<T: 'static> core::ops::Deref for Lease<'_, T> { +impl<T: 'static> core::ops::Deref for Lease<T> { type Target = T; fn deref(&self) -> &Self::Target { @@ -201,13 +200,13 @@ impl<T: 'static> core::ops::Deref for Lease<'_, T> { } } -impl<T: 'static> core::ops::DerefMut for Lease<'_, T> { +impl<T: 'static> core::ops::DerefMut for Lease<T> { fn deref_mut(&mut self) -> &mut Self::Target { self.entity.as_mut().unwrap().downcast_mut().unwrap() } } -impl<T> Drop for Lease<'_, T> { +impl<T> Drop for Lease<T> { fn drop(&mut self) { if self.entity.is_some() && !panicking() { panic!("Leases must be ended with EntityMap::end_lease") @@ -371,7 +370,7 @@ impl std::fmt::Debug for AnyEntity { } } -/// A strong, well typed reference to a struct which is managed +/// A strong, well-typed reference to a struct which is managed /// by GPUI #[derive(Deref, DerefMut)] pub struct Entity<T> { @@ -437,6 +436,19 @@ impl<T: 'static> Entity<T> { cx.update_entity(self, update) } + /// Updates the entity referenced by this handle with the given function. + pub fn as_mut<'a, C: AppContext>(&self, cx: &'a mut C) -> C::Result<GpuiBorrow<'a, T>> { + cx.as_mut(self) + } + + /// Updates the entity referenced by this handle with the given function. + pub fn write<C: AppContext>(&self, cx: &mut C, value: T) -> C::Result<()> { + self.update(cx, |entity, cx| { + *entity = value; + cx.notify(); + }) + } + /// Updates the entity referenced by this handle with the given function if /// the referenced entity still exists, within a visual context that has a window. /// Returns an error if the entity has been released. diff --git a/crates/gpui/src/app/test_context.rs b/crates/gpui/src/app/test_context.rs index dfc7af0d9c02ae08f7ff46b2400e3ebecc48f8ec..35e60326714f049faeaac54e8d979a91f9d97bbc 100644 --- a/crates/gpui/src/app/test_context.rs +++ b/crates/gpui/src/app/test_context.rs @@ -9,6 +9,7 @@ use crate::{ }; use anyhow::{anyhow, bail}; use futures::{Stream, StreamExt, channel::oneshot}; +use rand::{SeedableRng, rngs::StdRng}; use std::{cell::RefCell, future::Future, ops::Deref, rc::Rc, sync::Arc, time::Duration}; /// A TestAppContext is provided to tests created with `#[gpui::test]`, it provides @@ -63,6 +64,13 @@ impl AppContext for TestAppContext { app.update_entity(handle, update) } + fn as_mut<'a, T>(&'a mut self, _: &Entity<T>) -> Self::Result<super::GpuiBorrow<'a, T>> + where + T: 'static, + { + panic!("Cannot use as_mut with a test app context. 
Try calling update() first") + } + fn read_entity<T, R>( &self, handle: &Entity<T>, @@ -134,6 +142,12 @@ impl TestAppContext { } } + /// Create a single TestAppContext, for non-multi-client tests + pub fn single() -> Self { + let dispatcher = TestDispatcher::new(StdRng::from_entropy()); + Self::build(dispatcher, None) + } + /// The name of the test function that created this `TestAppContext` pub fn test_function_name(&self) -> Option<&'static str> { self.fn_name @@ -914,6 +928,13 @@ impl AppContext for VisualTestContext { self.cx.update_entity(handle, update) } + fn as_mut<'a, T>(&'a mut self, handle: &Entity<T>) -> Self::Result<super::GpuiBorrow<'a, T>> + where + T: 'static, + { + self.cx.as_mut(handle) + } + fn read_entity<T, R>( &self, handle: &Entity<T>, diff --git a/crates/gpui/src/color.rs b/crates/gpui/src/color.rs index 7fc9c24393907d3991edcf9ae82b25eee419e766..639c84c10144310b14a94c2a22b84957b8b09524 100644 --- a/crates/gpui/src/color.rs +++ b/crates/gpui/src/color.rs @@ -12,18 +12,13 @@ use std::{ /// Convert an RGB hex color code number to a color type pub fn rgb(hex: u32) -> Rgba { - let r = ((hex >> 16) & 0xFF) as f32 / 255.0; - let g = ((hex >> 8) & 0xFF) as f32 / 255.0; - let b = (hex & 0xFF) as f32 / 255.0; + let [_, r, g, b] = hex.to_be_bytes().map(|b| (b as f32) / 255.0); Rgba { r, g, b, a: 1.0 } } /// Convert an RGBA hex color code number to [`Rgba`] pub fn rgba(hex: u32) -> Rgba { - let r = ((hex >> 24) & 0xFF) as f32 / 255.0; - let g = ((hex >> 16) & 0xFF) as f32 / 255.0; - let b = ((hex >> 8) & 0xFF) as f32 / 255.0; - let a = (hex & 0xFF) as f32 / 255.0; + let [r, g, b, a] = hex.to_be_bytes().map(|b| (b as f32) / 255.0); Rgba { r, g, b, a } } @@ -40,6 +35,7 @@ pub(crate) fn swap_rgba_pa_to_bgra(color: &mut [u8]) { /// An RGBA color #[derive(PartialEq, Clone, Copy, Default)] +#[repr(C)] pub struct Rgba { /// The red component of the color, in the range 0.0 to 1.0 pub r: f32, @@ -63,14 +59,14 @@ impl Rgba { if other.a >= 1.0 { other } else if other.a <= 0.0 { - return *self; + *self } else { - return Rgba { + Rgba { r: (self.r * (1.0 - other.a)) + (other.r * other.a), g: (self.g * (1.0 - other.a)) + (other.g * other.a), b: (self.b * (1.0 - other.a)) + (other.b * other.a), a: self.a, - }; + } } } } @@ -494,12 +490,12 @@ impl Hsla { if alpha >= 1.0 { other } else if alpha <= 0.0 { - return self; + self } else { let converted_self = Rgba::from(self); let converted_other = Rgba::from(other); let blended_rgb = converted_self.blend(converted_other); - return Hsla::from(blended_rgb); + Hsla::from(blended_rgb) } } diff --git a/crates/gpui/src/element.rs b/crates/gpui/src/element.rs index 2852841b2c2b42ceceddaeebcf0b3abfa2684808..e5f49c7be141a3620e52599bcc2b151acc1f7319 100644 --- a/crates/gpui/src/element.rs +++ b/crates/gpui/src/element.rs @@ -39,7 +39,7 @@ use crate::{ use derive_more::{Deref, DerefMut}; pub(crate) use smallvec::SmallVec; use std::{ - any::Any, + any::{Any, type_name}, fmt::{self, Debug, Display}, mem, panic, }; @@ -220,14 +220,17 @@ impl<C: RenderOnce> Element for Component<C> { window: &mut Window, cx: &mut App, ) -> (LayoutId, Self::RequestLayoutState) { - let mut element = self - .component - .take() - .unwrap() - .render(window, cx) - .into_any_element(); - let layout_id = element.request_layout(window, cx); - (layout_id, element) + window.with_global_id(ElementId::Name(type_name::<C>().into()), |_, window| { + let mut element = self + .component + .take() + .unwrap() + .render(window, cx) + .into_any_element(); + + let layout_id = 
element.request_layout(window, cx); + (layout_id, element) + }) } fn prepaint( @@ -239,7 +242,9 @@ impl<C: RenderOnce> Element for Component<C> { window: &mut Window, cx: &mut App, ) { - element.prepaint(window, cx); + window.with_global_id(ElementId::Name(type_name::<C>().into()), |_, window| { + element.prepaint(window, cx); + }) } fn paint( @@ -252,7 +257,9 @@ impl<C: RenderOnce> Element for Component<C> { window: &mut Window, cx: &mut App, ) { - element.paint(window, cx); + window.with_global_id(ElementId::Name(type_name::<C>().into()), |_, window| { + element.paint(window, cx); + }) } } diff --git a/crates/gpui/src/elements/animation.rs b/crates/gpui/src/elements/animation.rs index bcdfa3562c747999dde96498e046ce7bd4629ac2..11dd19e260c20e49b87e05137771be73a3f816ea 100644 --- a/crates/gpui/src/elements/animation.rs +++ b/crates/gpui/src/elements/animation.rs @@ -1,4 +1,7 @@ -use std::time::{Duration, Instant}; +use std::{ + rc::Rc, + time::{Duration, Instant}, +}; use crate::{ AnyElement, App, Element, ElementId, GlobalElementId, InspectorElementId, IntoElement, Window, @@ -8,6 +11,7 @@ pub use easing::*; use smallvec::SmallVec; /// An animation that can be applied to an element. +#[derive(Clone)] pub struct Animation { /// The amount of time for which this animation should run pub duration: Duration, @@ -15,7 +19,7 @@ pub struct Animation { pub oneshot: bool, /// A function that takes a delta between 0 and 1 and returns a new delta /// between 0 and 1 based on the given easing function. - pub easing: Box<dyn Fn(f32) -> f32>, + pub easing: Rc<dyn Fn(f32) -> f32>, } impl Animation { @@ -25,7 +29,7 @@ impl Animation { Self { duration, oneshot: true, - easing: Box::new(linear), + easing: Rc::new(linear), } } @@ -39,7 +43,7 @@ impl Animation { /// The easing function will take a time delta between 0 and 1 and return a new delta /// between 0 and 1 pub fn with_easing(mut self, easing: impl Fn(f32) -> f32 + 'static) -> Self { - self.easing = Box::new(easing); + self.easing = Rc::new(easing); self } } diff --git a/crates/gpui/src/elements/div.rs b/crates/gpui/src/elements/div.rs index 6e05b384e15492f6ebd137004f0f13fd4a6d549c..09afbff929b99bb927d365621ea0550c28dcedf8 100644 --- a/crates/gpui/src/elements/div.rs +++ b/crates/gpui/src/elements/div.rs @@ -19,10 +19,10 @@ use crate::{ Action, AnyDrag, AnyElement, AnyTooltip, AnyView, App, Bounds, ClickEvent, DispatchPhase, Element, ElementId, Entity, FocusHandle, Global, GlobalElementId, Hitbox, HitboxBehavior, HitboxId, InspectorElementId, IntoElement, IsZero, KeyContext, KeyDownEvent, KeyUpEvent, - LayoutId, ModifiersChangedEvent, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent, - Overflow, ParentElement, Pixels, Point, Render, ScrollWheelEvent, SharedString, Size, Style, - StyleRefinement, Styled, Task, TooltipId, Visibility, Window, WindowControlArea, point, px, - size, + KeyboardButton, KeyboardClickEvent, LayoutId, ModifiersChangedEvent, MouseButton, + MouseClickEvent, MouseDownEvent, MouseMoveEvent, MouseUpEvent, Overflow, ParentElement, Pixels, + Point, Render, ScrollWheelEvent, SharedString, Size, Style, StyleRefinement, Styled, Task, + TooltipId, Visibility, Window, WindowControlArea, point, px, size, }; use collections::HashMap; use refineable::Refineable; @@ -484,10 +484,9 @@ impl Interactivity { where Self: Sized, { - self.click_listeners - .push(Box::new(move |event, window, cx| { - listener(event, window, cx) - })); + self.click_listeners.push(Rc::new(move |event, window, cx| { + listener(event, window, cx) + })); } /// On 
drag initiation, this callback will be used to create a new view to render the dragged value for a @@ -619,6 +618,13 @@ pub trait InteractiveElement: Sized { self } + /// Set index of the tab stop order. + fn tab_index(mut self, index: isize) -> Self { + self.interactivity().focusable = true; + self.interactivity().tab_index = Some(index); + self + } + /// Set the keymap context for this element. This will be used to determine /// which action to dispatch from the keymap. fn key_context<C, E>(mut self, key_context: C) -> Self @@ -903,7 +909,7 @@ pub trait InteractiveElement: Sized { /// Apply the given style when the given data type is dragged over this element fn drag_over<S: 'static>( mut self, - f: impl 'static + Fn(StyleRefinement, &S, &Window, &App) -> StyleRefinement, + f: impl 'static + Fn(StyleRefinement, &S, &mut Window, &mut App) -> StyleRefinement, ) -> Self { self.interactivity().drag_over_styles.push(( TypeId::of::<S>(), @@ -1149,7 +1155,7 @@ pub(crate) type MouseMoveListener = pub(crate) type ScrollWheelListener = Box<dyn Fn(&ScrollWheelEvent, DispatchPhase, &Hitbox, &mut Window, &mut App) + 'static>; -pub(crate) type ClickListener = Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>; +pub(crate) type ClickListener = Rc<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>; pub(crate) type DragListener = Box<dyn Fn(&dyn Any, Point<Pixels>, &mut Window, &mut App) -> AnyView + 'static>; @@ -1327,7 +1333,6 @@ impl Element for Div { } else if let Some(scroll_handle) = self.interactivity.tracked_scroll_handle.as_ref() { let mut state = scroll_handle.0.borrow_mut(); state.child_bounds = Vec::with_capacity(request_layout.child_layout_ids.len()); - state.bounds = bounds; for child_layout_id in &request_layout.child_layout_ids { let child_bounds = window.layout_bounds(*child_layout_id); child_min = child_min.min(&child_bounds.origin); @@ -1462,6 +1467,7 @@ pub struct Interactivity { pub(crate) tooltip_builder: Option<TooltipBuilder>, pub(crate) window_control: Option<WindowControlArea>, pub(crate) hitbox_behavior: HitboxBehavior, + pub(crate) tab_index: Option<isize>, #[cfg(any(feature = "inspector", debug_assertions))] pub(crate) source_location: Option<&'static core::panic::Location<'static>>, @@ -1521,12 +1527,17 @@ impl Interactivity { // as frames contain an element with this id. 
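// ---------------------------------------------------------------------------
// Editorial sketch, not part of the patch: a guess at how the new `tab_index`
// builder and the keyboard-activated clicks added in this diff might be used
// together. Only `div`, `id`, `tab_index`, `on_click` and
// `ClickEvent::standard_click` come from gpui / this diff; the element id,
// the prelude import and the callback body are illustrative assumptions.
use gpui::{ClickEvent, IntoElement, div, prelude::*};

fn confirm_button() -> impl IntoElement {
    div()
        .id("confirm")
        // Marks the element focusable and places it in the tab order
        // (lower indices come earlier).
        .tab_index(0)
        .on_click(|event: &ClickEvent, _window, _cx| {
            // With this patch the listener also fires for Enter/Space while
            // the element is focused, delivered as ClickEvent::Keyboard.
            if event.standard_click() {
                // handle activation here
            }
        })
}
// ---------------------------------------------------------------------------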
if self.focusable && self.tracked_focus_handle.is_none() { if let Some(element_state) = element_state.as_mut() { - self.tracked_focus_handle = Some( - element_state - .focus_handle - .get_or_insert_with(|| cx.focus_handle()) - .clone(), - ); + let mut handle = element_state + .focus_handle + .get_or_insert_with(|| cx.focus_handle()) + .clone() + .tab_stop(false); + + if let Some(index) = self.tab_index { + handle = handle.tab_index(index).tab_stop(true); + } + + self.tracked_focus_handle = Some(handle); } } @@ -1651,6 +1662,11 @@ impl Interactivity { window: &mut Window, _cx: &mut App, ) -> Point<Pixels> { + fn round_to_two_decimals(pixels: Pixels) -> Pixels { + const ROUNDING_FACTOR: f32 = 100.0; + (pixels * ROUNDING_FACTOR).round() / ROUNDING_FACTOR + } + if let Some(scroll_offset) = self.scroll_offset.as_ref() { let mut scroll_to_bottom = false; let mut tracked_scroll_handle = self @@ -1665,8 +1681,16 @@ impl Interactivity { let rem_size = window.rem_size(); let padding = style.padding.to_pixels(bounds.size.into(), rem_size); let padding_size = size(padding.left + padding.right, padding.top + padding.bottom); + // The floating point values produced by Taffy and ours often vary + // slightly after ~5 decimal places. This can lead to cases where after + // subtracting these, the container becomes scrollable for less than + // 0.00000x pixels. As we generally don't benefit from a precision that + // high for the maximum scroll, we round the scroll max to 2 decimal + // places here. let padded_content_size = self.content_size + padding_size; - let scroll_max = (padded_content_size - bounds.size).max(&Size::default()); + let scroll_max = (padded_content_size - bounds.size) + .map(round_to_two_decimals) + .max(&Default::default()); // Clamp scroll offset in case scroll max is smaller now (e.g., if children // were removed or the bounds became larger). let mut scroll_offset = scroll_offset.borrow_mut(); @@ -1679,7 +1703,8 @@ impl Interactivity { } if let Some(mut scroll_handle_state) = tracked_scroll_handle { - scroll_handle_state.padded_content_size = padded_content_size; + scroll_handle_state.max_offset = scroll_max; + scroll_handle_state.bounds = bounds; } *scroll_offset @@ -1729,6 +1754,10 @@ impl Interactivity { return ((), element_state); } + if let Some(focus_handle) = &self.tracked_focus_handle { + window.next_frame.tab_handles.insert(focus_handle); + } + window.with_element_opacity(style.opacity, |window| { style.paint(bounds, window, cx, |window: &mut Window, cx: &mut App| { window.with_text_style(style.text_style().cloned(), |window| { @@ -1920,6 +1949,12 @@ impl Interactivity { window: &mut Window, cx: &mut App, ) { + let is_focused = self + .tracked_focus_handle + .as_ref() + .map(|handle| handle.is_focused(window)) + .unwrap_or(false); + // If this element can be focused, register a mouse down listener // that will automatically transfer focus when hitting the element. // This behavior can be suppressed by using `cx.prevent_default()`. @@ -2083,6 +2118,39 @@ impl Interactivity { } }); + if is_focused { + // Press enter, space to trigger click, when the element is focused. 
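// ---------------------------------------------------------------------------
// Editorial sketch, not part of the patch: the key listener registered just
// below synthesizes ClickEvent::Keyboard for Enter/Space on the focused
// element, so click handlers that care about the source can match on the new
// enum. Field names are taken from the interactive.rs changes in this diff;
// the function itself is illustrative.
use gpui::ClickEvent;

fn describe_click(event: &ClickEvent) -> String {
    match event {
        ClickEvent::Mouse(mouse) => format!(
            "mouse click at {:?}, click_count = {}",
            mouse.up.position, mouse.up.click_count
        ),
        ClickEvent::Keyboard(keyboard) => format!(
            "keyboard click via {:?} on element bounds {:?}",
            keyboard.button, keyboard.bounds
        ),
    }
}
// ---------------------------------------------------------------------------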
+ window.on_key_event({ + let click_listeners = click_listeners.clone(); + let hitbox = hitbox.clone(); + move |event: &KeyUpEvent, phase, window, cx| { + if phase.bubble() && !window.default_prevented() { + let stroke = &event.keystroke; + let keyboard_button = if stroke.key.eq("enter") { + Some(KeyboardButton::Enter) + } else if stroke.key.eq("space") { + Some(KeyboardButton::Space) + } else { + None + }; + + if let Some(button) = keyboard_button + && !stroke.modifiers.modified() + { + let click_event = ClickEvent::Keyboard(KeyboardClickEvent { + button, + bounds: hitbox.bounds, + }); + + for listener in &click_listeners { + listener(&click_event, window, cx); + } + } + } + } + }); + } + window.on_mouse_event({ let mut captured_mouse_down = None; let hitbox = hitbox.clone(); @@ -2108,10 +2176,10 @@ impl Interactivity { // Fire click handlers during the bubble phase. DispatchPhase::Bubble => { if let Some(mouse_down) = captured_mouse_down.take() { - let mouse_click = ClickEvent { + let mouse_click = ClickEvent::Mouse(MouseClickEvent { down: mouse_down, up: event.clone(), - }; + }); for listener in &click_listeners { listener(&mouse_click, window, cx); } @@ -2919,7 +2987,7 @@ impl ScrollAnchor { struct ScrollHandleState { offset: Rc<RefCell<Point<Pixels>>>, bounds: Bounds<Pixels>, - padded_content_size: Size<Pixels>, + max_offset: Size<Pixels>, child_bounds: Vec<Bounds<Pixels>>, scroll_to_bottom: bool, overflow: Point<Overflow>, @@ -2948,6 +3016,11 @@ impl ScrollHandle { *self.0.borrow().offset.borrow() } + /// Get the maximum scroll offset. + pub fn max_offset(&self) -> Size<Pixels> { + self.0.borrow().max_offset + } + /// Get the top child that's scrolled into view. pub fn top_item(&self) -> usize { let state = self.0.borrow(); @@ -2972,21 +3045,11 @@ impl ScrollHandle { self.0.borrow().bounds } - /// Set the bounds into which this child is painted - pub(super) fn set_bounds(&self, bounds: Bounds<Pixels>) { - self.0.borrow_mut().bounds = bounds; - } - /// Get the bounds for a specific child. pub fn bounds_for_item(&self, ix: usize) -> Option<Bounds<Pixels>> { self.0.borrow().child_bounds.get(ix).cloned() } - /// Get the size of the content with padding of the container. 
- pub fn padded_content_size(&self) -> Size<Pixels> { - self.0.borrow().padded_content_size - } - /// scroll_to_item scrolls the minimal amount to ensure that the child is /// fully visible pub fn scroll_to_item(&self, ix: usize) { diff --git a/crates/gpui/src/elements/list.rs b/crates/gpui/src/elements/list.rs index 35a3b622b2e53028218ce0c42ab0a5ad7f1a4ec3..39f38bdc69d6a5d4c9ce8c7c349707e906124cca 100644 --- a/crates/gpui/src/elements/list.rs +++ b/crates/gpui/src/elements/list.rs @@ -16,12 +16,18 @@ use crate::{ use collections::VecDeque; use refineable::Refineable as _; use std::{cell::RefCell, ops::Range, rc::Rc}; -use sum_tree::{Bias, SumTree}; +use sum_tree::{Bias, Dimensions, SumTree}; + +type RenderItemFn = dyn FnMut(usize, &mut Window, &mut App) -> AnyElement + 'static; /// Construct a new list element -pub fn list(state: ListState) -> List { +pub fn list( + state: ListState, + render_item: impl FnMut(usize, &mut Window, &mut App) -> AnyElement + 'static, +) -> List { List { state, + render_item: Box::new(render_item), style: StyleRefinement::default(), sizing_behavior: ListSizingBehavior::default(), } @@ -30,6 +36,7 @@ pub fn list(state: ListState) -> List { /// A list element pub struct List { state: ListState, + render_item: Box<RenderItemFn>, style: StyleRefinement, sizing_behavior: ListSizingBehavior, } @@ -55,7 +62,6 @@ impl std::fmt::Debug for ListState { struct StateInner { last_layout_bounds: Option<Bounds<Pixels>>, last_padding: Option<Edges<Pixels>>, - render_item: Box<dyn FnMut(usize, &mut Window, &mut App) -> AnyElement>, items: SumTree<ListItem>, logical_scroll_top: Option<ListOffset>, alignment: ListAlignment, @@ -186,19 +192,10 @@ impl ListState { /// above and below the visible area. Elements within this area will /// be measured even though they are not visible. This can help ensure /// that the list doesn't flicker or pop in when scrolling. 
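// ---------------------------------------------------------------------------
// Editorial sketch, not part of the patch: with this change the item renderer
// moves from ListState::new to the list() element, so the state no longer
// owns a render closure. A minimal sketch of the new call pattern, assuming
// the usual gpui prelude; the item markup is illustrative.
use gpui::{IntoElement, ListAlignment, ListState, div, list, prelude::*, px};

fn build_state(item_count: usize) -> ListState {
    // Old signature: ListState::new(item_count, alignment, overdraw, render_item)
    ListState::new(item_count, ListAlignment::Top, px(512.))
}

fn render_list(state: ListState) -> impl IntoElement {
    list(state, |ix, _window, _cx| {
        div().child(format!("item {ix}")).into_any_element()
    })
    .w_full()
    .h_full()
}
// ---------------------------------------------------------------------------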
- pub fn new<R>( - item_count: usize, - alignment: ListAlignment, - overdraw: Pixels, - render_item: R, - ) -> Self - where - R: 'static + FnMut(usize, &mut Window, &mut App) -> AnyElement, - { + pub fn new(item_count: usize, alignment: ListAlignment, overdraw: Pixels) -> Self { let this = Self(Rc::new(RefCell::new(StateInner { last_layout_bounds: None, last_padding: None, - render_item: Box::new(render_item), items: SumTree::default(), logical_scroll_top: None, alignment, @@ -249,8 +246,8 @@ impl ListState { let state = &mut *self.0.borrow_mut(); let mut old_items = state.items.cursor::<Count>(&()); - let mut new_items = old_items.slice(&Count(old_range.start), Bias::Right, &()); - old_items.seek_forward(&Count(old_range.end), Bias::Right, &()); + let mut new_items = old_items.slice(&Count(old_range.start), Bias::Right); + old_items.seek_forward(&Count(old_range.end), Bias::Right); let mut spliced_count = 0; new_items.extend( @@ -260,7 +257,7 @@ impl ListState { }), &(), ); - new_items.append(old_items.suffix(&()), &()); + new_items.append(old_items.suffix(), &()); drop(old_items); state.items = new_items; @@ -300,14 +297,14 @@ impl ListState { let current_offset = self.logical_scroll_top(); let state = &mut *self.0.borrow_mut(); let mut cursor = state.items.cursor::<ListItemSummary>(&()); - cursor.seek(&Count(current_offset.item_ix), Bias::Right, &()); + cursor.seek(&Count(current_offset.item_ix), Bias::Right); let start_pixel_offset = cursor.start().height + current_offset.offset_in_item; let new_pixel_offset = (start_pixel_offset + distance).max(px(0.)); if new_pixel_offset > start_pixel_offset { - cursor.seek_forward(&Height(new_pixel_offset), Bias::Right, &()); + cursor.seek_forward(&Height(new_pixel_offset), Bias::Right); } else { - cursor.seek(&Height(new_pixel_offset), Bias::Right, &()); + cursor.seek(&Height(new_pixel_offset), Bias::Right); } state.logical_scroll_top = Some(ListOffset { @@ -343,11 +340,11 @@ impl ListState { scroll_top.offset_in_item = px(0.); } else { let mut cursor = state.items.cursor::<ListItemSummary>(&()); - cursor.seek(&Count(ix + 1), Bias::Right, &()); + cursor.seek(&Count(ix + 1), Bias::Right); let bottom = cursor.start().height + padding.top; let goal_top = px(0.).max(bottom - height + padding.bottom); - cursor.seek(&Height(goal_top), Bias::Left, &()); + cursor.seek(&Height(goal_top), Bias::Left); let start_ix = cursor.start().count; let start_item_top = cursor.start().height; @@ -371,14 +368,14 @@ impl ListState { return None; } - let mut cursor = state.items.cursor::<(Count, Height)>(&()); - cursor.seek(&Count(scroll_top.item_ix), Bias::Right, &()); + let mut cursor = state.items.cursor::<Dimensions<Count, Height>>(&()); + cursor.seek(&Count(scroll_top.item_ix), Bias::Right); let scroll_top = cursor.start().1.0 + scroll_top.offset_in_item; - cursor.seek_forward(&Count(ix), Bias::Right, &()); + cursor.seek_forward(&Count(ix), Bias::Right); if let Some(&ListItem::Measured { size, .. }) = cursor.item() { - let &(Count(count), Height(top)) = cursor.start(); + let &Dimensions(Count(count), Height(top), _) = cursor.start(); if count == ix { let top = bounds.top() + top - scroll_top; return Some(Bounds::from_corners( @@ -411,9 +408,9 @@ impl ListState { self.0.borrow_mut().set_offset_from_scrollbar(point); } - /// Returns the size of items we have measured. + /// Returns the maximum scroll offset according to the items we have measured. /// This value remains constant while dragging to prevent the scrollbar from moving away unexpectedly. 
- pub fn content_size_for_scrollbar(&self) -> Size<Pixels> { + pub fn max_offset_for_scrollbar(&self) -> Size<Pixels> { let state = self.0.borrow(); let bounds = state.last_layout_bounds.unwrap_or_default(); @@ -421,7 +418,7 @@ impl ListState { .scrollbar_drag_start_height .unwrap_or_else(|| state.items.summary().height); - Size::new(bounds.size.width, height) + Size::new(Pixels::ZERO, Pixels::ZERO.max(height - bounds.size.height)) } /// Returns the current scroll offset adjusted for the scrollbar @@ -431,7 +428,7 @@ impl ListState { let mut cursor = state.items.cursor::<ListItemSummary>(&()); let summary: ListItemSummary = - cursor.summary(&Count(logical_scroll_top.item_ix), Bias::Right, &()); + cursor.summary(&Count(logical_scroll_top.item_ix), Bias::Right); let content_height = state.items.summary().height; let drag_offset = // if dragging the scrollbar, we want to offset the point if the height changed @@ -450,9 +447,9 @@ impl ListState { impl StateInner { fn visible_range(&self, height: Pixels, scroll_top: &ListOffset) -> Range<usize> { let mut cursor = self.items.cursor::<ListItemSummary>(&()); - cursor.seek(&Count(scroll_top.item_ix), Bias::Right, &()); + cursor.seek(&Count(scroll_top.item_ix), Bias::Right); let start_y = cursor.start().height + scroll_top.offset_in_item; - cursor.seek_forward(&Height(start_y + height), Bias::Left, &()); + cursor.seek_forward(&Height(start_y + height), Bias::Left); scroll_top.item_ix..cursor.start().count + 1 } @@ -482,7 +479,7 @@ impl StateInner { self.logical_scroll_top = None; } else { let mut cursor = self.items.cursor::<ListItemSummary>(&()); - cursor.seek(&Height(new_scroll_top), Bias::Right, &()); + cursor.seek(&Height(new_scroll_top), Bias::Right); let item_ix = cursor.start().count; let offset_in_item = new_scroll_top - cursor.start().height; self.logical_scroll_top = Some(ListOffset { @@ -523,7 +520,7 @@ impl StateInner { fn scroll_top(&self, logical_scroll_top: &ListOffset) -> Pixels { let mut cursor = self.items.cursor::<ListItemSummary>(&()); - cursor.seek(&Count(logical_scroll_top.item_ix), Bias::Right, &()); + cursor.seek(&Count(logical_scroll_top.item_ix), Bias::Right); cursor.start().height + logical_scroll_top.offset_in_item } @@ -532,6 +529,7 @@ impl StateInner { available_width: Option<Pixels>, available_height: Pixels, padding: &Edges<Pixels>, + render_item: &mut RenderItemFn, window: &mut Window, cx: &mut App, ) -> LayoutItemsResponse { @@ -553,7 +551,7 @@ impl StateInner { let mut cursor = old_items.cursor::<Count>(&()); // Render items after the scroll top, including those in the trailing overdraw - cursor.seek(&Count(scroll_top.item_ix), Bias::Right, &()); + cursor.seek(&Count(scroll_top.item_ix), Bias::Right); for (ix, item) in cursor.by_ref().enumerate() { let visible_height = rendered_height - scroll_top.offset_in_item; if visible_height >= available_height + self.overdraw { @@ -566,7 +564,7 @@ impl StateInner { // If we're within the visible area or the height wasn't cached, render and measure the item's element if visible_height < available_height || size.is_none() { let item_index = scroll_top.item_ix + ix; - let mut element = (self.render_item)(item_index, window, cx); + let mut element = render_item(item_index, window, cx); let element_size = element.layout_as_root(available_item_space, window, cx); size = Some(element_size); if visible_height < available_height { @@ -592,16 +590,16 @@ impl StateInner { rendered_height += padding.bottom; // Prepare to start walking upward from the item at the scroll top. 
- cursor.seek(&Count(scroll_top.item_ix), Bias::Right, &()); + cursor.seek(&Count(scroll_top.item_ix), Bias::Right); // If the rendered items do not fill the visible region, then adjust // the scroll top upward. if rendered_height - scroll_top.offset_in_item < available_height { while rendered_height < available_height { - cursor.prev(&()); + cursor.prev(); if let Some(item) = cursor.item() { let item_index = cursor.start().0; - let mut element = (self.render_item)(item_index, window, cx); + let mut element = render_item(item_index, window, cx); let element_size = element.layout_as_root(available_item_space, window, cx); let focus_handle = item.focus_handle(); rendered_height += element_size.height; @@ -645,12 +643,12 @@ impl StateInner { // Measure items in the leading overdraw let mut leading_overdraw = scroll_top.offset_in_item; while leading_overdraw < self.overdraw { - cursor.prev(&()); + cursor.prev(); if let Some(item) = cursor.item() { let size = if let ListItem::Measured { size, .. } = item { *size } else { - let mut element = (self.render_item)(cursor.start().0, window, cx); + let mut element = render_item(cursor.start().0, window, cx); element.layout_as_root(available_item_space, window, cx) }; @@ -666,10 +664,10 @@ impl StateInner { let measured_range = cursor.start().0..(cursor.start().0 + measured_items.len()); let mut cursor = old_items.cursor::<Count>(&()); - let mut new_items = cursor.slice(&Count(measured_range.start), Bias::Right, &()); + let mut new_items = cursor.slice(&Count(measured_range.start), Bias::Right); new_items.extend(measured_items, &()); - cursor.seek(&Count(measured_range.end), Bias::Right, &()); - new_items.append(cursor.suffix(&()), &()); + cursor.seek(&Count(measured_range.end), Bias::Right); + new_items.append(cursor.suffix(), &()); self.items = new_items; // If none of the visible items are focused, check if an off-screen item is focused @@ -679,11 +677,11 @@ impl StateInner { let mut cursor = self .items .filter::<_, Count>(&(), |summary| summary.has_focus_handles); - cursor.next(&()); + cursor.next(); while let Some(item) = cursor.item() { if item.contains_focused(window, cx) { let item_index = cursor.start().0; - let mut element = (self.render_item)(cursor.start().0, window, cx); + let mut element = render_item(cursor.start().0, window, cx); let size = element.layout_as_root(available_item_space, window, cx); item_layouts.push_back(ItemLayout { index: item_index, @@ -692,7 +690,7 @@ impl StateInner { }); break; } - cursor.next(&()); + cursor.next(); } } @@ -708,6 +706,7 @@ impl StateInner { bounds: Bounds<Pixels>, padding: Edges<Pixels>, autoscroll: bool, + render_item: &mut RenderItemFn, window: &mut Window, cx: &mut App, ) -> Result<LayoutItemsResponse, ListOffset> { @@ -716,6 +715,7 @@ impl StateInner { Some(bounds.size.width), bounds.size.height, &padding, + render_item, window, cx, ); @@ -741,7 +741,7 @@ impl StateInner { }); } else if autoscroll_bounds.bottom() > bounds.bottom() { let mut cursor = self.items.cursor::<Count>(&()); - cursor.seek(&Count(item.index), Bias::Right, &()); + cursor.seek(&Count(item.index), Bias::Right); let mut height = bounds.size.height - padding.top - padding.bottom; // Account for the height of the element down until the autoscroll bottom. @@ -749,12 +749,11 @@ impl StateInner { // Keep decreasing the scroll top until we fill all the available space. 
while height > Pixels::ZERO { - cursor.prev(&()); + cursor.prev(); let Some(item) = cursor.item() else { break }; let size = item.size().unwrap_or_else(|| { - let mut item = - (self.render_item)(cursor.start().0, window, cx); + let mut item = render_item(cursor.start().0, window, cx); let item_available_size = size( bounds.size.width.into(), AvailableSpace::MinContent, @@ -806,7 +805,7 @@ impl StateInner { self.logical_scroll_top = None; } else { let mut cursor = self.items.cursor::<ListItemSummary>(&()); - cursor.seek(&Height(new_scroll_top), Bias::Right, &()); + cursor.seek(&Height(new_scroll_top), Bias::Right); let item_ix = cursor.start().count; let offset_in_item = new_scroll_top - cursor.start().height; @@ -876,8 +875,14 @@ impl Element for List { window.rem_size(), ); - let layout_response = - state.layout_items(None, available_height, &padding, window, cx); + let layout_response = state.layout_items( + None, + available_height, + &padding, + &mut self.render_item, + window, + cx, + ); let max_element_width = layout_response.max_item_width; let summary = state.items.summary(); @@ -951,15 +956,16 @@ impl Element for List { let padding = style .padding .to_pixels(bounds.size.into(), window.rem_size()); - let layout = match state.prepaint_items(bounds, padding, true, window, cx) { - Ok(layout) => layout, - Err(autoscroll_request) => { - state.logical_scroll_top = Some(autoscroll_request); - state - .prepaint_items(bounds, padding, false, window, cx) - .unwrap() - } - }; + let layout = + match state.prepaint_items(bounds, padding, true, &mut self.render_item, window, cx) { + Ok(layout) => layout, + Err(autoscroll_request) => { + state.logical_scroll_top = Some(autoscroll_request); + state + .prepaint_items(bounds, padding, false, &mut self.render_item, window, cx) + .unwrap() + } + }; state.last_layout_bounds = Some(bounds); state.last_padding = Some(padding); @@ -1108,9 +1114,7 @@ mod test { let cx = cx.add_empty_window(); - let state = ListState::new(5, crate::ListAlignment::Top, px(10.), |_, _, _| { - div().h(px(10.)).w_full().into_any() - }); + let state = ListState::new(5, crate::ListAlignment::Top, px(10.)); // Ensure that the list is scrolled to the top state.scroll_to(gpui::ListOffset { @@ -1121,7 +1125,11 @@ mod test { struct TestView(ListState); impl Render for TestView { fn render(&mut self, _: &mut Window, _: &mut Context<Self>) -> impl IntoElement { - list(self.0.clone()).w_full().h_full() + list(self.0.clone(), |_, _, _| { + div().h(px(10.)).w_full().into_any() + }) + .w_full() + .h_full() } } @@ -1154,14 +1162,16 @@ mod test { let cx = cx.add_empty_window(); - let state = ListState::new(5, crate::ListAlignment::Top, px(10.), |_, _, _| { - div().h(px(20.)).w_full().into_any() - }); + let state = ListState::new(5, crate::ListAlignment::Top, px(10.)); struct TestView(ListState); impl Render for TestView { fn render(&mut self, _: &mut Window, _: &mut Context<Self>) -> impl IntoElement { - list(self.0.clone()).w_full().h_full() + list(self.0.clone(), |_, _, _| { + div().h(px(20.)).w_full().into_any() + }) + .w_full() + .h_full() } } diff --git a/crates/gpui/src/elements/uniform_list.rs b/crates/gpui/src/elements/uniform_list.rs index 52e2015c20f9983e78c126cc920ed115eef0fd7a..cdf90d4eb8934de99a21c65b6c9efa2a2fdde258 100644 --- a/crates/gpui/src/elements/uniform_list.rs +++ b/crates/gpui/src/elements/uniform_list.rs @@ -88,15 +88,24 @@ pub enum ScrollStrategy { /// May not be possible if there's not enough list items above the item scrolled to: /// in this case, the element will 
be placed at the closest possible position. Center, - /// Scrolls the element to be at the given item index from the top of the viewport. - ToPosition(usize), +} + +#[derive(Clone, Copy, Debug)] +#[allow(missing_docs)] +pub struct DeferredScrollToItem { + /// The item index to scroll to + pub item_index: usize, + /// The scroll strategy to use + pub strategy: ScrollStrategy, + /// The offset in number of items + pub offset: usize, } #[derive(Clone, Debug, Default)] #[allow(missing_docs)] pub struct UniformListScrollState { pub base_handle: ScrollHandle, - pub deferred_scroll_to_item: Option<(usize, ScrollStrategy)>, + pub deferred_scroll_to_item: Option<DeferredScrollToItem>, /// Size of the item, captured during last layout. pub last_item_size: Option<ItemSize>, /// Whether the list was vertically flipped during last layout. @@ -126,7 +135,24 @@ impl UniformListScrollHandle { /// Scroll the list to the given item index. pub fn scroll_to_item(&self, ix: usize, strategy: ScrollStrategy) { - self.0.borrow_mut().deferred_scroll_to_item = Some((ix, strategy)); + self.0.borrow_mut().deferred_scroll_to_item = Some(DeferredScrollToItem { + item_index: ix, + strategy, + offset: 0, + }); + } + + /// Scroll the list to the given item index with an offset. + /// + /// For ScrollStrategy::Top, the item will be placed at the offset position from the top. + /// + /// For ScrollStrategy::Center, the item will be centered between offset and the last visible item. + pub fn scroll_to_item_with_offset(&self, ix: usize, strategy: ScrollStrategy, offset: usize) { + self.0.borrow_mut().deferred_scroll_to_item = Some(DeferredScrollToItem { + item_index: ix, + strategy, + offset, + }); } /// Check if the list is flipped vertically. @@ -139,7 +165,8 @@ impl UniformListScrollHandle { pub fn logical_scroll_top_index(&self) -> usize { let this = self.0.borrow(); this.deferred_scroll_to_item - .map(|(ix, _)| ix) + .as_ref() + .map(|deferred| deferred.item_index) .unwrap_or_else(|| this.base_handle.logical_scroll_top().0) } @@ -295,9 +322,8 @@ impl Element for UniformList { bounds.bottom_right() - point(border.right + padding.right, border.bottom), ); - let y_flipped = if let Some(scroll_handle) = self.scroll_handle.as_mut() { - let mut scroll_state = scroll_handle.0.borrow_mut(); - scroll_state.base_handle.set_bounds(bounds); + let y_flipped = if let Some(scroll_handle) = &self.scroll_handle { + let scroll_state = scroll_handle.0.borrow(); scroll_state.y_flipped } else { false @@ -321,7 +347,8 @@ impl Element for UniformList { scroll_offset.x = Pixels::ZERO; } - if let Some((mut ix, scroll_strategy)) = shared_scroll_to_item { + if let Some(deferred_scroll) = shared_scroll_to_item { + let mut ix = deferred_scroll.item_index; if y_flipped { ix = self.item_count.saturating_sub(ix + 1); } @@ -330,23 +357,28 @@ impl Element for UniformList { let item_top = item_height * ix + padding.top; let item_bottom = item_top + item_height; let scroll_top = -updated_scroll_offset.y; + let offset_pixels = item_height * deferred_scroll.offset; let mut scrolled_to_top = false; - if item_top < scroll_top + padding.top { + + if item_top < scroll_top + padding.top + offset_pixels { scrolled_to_top = true; - updated_scroll_offset.y = -(item_top) + padding.top; + updated_scroll_offset.y = -(item_top) + padding.top + offset_pixels; } else if item_bottom > scroll_top + list_height - padding.bottom { scrolled_to_top = true; updated_scroll_offset.y = -(item_bottom - list_height) - padding.bottom; } - match scroll_strategy { + match 
deferred_scroll.strategy { ScrollStrategy::Top => {} ScrollStrategy::Center => { if scrolled_to_top { let item_center = item_top + item_height / 2.0; - let target_scroll_top = item_center - list_height / 2.0; - if item_top < scroll_top + let viewport_height = list_height - offset_pixels; + let viewport_center = offset_pixels + viewport_height / 2.0; + let target_scroll_top = item_center - viewport_center; + + if item_top < scroll_top + offset_pixels || item_bottom > scroll_top + list_height { updated_scroll_offset.y = -target_scroll_top @@ -356,15 +388,6 @@ impl Element for UniformList { } } } - ScrollStrategy::ToPosition(sticky_index) => { - let target_y_in_viewport = item_height * sticky_index; - let target_scroll_top = item_top - target_y_in_viewport; - let max_scroll_top = - (content_height - list_height).max(Pixels::ZERO); - let new_scroll_top = - target_scroll_top.clamp(Pixels::ZERO, max_scroll_top); - updated_scroll_offset.y = -new_scroll_top; - } } scroll_offset = *updated_scroll_offset } diff --git a/crates/gpui/src/geometry.rs b/crates/gpui/src/geometry.rs index 74be6344f92a2c478318641be5a78eb7bacfe28e..3d2d9cd9db693f8aa72f314da324ab2cb8e98789 100644 --- a/crates/gpui/src/geometry.rs +++ b/crates/gpui/src/geometry.rs @@ -3522,7 +3522,7 @@ impl Serialize for Length { /// # Returns /// /// A `DefiniteLength` representing the relative length as a fraction of the parent's size. -pub fn relative(fraction: f32) -> DefiniteLength { +pub const fn relative(fraction: f32) -> DefiniteLength { DefiniteLength::Fraction(fraction) } diff --git a/crates/gpui/src/gpui.rs b/crates/gpui/src/gpui.rs index 91461a4d2c8f1bbf1504a36429064a038bedec21..09799eb910f0eeece17fd9975c3c13f6accd2df6 100644 --- a/crates/gpui/src/gpui.rs +++ b/crates/gpui/src/gpui.rs @@ -95,6 +95,7 @@ mod style; mod styled; mod subscription; mod svg_renderer; +mod tab_stop; mod taffy; #[cfg(any(test, feature = "test-support"))] pub mod test; @@ -151,6 +152,7 @@ pub use style::*; pub use styled::*; pub use subscription::*; use svg_renderer::*; +pub(crate) use tab_stop::*; pub use taffy::{AvailableSpace, LayoutId}; #[cfg(any(test, feature = "test-support"))] pub use test::*; @@ -197,6 +199,11 @@ pub trait AppContext { where T: 'static; + /// Update a entity in the app context. + fn as_mut<'a, T>(&'a mut self, handle: &Entity<T>) -> Self::Result<GpuiBorrow<'a, T>> + where + T: 'static; + /// Read a entity from the app context. fn read_entity<T, R>( &self, diff --git a/crates/gpui/src/interactive.rs b/crates/gpui/src/interactive.rs index edd807da11410fa7255cd4613704a9444c197bb0..218ae5fcdfbb60b2dd99c8a656d95c3962edc98c 100644 --- a/crates/gpui/src/interactive.rs +++ b/crates/gpui/src/interactive.rs @@ -1,6 +1,6 @@ use crate::{ - Capslock, Context, Empty, IntoElement, Keystroke, Modifiers, Pixels, Point, Render, Window, - point, seal::Sealed, + Bounds, Capslock, Context, Empty, IntoElement, Keystroke, Modifiers, Pixels, Point, Render, + Window, point, seal::Sealed, }; use smallvec::SmallVec; use std::{any::Any, fmt::Debug, ops::Deref, path::PathBuf}; @@ -141,7 +141,7 @@ impl MouseEvent for MouseUpEvent {} /// A click event, generated when a mouse button is pressed and released. #[derive(Clone, Debug, Default)] -pub struct ClickEvent { +pub struct MouseClickEvent { /// The mouse event when the button was pressed. pub down: MouseDownEvent, @@ -149,18 +149,126 @@ pub struct ClickEvent { pub up: MouseUpEvent, } +/// A click event that was generated by a keyboard button being pressed and released. 
+#[derive(Clone, Debug, Default)] +pub struct KeyboardClickEvent { + /// The keyboard button that was pressed to trigger the click. + pub button: KeyboardButton, + + /// The bounds of the element that was clicked. + pub bounds: Bounds<Pixels>, +} + +/// A click event, generated when a mouse button or keyboard button is pressed and released. +#[derive(Clone, Debug)] +pub enum ClickEvent { + /// A click event trigger by a mouse button being pressed and released. + Mouse(MouseClickEvent), + /// A click event trigger by a keyboard button being pressed and released. + Keyboard(KeyboardClickEvent), +} + +impl Default for ClickEvent { + fn default() -> Self { + ClickEvent::Keyboard(KeyboardClickEvent::default()) + } +} + impl ClickEvent { - /// Returns the modifiers that were held down during both the - /// mouse down and mouse up events + /// Returns the modifiers that were held during the click event + /// + /// `Keyboard`: The keyboard click events never have modifiers. + /// `Mouse`: Modifiers that were held during the mouse key up event. pub fn modifiers(&self) -> Modifiers { - Modifiers { - control: self.up.modifiers.control && self.down.modifiers.control, - alt: self.up.modifiers.alt && self.down.modifiers.alt, - shift: self.up.modifiers.shift && self.down.modifiers.shift, - platform: self.up.modifiers.platform && self.down.modifiers.platform, - function: self.up.modifiers.function && self.down.modifiers.function, + match self { + // Click events are only generated from keyboard events _without any modifiers_, so we know the modifiers are always Default + ClickEvent::Keyboard(_) => Modifiers::default(), + // Click events on the web only reflect the modifiers for the keyup event, + // tested via observing the behavior of the `ClickEvent.shiftKey` field in Chrome 138 + // under various combinations of modifiers and keyUp / keyDown events. + ClickEvent::Mouse(event) => event.up.modifiers, + } + } + + /// Returns the position of the click event + /// + /// `Keyboard`: The bottom left corner of the clicked hitbox + /// `Mouse`: The position of the mouse when the button was released. + pub fn position(&self) -> Point<Pixels> { + match self { + ClickEvent::Keyboard(event) => event.bounds.bottom_left(), + ClickEvent::Mouse(event) => event.up.position, } } + + /// Returns the mouse position of the click event + /// + /// `Keyboard`: None + /// `Mouse`: The position of the mouse when the button was released. 
+ pub fn mouse_position(&self) -> Option<Point<Pixels>> { + match self { + ClickEvent::Keyboard(_) => None, + ClickEvent::Mouse(event) => Some(event.up.position), + } + } + + /// Returns if this was a right click + /// + /// `Keyboard`: false + /// `Mouse`: Whether the right button was pressed and released + pub fn is_right_click(&self) -> bool { + match self { + ClickEvent::Keyboard(_) => false, + ClickEvent::Mouse(event) => { + event.down.button == MouseButton::Right && event.up.button == MouseButton::Right + } + } + } + + /// Returns whether the click was a standard click + /// + /// `Keyboard`: Always true + /// `Mouse`: Left button pressed and released + pub fn standard_click(&self) -> bool { + match self { + ClickEvent::Keyboard(_) => true, + ClickEvent::Mouse(event) => { + event.down.button == MouseButton::Left && event.up.button == MouseButton::Left + } + } + } + + /// Returns whether the click focused the element + /// + /// `Keyboard`: false, keyboard clicks only work if an element is already focused + /// `Mouse`: Whether this was the first focusing click + pub fn first_focus(&self) -> bool { + match self { + ClickEvent::Keyboard(_) => false, + ClickEvent::Mouse(event) => event.down.first_mouse, + } + } + + /// Returns the click count of the click event + /// + /// `Keyboard`: Always 1 + /// `Mouse`: Count of clicks from MouseUpEvent + pub fn click_count(&self) -> usize { + match self { + ClickEvent::Keyboard(_) => 1, + ClickEvent::Mouse(event) => event.up.click_count, + } + } +} + +/// An enum representing the keyboard button that was pressed for a click event. +#[derive(Hash, PartialEq, Eq, Copy, Clone, Debug, Default)] +pub enum KeyboardButton { + /// Enter key was clicked + #[default] + Enter, + /// Space key was clicked + Space, } /// An enum representing the mouse button that was pressed. diff --git a/crates/gpui/src/key_dispatch.rs b/crates/gpui/src/key_dispatch.rs index a290a132c3b5f9fa42e338c28b86de7ded5b10ac..cc6ebb9b08db114c1aabbb2a58296fc3aa8a9949 100644 --- a/crates/gpui/src/key_dispatch.rs +++ b/crates/gpui/src/key_dispatch.rs @@ -50,8 +50,8 @@ /// KeyBinding::new("cmd-k left", pane::SplitLeft, Some("Pane")) /// use crate::{ - Action, ActionRegistry, App, BindingIndex, DispatchPhase, EntityId, FocusId, KeyBinding, - KeyContext, Keymap, Keystroke, ModifiersChangedEvent, Window, + Action, ActionRegistry, App, DispatchPhase, EntityId, FocusId, KeyBinding, KeyContext, Keymap, + Keystroke, ModifiersChangedEvent, Window, }; use collections::FxHashMap; use smallvec::SmallVec; @@ -406,16 +406,11 @@ impl DispatchTree { // methods, but this can't be done very cleanly since keymap must be borrowed. 
let keymap = self.keymap.borrow(); keymap - .bindings_for_action_with_indices(action) - .filter(|(binding_index, binding)| { - Self::binding_matches_predicate_and_not_shadowed( - &keymap, - *binding_index, - &binding.keystrokes, - context_stack, - ) + .bindings_for_action(action) + .filter(|binding| { + Self::binding_matches_predicate_and_not_shadowed(&keymap, &binding, context_stack) }) - .map(|(_, binding)| binding.clone()) + .cloned() .collect() } @@ -428,28 +423,22 @@ impl DispatchTree { ) -> Option<KeyBinding> { let keymap = self.keymap.borrow(); keymap - .bindings_for_action_with_indices(action) + .bindings_for_action(action) .rev() - .find_map(|(binding_index, binding)| { - let found = Self::binding_matches_predicate_and_not_shadowed( - &keymap, - binding_index, - &binding.keystrokes, - context_stack, - ); - if found { Some(binding.clone()) } else { None } + .find(|binding| { + Self::binding_matches_predicate_and_not_shadowed(&keymap, &binding, context_stack) }) + .cloned() } fn binding_matches_predicate_and_not_shadowed( keymap: &Keymap, - binding_index: BindingIndex, - keystrokes: &[Keystroke], + binding: &KeyBinding, context_stack: &[KeyContext], ) -> bool { - let (bindings, _) = keymap.bindings_for_input_with_indices(&keystrokes, context_stack); - if let Some((highest_precedence_index, _)) = bindings.iter().next() { - binding_index == *highest_precedence_index + let (bindings, _) = keymap.bindings_for_input(&binding.keystrokes, context_stack); + if let Some(found) = bindings.iter().next() { + found.action.partial_eq(binding.action.as_ref()) } else { false } diff --git a/crates/gpui/src/keymap.rs b/crates/gpui/src/keymap.rs index b5dbab15c77a0cfff96885e5835f602197e408e6..83d7479a04423d249a2be69c69756211eb9d485d 100644 --- a/crates/gpui/src/keymap.rs +++ b/crates/gpui/src/keymap.rs @@ -5,7 +5,7 @@ pub use binding::*; pub use context::*; use crate::{Action, Keystroke, is_no_action}; -use collections::HashMap; +use collections::{HashMap, HashSet}; use smallvec::SmallVec; use std::any::TypeId; @@ -24,7 +24,7 @@ pub struct Keymap { } /// Index of a binding within a keymap. -#[derive(Copy, Clone, Debug, Eq, PartialEq)] +#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] pub struct BindingIndex(usize); impl Keymap { @@ -77,15 +77,6 @@ impl Keymap { &'a self, action: &'a dyn Action, ) -> impl 'a + DoubleEndedIterator<Item = &'a KeyBinding> { - self.bindings_for_action_with_indices(action) - .map(|(_, binding)| binding) - } - - /// Like `bindings_for_action_with_indices`, but also returns the binding indices. 
- pub fn bindings_for_action_with_indices<'a>( - &'a self, - action: &'a dyn Action, - ) -> impl 'a + DoubleEndedIterator<Item = (BindingIndex, &'a KeyBinding)> { let action_id = action.type_id(); let binding_indices = self .binding_indices_by_action_id @@ -118,7 +109,7 @@ impl Keymap { } } - Some((BindingIndex(*ix), binding)) + Some(binding) }) } @@ -153,107 +144,63 @@ impl Keymap { input: &[Keystroke], context_stack: &[KeyContext], ) -> (SmallVec<[KeyBinding; 1]>, bool) { - let (bindings, pending) = self.bindings_for_input_with_indices(input, context_stack); - let bindings = bindings - .into_iter() - .map(|(_, binding)| binding) - .collect::<SmallVec<[KeyBinding; 1]>>(); - (bindings, pending) - } + let mut matched_bindings = SmallVec::<[(usize, BindingIndex, &KeyBinding); 1]>::new(); + let mut pending_bindings = SmallVec::<[(BindingIndex, &KeyBinding); 1]>::new(); + + for (ix, binding) in self.bindings().enumerate().rev() { + let Some(depth) = self.binding_enabled(binding, &context_stack) else { + continue; + }; + let Some(pending) = binding.match_keystrokes(input) else { + continue; + }; + + if !pending { + matched_bindings.push((depth, BindingIndex(ix), binding)); + } else { + pending_bindings.push((BindingIndex(ix), binding)); + } + } - /// Like `bindings_for_input`, but also returns the binding indices. - pub fn bindings_for_input_with_indices( - &self, - input: &[Keystroke], - context_stack: &[KeyContext], - ) -> (SmallVec<[(BindingIndex, KeyBinding); 1]>, bool) { - let possibilities = self - .bindings() - .enumerate() - .rev() - .filter_map(|(ix, binding)| { - binding - .match_keystrokes(input) - .map(|pending| (BindingIndex(ix), binding, pending)) - }); - - let mut bindings: SmallVec<[(BindingIndex, KeyBinding, usize); 1]> = SmallVec::new(); - - // (pending, is_no_action, depth, keystrokes) - let mut pending_info_opt: Option<(bool, bool, usize, &[Keystroke])> = None; - - 'outer: for (binding_index, binding, pending) in possibilities { - for depth in (0..=context_stack.len()).rev() { - if self.binding_enabled(binding, &context_stack[0..depth]) { - let is_no_action = is_no_action(&*binding.action); - // We only want to consider a binding pending if it has an action - // This, however, means that if we have both a NoAction binding and a binding - // with an action at the same depth, we should still set is_pending to true. - if let Some(pending_info) = pending_info_opt.as_mut() { - let ( - already_pending, - pending_is_no_action, - pending_depth, - pending_keystrokes, - ) = *pending_info; - - // We only want to change the pending status if it's not already pending AND if - // the existing pending status was set by a NoAction binding. 
This avoids a NoAction - // binding erroneously setting the pending status to true when a binding with an action - // already set it to false - // - // We also want to change the pending status if the keystrokes don't match, - // meaning it's different keystrokes than the NoAction that set pending to false - if pending - && !already_pending - && pending_is_no_action - && (pending_depth == depth - || pending_keystrokes != binding.keystrokes()) - { - pending_info.0 = !is_no_action; - } - } else { - pending_info_opt = Some(( - pending && !is_no_action, - is_no_action, - depth, - binding.keystrokes(), - )); - } + matched_bindings.sort_by(|(depth_a, ix_a, _), (depth_b, ix_b, _)| { + depth_b.cmp(depth_a).then(ix_b.cmp(ix_a)) + }); - if !pending { - bindings.push((binding_index, binding.clone(), depth)); - continue 'outer; - } - } + let mut bindings: SmallVec<[_; 1]> = SmallVec::new(); + let mut first_binding_index = None; + for (_, ix, binding) in matched_bindings { + if is_no_action(&*binding.action) { + break; } + bindings.push(binding.clone()); + first_binding_index.get_or_insert(ix); } - // sort by descending depth - bindings.sort_by(|a, b| a.2.cmp(&b.2).reverse()); - let bindings = bindings - .into_iter() - .map_while(|(binding_index, binding, _)| { - if is_no_action(&*binding.action) { - None - } else { - Some((binding_index, binding)) - } - }) - .collect(); - (bindings, pending_info_opt.unwrap_or_default().0) + let mut pending = HashSet::default(); + for (ix, binding) in pending_bindings.into_iter().rev() { + if let Some(binding_ix) = first_binding_index + && binding_ix > ix + { + continue; + } + if is_no_action(&*binding.action) { + pending.remove(&&binding.keystrokes); + continue; + } + pending.insert(&binding.keystrokes); + } + + (bindings, !pending.is_empty()) } /// Check if the given binding is enabled, given a certain key context. - fn binding_enabled(&self, binding: &KeyBinding, context: &[KeyContext]) -> bool { - // If binding has a context predicate, it must match the current context, + /// Returns the deepest depth at which the binding matches, or None if it doesn't match. 
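// ---------------------------------------------------------------------------
// Editorial sketch, not part of the patch: the depth returned below is the
// length of the deepest prefix of the context stack on which the binding's
// predicate evaluates true, and bindings_for_input sorts matches by that
// depth so deeper (more specific) bindings take precedence. A worked example
// using only APIs that appear in this diff:
use gpui::{KeyBindingContextPredicate, KeyContext};

fn depth_example() {
    let contexts = [
        KeyContext::parse("Workspace").unwrap(),
        KeyContext::parse("Pane").unwrap(),
        KeyContext::parse("Editor vim_mode=normal").unwrap(),
    ];
    let vim = KeyBindingContextPredicate::parse("vim_mode == normal").unwrap();
    let pane = KeyBindingContextPredicate::parse("Pane").unwrap();
    assert_eq!(vim.depth_of(&contexts), Some(3)); // matches at the deepest context
    assert_eq!(pane.depth_of(&contexts), Some(2));
    assert_eq!(
        KeyBindingContextPredicate::parse("Terminal")
            .unwrap()
            .depth_of(&contexts),
        None
    );
}
// ---------------------------------------------------------------------------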
+ fn binding_enabled(&self, binding: &KeyBinding, contexts: &[KeyContext]) -> Option<usize> { if let Some(predicate) = &binding.context_predicate { - if !predicate.eval(context) { - return false; - } + predicate.depth_of(contexts) + } else { + Some(contexts.len()) } - - true } } @@ -280,18 +227,57 @@ mod tests { keymap.add_bindings(bindings.clone()); // global bindings are enabled in all contexts - assert!(keymap.binding_enabled(&bindings[0], &[])); - assert!(keymap.binding_enabled(&bindings[0], &[KeyContext::parse("terminal").unwrap()])); + assert_eq!(keymap.binding_enabled(&bindings[0], &[]), Some(0)); + assert_eq!( + keymap.binding_enabled(&bindings[0], &[KeyContext::parse("terminal").unwrap()]), + Some(1) + ); // contextual bindings are enabled in contexts that match their predicate - assert!(!keymap.binding_enabled(&bindings[1], &[KeyContext::parse("barf x=y").unwrap()])); - assert!(keymap.binding_enabled(&bindings[1], &[KeyContext::parse("pane x=y").unwrap()])); - - assert!(!keymap.binding_enabled(&bindings[2], &[KeyContext::parse("editor").unwrap()])); - assert!(keymap.binding_enabled( - &bindings[2], - &[KeyContext::parse("editor mode=full").unwrap()] - )); + assert_eq!( + keymap.binding_enabled(&bindings[1], &[KeyContext::parse("barf x=y").unwrap()]), + None + ); + assert_eq!( + keymap.binding_enabled(&bindings[1], &[KeyContext::parse("pane x=y").unwrap()]), + Some(1) + ); + + assert_eq!( + keymap.binding_enabled(&bindings[2], &[KeyContext::parse("editor").unwrap()]), + None + ); + assert_eq!( + keymap.binding_enabled( + &bindings[2], + &[KeyContext::parse("editor mode=full").unwrap()] + ), + Some(1) + ); + } + + #[test] + fn test_depth_precedence() { + let bindings = [ + KeyBinding::new("ctrl-a", ActionBeta {}, Some("pane")), + KeyBinding::new("ctrl-a", ActionGamma {}, Some("editor")), + ]; + + let mut keymap = Keymap::default(); + keymap.add_bindings(bindings.clone()); + + let (result, pending) = keymap.bindings_for_input( + &[Keystroke::parse("ctrl-a").unwrap()], + &[ + KeyContext::parse("pane").unwrap(), + KeyContext::parse("editor").unwrap(), + ], + ); + + assert!(!pending); + assert_eq!(result.len(), 2); + assert!(result[0].action.partial_eq(&ActionGamma {})); + assert!(result[1].action.partial_eq(&ActionBeta {})); } #[test] @@ -445,6 +431,193 @@ mod tests { assert_eq!(space_editor.1, true); } + #[test] + fn test_override_multikey() { + let bindings = [ + KeyBinding::new("ctrl-w left", ActionAlpha {}, Some("editor")), + KeyBinding::new("ctrl-w", NoAction {}, Some("editor")), + ]; + + let mut keymap = Keymap::default(); + keymap.add_bindings(bindings.clone()); + + // Ensure `space` results in pending input on the workspace, but not editor + let (result, pending) = keymap.bindings_for_input( + &[Keystroke::parse("ctrl-w").unwrap()], + &[KeyContext::parse("editor").unwrap()], + ); + assert!(result.is_empty()); + assert_eq!(pending, true); + + let bindings = [ + KeyBinding::new("ctrl-w left", ActionAlpha {}, Some("editor")), + KeyBinding::new("ctrl-w", ActionBeta {}, Some("editor")), + ]; + + let mut keymap = Keymap::default(); + keymap.add_bindings(bindings.clone()); + + // Ensure `space` results in pending input on the workspace, but not editor + let (result, pending) = keymap.bindings_for_input( + &[Keystroke::parse("ctrl-w").unwrap()], + &[KeyContext::parse("editor").unwrap()], + ); + assert_eq!(result.len(), 1); + assert_eq!(pending, false); + } + + #[test] + fn test_simple_disable() { + let bindings = [ + KeyBinding::new("ctrl-x", ActionAlpha {}, Some("editor")), + 
KeyBinding::new("ctrl-x", NoAction {}, Some("editor")), + ]; + + let mut keymap = Keymap::default(); + keymap.add_bindings(bindings.clone()); + + // Ensure `space` results in pending input on the workspace, but not editor + let (result, pending) = keymap.bindings_for_input( + &[Keystroke::parse("ctrl-x").unwrap()], + &[KeyContext::parse("editor").unwrap()], + ); + assert!(result.is_empty()); + assert_eq!(pending, false); + } + + #[test] + fn test_fail_to_disable() { + // disabled at the wrong level + let bindings = [ + KeyBinding::new("ctrl-x", ActionAlpha {}, Some("editor")), + KeyBinding::new("ctrl-x", NoAction {}, Some("workspace")), + ]; + + let mut keymap = Keymap::default(); + keymap.add_bindings(bindings.clone()); + + // Ensure `space` results in pending input on the workspace, but not editor + let (result, pending) = keymap.bindings_for_input( + &[Keystroke::parse("ctrl-x").unwrap()], + &[ + KeyContext::parse("workspace").unwrap(), + KeyContext::parse("editor").unwrap(), + ], + ); + assert_eq!(result.len(), 1); + assert_eq!(pending, false); + } + + #[test] + fn test_disable_deeper() { + let bindings = [ + KeyBinding::new("ctrl-x", ActionAlpha {}, Some("workspace")), + KeyBinding::new("ctrl-x", NoAction {}, Some("editor")), + ]; + + let mut keymap = Keymap::default(); + keymap.add_bindings(bindings.clone()); + + // Ensure `space` results in pending input on the workspace, but not editor + let (result, pending) = keymap.bindings_for_input( + &[Keystroke::parse("ctrl-x").unwrap()], + &[ + KeyContext::parse("workspace").unwrap(), + KeyContext::parse("editor").unwrap(), + ], + ); + assert_eq!(result.len(), 0); + assert_eq!(pending, false); + } + + #[test] + fn test_pending_match_enabled() { + let bindings = [ + KeyBinding::new("ctrl-x", ActionBeta, Some("vim_mode == normal")), + KeyBinding::new("ctrl-x 0", ActionAlpha, Some("Workspace")), + ]; + let mut keymap = Keymap::default(); + keymap.add_bindings(bindings.clone()); + + let matched = keymap.bindings_for_input( + &[Keystroke::parse("ctrl-x")].map(Result::unwrap), + &[ + KeyContext::parse("Workspace"), + KeyContext::parse("Pane"), + KeyContext::parse("Editor vim_mode=normal"), + ] + .map(Result::unwrap), + ); + assert_eq!(matched.0.len(), 1); + assert!(matched.0[0].action.partial_eq(&ActionBeta)); + assert!(matched.1); + } + + #[test] + fn test_pending_match_enabled_extended() { + let bindings = [ + KeyBinding::new("ctrl-x", ActionBeta, Some("vim_mode == normal")), + KeyBinding::new("ctrl-x 0", NoAction, Some("Workspace")), + ]; + let mut keymap = Keymap::default(); + keymap.add_bindings(bindings.clone()); + + let matched = keymap.bindings_for_input( + &[Keystroke::parse("ctrl-x")].map(Result::unwrap), + &[ + KeyContext::parse("Workspace"), + KeyContext::parse("Pane"), + KeyContext::parse("Editor vim_mode=normal"), + ] + .map(Result::unwrap), + ); + assert_eq!(matched.0.len(), 1); + assert!(matched.0[0].action.partial_eq(&ActionBeta)); + assert!(!matched.1); + let bindings = [ + KeyBinding::new("ctrl-x", ActionBeta, Some("Workspace")), + KeyBinding::new("ctrl-x 0", NoAction, Some("vim_mode == normal")), + ]; + let mut keymap = Keymap::default(); + keymap.add_bindings(bindings.clone()); + + let matched = keymap.bindings_for_input( + &[Keystroke::parse("ctrl-x")].map(Result::unwrap), + &[ + KeyContext::parse("Workspace"), + KeyContext::parse("Pane"), + KeyContext::parse("Editor vim_mode=normal"), + ] + .map(Result::unwrap), + ); + assert_eq!(matched.0.len(), 1); + assert!(matched.0[0].action.partial_eq(&ActionBeta)); + 
assert!(!matched.1); + } + + #[test] + fn test_overriding_prefix() { + let bindings = [ + KeyBinding::new("ctrl-x 0", ActionAlpha, Some("Workspace")), + KeyBinding::new("ctrl-x", ActionBeta, Some("vim_mode == normal")), + ]; + let mut keymap = Keymap::default(); + keymap.add_bindings(bindings.clone()); + + let matched = keymap.bindings_for_input( + &[Keystroke::parse("ctrl-x")].map(Result::unwrap), + &[ + KeyContext::parse("Workspace"), + KeyContext::parse("Pane"), + KeyContext::parse("Editor vim_mode=normal"), + ] + .map(Result::unwrap), + ); + assert_eq!(matched.0.len(), 1); + assert!(matched.0[0].action.partial_eq(&ActionBeta)); + assert!(!matched.1); + } + #[test] fn test_bindings_for_action() { let bindings = [ diff --git a/crates/gpui/src/keymap/context.rs b/crates/gpui/src/keymap/context.rs index eaad06098218275ab37c9078c358cab019e90761..281035fe97614dd810f1057c8094b2c698984166 100644 --- a/crates/gpui/src/keymap/context.rs +++ b/crates/gpui/src/keymap/context.rs @@ -178,7 +178,7 @@ pub enum KeyBindingContextPredicate { NotEqual(SharedString, SharedString), /// A predicate that will match a given predicate appearing below another predicate. /// in the element tree - Child( + Descendant( Box<KeyBindingContextPredicate>, Box<KeyBindingContextPredicate>, ), @@ -203,7 +203,7 @@ impl fmt::Display for KeyBindingContextPredicate { Self::Equal(left, right) => write!(f, "{} == {}", left, right), Self::NotEqual(left, right) => write!(f, "{} != {}", left, right), Self::Not(pred) => write!(f, "!{}", pred), - Self::Child(parent, child) => write!(f, "{} > {}", parent, child), + Self::Descendant(parent, child) => write!(f, "{} > {}", parent, child), Self::And(left, right) => write!(f, "({} && {})", left, right), Self::Or(left, right) => write!(f, "({} || {})", left, right), } @@ -249,8 +249,25 @@ impl KeyBindingContextPredicate { } } + /// Find the deepest depth at which the predicate matches. + pub fn depth_of(&self, contexts: &[KeyContext]) -> Option<usize> { + for depth in (0..=contexts.len()).rev() { + let context_slice = &contexts[0..depth]; + if self.eval_inner(context_slice, contexts) { + return Some(depth); + } + } + None + } + + /// Eval a predicate against a set of contexts, arranged from lowest to highest. + #[allow(unused)] + pub(crate) fn eval(&self, contexts: &[KeyContext]) -> bool { + self.eval_inner(contexts, contexts) + } + /// Eval a predicate against a set of contexts, arranged from lowest to highest. - pub fn eval(&self, contexts: &[KeyContext]) -> bool { + pub fn eval_inner(&self, contexts: &[KeyContext], all_contexts: &[KeyContext]) -> bool { let Some(context) = contexts.last() else { return false; }; @@ -264,12 +281,38 @@ impl KeyBindingContextPredicate { .get(left) .map(|value| value != right) .unwrap_or(true), - Self::Not(pred) => !pred.eval(contexts), - Self::Child(parent, child) => { - parent.eval(&contexts[..contexts.len() - 1]) && child.eval(contexts) + Self::Not(pred) => { + for i in 0..all_contexts.len() { + if pred.eval_inner(&all_contexts[..=i], all_contexts) { + return false; + } + } + return true; + } + // Workspace > Pane > Editor + // + // Pane > (Pane > Editor) // should match? + // (Pane > Pane) > Editor // should not match? + // Pane > !Workspace <-- should match? + // !Workspace <-- shouldn't match? 
+ Self::Descendant(parent, child) => { + for i in 0..contexts.len() - 1 { + // [Workspace > Pane], [Editor] + if parent.eval_inner(&contexts[..=i], all_contexts) { + if !child.eval_inner(&contexts[i + 1..], &contexts[i + 1..]) { + return false; + } + return true; + } + } + return false; + } + Self::And(left, right) => { + left.eval_inner(contexts, all_contexts) && right.eval_inner(contexts, all_contexts) + } + Self::Or(left, right) => { + left.eval_inner(contexts, all_contexts) || right.eval_inner(contexts, all_contexts) } - Self::And(left, right) => left.eval(contexts) && right.eval(contexts), - Self::Or(left, right) => left.eval(contexts) || right.eval(contexts), } } @@ -285,7 +328,7 @@ impl KeyBindingContextPredicate { } match other { - KeyBindingContextPredicate::Child(_, child) => self.is_superset(child), + KeyBindingContextPredicate::Descendant(_, child) => self.is_superset(child), KeyBindingContextPredicate::And(left, right) => { self.is_superset(left) || self.is_superset(right) } @@ -375,7 +418,7 @@ impl KeyBindingContextPredicate { } fn new_child(self, other: Self) -> Result<Self> { - Ok(Self::Child(Box::new(self), Box::new(other))) + Ok(Self::Descendant(Box::new(self), Box::new(other))) } fn new_eq(self, other: Self) -> Result<Self> { @@ -418,6 +461,8 @@ fn skip_whitespace(source: &str) -> &str { #[cfg(test)] mod tests { + use core::slice; + use super::*; use crate as gpui; use KeyBindingContextPredicate::*; @@ -598,4 +643,122 @@ mod tests { assert_eq!(a.is_superset(&b), result, "({a:?}).is_superset({b:?})"); } } + + #[test] + fn test_child_operator() { + let predicate = KeyBindingContextPredicate::parse("parent > child").unwrap(); + + let parent_context = KeyContext::try_from("parent").unwrap(); + let child_context = KeyContext::try_from("child").unwrap(); + + let contexts = vec![parent_context.clone(), child_context.clone()]; + assert!(predicate.eval(&contexts)); + + let grandparent_context = KeyContext::try_from("grandparent").unwrap(); + + let contexts = vec![ + grandparent_context, + parent_context.clone(), + child_context.clone(), + ]; + assert!(predicate.eval(&contexts)); + + let other_context = KeyContext::try_from("other").unwrap(); + + let contexts = vec![other_context.clone(), child_context.clone()]; + assert!(!predicate.eval(&contexts)); + + let contexts = vec![ + parent_context.clone(), + other_context.clone(), + child_context.clone(), + ]; + assert!(predicate.eval(&contexts)); + + assert!(!predicate.eval(&[])); + assert!(!predicate.eval(slice::from_ref(&child_context))); + assert!(!predicate.eval(&[parent_context])); + + let zany_predicate = KeyBindingContextPredicate::parse("child > child").unwrap(); + assert!(!zany_predicate.eval(slice::from_ref(&child_context))); + assert!(zany_predicate.eval(&[child_context.clone(), child_context.clone()])); + } + + #[test] + fn test_not_operator() { + let not_predicate = KeyBindingContextPredicate::parse("!editor").unwrap(); + let editor_context = KeyContext::try_from("editor").unwrap(); + let workspace_context = KeyContext::try_from("workspace").unwrap(); + let parent_context = KeyContext::try_from("parent").unwrap(); + let child_context = KeyContext::try_from("child").unwrap(); + + assert!(not_predicate.eval(slice::from_ref(&workspace_context))); + assert!(!not_predicate.eval(slice::from_ref(&editor_context))); + assert!(!not_predicate.eval(&[editor_context.clone(), workspace_context.clone()])); + assert!(!not_predicate.eval(&[workspace_context.clone(), editor_context.clone()])); + + let complex_not = 
KeyBindingContextPredicate::parse("!editor && workspace").unwrap(); + assert!(complex_not.eval(slice::from_ref(&workspace_context))); + assert!(!complex_not.eval(&[editor_context.clone(), workspace_context.clone()])); + + let not_mode_predicate = KeyBindingContextPredicate::parse("!(mode == full)").unwrap(); + let mut mode_context = KeyContext::default(); + mode_context.set("mode", "full"); + assert!(!not_mode_predicate.eval(&[mode_context.clone()])); + + let mut other_mode_context = KeyContext::default(); + other_mode_context.set("mode", "partial"); + assert!(not_mode_predicate.eval(&[other_mode_context])); + + let not_descendant = KeyBindingContextPredicate::parse("!(parent > child)").unwrap(); + assert!(not_descendant.eval(slice::from_ref(&parent_context))); + assert!(not_descendant.eval(slice::from_ref(&child_context))); + assert!(!not_descendant.eval(&[parent_context.clone(), child_context.clone()])); + + let not_descendant = KeyBindingContextPredicate::parse("parent > !child").unwrap(); + assert!(!not_descendant.eval(slice::from_ref(&parent_context))); + assert!(!not_descendant.eval(slice::from_ref(&child_context))); + assert!(!not_descendant.eval(&[parent_context.clone(), child_context.clone()])); + + let double_not = KeyBindingContextPredicate::parse("!!editor").unwrap(); + assert!(double_not.eval(slice::from_ref(&editor_context))); + assert!(!double_not.eval(slice::from_ref(&workspace_context))); + + // Test complex descendant cases + let workspace_context = KeyContext::try_from("Workspace").unwrap(); + let pane_context = KeyContext::try_from("Pane").unwrap(); + let editor_context = KeyContext::try_from("Editor").unwrap(); + + // Workspace > Pane > Editor + let workspace_pane_editor = vec![ + workspace_context.clone(), + pane_context.clone(), + editor_context.clone(), + ]; + + // Pane > (Pane > Editor) - should not match + let pane_pane_editor = KeyBindingContextPredicate::parse("Pane > (Pane > Editor)").unwrap(); + assert!(!pane_pane_editor.eval(&workspace_pane_editor)); + + let workspace_pane_editor_predicate = + KeyBindingContextPredicate::parse("Workspace > Pane > Editor").unwrap(); + assert!(workspace_pane_editor_predicate.eval(&workspace_pane_editor)); + + // (Pane > Pane) > Editor - should not match + let pane_pane_then_editor = + KeyBindingContextPredicate::parse("(Pane > Pane) > Editor").unwrap(); + assert!(!pane_pane_then_editor.eval(&workspace_pane_editor)); + + // Pane > !Workspace - should match + let pane_not_workspace = KeyBindingContextPredicate::parse("Pane > !Workspace").unwrap(); + assert!(pane_not_workspace.eval(&[pane_context.clone(), editor_context.clone()])); + assert!(!pane_not_workspace.eval(&[pane_context.clone(), workspace_context.clone()])); + + // !Workspace - shouldn't match when Workspace is in the context + let not_workspace = KeyBindingContextPredicate::parse("!Workspace").unwrap(); + assert!(!not_workspace.eval(slice::from_ref(&workspace_context))); + assert!(not_workspace.eval(slice::from_ref(&pane_context))); + assert!(not_workspace.eval(slice::from_ref(&editor_context))); + assert!(!not_workspace.eval(&workspace_pane_editor)); + } } diff --git a/crates/gpui/src/path_builder.rs b/crates/gpui/src/path_builder.rs index 13c168b0bb90f7d209ce02cbab798faf48ae1d2f..6c8cfddd523c4d56c81ebcbbf1437b5cc418d73c 100644 --- a/crates/gpui/src/path_builder.rs +++ b/crates/gpui/src/path_builder.rs @@ -336,7 +336,10 @@ impl PathBuilder { let v1 = buf.vertices[i1]; let v2 = buf.vertices[i2]; - path.push_triangle((v0.into(), v1.into(), v2.into())); + 
path.push_triangle( + (v0.into(), v1.into(), v2.into()), + (point(0., 1.), point(0., 1.), point(0., 1.)), + ); } path diff --git a/crates/gpui/src/platform.rs b/crates/gpui/src/platform.rs index 0250e59a9bbc363a61377dc8c0ab01bccd820df3..b495d70dfdd3594a27ed3c1793e7e0ac4e7e0b4a 100644 --- a/crates/gpui/src/platform.rs +++ b/crates/gpui/src/platform.rs @@ -13,8 +13,7 @@ mod mac; any(target_os = "linux", target_os = "freebsd"), any(feature = "x11", feature = "wayland") ), - target_os = "windows", - feature = "macos-blade" + all(target_os = "macos", feature = "macos-blade") ))] mod blade; @@ -85,7 +84,7 @@ pub(crate) use test::*; pub(crate) use windows::*; #[cfg(any(test, feature = "test-support"))] -pub use test::{TestDispatcher, TestScreenCaptureSource}; +pub use test::{TestDispatcher, TestScreenCaptureSource, TestScreenCaptureStream}; /// Returns a background executor for the current platform. pub fn background_executor() -> BackgroundExecutor { @@ -189,13 +188,12 @@ pub(crate) trait Platform: 'static { false } #[cfg(feature = "screen-capture")] - fn screen_capture_sources( - &self, - ) -> oneshot::Receiver<Result<Vec<Box<dyn ScreenCaptureSource>>>>; + fn screen_capture_sources(&self) + -> oneshot::Receiver<Result<Vec<Rc<dyn ScreenCaptureSource>>>>; #[cfg(not(feature = "screen-capture"))] fn screen_capture_sources( &self, - ) -> oneshot::Receiver<anyhow::Result<Vec<Box<dyn ScreenCaptureSource>>>> { + ) -> oneshot::Receiver<anyhow::Result<Vec<Rc<dyn ScreenCaptureSource>>>> { let (sources_tx, sources_rx) = oneshot::channel(); sources_tx .send(Err(anyhow::anyhow!( @@ -293,10 +291,23 @@ pub trait PlatformDisplay: Send + Sync + Debug { } } +/// Metadata for a given [ScreenCaptureSource] +#[derive(Clone)] +pub struct SourceMetadata { + /// Opaque identifier of this screen. + pub id: u64, + /// Human-readable label for this source. + pub label: Option<SharedString>, + /// Whether this source is the main display. + pub is_main: Option<bool>, + /// Video resolution of this source. + pub resolution: Size<DevicePixels>, +} + /// A source of on-screen video content that can be captured. pub trait ScreenCaptureSource { - /// Returns the video resolution of this source. - fn resolution(&self) -> Result<Size<DevicePixels>>; + /// Returns metadata for this source. + fn metadata(&self) -> Result<SourceMetadata>; /// Start capture video from this source, invoking the given callback /// with each frame. @@ -308,7 +319,10 @@ pub trait ScreenCaptureSource { } /// A video stream captured from a screen. -pub trait ScreenCaptureStream {} +pub trait ScreenCaptureStream { + /// Returns metadata for this source. + fn metadata(&self) -> Result<SourceMetadata>; +} /// A frame of video captured from a screen. 
pub struct ScreenCaptureFrame(pub PlatformScreenCaptureFrame); @@ -433,6 +447,8 @@ impl Tiling { #[derive(Debug, Copy, Clone, Eq, PartialEq, Default)] pub(crate) struct RequestFrameOptions { pub(crate) require_presentation: bool, + /// Force refresh of all rendering states when true + pub(crate) force_render: bool, } pub(crate) trait PlatformWindow: HasWindowHandle + HasDisplayHandle { diff --git a/crates/gpui/src/platform/blade/blade_atlas.rs b/crates/gpui/src/platform/blade/blade_atlas.rs index 0b119c39101ff36199d41d7905fb6e9e25db4a68..74500ebf8324e4747122ac425388bc122953185e 100644 --- a/crates/gpui/src/platform/blade/blade_atlas.rs +++ b/crates/gpui/src/platform/blade/blade_atlas.rs @@ -38,8 +38,6 @@ impl BladeAtlasState { } pub struct BladeTextureInfo { - #[allow(dead_code)] - pub size: gpu::Extent, pub raw_view: gpu::TextureView, } @@ -63,15 +61,6 @@ impl BladeAtlas { self.0.lock().destroy(); } - #[allow(dead_code)] - pub(crate) fn clear_textures(&self, texture_kind: AtlasTextureKind) { - let mut lock = self.0.lock(); - let textures = &mut lock.storage[texture_kind]; - for texture in textures.iter_mut() { - texture.clear(); - } - } - pub fn before_frame(&self, gpu_encoder: &mut gpu::CommandEncoder) { let mut lock = self.0.lock(); lock.flush(gpu_encoder); @@ -85,13 +74,7 @@ impl BladeAtlas { pub fn get_texture_info(&self, id: AtlasTextureId) -> BladeTextureInfo { let lock = self.0.lock(); let texture = &lock.storage[id]; - let size = texture.allocator.size(); BladeTextureInfo { - size: gpu::Extent { - width: size.width as u32, - height: size.height as u32, - depth: 1, - }, raw_view: texture.raw_view, } } @@ -334,10 +317,6 @@ struct BladeAtlasTexture { } impl BladeAtlasTexture { - fn clear(&mut self) { - self.allocator.clear(); - } - fn allocate(&mut self, size: Size<DevicePixels>) -> Option<AtlasTile> { let allocation = self.allocator.allocate(size.into())?; let tile = AtlasTile { diff --git a/crates/gpui/src/platform/blade/blade_renderer.rs b/crates/gpui/src/platform/blade/blade_renderer.rs index 1b9f111b0d44f2182e5e76b17228b41b66baa32b..46d3c16c72a9c10c0e686aff425fcc236c253ce7 100644 --- a/crates/gpui/src/platform/blade/blade_renderer.rs +++ b/crates/gpui/src/platform/blade/blade_renderer.rs @@ -3,15 +3,15 @@ use super::{BladeAtlas, BladeContext}; use crate::{ - Background, Bounds, ContentMask, DevicePixels, GpuSpecs, MonochromeSprite, PathVertex, - PolychromeSprite, PrimitiveBatch, Quad, ScaledPixels, Scene, Shadow, Size, Underline, + Background, Bounds, DevicePixels, GpuSpecs, MonochromeSprite, Path, Point, PolychromeSprite, + PrimitiveBatch, Quad, ScaledPixels, Scene, Shadow, Size, Underline, }; -use blade_graphics::{self as gpu}; +use blade_graphics as gpu; use blade_util::{BufferBelt, BufferBeltDescriptor}; use bytemuck::{Pod, Zeroable}; #[cfg(target_os = "macos")] use media::core_video::CVMetalTextureCache; -use std::{mem, sync::Arc}; +use std::sync::Arc; const MAX_FRAME_TIME_MS: u32 = 10000; @@ -61,9 +61,16 @@ struct ShaderShadowsData { } #[derive(blade_macros::ShaderData)] -struct ShaderPathsData { +struct ShaderPathRasterizationData { globals: GlobalParams, b_path_vertices: gpu::BufferPiece, +} + +#[derive(blade_macros::ShaderData)] +struct ShaderPathsData { + globals: GlobalParams, + t_sprite: gpu::TextureView, + s_sprite: gpu::Sampler, b_path_sprites: gpu::BufferPiece, } @@ -102,28 +109,21 @@ struct ShaderSurfacesData { #[repr(C)] struct PathSprite { bounds: Bounds<ScaledPixels>, - color: Background, } -/// Argument buffer layout for `draw_indirect` commands. 
+#[derive(Clone, Debug)] #[repr(C)] -#[derive(Copy, Clone, Debug, Default, Pod, Zeroable)] -pub struct DrawIndirectArgs { - /// The number of vertices to draw. - pub vertex_count: u32, - /// The number of instances to draw. - pub instance_count: u32, - /// The Index of the first vertex to draw. - pub first_vertex: u32, - /// The instance ID of the first instance to draw. - /// - /// Has to be 0, unless [`Features::INDIRECT_FIRST_INSTANCE`](crate::Features::INDIRECT_FIRST_INSTANCE) is enabled. - pub first_instance: u32, +struct PathRasterizationVertex { + xy_position: Point<ScaledPixels>, + st_position: Point<f32>, + color: Background, + bounds: Bounds<ScaledPixels>, } struct BladePipelines { quads: gpu::RenderPipeline, shadows: gpu::RenderPipeline, + path_rasterization: gpu::RenderPipeline, paths: gpu::RenderPipeline, underlines: gpu::RenderPipeline, mono_sprites: gpu::RenderPipeline, @@ -132,7 +132,7 @@ struct BladePipelines { } impl BladePipelines { - fn new(gpu: &gpu::Context, surface_info: gpu::SurfaceInfo, sample_count: u32) -> Self { + fn new(gpu: &gpu::Context, surface_info: gpu::SurfaceInfo, path_sample_count: u32) -> Self { use gpu::ShaderData as _; log::info!( @@ -146,10 +146,7 @@ impl BladePipelines { shader.check_struct_size::<SurfaceParams>(); shader.check_struct_size::<Quad>(); shader.check_struct_size::<Shadow>(); - assert_eq!( - mem::size_of::<PathVertex<ScaledPixels>>(), - shader.get_struct_size("PathVertex") as usize, - ); + shader.check_struct_size::<PathRasterizationVertex>(); shader.check_struct_size::<PathSprite>(); shader.check_struct_size::<Underline>(); shader.check_struct_size::<MonochromeSprite>(); @@ -180,10 +177,7 @@ impl BladePipelines { depth_stencil: None, fragment: Some(shader.at("fs_quad")), color_targets, - multisample_state: gpu::MultisampleState { - sample_count, - ..Default::default() - }, + multisample_state: gpu::MultisampleState::default(), }), shadows: gpu.create_render_pipeline(gpu::RenderPipelineDesc { name: "shadows", @@ -197,8 +191,33 @@ impl BladePipelines { depth_stencil: None, fragment: Some(shader.at("fs_shadow")), color_targets, + multisample_state: gpu::MultisampleState::default(), + }), + path_rasterization: gpu.create_render_pipeline(gpu::RenderPipelineDesc { + name: "path_rasterization", + data_layouts: &[&ShaderPathRasterizationData::layout()], + vertex: shader.at("vs_path_rasterization"), + vertex_fetches: &[], + primitive: gpu::PrimitiveState { + topology: gpu::PrimitiveTopology::TriangleList, + ..Default::default() + }, + depth_stencil: None, + fragment: Some(shader.at("fs_path_rasterization")), + // The original implementation was using ADDITIVE blend mode, + // but it's unclear why + // color_targets: &[gpu::ColorTargetState { + // format: PATH_TEXTURE_FORMAT, + // blend: Some(gpu::BlendState::ADDITIVE), + // write_mask: gpu::ColorWrites::default(), + // }], + color_targets: &[gpu::ColorTargetState { + format: surface_info.format, + blend: Some(gpu::BlendState::PREMULTIPLIED_ALPHA_BLENDING), + write_mask: gpu::ColorWrites::default(), + }], multisample_state: gpu::MultisampleState { - sample_count, + sample_count: path_sample_count, ..Default::default() }, }), @@ -208,16 +227,20 @@ impl BladePipelines { vertex: shader.at("vs_path"), vertex_fetches: &[], primitive: gpu::PrimitiveState { - topology: gpu::PrimitiveTopology::TriangleList, + topology: gpu::PrimitiveTopology::TriangleStrip, ..Default::default() }, depth_stencil: None, fragment: Some(shader.at("fs_path")), - color_targets, - multisample_state: gpu::MultisampleState { -
sample_count, - ..Default::default() - }, + color_targets: &[gpu::ColorTargetState { + format: surface_info.format, + blend: Some(gpu::BlendState { + color: gpu::BlendComponent::OVER, + alpha: gpu::BlendComponent::ADDITIVE, + }), + write_mask: gpu::ColorWrites::default(), + }], + multisample_state: gpu::MultisampleState::default(), }), underlines: gpu.create_render_pipeline(gpu::RenderPipelineDesc { name: "underlines", @@ -231,10 +254,7 @@ impl BladePipelines { depth_stencil: None, fragment: Some(shader.at("fs_underline")), color_targets, - multisample_state: gpu::MultisampleState { - sample_count, - ..Default::default() - }, + multisample_state: gpu::MultisampleState::default(), }), mono_sprites: gpu.create_render_pipeline(gpu::RenderPipelineDesc { name: "mono-sprites", @@ -248,10 +268,7 @@ impl BladePipelines { depth_stencil: None, fragment: Some(shader.at("fs_mono_sprite")), color_targets, - multisample_state: gpu::MultisampleState { - sample_count, - ..Default::default() - }, + multisample_state: gpu::MultisampleState::default(), }), poly_sprites: gpu.create_render_pipeline(gpu::RenderPipelineDesc { name: "poly-sprites", @@ -265,10 +282,7 @@ impl BladePipelines { depth_stencil: None, fragment: Some(shader.at("fs_poly_sprite")), color_targets, - multisample_state: gpu::MultisampleState { - sample_count, - ..Default::default() - }, + multisample_state: gpu::MultisampleState::default(), }), surfaces: gpu.create_render_pipeline(gpu::RenderPipelineDesc { name: "surfaces", @@ -282,10 +296,7 @@ impl BladePipelines { depth_stencil: None, fragment: Some(shader.at("fs_surface")), color_targets, - multisample_state: gpu::MultisampleState { - sample_count, - ..Default::default() - }, + multisample_state: gpu::MultisampleState::default(), }), } } @@ -293,6 +304,7 @@ impl BladePipelines { fn destroy(&mut self, gpu: &gpu::Context) { gpu.destroy_render_pipeline(&mut self.quads); gpu.destroy_render_pipeline(&mut self.shadows); + gpu.destroy_render_pipeline(&mut self.path_rasterization); gpu.destroy_render_pipeline(&mut self.paths); gpu.destroy_render_pipeline(&mut self.underlines); gpu.destroy_render_pipeline(&mut self.mono_sprites); @@ -322,9 +334,11 @@ pub struct BladeRenderer { atlas_sampler: gpu::Sampler, #[cfg(target_os = "macos")] core_video_texture_cache: CVMetalTextureCache, - sample_count: u32, - texture_msaa: Option<gpu::Texture>, - texture_view_msaa: Option<gpu::TextureView>, + path_sample_count: u32, + path_intermediate_texture: gpu::Texture, + path_intermediate_texture_view: gpu::TextureView, + path_intermediate_msaa_texture: Option<gpu::Texture>, + path_intermediate_msaa_texture_view: Option<gpu::TextureView>, } impl BladeRenderer { @@ -333,18 +347,6 @@ impl BladeRenderer { window: &I, config: BladeSurfaceConfig, ) -> anyhow::Result<Self> { - // workaround for https://github.com/zed-industries/zed/issues/26143 - let sample_count = std::env::var("ZED_SAMPLE_COUNT") - .ok() - .or_else(|| std::env::var("ZED_PATH_SAMPLE_COUNT").ok()) - .and_then(|v| v.parse().ok()) - .or_else(|| { - [4, 2, 1] - .into_iter() - .find(|count| context.gpu.supports_texture_sample_count(*count)) - }) - .unwrap_or(1); - let surface_config = gpu::SurfaceConfig { size: config.size, usage: gpu::TextureUsage::TARGET, @@ -358,21 +360,21 @@ impl BladeRenderer { .create_surface_configured(window, surface_config) .map_err(|err| anyhow::anyhow!("Failed to create surface: {err:?}"))?; - let (texture_msaa, texture_view_msaa) = create_msaa_texture_if_needed( - &context.gpu, - surface.info().format, - config.size.width, - 
config.size.height, - sample_count, - ) - .unzip(); - let command_encoder = context.gpu.create_command_encoder(gpu::CommandEncoderDesc { name: "main", buffer_count: 2, }); - - let pipelines = BladePipelines::new(&context.gpu, surface.info(), sample_count); + // workaround for https://github.com/zed-industries/zed/issues/26143 + let path_sample_count = std::env::var("ZED_PATH_SAMPLE_COUNT") + .ok() + .and_then(|v| v.parse().ok()) + .or_else(|| { + [4, 2, 1] + .into_iter() + .find(|count| context.gpu.supports_texture_sample_count(*count)) + }) + .unwrap_or(1); + let pipelines = BladePipelines::new(&context.gpu, surface.info(), path_sample_count); let instance_belt = BufferBelt::new(BufferBeltDescriptor { memory: gpu::Memory::Shared, min_chunk_size: 0x1000, @@ -380,12 +382,29 @@ impl BladeRenderer { }); let atlas = Arc::new(BladeAtlas::new(&context.gpu)); let atlas_sampler = context.gpu.create_sampler(gpu::SamplerDesc { - name: "atlas", + name: "path rasterization sampler", mag_filter: gpu::FilterMode::Linear, min_filter: gpu::FilterMode::Linear, ..Default::default() }); + let (path_intermediate_texture, path_intermediate_texture_view) = + create_path_intermediate_texture( + &context.gpu, + surface.info().format, + config.size.width, + config.size.height, + ); + let (path_intermediate_msaa_texture, path_intermediate_msaa_texture_view) = + create_msaa_texture_if_needed( + &context.gpu, + surface.info().format, + config.size.width, + config.size.height, + path_sample_count, + ) + .unzip(); + #[cfg(target_os = "macos")] let core_video_texture_cache = unsafe { CVMetalTextureCache::new( @@ -406,9 +425,11 @@ impl BladeRenderer { atlas_sampler, #[cfg(target_os = "macos")] core_video_texture_cache, - sample_count, - texture_msaa, - texture_view_msaa, + path_sample_count, + path_intermediate_texture, + path_intermediate_texture_view, + path_intermediate_msaa_texture, + path_intermediate_msaa_texture_view, }) } @@ -461,24 +482,35 @@ impl BladeRenderer { self.surface_config.size = gpu_size; self.gpu .reconfigure_surface(&mut self.surface, self.surface_config); - - if let Some(texture_msaa) = self.texture_msaa { - self.gpu.destroy_texture(texture_msaa); + self.gpu.destroy_texture(self.path_intermediate_texture); + self.gpu + .destroy_texture_view(self.path_intermediate_texture_view); + if let Some(msaa_texture) = self.path_intermediate_msaa_texture { + self.gpu.destroy_texture(msaa_texture); } - if let Some(texture_view_msaa) = self.texture_view_msaa { - self.gpu.destroy_texture_view(texture_view_msaa); + if let Some(msaa_view) = self.path_intermediate_msaa_texture_view { + self.gpu.destroy_texture_view(msaa_view); } - - let (texture_msaa, texture_view_msaa) = create_msaa_texture_if_needed( - &self.gpu, - self.surface.info().format, - gpu_size.width, - gpu_size.height, - self.sample_count, - ) - .unzip(); - self.texture_msaa = texture_msaa; - self.texture_view_msaa = texture_view_msaa; + let (path_intermediate_texture, path_intermediate_texture_view) = + create_path_intermediate_texture( + &self.gpu, + self.surface.info().format, + gpu_size.width, + gpu_size.height, + ); + self.path_intermediate_texture = path_intermediate_texture; + self.path_intermediate_texture_view = path_intermediate_texture_view; + let (path_intermediate_msaa_texture, path_intermediate_msaa_texture_view) = + create_msaa_texture_if_needed( + &self.gpu, + self.surface.info().format, + gpu_size.width, + gpu_size.height, + self.path_sample_count, + ) + .unzip(); + self.path_intermediate_msaa_texture = path_intermediate_msaa_texture; + 
self.path_intermediate_msaa_texture_view = path_intermediate_msaa_texture_view; } } @@ -489,7 +521,8 @@ impl BladeRenderer { self.gpu .reconfigure_surface(&mut self.surface, self.surface_config); self.pipelines.destroy(&self.gpu); - self.pipelines = BladePipelines::new(&self.gpu, self.surface.info(), self.sample_count); + self.pipelines = + BladePipelines::new(&self.gpu, self.surface.info(), self.path_sample_count); } } @@ -527,6 +560,67 @@ impl BladeRenderer { objc2::rc::Retained::as_ptr(&self.surface.metal_layer()) as *mut _ } + #[profiling::function] + fn draw_paths_to_intermediate( + &mut self, + paths: &[Path<ScaledPixels>], + width: f32, + height: f32, + ) { + self.command_encoder + .init_texture(self.path_intermediate_texture); + if let Some(msaa_texture) = self.path_intermediate_msaa_texture { + self.command_encoder.init_texture(msaa_texture); + } + + let target = if let Some(msaa_view) = self.path_intermediate_msaa_texture_view { + gpu::RenderTarget { + view: msaa_view, + init_op: gpu::InitOp::Clear(gpu::TextureColor::TransparentBlack), + finish_op: gpu::FinishOp::ResolveTo(self.path_intermediate_texture_view), + } + } else { + gpu::RenderTarget { + view: self.path_intermediate_texture_view, + init_op: gpu::InitOp::Clear(gpu::TextureColor::TransparentBlack), + finish_op: gpu::FinishOp::Store, + } + }; + if let mut pass = self.command_encoder.render( + "rasterize paths", + gpu::RenderTargetSet { + colors: &[target], + depth_stencil: None, + }, + ) { + let globals = GlobalParams { + viewport_size: [width, height], + premultiplied_alpha: 0, + pad: 0, + }; + let mut encoder = pass.with(&self.pipelines.path_rasterization); + + let mut vertices = Vec::new(); + for path in paths { + vertices.extend(path.vertices.iter().map(|v| PathRasterizationVertex { + xy_position: v.xy_position, + st_position: v.st_position, + color: path.color, + bounds: path.clipped_bounds(), + })); + } + let vertex_buf = unsafe { self.instance_belt.alloc_typed(&vertices, &self.gpu) }; + encoder.bind( + 0, + &ShaderPathRasterizationData { + globals, + b_path_vertices: vertex_buf, + }, + ); + encoder.draw(0, vertices.len() as u32, 0, 1); + } + } + pub fn destroy(&mut self) { self.wait_for_gpu(); self.atlas.destroy(); @@ -535,11 +629,14 @@ impl BladeRenderer { self.gpu.destroy_command_encoder(&mut self.command_encoder); self.pipelines.destroy(&self.gpu); self.gpu.destroy_surface(&mut self.surface); - if let Some(texture_msaa) = self.texture_msaa { - self.gpu.destroy_texture(texture_msaa); + self.gpu.destroy_texture(self.path_intermediate_texture); + self.gpu + .destroy_texture_view(self.path_intermediate_texture_view); + if let Some(msaa_texture) = self.path_intermediate_msaa_texture { + self.gpu.destroy_texture(msaa_texture); } - if let Some(texture_view_msaa) = self.texture_view_msaa { - self.gpu.destroy_texture_view(texture_view_msaa); + if let Some(msaa_view) = self.path_intermediate_msaa_texture_view { + self.gpu.destroy_texture_view(msaa_view); } } @@ -551,10 +648,6 @@ impl BladeRenderer { profiling::scope!("acquire frame"); self.surface.acquire_frame() }; - let frame_view = frame.texture_view(); - if let Some(texture_msaa) = self.texture_msaa { - self.command_encoder.init_texture(texture_msaa); - } self.command_encoder.init_texture(frame.texture()); let globals = GlobalParams { @@ -569,253 +662,245 @@ impl BladeRenderer { pad: 0, }; - let target = if let Some(texture_view_msaa) = self.texture_view_msaa { - gpu::RenderTarget { - view: texture_view_msaa, - init_op: 
gpu::InitOp::Clear(gpu::TextureColor::TransparentBlack), - finish_op: gpu::FinishOp::ResolveTo(frame_view), - } - } else { - gpu::RenderTarget { - view: frame_view, - init_op: gpu::InitOp::Clear(gpu::TextureColor::TransparentBlack), - finish_op: gpu::FinishOp::Store, - } - }; - - // draw to the target texture - if let mut pass = self.command_encoder.render( + let mut pass = self.command_encoder.render( "main", gpu::RenderTargetSet { - colors: &[target], + colors: &[gpu::RenderTarget { + view: frame.texture_view(), + init_op: gpu::InitOp::Clear(gpu::TextureColor::TransparentBlack), + finish_op: gpu::FinishOp::Store, + }], depth_stencil: None, }, - ) { - profiling::scope!("render pass"); - for batch in scene.batches() { - match batch { - PrimitiveBatch::Quads(quads) => { - let instance_buf = - unsafe { self.instance_belt.alloc_typed(quads, &self.gpu) }; - let mut encoder = pass.with(&self.pipelines.quads); - encoder.bind( - 0, - &ShaderQuadsData { - globals, - b_quads: instance_buf, - }, - ); - encoder.draw(0, 4, 0, quads.len() as u32); - } - PrimitiveBatch::Shadows(shadows) => { - let instance_buf = - unsafe { self.instance_belt.alloc_typed(shadows, &self.gpu) }; - let mut encoder = pass.with(&self.pipelines.shadows); - encoder.bind( - 0, - &ShaderShadowsData { - globals, - b_shadows: instance_buf, - }, - ); - encoder.draw(0, 4, 0, shadows.len() as u32); - } - PrimitiveBatch::Paths(paths) => { - let mut encoder = pass.with(&self.pipelines.paths); - - let mut vertices = Vec::new(); - let mut sprites = Vec::with_capacity(paths.len()); - let mut draw_indirect_commands = Vec::with_capacity(paths.len()); - let mut first_vertex = 0; - - for (i, path) in paths.iter().enumerate() { - draw_indirect_commands.push(DrawIndirectArgs { - vertex_count: path.vertices.len() as u32, - instance_count: 1, - first_vertex, - first_instance: i as u32, - }); - first_vertex += path.vertices.len() as u32; - - vertices.extend(path.vertices.iter().map(|v| PathVertex { - xy_position: v.xy_position, - content_mask: ContentMask { - bounds: path.content_mask.bounds, - }, - })); + ); - sprites.push(PathSprite { - bounds: path.bounds, - color: path.color, - }); + profiling::scope!("render pass"); + for batch in scene.batches() { + match batch { + PrimitiveBatch::Quads(quads) => { + let instance_buf = unsafe { self.instance_belt.alloc_typed(quads, &self.gpu) }; + let mut encoder = pass.with(&self.pipelines.quads); + encoder.bind( + 0, + &ShaderQuadsData { + globals, + b_quads: instance_buf, + }, + ); + encoder.draw(0, 4, 0, quads.len() as u32); + } + PrimitiveBatch::Shadows(shadows) => { + let instance_buf = + unsafe { self.instance_belt.alloc_typed(shadows, &self.gpu) }; + let mut encoder = pass.with(&self.pipelines.shadows); + encoder.bind( + 0, + &ShaderShadowsData { + globals, + b_shadows: instance_buf, + }, + ); + encoder.draw(0, 4, 0, shadows.len() as u32); + } + PrimitiveBatch::Paths(paths) => { + let Some(first_path) = paths.first() else { + continue; + }; + drop(pass); + self.draw_paths_to_intermediate( + paths, + self.surface_config.size.width as f32, + self.surface_config.size.height as f32, + ); + pass = self.command_encoder.render( + "main", + gpu::RenderTargetSet { + colors: &[gpu::RenderTarget { + view: frame.texture_view(), + init_op: gpu::InitOp::Load, + finish_op: gpu::FinishOp::Store, + }], + depth_stencil: None, + }, + ); + let mut encoder = pass.with(&self.pipelines.paths); + // When copying paths from the intermediate texture to the drawable, + // each pixel must only be copied once, in case of 
transparent paths. + // + // If all paths have the same draw order, then their bounds are all + // disjoint, so we can copy each path's bounds individually. If this + // batch combines different draw orders, we perform a single copy + // for a minimal spanning rect. + let sprites = if paths.last().unwrap().order == first_path.order { + paths + .iter() + .map(|path| PathSprite { + bounds: path.clipped_bounds(), + }) + .collect() + } else { + let mut bounds = first_path.clipped_bounds(); + for path in paths.iter().skip(1) { + bounds = bounds.union(&path.clipped_bounds()); } - - let b_path_vertices = - unsafe { self.instance_belt.alloc_typed(&vertices, &self.gpu) }; - let instance_buf = - unsafe { self.instance_belt.alloc_typed(&sprites, &self.gpu) }; - let indirect_buf = unsafe { - self.instance_belt - .alloc_typed(&draw_indirect_commands, &self.gpu) + vec![PathSprite { bounds }] + }; + let instance_buf = + unsafe { self.instance_belt.alloc_typed(&sprites, &self.gpu) }; + encoder.bind( + 0, + &ShaderPathsData { + globals, + t_sprite: self.path_intermediate_texture_view, + s_sprite: self.atlas_sampler, + b_path_sprites: instance_buf, + }, + ); + encoder.draw(0, 4, 0, sprites.len() as u32); + } + PrimitiveBatch::Underlines(underlines) => { + let instance_buf = + unsafe { self.instance_belt.alloc_typed(underlines, &self.gpu) }; + let mut encoder = pass.with(&self.pipelines.underlines); + encoder.bind( + 0, + &ShaderUnderlinesData { + globals, + b_underlines: instance_buf, + }, + ); + encoder.draw(0, 4, 0, underlines.len() as u32); + } + PrimitiveBatch::MonochromeSprites { + texture_id, + sprites, + } => { + let tex_info = self.atlas.get_texture_info(texture_id); + let instance_buf = + unsafe { self.instance_belt.alloc_typed(sprites, &self.gpu) }; + let mut encoder = pass.with(&self.pipelines.mono_sprites); + encoder.bind( + 0, + &ShaderMonoSpritesData { + globals, + t_sprite: tex_info.raw_view, + s_sprite: self.atlas_sampler, + b_mono_sprites: instance_buf, + }, + ); + encoder.draw(0, 4, 0, sprites.len() as u32); + } + PrimitiveBatch::PolychromeSprites { + texture_id, + sprites, + } => { + let tex_info = self.atlas.get_texture_info(texture_id); + let instance_buf = + unsafe { self.instance_belt.alloc_typed(sprites, &self.gpu) }; + let mut encoder = pass.with(&self.pipelines.poly_sprites); + encoder.bind( + 0, + &ShaderPolySpritesData { + globals, + t_sprite: tex_info.raw_view, + s_sprite: self.atlas_sampler, + b_poly_sprites: instance_buf, + }, + ); + encoder.draw(0, 4, 0, sprites.len() as u32); + } + PrimitiveBatch::Surfaces(surfaces) => { + let mut _encoder = pass.with(&self.pipelines.surfaces); + + for surface in surfaces { + #[cfg(not(target_os = "macos"))] + { + let _ = surface; + continue; }; - encoder.bind( - 0, - &ShaderPathsData { - globals, - b_path_vertices, - b_path_sprites: instance_buf, - }, - ); - - for i in 0..paths.len() { - encoder.draw_indirect(indirect_buf.buffer.at(indirect_buf.offset - + (i * mem::size_of::<DrawIndirectArgs>()) as u64)); - } - } - PrimitiveBatch::Underlines(underlines) => { - let instance_buf = - unsafe { self.instance_belt.alloc_typed(underlines, &self.gpu) }; - let mut encoder = pass.with(&self.pipelines.underlines); - encoder.bind( - 0, - &ShaderUnderlinesData { - globals, - b_underlines: instance_buf, - }, - ); - encoder.draw(0, 4, 0, underlines.len() as u32); - } - PrimitiveBatch::MonochromeSprites { - texture_id, - sprites, - } => { - let tex_info = self.atlas.get_texture_info(texture_id); - let instance_buf = - unsafe { 
self.instance_belt.alloc_typed(sprites, &self.gpu) }; - let mut encoder = pass.with(&self.pipelines.mono_sprites); - encoder.bind( - 0, - &ShaderMonoSpritesData { - globals, - t_sprite: tex_info.raw_view, - s_sprite: self.atlas_sampler, - b_mono_sprites: instance_buf, - }, - ); - encoder.draw(0, 4, 0, sprites.len() as u32); - } - PrimitiveBatch::PolychromeSprites { - texture_id, - sprites, - } => { - let tex_info = self.atlas.get_texture_info(texture_id); - let instance_buf = - unsafe { self.instance_belt.alloc_typed(sprites, &self.gpu) }; - let mut encoder = pass.with(&self.pipelines.poly_sprites); - encoder.bind( - 0, - &ShaderPolySpritesData { - globals, - t_sprite: tex_info.raw_view, - s_sprite: self.atlas_sampler, - b_poly_sprites: instance_buf, - }, - ); - encoder.draw(0, 4, 0, sprites.len() as u32); - } - PrimitiveBatch::Surfaces(surfaces) => { - let mut _encoder = pass.with(&self.pipelines.surfaces); - - for surface in surfaces { - #[cfg(not(target_os = "macos"))] - { - let _ = surface; - continue; - }; + #[cfg(target_os = "macos")] + { + let (t_y, t_cb_cr) = unsafe { + use core_foundation::base::TCFType as _; + use std::ptr; - #[cfg(target_os = "macos")] - { - let (t_y, t_cb_cr) = unsafe { - use core_foundation::base::TCFType as _; - use std::ptr; - - assert_eq!( + assert_eq!( surface.image_buffer.get_pixel_format(), core_video::pixel_buffer::kCVPixelFormatType_420YpCbCr8BiPlanarFullRange ); - let y_texture = self - .core_video_texture_cache - .create_texture_from_image( - surface.image_buffer.as_concrete_TypeRef(), - ptr::null(), - metal::MTLPixelFormat::R8Unorm, - surface.image_buffer.get_width_of_plane(0), - surface.image_buffer.get_height_of_plane(0), - 0, - ) - .unwrap(); - let cb_cr_texture = self - .core_video_texture_cache - .create_texture_from_image( - surface.image_buffer.as_concrete_TypeRef(), - ptr::null(), - metal::MTLPixelFormat::RG8Unorm, - surface.image_buffer.get_width_of_plane(1), - surface.image_buffer.get_height_of_plane(1), - 1, - ) - .unwrap(); - ( - gpu::TextureView::from_metal_texture( - &objc2::rc::Retained::retain( - foreign_types::ForeignTypeRef::as_ptr( - y_texture.as_texture_ref(), - ) - as *mut objc2::runtime::ProtocolObject< - dyn objc2_metal::MTLTexture, - >, + let y_texture = self + .core_video_texture_cache + .create_texture_from_image( + surface.image_buffer.as_concrete_TypeRef(), + ptr::null(), + metal::MTLPixelFormat::R8Unorm, + surface.image_buffer.get_width_of_plane(0), + surface.image_buffer.get_height_of_plane(0), + 0, + ) + .unwrap(); + let cb_cr_texture = self + .core_video_texture_cache + .create_texture_from_image( + surface.image_buffer.as_concrete_TypeRef(), + ptr::null(), + metal::MTLPixelFormat::RG8Unorm, + surface.image_buffer.get_width_of_plane(1), + surface.image_buffer.get_height_of_plane(1), + 1, + ) + .unwrap(); + ( + gpu::TextureView::from_metal_texture( + &objc2::rc::Retained::retain( + foreign_types::ForeignTypeRef::as_ptr( + y_texture.as_texture_ref(), ) - .unwrap(), - gpu::TexelAspects::COLOR, - ), - gpu::TextureView::from_metal_texture( - &objc2::rc::Retained::retain( - foreign_types::ForeignTypeRef::as_ptr( - cb_cr_texture.as_texture_ref(), - ) - as *mut objc2::runtime::ProtocolObject< - dyn objc2_metal::MTLTexture, - >, + as *mut objc2::runtime::ProtocolObject< + dyn objc2_metal::MTLTexture, + >, + ) + .unwrap(), + gpu::TexelAspects::COLOR, + ), + gpu::TextureView::from_metal_texture( + &objc2::rc::Retained::retain( + foreign_types::ForeignTypeRef::as_ptr( + cb_cr_texture.as_texture_ref(), ) - .unwrap(), - 
gpu::TexelAspects::COLOR, - ), - ) - }; - - _encoder.bind( - 0, - &ShaderSurfacesData { - globals, - surface_locals: SurfaceParams { - bounds: surface.bounds.into(), - content_mask: surface.content_mask.bounds.into(), - }, - t_y, - t_cb_cr, - s_surface: self.atlas_sampler, + as *mut objc2::runtime::ProtocolObject< + dyn objc2_metal::MTLTexture, + >, + ) + .unwrap(), + gpu::TexelAspects::COLOR, + ), + ) + }; + + _encoder.bind( + 0, + &ShaderSurfacesData { + globals, + surface_locals: SurfaceParams { + bounds: surface.bounds.into(), + content_mask: surface.content_mask.bounds.into(), }, - ); + t_y, + t_cb_cr, + s_surface: self.atlas_sampler, + }, + ); - _encoder.draw(0, 4, 0, 1); - } + _encoder.draw(0, 4, 0, 1); } } } } } + drop(pass); self.command_encoder.present(frame); let sync_point = self.gpu.submit(&mut self.command_encoder); @@ -829,6 +914,39 @@ impl BladeRenderer { } } +fn create_path_intermediate_texture( + gpu: &gpu::Context, + format: gpu::TextureFormat, + width: u32, + height: u32, +) -> (gpu::Texture, gpu::TextureView) { + let texture = gpu.create_texture(gpu::TextureDesc { + name: "path intermediate", + format, + size: gpu::Extent { + width, + height, + depth: 1, + }, + array_layer_count: 1, + mip_level_count: 1, + sample_count: 1, + dimension: gpu::TextureDimension::D2, + usage: gpu::TextureUsage::COPY | gpu::TextureUsage::RESOURCE | gpu::TextureUsage::TARGET, + external: None, + }); + let texture_view = gpu.create_texture_view( + texture, + gpu::TextureViewDesc { + name: "path intermediate view", + format, + dimension: gpu::ViewDimension::D2, + subresources: &Default::default(), + }, + ); + (texture, texture_view) +} + fn create_msaa_texture_if_needed( gpu: &gpu::Context, format: gpu::TextureFormat, @@ -839,9 +957,8 @@ fn create_msaa_texture_if_needed( if sample_count <= 1 { return None; } - let texture_msaa = gpu.create_texture(gpu::TextureDesc { - name: "msaa", + name: "path intermediate msaa", format, size: gpu::Extent { width, @@ -858,7 +975,7 @@ fn create_msaa_texture_if_needed( let texture_view_msaa = gpu.create_texture_view( texture_msaa, gpu::TextureViewDesc { - name: "msaa view", + name: "path intermediate msaa view", format, dimension: gpu::ViewDimension::D2, subresources: &Default::default(), diff --git a/crates/gpui/src/platform/blade/shaders.wgsl b/crates/gpui/src/platform/blade/shaders.wgsl index 00c9d07af7d670a8bc00f4374c143bb28ff2b6d6..b1ffb1812effa24e674e0238fcc6ddef7dc0f882 100644 --- a/crates/gpui/src/platform/blade/shaders.wgsl +++ b/crates/gpui/src/platform/blade/shaders.wgsl @@ -922,62 +922,103 @@ fn fs_shadow(input: ShadowVarying) -> @location(0) vec4<f32> { return blend_color(input.color, alpha); } -// --- paths --- // +// --- path rasterization --- // -struct PathVertex { +struct PathRasterizationVertex { xy_position: vec2<f32>, - content_mask: Bounds, + st_position: vec2<f32>, + color: Background, + bounds: Bounds, +} + +var<storage, read> b_path_vertices: array<PathRasterizationVertex>; + +struct PathRasterizationVarying { + @builtin(position) position: vec4<f32>, + @location(0) st_position: vec2<f32>, + @location(1) vertex_id: u32, + //TODO: use `clip_distance` once Naga supports it + @location(3) clip_distances: vec4<f32>, } +@vertex +fn vs_path_rasterization(@builtin(vertex_index) vertex_id: u32) -> PathRasterizationVarying { + let v = b_path_vertices[vertex_id]; + + var out = PathRasterizationVarying(); + out.position = to_device_position_impl(v.xy_position); + out.st_position = v.st_position; + out.vertex_id = vertex_id; + out.clip_distances = 
distance_from_clip_rect_impl(v.xy_position, v.bounds); + return out; +} + +@fragment +fn fs_path_rasterization(input: PathRasterizationVarying) -> @location(0) vec4<f32> { + let dx = dpdx(input.st_position); + let dy = dpdy(input.st_position); + if (any(input.clip_distances < vec4<f32>(0.0))) { + return vec4<f32>(0.0); + } + + let v = b_path_vertices[input.vertex_id]; + let background = v.color; + let bounds = v.bounds; + + var alpha: f32; + if (length(vec2<f32>(dx.x, dy.x)) < 0.001) { + // If the gradient is too small, return a solid color. + alpha = 1.0; + } else { + let gradient = 2.0 * input.st_position.xx * vec2<f32>(dx.x, dy.x) - vec2<f32>(dx.y, dy.y); + let f = input.st_position.x * input.st_position.x - input.st_position.y; + let distance = f / length(gradient); + alpha = saturate(0.5 - distance); + } + let gradient_color = prepare_gradient_color( + background.tag, + background.color_space, + background.solid, + background.colors, + ); + let color = gradient_color(background, input.position.xy, bounds, + gradient_color.solid, gradient_color.color0, gradient_color.color1); + return vec4<f32>(color.rgb * color.a * alpha, color.a * alpha); +} + +// --- paths --- // + struct PathSprite { bounds: Bounds, - color: Background, } -var<storage, read> b_path_vertices: array<PathVertex>; var<storage, read> b_path_sprites: array<PathSprite>; struct PathVarying { @builtin(position) position: vec4<f32>, - @location(0) clip_distances: vec4<f32>, - @location(1) @interpolate(flat) instance_id: u32, - @location(2) @interpolate(flat) color_solid: vec4<f32>, - @location(3) @interpolate(flat) color0: vec4<f32>, - @location(4) @interpolate(flat) color1: vec4<f32>, + @location(0) texture_coords: vec2<f32>, } @vertex fn vs_path(@builtin(vertex_index) vertex_id: u32, @builtin(instance_index) instance_id: u32) -> PathVarying { - let v = b_path_vertices[vertex_id]; + let unit_vertex = vec2<f32>(f32(vertex_id & 1u), 0.5 * f32(vertex_id & 2u)); let sprite = b_path_sprites[instance_id]; + // Don't apply content mask because it was already accounted for when rasterizing the path. 
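+ // `unit_vertex` walks the four corners of a unit quad (drawn as a triangle strip); + // `to_device_position` stretches it over the sprite's screen-space bounds.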
+ let device_position = to_device_position(unit_vertex, sprite.bounds); + // For screen-space intermediate texture, convert screen position to texture coordinates + let screen_position = sprite.bounds.origin + unit_vertex * sprite.bounds.size; + let texture_coords = screen_position / globals.viewport_size; var out = PathVarying(); - out.position = to_device_position_impl(v.xy_position); - out.clip_distances = distance_from_clip_rect_impl(v.xy_position, v.content_mask); - out.instance_id = instance_id; + out.position = device_position; + out.texture_coords = texture_coords; - let gradient = prepare_gradient_color( - sprite.color.tag, - sprite.color.color_space, - sprite.color.solid, - sprite.color.colors - ); - out.color_solid = gradient.solid; - out.color0 = gradient.color0; - out.color1 = gradient.color1; return out; } @fragment fn fs_path(input: PathVarying) -> @location(0) vec4<f32> { - if any(input.clip_distances < vec4<f32>(0.0)) { - return vec4<f32>(0.0); - } - - let sprite = b_path_sprites[input.instance_id]; - let background = sprite.color; - let color = gradient_color(background, input.position.xy, sprite.bounds, - input.color_solid, input.color0, input.color1); - return blend_color(color, 1.0); + let sample = textureSample(t_sprite, s_sprite, input.texture_coords); + return sample; } // --- underlines --- // diff --git a/crates/gpui/src/platform/keystroke.rs b/crates/gpui/src/platform/keystroke.rs index 40387a820230cfc0f73f90643c082619ceaa595a..24601eefd6de450622247caaca5ff680c60a3257 100644 --- a/crates/gpui/src/platform/keystroke.rs +++ b/crates/gpui/src/platform/keystroke.rs @@ -13,6 +13,9 @@ pub struct Keystroke { /// key is the character printed on the key that was pressed /// e.g. for option-s, key is "s" + /// On layouts that do not have ascii keys (e.g. Thai) + /// this will be the ASCII-equivalent character (q instead of ๆ), + /// and the typed character will be present in key_char. pub key: String, /// key_char is the character that could have been typed when @@ -531,11 +534,62 @@ impl Modifiers { /// Checks if this [`Modifiers`] is a subset of another [`Modifiers`]. 
pub fn is_subset_of(&self, other: &Modifiers) -> bool { - (other.control || !self.control) - && (other.alt || !self.alt) - && (other.shift || !self.shift) - && (other.platform || !self.platform) - && (other.function || !self.function) + (*other & *self) == *self + } +} + +impl std::ops::BitOr for Modifiers { + type Output = Self; + + fn bitor(mut self, other: Self) -> Self::Output { + self |= other; + self + } +} + +impl std::ops::BitOrAssign for Modifiers { + fn bitor_assign(&mut self, other: Self) { + self.control |= other.control; + self.alt |= other.alt; + self.shift |= other.shift; + self.platform |= other.platform; + self.function |= other.function; + } +} + +impl std::ops::BitXor for Modifiers { + type Output = Self; + fn bitxor(mut self, rhs: Self) -> Self::Output { + self ^= rhs; + self + } +} + +impl std::ops::BitXorAssign for Modifiers { + fn bitxor_assign(&mut self, other: Self) { + self.control ^= other.control; + self.alt ^= other.alt; + self.shift ^= other.shift; + self.platform ^= other.platform; + self.function ^= other.function; + } +} + +impl std::ops::BitAnd for Modifiers { + type Output = Self; + fn bitand(mut self, rhs: Self) -> Self::Output { + self &= rhs; + self + } +} + +impl std::ops::BitAndAssign for Modifiers { + fn bitand_assign(&mut self, other: Self) { + self.control &= other.control; + self.alt &= other.alt; + self.shift &= other.shift; + self.platform &= other.platform; + self.function &= other.function; } } diff --git a/crates/gpui/src/platform/linux/headless/client.rs b/crates/gpui/src/platform/linux/headless/client.rs index 663a740389e68c0505a4b3f1f55a3b4681aacfa6..da54db371033bac53e2ac3324306fa86eb57fb57 100644 --- a/crates/gpui/src/platform/linux/headless/client.rs +++ b/crates/gpui/src/platform/linux/headless/client.rs @@ -73,7 +73,7 @@ impl LinuxClient for HeadlessClient { #[cfg(feature = "screen-capture")] fn screen_capture_sources( &self, - ) -> futures::channel::oneshot::Receiver<anyhow::Result<Vec<Box<dyn crate::ScreenCaptureSource>>>> + ) -> futures::channel::oneshot::Receiver<anyhow::Result<Vec<Rc<dyn crate::ScreenCaptureSource>>>> { let (mut tx, rx) = futures::channel::oneshot::channel(); tx.send(Err(anyhow::anyhow!( diff --git a/crates/gpui/src/platform/linux/platform.rs b/crates/gpui/src/platform/linux/platform.rs index af53899b437c244fd06d43b7963920c9596b94a0..fe6a36baa854856eb961a020ab35a7bd0195d465 100644 --- a/crates/gpui/src/platform/linux/platform.rs +++ b/crates/gpui/src/platform/linux/platform.rs @@ -56,7 +56,7 @@ pub trait LinuxClient { #[cfg(feature = "screen-capture")] fn screen_capture_sources( &self, - ) -> oneshot::Receiver<Result<Vec<Box<dyn crate::ScreenCaptureSource>>>>; + ) -> oneshot::Receiver<Result<Vec<Rc<dyn crate::ScreenCaptureSource>>>>; fn open_window( &self, @@ -245,7 +245,7 @@ impl<P: LinuxClient + 'static> Platform for P { #[cfg(feature = "screen-capture")] fn screen_capture_sources( &self, - ) -> oneshot::Receiver<Result<Vec<Box<dyn crate::ScreenCaptureSource>>>> { + ) -> oneshot::Receiver<Result<Vec<Rc<dyn crate::ScreenCaptureSource>>>> { self.screen_capture_sources() } @@ -706,6 +706,60 @@ pub(super) fn log_cursor_icon_warning(message: impl std::fmt::Display) { } } +#[cfg(any(feature = "wayland", feature = "x11"))] +fn guess_ascii(keycode: Keycode, shift: bool) -> Option<char> { + let c = match (keycode.raw(), shift) { + (24, _) => 'q', + (25, _) => 'w', + (26, _) => 'e', + (27, _) => 'r', + (28, _) => 't', + (29, _) => 'y', + (30, _) => 'u', + (31, _) => 'i', + (32, _) => 'o', + (33, _) => 'p', + (34, false) 
=> '[', + (34, true) => '{', + (35, false) => ']', + (35, true) => '}', + (38, _) => 'a', + (39, _) => 's', + (40, _) => 'd', + (41, _) => 'f', + (42, _) => 'g', + (43, _) => 'h', + (44, _) => 'j', + (45, _) => 'k', + (46, _) => 'l', + (47, false) => ';', + (47, true) => ':', + (48, false) => '\'', + (48, true) => '"', + (49, false) => '`', + (49, true) => '~', + (51, false) => '\\', + (51, true) => '|', + (52, _) => 'z', + (53, _) => 'x', + (54, _) => 'c', + (55, _) => 'v', + (56, _) => 'b', + (57, _) => 'n', + (58, _) => 'm', + (59, false) => ',', + (59, true) => '<', + (60, false) => '.', + (60, true) => '>', + (61, false) => '/', + (61, true) => '?', + + _ => return None, + }; + + Some(c) +} + #[cfg(any(feature = "wayland", feature = "x11"))] impl crate::Keystroke { pub(super) fn from_xkb( @@ -768,11 +822,43 @@ impl crate::Keystroke { Keysym::underscore => "_".to_owned(), Keysym::equal => "=".to_owned(), Keysym::plus => "+".to_owned(), + Keysym::space => "space".to_owned(), + Keysym::BackSpace => "backspace".to_owned(), + Keysym::Tab => "tab".to_owned(), + Keysym::Delete => "delete".to_owned(), + Keysym::Escape => "escape".to_owned(), + + Keysym::Left => "left".to_owned(), + Keysym::Right => "right".to_owned(), + Keysym::Up => "up".to_owned(), + Keysym::Down => "down".to_owned(), + Keysym::Home => "home".to_owned(), + Keysym::End => "end".to_owned(), _ => { let name = xkb::keysym_get_name(key_sym).to_lowercase(); if key_sym.is_keypad_key() { name.replace("kp_", "") + } else if let Some(key) = key_utf8.chars().next() + && key_utf8.len() == 1 + && key.is_ascii() + { + if key.is_ascii_graphic() { + key_utf8.to_lowercase() + // map ctrl-a to `a` + // ctrl-0..9 may emit control codes like ctrl-[, but + // we don't want to map them to `[` + } else if key_utf32 <= 0x1f + && !name.chars().next().is_some_and(|c| c.is_ascii_digit()) + { + ((key_utf32 as u8 + 0x40) as char) + .to_ascii_lowercase() + .to_string() + } else { + name + } + } else if let Some(key_en) = guess_ascii(keycode, modifiers.shift) { + String::from(key_en) } else { name } diff --git a/crates/gpui/src/platform/linux/wayland/client.rs b/crates/gpui/src/platform/linux/wayland/client.rs index 4f9b0896153667c2efa0ad4fddb76e8b1ec44451..af7c545397f12eed62d746481fb364612c4b055f 100644 --- a/crates/gpui/src/platform/linux/wayland/client.rs +++ b/crates/gpui/src/platform/linux/wayland/client.rs @@ -671,7 +671,7 @@ impl LinuxClient for WaylandClient { #[cfg(feature = "screen-capture")] fn screen_capture_sources( &self, - ) -> futures::channel::oneshot::Receiver<anyhow::Result<Vec<Box<dyn crate::ScreenCaptureSource>>>> + ) -> futures::channel::oneshot::Receiver<anyhow::Result<Vec<Rc<dyn crate::ScreenCaptureSource>>>> { // todo! Try window resizing as that may have unexpected results.
crate::platform::scap_screen_capture::start_scap_default_target_source( diff --git a/crates/gpui/src/platform/linux/wayland/window.rs b/crates/gpui/src/platform/linux/wayland/window.rs index 36e070b0b0fc03d1dd6cd3402eedd228dbc909e3..2b2207e22c86fc25e6387581bb92b9c304f4bc9d 100644 --- a/crates/gpui/src/platform/linux/wayland/window.rs +++ b/crates/gpui/src/platform/linux/wayland/window.rs @@ -76,6 +76,7 @@ struct InProgressConfigure { size: Option<Size<Pixels>>, fullscreen: bool, maximized: bool, + resizing: bool, tiling: Tiling, } @@ -107,9 +108,10 @@ pub struct WaylandWindowState { active: bool, hovered: bool, in_progress_configure: Option<InProgressConfigure>, + resize_throttle: bool, in_progress_window_controls: Option<WindowControls>, window_controls: WindowControls, - inset: Option<Pixels>, + client_inset: Option<Pixels>, } #[derive(Clone)] @@ -176,6 +178,7 @@ impl WaylandWindowState { tiling: Tiling::default(), window_bounds: options.bounds, in_progress_configure: None, + resize_throttle: false, client, appearance, handle, @@ -183,7 +186,7 @@ impl WaylandWindowState { hovered: false, in_progress_window_controls: None, window_controls: WindowControls::default(), - inset: None, + client_inset: None, }) } @@ -208,6 +211,13 @@ impl WaylandWindowState { self.display = current_output; scale } + + pub fn inset(&self) -> Pixels { + match self.decorations { + WindowDecorations::Server => px(0.0), + WindowDecorations::Client => self.client_inset.unwrap_or(px(0.0)), + } + } } pub(crate) struct WaylandWindow(pub WaylandWindowStatePtr); @@ -335,6 +345,7 @@ impl WaylandWindowStatePtr { pub fn frame(&self) { let mut state = self.state.borrow_mut(); state.surface.frame(&state.globals.qh, state.surface.id()); + state.resize_throttle = false; drop(state); let mut cb = self.callbacks.borrow_mut(); @@ -366,11 +377,17 @@ impl WaylandWindowStatePtr { state.fullscreen = configure.fullscreen; state.maximized = configure.maximized; state.tiling = configure.tiling; + // Limit interactive resizes to once per vblank + if configure.resizing && state.resize_throttle { + return; + } else if configure.resizing { + state.resize_throttle = true; + } if !configure.fullscreen && !configure.maximized { configure.size = if got_unmaximized { Some(state.window_bounds.size) } else { - compute_outer_size(state.inset, configure.size, state.tiling) + compute_outer_size(state.inset(), configure.size, state.tiling) }; if let Some(size) = configure.size { state.window_bounds = Bounds { @@ -390,7 +407,7 @@ impl WaylandWindowStatePtr { let window_geometry = inset_by_tiling( state.bounds.map_origin(|_| px(0.0)), - state.inset.unwrap_or(px(0.0)), + state.inset(), state.tiling, ) .map(|v| v.0 as i32) @@ -472,6 +489,7 @@ impl WaylandWindowStatePtr { let mut tiling = Tiling::default(); let mut fullscreen = false; let mut maximized = false; + let mut resizing = false; for state in states { match state { @@ -481,6 +499,7 @@ impl WaylandWindowStatePtr { xdg_toplevel::State::Fullscreen => { fullscreen = true; } + xdg_toplevel::State::Resizing => resizing = true, xdg_toplevel::State::TiledTop => { tiling.top = true; } @@ -508,6 +527,7 @@ impl WaylandWindowStatePtr { size, fullscreen, maximized, + resizing, tiling, }); @@ -805,7 +825,7 @@ impl PlatformWindow for WaylandWindow { } else if state.maximized { WindowBounds::Maximized(state.window_bounds) } else { - let inset = state.inset.unwrap_or(px(0.)); + let inset = state.inset(); drop(state); WindowBounds::Windowed(self.bounds().inset(inset)) } @@ -1060,8 +1080,8 @@ impl PlatformWindow for 
WaylandWindow { fn set_client_inset(&self, inset: Pixels) { let mut state = self.borrow_mut(); - if Some(inset) != state.inset { - state.inset = Some(inset); + if Some(inset) != state.client_inset { + state.client_inset = Some(inset); update_window(state); } } @@ -1081,9 +1101,7 @@ fn update_window(mut state: RefMut<WaylandWindowState>) { state.renderer.update_transparency(!opaque); let mut opaque_area = state.window_bounds.map(|v| v.0 as i32); - if let Some(inset) = state.inset { - opaque_area.inset(inset.0 as i32); - } + opaque_area.inset(state.inset().0 as i32); let region = state .globals @@ -1156,12 +1174,10 @@ impl ResizeEdge { /// updating to account for the client decorations. But that's not the area we want to render /// to, due to our intrusize CSD. So, here we calculate the 'actual' size, by adding back in the insets fn compute_outer_size( - inset: Option<Pixels>, + inset: Pixels, new_size: Option<Size<Pixels>>, tiling: Tiling, ) -> Option<Size<Pixels>> { - let Some(inset) = inset else { return new_size }; - new_size.map(|mut new_size| { if !tiling.top { new_size.height += inset; diff --git a/crates/gpui/src/platform/linux/x11/client.rs b/crates/gpui/src/platform/linux/x11/client.rs index 430ce9260b87ae1c4c7c64b463e647a4c6e6c90a..573e4addf75b90d50e7f453555507462280fb3d4 100644 --- a/crates/gpui/src/platform/linux/x11/client.rs +++ b/crates/gpui/src/platform/linux/x11/client.rs @@ -1,23 +1,22 @@ use crate::{Capslock, xcb_flush}; -use core::str; -use std::{ - cell::RefCell, - collections::{BTreeMap, HashSet}, - ops::Deref, - path::PathBuf, - rc::{Rc, Weak}, - time::{Duration, Instant}, -}; - use anyhow::{Context as _, anyhow}; use calloop::{ EventLoop, LoopHandle, RegistrationToken, generic::{FdWrapper, Generic}, }; use collections::HashMap; +use core::str; use http_client::Url; use log::Level; use smallvec::SmallVec; +use std::{ + cell::RefCell, + collections::{BTreeMap, HashSet}, + ops::Deref, + path::PathBuf, + rc::{Rc, Weak}, + time::{Duration, Instant}, +}; use util::ResultExt; use x11rb::{ @@ -38,7 +37,7 @@ use x11rb::{ }; use xim::{AttributeName, Client, InputStyle, x11rb::X11rbClient}; use xkbc::x11::ffi::{XKB_X11_MIN_MAJOR_XKB_VERSION, XKB_X11_MIN_MINOR_XKB_VERSION}; -use xkbcommon::xkb::{self as xkbc, LayoutIndex, ModMask, STATE_LAYOUT_EFFECTIVE}; +use xkbcommon::xkb::{self as xkbc, STATE_LAYOUT_EFFECTIVE}; use super::{ ButtonOrScroll, ScrollDirection, X11Display, X11WindowStatePtr, XcbAtoms, XimCallbackEvent, @@ -77,6 +76,8 @@ pub(crate) const XINPUT_ALL_DEVICES: xinput::DeviceId = 0; /// terminology is both archaic and unclear. 
pub(crate) const XINPUT_ALL_DEVICE_GROUPS: xinput::DeviceId = 1; +const GPUI_X11_SCALE_FACTOR_ENV: &str = "GPUI_X11_SCALE_FACTOR"; + pub(crate) struct WindowRef { window: X11WindowStatePtr, refresh_state: Option<RefreshState>, @@ -139,13 +140,6 @@ impl From<xim::ClientError> for EventHandlerError { } } -#[derive(Debug, Default, Clone)] -struct XKBStateNotiy { - depressed_layout: LayoutIndex, - latched_layout: LayoutIndex, - locked_layout: LayoutIndex, -} - #[derive(Debug, Default)] pub struct Xdnd { other_window: xproto::Window, @@ -198,7 +192,6 @@ pub struct X11ClientState { pub(crate) mouse_focused_window: Option<xproto::Window>, pub(crate) keyboard_focused_window: Option<xproto::Window>, pub(crate) xkb: xkbc::State, - previous_xkb_state: XKBStateNotiy, keyboard_layout: LinuxKeyboardLayout, pub(crate) ximc: Option<X11rbClient<Rc<XCBConnection>>>, pub(crate) xim_handler: Option<XimHandler>, @@ -424,12 +417,7 @@ impl X11Client { let resource_database = x11rb::resource_manager::new_from_default(&xcb_connection) .context("Failed to create resource database")?; - let scale_factor = resource_database - .get_value("Xft.dpi", "Xft.dpi") - .ok() - .flatten() - .map(|dpi: f32| dpi / 96.0) - .unwrap_or(1.0); + let scale_factor = get_scale_factor(&xcb_connection, &resource_database, x_root_index); let cursor_handle = cursor::Handle::new(&xcb_connection, x_root_index, &resource_database) .context("Failed to initialize cursor theme handler")? .reply() @@ -510,7 +498,6 @@ impl X11Client { mouse_focused_window: None, keyboard_focused_window: None, xkb: xkb_state, - previous_xkb_state: XKBStateNotiy::default(), keyboard_layout, ximc, xim_handler, @@ -962,14 +949,6 @@ impl X11Client { state.xkb_device_id, ) }; - let depressed_layout = xkb_state.serialize_layout(xkbc::STATE_LAYOUT_DEPRESSED); - let latched_layout = xkb_state.serialize_layout(xkbc::STATE_LAYOUT_LATCHED); - let locked_layout = xkb_state.serialize_layout(xkbc::ffi::XKB_STATE_LAYOUT_LOCKED); - state.previous_xkb_state = XKBStateNotiy { - depressed_layout, - latched_layout, - locked_layout, - }; state.xkb = xkb_state; drop(state); self.handle_keyboard_layout_change(); @@ -986,12 +965,6 @@ impl X11Client { event.latched_group as u32, event.locked_group.into(), ); - state.previous_xkb_state = XKBStateNotiy { - depressed_layout: event.base_group as u32, - latched_layout: event.latched_group as u32, - locked_layout: event.locked_group.into(), - }; - let modifiers = Modifiers::from_xkb(&state.xkb); let capslock = Capslock::from_xkb(&state.xkb); if state.last_modifiers_changed_event == modifiers @@ -1028,20 +1001,16 @@ impl X11Client { state.pre_key_char_down.take(); let keystroke = { let code = event.detail.into(); - let xkb_state = state.previous_xkb_state.clone(); - state.xkb.update_mask( - event.state.bits() as ModMask, - 0, - 0, - xkb_state.depressed_layout, - xkb_state.latched_layout, - xkb_state.locked_layout, - ); let mut keystroke = crate::Keystroke::from_xkb(&state.xkb, modifiers, code); let keysym = state.xkb.key_get_one_sym(code); + if keysym.is_modifier_key() { return Some(()); } + + // should be called after key_get_one_sym + state.xkb.update_key(code, xkbc::KeyDirection::Down); + if let Some(mut compose_state) = state.compose_state.take() { compose_state.feed(keysym); match compose_state.status() { @@ -1096,20 +1065,16 @@ impl X11Client { let keystroke = { let code = event.detail.into(); - let xkb_state = state.previous_xkb_state.clone(); - state.xkb.update_mask( - event.state.bits() as ModMask, - 0, - 0, - xkb_state.depressed_layout, 
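For reference, the GPUI_X11_SCALE_FACTOR override introduced in this file is read once at client startup and takes precedence over both Xft.dpi and RandR. Two illustrative invocations (shell syntax; the binary name is only an example):

    GPUI_X11_SCALE_FACTOR=1.5 zed    # force a fixed scale factor
    GPUI_X11_SCALE_FACTOR=randr zed  # derive the factor from RandR monitor geometry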
- xkb_state.latched_layout, - xkb_state.locked_layout, - ); let keystroke = crate::Keystroke::from_xkb(&state.xkb, modifiers, code); let keysym = state.xkb.key_get_one_sym(code); + if keysym.is_modifier_key() { return Some(()); } + + // should be called after key_get_one_sym + state.xkb.update_key(code, xkbc::KeyDirection::Up); + keystroke }; drop(state); @@ -1485,7 +1450,7 @@ impl LinuxClient for X11Client { #[cfg(feature = "screen-capture")] fn screen_capture_sources( &self, - ) -> futures::channel::oneshot::Receiver<anyhow::Result<Vec<Box<dyn crate::ScreenCaptureSource>>>> + ) -> futures::channel::oneshot::Receiver<anyhow::Result<Vec<Rc<dyn crate::ScreenCaptureSource>>>> { crate::platform::scap_screen_capture::scap_screen_sources( &self.0.borrow().common.foreground_executor, @@ -1830,6 +1795,7 @@ impl X11ClientState { drop(state); window.refresh(RequestFrameOptions { require_presentation: expose_event_received, + force_render: false, }); } xcb_connection @@ -2272,3 +2238,253 @@ fn create_invisible_cursor( xcb_flush(connection); Ok(cursor) } + +enum DpiMode { + Randr, + Scale(f32), + NotSet, +} + +fn get_scale_factor( + connection: &XCBConnection, + resource_database: &Database, + screen_index: usize, +) -> f32 { + let env_dpi = std::env::var(GPUI_X11_SCALE_FACTOR_ENV) + .ok() + .map(|var| { + if var.to_lowercase() == "randr" { + DpiMode::Randr + } else if let Ok(scale) = var.parse::<f32>() { + if valid_scale_factor(scale) { + DpiMode::Scale(scale) + } else { + panic!( + "`{}` must be a positive normal number or `randr`. Got `{}`", + GPUI_X11_SCALE_FACTOR_ENV, var + ); + } + } else if var.is_empty() { + DpiMode::NotSet + } else { + panic!( + "`{}` must be a positive number or `randr`. Got `{}`", + GPUI_X11_SCALE_FACTOR_ENV, var + ); + } + }) + .unwrap_or(DpiMode::NotSet); + + match env_dpi { + DpiMode::Scale(scale) => { + log::info!( + "Using scale factor from {}: {}", + GPUI_X11_SCALE_FACTOR_ENV, + scale + ); + return scale; + } + DpiMode::Randr => { + if let Some(scale) = get_randr_scale_factor(connection, screen_index) { + log::info!( + "Using RandR scale factor from {}=randr: {}", + GPUI_X11_SCALE_FACTOR_ENV, + scale + ); + return scale; + } + log::warn!("Failed to calculate RandR scale factor, falling back to default"); + return 1.0; + } + DpiMode::NotSet => {} + } + + // TODO: Use scale factor from XSettings here + + if let Some(dpi) = resource_database + .get_value::<f32>("Xft.dpi", "Xft.dpi") + .ok() + .flatten() + { + let scale = dpi / 96.0; // base dpi + log::info!("Using scale factor from Xft.dpi: {}", scale); + return scale; + } + + if let Some(scale) = get_randr_scale_factor(connection, screen_index) { + log::info!("Using RandR scale factor: {}", scale); + return scale; + } + + log::info!("Using default scale factor: 1.0"); + 1.0 +} + +fn get_randr_scale_factor(connection: &XCBConnection, screen_index: usize) -> Option<f32> { + let root = connection.setup().roots.get(screen_index)?.root; + + let version_cookie = connection.randr_query_version(1, 6).ok()?; + let version_reply = version_cookie.reply().ok()?; + if version_reply.major_version < 1 + || (version_reply.major_version == 1 && version_reply.minor_version < 5) + { + return legacy_get_randr_scale_factor(connection, root); // for randr <1.5 + } + + let monitors_cookie = connection.randr_get_monitors(root, true).ok()?; // true for active only + let monitors_reply = monitors_cookie.reply().ok()?; + + let mut fallback_scale: Option<f32> = None; + for monitor in monitors_reply.monitors { + if monitor.width_in_millimeters == 
0 || monitor.height_in_millimeters == 0 { + continue; + } + let scale_factor = get_dpi_factor( + (monitor.width as u32, monitor.height as u32), + ( + monitor.width_in_millimeters as u64, + monitor.height_in_millimeters as u64, + ), + ); + if monitor.primary { + return Some(scale_factor); + } else if fallback_scale.is_none() { + fallback_scale = Some(scale_factor); + } + } + + fallback_scale +} + +fn legacy_get_randr_scale_factor(connection: &XCBConnection, root: u32) -> Option<f32> { + let primary_cookie = connection.randr_get_output_primary(root).ok()?; + let primary_reply = primary_cookie.reply().ok()?; + let primary_output = primary_reply.output; + + let primary_output_cookie = connection + .randr_get_output_info(primary_output, x11rb::CURRENT_TIME) + .ok()?; + let primary_output_info = primary_output_cookie.reply().ok()?; + + // try primary + if primary_output_info.connection == randr::Connection::CONNECTED + && primary_output_info.mm_width > 0 + && primary_output_info.mm_height > 0 + && primary_output_info.crtc != 0 + { + let crtc_cookie = connection + .randr_get_crtc_info(primary_output_info.crtc, x11rb::CURRENT_TIME) + .ok()?; + let crtc_info = crtc_cookie.reply().ok()?; + + if crtc_info.width > 0 && crtc_info.height > 0 { + let scale_factor = get_dpi_factor( + (crtc_info.width as u32, crtc_info.height as u32), + ( + primary_output_info.mm_width as u64, + primary_output_info.mm_height as u64, + ), + ); + return Some(scale_factor); + } + } + + // fallback: full scan + let resources_cookie = connection.randr_get_screen_resources_current(root).ok()?; + let screen_resources = resources_cookie.reply().ok()?; + + let mut crtc_cookies = Vec::with_capacity(screen_resources.crtcs.len()); + for &crtc in &screen_resources.crtcs { + if let Ok(cookie) = connection.randr_get_crtc_info(crtc, x11rb::CURRENT_TIME) { + crtc_cookies.push((crtc, cookie)); + } + } + + let mut crtc_infos: HashMap<randr::Crtc, randr::GetCrtcInfoReply> = HashMap::default(); + let mut valid_outputs: HashSet<randr::Output> = HashSet::new(); + for (crtc, cookie) in crtc_cookies { + if let Ok(reply) = cookie.reply() { + if reply.width > 0 && reply.height > 0 && !reply.outputs.is_empty() { + crtc_infos.insert(crtc, reply.clone()); + valid_outputs.extend(&reply.outputs); + } + } + } + + if valid_outputs.is_empty() { + return None; + } + + let mut output_cookies = Vec::with_capacity(valid_outputs.len()); + for &output in &valid_outputs { + if let Ok(cookie) = connection.randr_get_output_info(output, x11rb::CURRENT_TIME) { + output_cookies.push((output, cookie)); + } + } + let mut output_infos: HashMap<randr::Output, randr::GetOutputInfoReply> = HashMap::default(); + for (output, cookie) in output_cookies { + if let Ok(reply) = cookie.reply() { + output_infos.insert(output, reply); + } + } + + let mut fallback_scale: Option<f32> = None; + for crtc_info in crtc_infos.values() { + for &output in &crtc_info.outputs { + if let Some(output_info) = output_infos.get(&output) { + if output_info.connection != randr::Connection::CONNECTED { + continue; + } + + if output_info.mm_width == 0 || output_info.mm_height == 0 { + continue; + } + + let scale_factor = get_dpi_factor( + (crtc_info.width as u32, crtc_info.height as u32), + (output_info.mm_width as u64, output_info.mm_height as u64), + ); + + if output != primary_output && fallback_scale.is_none() { + fallback_scale = Some(scale_factor); + } + } + } + } + + fallback_scale +} + +fn get_dpi_factor((width_px, height_px): (u32, u32), (width_mm, height_mm): (u64, u64)) -> f32 { + let ppmm = 
((width_px as f64 * height_px as f64) / (width_mm as f64 * height_mm as f64)).sqrt(); // pixels per mm + + const MM_PER_INCH: f64 = 25.4; + const BASE_DPI: f64 = 96.0; + const QUANTIZE_STEP: f64 = 12.0; // e.g. 1.25 = 15/12, 1.5 = 18/12, 1.75 = 21/12, 2.0 = 24/12 + const MIN_SCALE: f64 = 1.0; + const MAX_SCALE: f64 = 20.0; + + let dpi_factor = + ((ppmm * (QUANTIZE_STEP * MM_PER_INCH / BASE_DPI)).round() / QUANTIZE_STEP).max(MIN_SCALE); + + let validated_factor = if dpi_factor <= MAX_SCALE { + dpi_factor + } else { + MIN_SCALE + }; + + if valid_scale_factor(validated_factor as f32) { + validated_factor as f32 + } else { + log::warn!( + "Calculated DPI factor {} is invalid, using 1.0", + validated_factor + ); + 1.0 + } +} + +#[inline] +fn valid_scale_factor(scale_factor: f32) -> bool { + scale_factor.is_sign_positive() && scale_factor.is_normal() +} diff --git a/crates/gpui/src/platform/mac/metal_atlas.rs b/crates/gpui/src/platform/mac/metal_atlas.rs index 0c8e1d37032f48994bbf41fb44a77efe991e47bf..5d2d8e63e06a1ea6251c1fd2edf461eeeedec612 100644 --- a/crates/gpui/src/platform/mac/metal_atlas.rs +++ b/crates/gpui/src/platform/mac/metal_atlas.rs @@ -25,27 +25,6 @@ impl MetalAtlas { pub(crate) fn metal_texture(&self, id: AtlasTextureId) -> metal::Texture { self.0.lock().texture(id).metal_texture.clone() } - - #[allow(dead_code)] - pub(crate) fn allocate( - &self, - size: Size<DevicePixels>, - texture_kind: AtlasTextureKind, - ) -> Option<AtlasTile> { - self.0.lock().allocate(size, texture_kind) - } - - #[allow(dead_code)] - pub(crate) fn clear_textures(&self, texture_kind: AtlasTextureKind) { - let mut lock = self.0.lock(); - let textures = match texture_kind { - AtlasTextureKind::Monochrome => &mut lock.monochrome_textures, - AtlasTextureKind::Polychrome => &mut lock.polychrome_textures, - }; - for texture in textures.iter_mut() { - texture.clear(); - } - } } struct MetalAtlasState { @@ -212,10 +191,6 @@ struct MetalAtlasTexture { } impl MetalAtlasTexture { - fn clear(&mut self) { - self.allocator.clear(); - } - fn allocate(&mut self, size: Size<DevicePixels>) -> Option<AtlasTile> { let allocation = self.allocator.allocate(size.into())?; let tile = AtlasTile { diff --git a/crates/gpui/src/platform/mac/metal_renderer.rs b/crates/gpui/src/platform/mac/metal_renderer.rs index 8936cf242cf3d997495d486d471a14285ae7caa0..629654014d5a15632c5992d9347cab3ee1fd28d9 100644 --- a/crates/gpui/src/platform/mac/metal_renderer.rs +++ b/crates/gpui/src/platform/mac/metal_renderer.rs @@ -1,7 +1,7 @@ use super::metal_atlas::MetalAtlas; use crate::{ AtlasTextureId, Background, Bounds, ContentMask, DevicePixels, MonochromeSprite, PaintSurface, - Path, PathVertex, PolychromeSprite, PrimitiveBatch, Quad, ScaledPixels, Scene, Shadow, Size, + Path, Point, PolychromeSprite, PrimitiveBatch, Quad, ScaledPixels, Scene, Shadow, Size, Surface, Underline, point, size, }; use anyhow::Result; @@ -11,6 +11,7 @@ use cocoa::{ foundation::{NSSize, NSUInteger}, quartzcore::AutoresizingMask, }; + use core_foundation::base::TCFType; use core_video::{ metal_texture::CVMetalTextureGetTexture, metal_texture_cache::CVMetalTextureCache, @@ -18,11 +19,12 @@ use core_video::{ }; use foreign_types::{ForeignType, ForeignTypeRef}; use metal::{ - CAMetalLayer, CommandQueue, MTLDrawPrimitivesIndirectArguments, MTLPixelFormat, - MTLResourceOptions, NSRange, + CAMetalLayer, CommandQueue, MTLPixelFormat, MTLResourceOptions, NSRange, + RenderPassColorAttachmentDescriptorRef, }; use objc::{self, msg_send, sel, sel_impl}; use parking_lot::Mutex; + use 
std::{cell::Cell, ffi::c_void, mem, ptr, sync::Arc}; // Exported to metal @@ -32,6 +34,9 @@ pub(crate) type PointF = crate::Point<f32>; const SHADERS_METALLIB: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/shaders.metallib")); #[cfg(feature = "runtime_shaders")] const SHADERS_SOURCE_FILE: &str = include_str!(concat!(env!("OUT_DIR"), "/stitched_shaders.metal")); +// Use 4x MSAA, all devices support it. +// https://developer.apple.com/documentation/metal/mtldevice/1433355-supportstexturesamplecount +const PATH_SAMPLE_COUNT: u32 = 4; pub type Context = Arc<Mutex<InstanceBufferPool>>; pub type Renderer = MetalRenderer; @@ -96,7 +101,8 @@ pub(crate) struct MetalRenderer { layer: metal::MetalLayer, presents_with_transaction: bool, command_queue: CommandQueue, - path_pipeline_state: metal::RenderPipelineState, + paths_rasterization_pipeline_state: metal::RenderPipelineState, + path_sprites_pipeline_state: metal::RenderPipelineState, shadows_pipeline_state: metal::RenderPipelineState, quads_pipeline_state: metal::RenderPipelineState, underlines_pipeline_state: metal::RenderPipelineState, @@ -108,8 +114,17 @@ pub(crate) struct MetalRenderer { instance_buffer_pool: Arc<Mutex<InstanceBufferPool>>, sprite_atlas: Arc<MetalAtlas>, core_video_texture_cache: core_video::metal_texture_cache::CVMetalTextureCache, - sample_count: u64, - msaa_texture: Option<metal::Texture>, + path_intermediate_texture: Option<metal::Texture>, + path_intermediate_msaa_texture: Option<metal::Texture>, + path_sample_count: u32, +} + +#[repr(C)] +pub struct PathRasterizationVertex { + pub xy_position: Point<ScaledPixels>, + pub st_position: Point<f32>, + pub color: Background, + pub bounds: Bounds<ScaledPixels>, } impl MetalRenderer { @@ -168,19 +183,22 @@ impl MetalRenderer { MTLResourceOptions::StorageModeManaged, ); - let sample_count = [4, 2, 1] - .into_iter() - .find(|count| device.supports_texture_sample_count(*count)) - .unwrap_or(1); - - let path_pipeline_state = build_pipeline_state( + let paths_rasterization_pipeline_state = build_path_rasterization_pipeline_state( &device, &library, - "paths", - "path_vertex", - "path_fragment", + "paths_rasterization", + "path_rasterization_vertex", + "path_rasterization_fragment", + MTLPixelFormat::BGRA8Unorm, + PATH_SAMPLE_COUNT, + ); + let path_sprites_pipeline_state = build_path_sprite_pipeline_state( + &device, + &library, + "path_sprites", + "path_sprite_vertex", + "path_sprite_fragment", MTLPixelFormat::BGRA8Unorm, - sample_count, ); let shadows_pipeline_state = build_pipeline_state( &device, @@ -189,7 +207,6 @@ impl MetalRenderer { "shadow_vertex", "shadow_fragment", MTLPixelFormat::BGRA8Unorm, - sample_count, ); let quads_pipeline_state = build_pipeline_state( &device, @@ -198,7 +215,6 @@ impl MetalRenderer { "quad_vertex", "quad_fragment", MTLPixelFormat::BGRA8Unorm, - sample_count, ); let underlines_pipeline_state = build_pipeline_state( &device, @@ -207,7 +223,6 @@ impl MetalRenderer { "underline_vertex", "underline_fragment", MTLPixelFormat::BGRA8Unorm, - sample_count, ); let monochrome_sprites_pipeline_state = build_pipeline_state( &device, @@ -216,7 +231,6 @@ impl MetalRenderer { "monochrome_sprite_vertex", "monochrome_sprite_fragment", MTLPixelFormat::BGRA8Unorm, - sample_count, ); let polychrome_sprites_pipeline_state = build_pipeline_state( &device, @@ -225,7 +239,6 @@ impl MetalRenderer { "polychrome_sprite_vertex", "polychrome_sprite_fragment", MTLPixelFormat::BGRA8Unorm, - sample_count, ); let surfaces_pipeline_state = build_pipeline_state( &device, @@ 
-234,21 +247,20 @@ impl MetalRenderer { "surface_vertex", "surface_fragment", MTLPixelFormat::BGRA8Unorm, - sample_count, ); let command_queue = device.new_command_queue(); let sprite_atlas = Arc::new(MetalAtlas::new(device.clone())); let core_video_texture_cache = CVMetalTextureCache::new(None, device.clone(), None).unwrap(); - let msaa_texture = create_msaa_texture(&device, &layer, sample_count); Self { device, layer, presents_with_transaction: false, command_queue, - path_pipeline_state, + paths_rasterization_pipeline_state, + path_sprites_pipeline_state, shadows_pipeline_state, quads_pipeline_state, underlines_pipeline_state, @@ -259,8 +271,9 @@ impl MetalRenderer { instance_buffer_pool, sprite_atlas, core_video_texture_cache, - sample_count, - msaa_texture, + path_intermediate_texture: None, + path_intermediate_msaa_texture: None, + path_sample_count: PATH_SAMPLE_COUNT, } } @@ -293,8 +306,31 @@ impl MetalRenderer { setDrawableSize: size ]; } + let device_pixels_size = Size { + width: DevicePixels(size.width as i32), + height: DevicePixels(size.height as i32), + }; + self.update_path_intermediate_textures(device_pixels_size); + } - self.msaa_texture = create_msaa_texture(&self.device, &self.layer, self.sample_count); + fn update_path_intermediate_textures(&mut self, size: Size<DevicePixels>) { + let texture_descriptor = metal::TextureDescriptor::new(); + texture_descriptor.set_width(size.width.0 as u64); + texture_descriptor.set_height(size.height.0 as u64); + texture_descriptor.set_pixel_format(metal::MTLPixelFormat::BGRA8Unorm); + texture_descriptor + .set_usage(metal::MTLTextureUsage::RenderTarget | metal::MTLTextureUsage::ShaderRead); + self.path_intermediate_texture = Some(self.device.new_texture(&texture_descriptor)); + + if self.path_sample_count > 1 { + let mut msaa_descriptor = texture_descriptor.clone(); + msaa_descriptor.set_texture_type(metal::MTLTextureType::D2Multisample); + msaa_descriptor.set_storage_mode(metal::MTLStorageMode::Private); + msaa_descriptor.set_sample_count(self.path_sample_count as _); + self.path_intermediate_msaa_texture = Some(self.device.new_texture(&msaa_descriptor)); + } else { + self.path_intermediate_msaa_texture = None; + } } pub fn update_transparency(&self, _transparent: bool) { @@ -380,36 +416,18 @@ impl MetalRenderer { ) -> Result<metal::CommandBuffer> { let command_queue = self.command_queue.clone(); let command_buffer = command_queue.new_command_buffer(); - let mut instance_offset = 0; - let render_pass_descriptor = metal::RenderPassDescriptor::new(); - let color_attachment = render_pass_descriptor - .color_attachments() - .object_at(0) - .unwrap(); - - if let Some(msaa_texture_ref) = self.msaa_texture.as_deref() { - color_attachment.set_texture(Some(msaa_texture_ref)); - color_attachment.set_load_action(metal::MTLLoadAction::Clear); - color_attachment.set_store_action(metal::MTLStoreAction::MultisampleResolve); - color_attachment.set_resolve_texture(Some(drawable.texture())); - } else { - color_attachment.set_load_action(metal::MTLLoadAction::Clear); - color_attachment.set_texture(Some(drawable.texture())); - color_attachment.set_store_action(metal::MTLStoreAction::Store); - } - let alpha = if self.layer.is_opaque() { 1. } else { 0. 
}; - color_attachment.set_clear_color(metal::MTLClearColor::new(0., 0., 0., alpha)); - let command_encoder = command_buffer.new_render_command_encoder(render_pass_descriptor); + let mut instance_offset = 0; - command_encoder.set_viewport(metal::MTLViewport { - originX: 0.0, - originY: 0.0, - width: i32::from(viewport_size.width) as f64, - height: i32::from(viewport_size.height) as f64, - znear: 0.0, - zfar: 1.0, - }); + let mut command_encoder = new_command_encoder( + command_buffer, + drawable, + viewport_size, + |color_attachment| { + color_attachment.set_load_action(metal::MTLLoadAction::Clear); + color_attachment.set_clear_color(metal::MTLClearColor::new(0., 0., 0., alpha)); + }, + ); for batch in scene.batches() { let ok = match batch { @@ -418,28 +436,53 @@ impl MetalRenderer { instance_buffer, &mut instance_offset, viewport_size, - command_encoder, + &command_encoder, ), PrimitiveBatch::Quads(quads) => self.draw_quads( quads, instance_buffer, &mut instance_offset, viewport_size, - command_encoder, - ), - PrimitiveBatch::Paths(paths) => self.draw_paths( - paths, - instance_buffer, - &mut instance_offset, - viewport_size, - command_encoder, + &command_encoder, ), + PrimitiveBatch::Paths(paths) => { + command_encoder.end_encoding(); + + let did_draw = self.draw_paths_to_intermediate( + paths, + instance_buffer, + &mut instance_offset, + viewport_size, + command_buffer, + ); + + command_encoder = new_command_encoder( + command_buffer, + drawable, + viewport_size, + |color_attachment| { + color_attachment.set_load_action(metal::MTLLoadAction::Load); + }, + ); + + if did_draw { + self.draw_paths_from_intermediate( + paths, + instance_buffer, + &mut instance_offset, + viewport_size, + &command_encoder, + ) + } else { + false + } + } PrimitiveBatch::Underlines(underlines) => self.draw_underlines( underlines, instance_buffer, &mut instance_offset, viewport_size, - command_encoder, + &command_encoder, ), PrimitiveBatch::MonochromeSprites { texture_id, @@ -450,7 +493,7 @@ impl MetalRenderer { instance_buffer, &mut instance_offset, viewport_size, - command_encoder, + &command_encoder, ), PrimitiveBatch::PolychromeSprites { texture_id, @@ -461,17 +504,16 @@ impl MetalRenderer { instance_buffer, &mut instance_offset, viewport_size, - command_encoder, + &command_encoder, ), PrimitiveBatch::Surfaces(surfaces) => self.draw_surfaces( surfaces, instance_buffer, &mut instance_offset, viewport_size, - command_encoder, + &command_encoder, ), }; - if !ok { command_encoder.end_encoding(); anyhow::bail!( @@ -496,6 +538,92 @@ impl MetalRenderer { Ok(command_buffer.to_owned()) } + fn draw_paths_to_intermediate( + &self, + paths: &[Path<ScaledPixels>], + instance_buffer: &mut InstanceBuffer, + instance_offset: &mut usize, + viewport_size: Size<DevicePixels>, + command_buffer: &metal::CommandBufferRef, + ) -> bool { + if paths.is_empty() { + return true; + } + let Some(intermediate_texture) = &self.path_intermediate_texture else { + return false; + }; + + let render_pass_descriptor = metal::RenderPassDescriptor::new(); + let color_attachment = render_pass_descriptor + .color_attachments() + .object_at(0) + .unwrap(); + color_attachment.set_load_action(metal::MTLLoadAction::Clear); + color_attachment.set_clear_color(metal::MTLClearColor::new(0., 0., 0., 0.)); + + if let Some(msaa_texture) = &self.path_intermediate_msaa_texture { + color_attachment.set_texture(Some(msaa_texture)); + color_attachment.set_resolve_texture(Some(intermediate_texture)); + 
color_attachment.set_store_action(metal::MTLStoreAction::MultisampleResolve); + } else { + color_attachment.set_texture(Some(intermediate_texture)); + color_attachment.set_store_action(metal::MTLStoreAction::Store); + } + + let command_encoder = command_buffer.new_render_command_encoder(render_pass_descriptor); + command_encoder.set_render_pipeline_state(&self.paths_rasterization_pipeline_state); + + align_offset(instance_offset); + let mut vertices = Vec::new(); + for path in paths { + vertices.extend(path.vertices.iter().map(|v| PathRasterizationVertex { + xy_position: v.xy_position, + st_position: v.st_position, + color: path.color, + bounds: path.bounds.intersect(&path.content_mask.bounds), + })); + } + let vertices_bytes_len = mem::size_of_val(vertices.as_slice()); + let next_offset = *instance_offset + vertices_bytes_len; + if next_offset > instance_buffer.size { + command_encoder.end_encoding(); + return false; + } + command_encoder.set_vertex_buffer( + PathRasterizationInputIndex::Vertices as u64, + Some(&instance_buffer.metal_buffer), + *instance_offset as u64, + ); + command_encoder.set_vertex_bytes( + PathRasterizationInputIndex::ViewportSize as u64, + mem::size_of_val(&viewport_size) as u64, + &viewport_size as *const Size<DevicePixels> as *const _, + ); + command_encoder.set_fragment_buffer( + PathRasterizationInputIndex::Vertices as u64, + Some(&instance_buffer.metal_buffer), + *instance_offset as u64, + ); + let buffer_contents = + unsafe { (instance_buffer.metal_buffer.contents() as *mut u8).add(*instance_offset) }; + unsafe { + ptr::copy_nonoverlapping( + vertices.as_ptr() as *const u8, + buffer_contents, + vertices_bytes_len, + ); + } + command_encoder.draw_primitives( + metal::MTLPrimitiveType::Triangle, + 0, + vertices.len() as u64, + ); + *instance_offset = next_offset; + + command_encoder.end_encoding(); + true + } + fn draw_shadows( &self, shadows: &[Shadow], @@ -618,7 +746,7 @@ impl MetalRenderer { true } - fn draw_paths( + fn draw_paths_from_intermediate( &self, paths: &[Path<ScaledPixels>], instance_buffer: &mut InstanceBuffer, @@ -626,112 +754,85 @@ impl MetalRenderer { viewport_size: Size<DevicePixels>, command_encoder: &metal::RenderCommandEncoderRef, ) -> bool { - if paths.is_empty() { + let Some(ref first_path) = paths.first() else { return true; - } - - command_encoder.set_render_pipeline_state(&self.path_pipeline_state); - - unsafe { - let base_addr = instance_buffer.metal_buffer.contents(); - let mut p = (base_addr as *mut u8).add(*instance_offset); - let mut draw_indirect_commands = Vec::with_capacity(paths.len()); - - // copy vertices - let vertices_offset = (p as usize) - (base_addr as usize); - let mut first_vertex = 0; - for (i, path) in paths.iter().enumerate() { - if (p as usize) - (base_addr as usize) - + (mem::size_of::<PathVertex<ScaledPixels>>() * path.vertices.len()) - > instance_buffer.size - { - return false; - } + }; - for v in &path.vertices { - *(p as *mut PathVertex<ScaledPixels>) = PathVertex { - xy_position: v.xy_position, - content_mask: ContentMask { - bounds: path.content_mask.bounds, - }, - }; - p = p.add(mem::size_of::<PathVertex<ScaledPixels>>()); - } + let Some(ref intermediate_texture) = self.path_intermediate_texture else { + return false; + }; - draw_indirect_commands.push(MTLDrawPrimitivesIndirectArguments { - vertexCount: path.vertices.len() as u32, - instanceCount: 1, - vertexStart: first_vertex, - baseInstance: i as u32, - }); - first_vertex += path.vertices.len() as u32; - } + 
command_encoder.set_render_pipeline_state(&self.path_sprites_pipeline_state); + command_encoder.set_vertex_buffer( + SpriteInputIndex::Vertices as u64, + Some(&self.unit_vertices), + 0, + ); + command_encoder.set_vertex_bytes( + SpriteInputIndex::ViewportSize as u64, + mem::size_of_val(&viewport_size) as u64, + &viewport_size as *const Size<DevicePixels> as *const _, + ); - // copy sprites - let sprites_offset = (p as u64) - (base_addr as u64); - if (p as usize) - (base_addr as usize) + (mem::size_of::<PathSprite>() * paths.len()) - > instance_buffer.size - { - return false; - } - for path in paths { - *(p as *mut PathSprite) = PathSprite { - bounds: path.bounds, - color: path.color, - }; - p = p.add(mem::size_of::<PathSprite>()); - } + command_encoder.set_fragment_texture( + SpriteInputIndex::AtlasTexture as u64, + Some(intermediate_texture), + ); - // copy indirect commands - let icb_bytes_len = mem::size_of_val(draw_indirect_commands.as_slice()); - let icb_offset = (p as u64) - (base_addr as u64); - if (p as usize) - (base_addr as usize) + icb_bytes_len > instance_buffer.size { - return false; + // When copying paths from the intermediate texture to the drawable, + // each pixel must only be copied once, in case of transparent paths. + // + // If all paths have the same draw order, then their bounds are all + // disjoint, so we can copy each path's bounds individually. If this + // batch combines different draw orders, we perform a single copy + // for a minimal spanning rect. + let sprites; + if paths.last().unwrap().order == first_path.order { + sprites = paths + .iter() + .map(|path| PathSprite { + bounds: path.clipped_bounds(), + }) + .collect(); + } else { + let mut bounds = first_path.clipped_bounds(); + for path in paths.iter().skip(1) { + bounds = bounds.union(&path.clipped_bounds()); } - ptr::copy_nonoverlapping( - draw_indirect_commands.as_ptr() as *const u8, - p, - icb_bytes_len, - ); - p = p.add(icb_bytes_len); - - // draw path - command_encoder.set_vertex_buffer( - PathInputIndex::Vertices as u64, - Some(&instance_buffer.metal_buffer), - vertices_offset as u64, - ); + sprites = vec![PathSprite { bounds }]; + } - command_encoder.set_vertex_bytes( - PathInputIndex::ViewportSize as u64, - mem::size_of_val(&viewport_size) as u64, - &viewport_size as *const Size<DevicePixels> as *const _, - ); + align_offset(instance_offset); + let sprite_bytes_len = mem::size_of_val(sprites.as_slice()); + let next_offset = *instance_offset + sprite_bytes_len; + if next_offset > instance_buffer.size { + return false; + } - command_encoder.set_vertex_buffer( - PathInputIndex::Sprites as u64, - Some(&instance_buffer.metal_buffer), - sprites_offset, - ); + command_encoder.set_vertex_buffer( + SpriteInputIndex::Sprites as u64, + Some(&instance_buffer.metal_buffer), + *instance_offset as u64, + ); - command_encoder.set_fragment_buffer( - PathInputIndex::Sprites as u64, - Some(&instance_buffer.metal_buffer), - sprites_offset, + let buffer_contents = + unsafe { (instance_buffer.metal_buffer.contents() as *mut u8).add(*instance_offset) }; + unsafe { + ptr::copy_nonoverlapping( + sprites.as_ptr() as *const u8, + buffer_contents, + sprite_bytes_len, ); - - for i in 0..paths.len() { - command_encoder.draw_primitives_indirect( - metal::MTLPrimitiveType::Triangle, - &instance_buffer.metal_buffer, - icb_offset - + (i * std::mem::size_of::<MTLDrawPrimitivesIndirectArguments>()) as u64, - ); - } - - *instance_offset = (p as usize) - (base_addr as usize); } + command_encoder.draw_primitives_instanced( + 
metal::MTLPrimitiveType::Triangle, + 0, + 6, + sprites.len() as u64, + ); + *instance_offset = next_offset; + true } @@ -1046,6 +1147,33 @@ impl MetalRenderer { } } +fn new_command_encoder<'a>( + command_buffer: &'a metal::CommandBufferRef, + drawable: &'a metal::MetalDrawableRef, + viewport_size: Size<DevicePixels>, + configure_color_attachment: impl Fn(&RenderPassColorAttachmentDescriptorRef), +) -> &'a metal::RenderCommandEncoderRef { + let render_pass_descriptor = metal::RenderPassDescriptor::new(); + let color_attachment = render_pass_descriptor + .color_attachments() + .object_at(0) + .unwrap(); + color_attachment.set_texture(Some(drawable.texture())); + color_attachment.set_store_action(metal::MTLStoreAction::Store); + configure_color_attachment(color_attachment); + + let command_encoder = command_buffer.new_render_command_encoder(render_pass_descriptor); + command_encoder.set_viewport(metal::MTLViewport { + originX: 0.0, + originY: 0.0, + width: i32::from(viewport_size.width) as f64, + height: i32::from(viewport_size.height) as f64, + znear: 0.0, + zfar: 1.0, + }); + command_encoder +} + fn build_pipeline_state( device: &metal::DeviceRef, library: &metal::LibraryRef, @@ -1053,7 +1181,6 @@ fn build_pipeline_state( vertex_fn_name: &str, fragment_fn_name: &str, pixel_format: metal::MTLPixelFormat, - sample_count: u64, ) -> metal::RenderPipelineState { let vertex_fn = library .get_function(vertex_fn_name, None) @@ -1066,7 +1193,6 @@ fn build_pipeline_state( descriptor.set_label(label); descriptor.set_vertex_function(Some(vertex_fn.as_ref())); descriptor.set_fragment_function(Some(fragment_fn.as_ref())); - descriptor.set_sample_count(sample_count); let color_attachment = descriptor.color_attachments().object_at(0).unwrap(); color_attachment.set_pixel_format(pixel_format); color_attachment.set_blending_enabled(true); @@ -1082,43 +1208,82 @@ fn build_pipeline_state( .expect("could not create render pipeline state") } -// Align to multiples of 256 make Metal happy. 
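The sprite selection above implements the rule documented in draw_paths_from_intermediate: when the intermediate texture is composited back onto the drawable, each pixel may be copied at most once, otherwise translucent paths would be blended twice. A self-contained sketch of that decision, using simplified stand-in types rather than gpui's:

    #[derive(Clone, Copy)]
    struct Rect {
        min: (f32, f32),
        max: (f32, f32),
    }

    impl Rect {
        fn union(self, other: Rect) -> Rect {
            Rect {
                min: (self.min.0.min(other.min.0), self.min.1.min(other.min.1)),
                max: (self.max.0.max(other.max.0), self.max.1.max(other.max.1)),
            }
        }
    }

    struct PathInfo {
        order: u32,
        clipped_bounds: Rect,
    }

    /// Paths that share a draw order have disjoint bounds, so each one can be
    /// copied rect-by-rect; a batch that mixes orders falls back to a single
    /// spanning rect so no pixel is copied twice.
    fn sprite_rects(paths: &[PathInfo]) -> Vec<Rect> {
        let (first, last) = (paths.first().unwrap(), paths.last().unwrap());
        if first.order == last.order {
            paths.iter().map(|p| p.clipped_bounds).collect()
        } else {
            vec![
                paths[1..]
                    .iter()
                    .fold(first.clipped_bounds, |acc, p| acc.union(p.clipped_bounds)),
            ]
        }
    }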
-fn align_offset(offset: &mut usize) { - *offset = (*offset).div_ceil(256) * 256; -} +fn build_path_sprite_pipeline_state( + device: &metal::DeviceRef, + library: &metal::LibraryRef, + label: &str, + vertex_fn_name: &str, + fragment_fn_name: &str, + pixel_format: metal::MTLPixelFormat, +) -> metal::RenderPipelineState { + let vertex_fn = library + .get_function(vertex_fn_name, None) + .expect("error locating vertex function"); + let fragment_fn = library + .get_function(fragment_fn_name, None) + .expect("error locating fragment function"); -fn create_msaa_texture( - device: &metal::Device, - layer: &metal::MetalLayer, - sample_count: u64, -) -> Option<metal::Texture> { - let viewport_size = layer.drawable_size(); - let width = viewport_size.width.ceil() as u64; - let height = viewport_size.height.ceil() as u64; - - if width == 0 || height == 0 { - return None; - } + let descriptor = metal::RenderPipelineDescriptor::new(); + descriptor.set_label(label); + descriptor.set_vertex_function(Some(vertex_fn.as_ref())); + descriptor.set_fragment_function(Some(fragment_fn.as_ref())); + let color_attachment = descriptor.color_attachments().object_at(0).unwrap(); + color_attachment.set_pixel_format(pixel_format); + color_attachment.set_blending_enabled(true); + color_attachment.set_rgb_blend_operation(metal::MTLBlendOperation::Add); + color_attachment.set_alpha_blend_operation(metal::MTLBlendOperation::Add); + color_attachment.set_source_rgb_blend_factor(metal::MTLBlendFactor::One); + color_attachment.set_source_alpha_blend_factor(metal::MTLBlendFactor::One); + color_attachment.set_destination_rgb_blend_factor(metal::MTLBlendFactor::OneMinusSourceAlpha); + color_attachment.set_destination_alpha_blend_factor(metal::MTLBlendFactor::One); - if sample_count <= 1 { - return None; - } + device + .new_render_pipeline_state(&descriptor) + .expect("could not create render pipeline state") +} - let texture_descriptor = metal::TextureDescriptor::new(); - texture_descriptor.set_texture_type(metal::MTLTextureType::D2Multisample); +fn build_path_rasterization_pipeline_state( + device: &metal::DeviceRef, + library: &metal::LibraryRef, + label: &str, + vertex_fn_name: &str, + fragment_fn_name: &str, + pixel_format: metal::MTLPixelFormat, + path_sample_count: u32, +) -> metal::RenderPipelineState { + let vertex_fn = library + .get_function(vertex_fn_name, None) + .expect("error locating vertex function"); + let fragment_fn = library + .get_function(fragment_fn_name, None) + .expect("error locating fragment function"); - // MTLStorageMode default is `shared` only for Apple silicon GPUs. Use `private` for Apple and Intel GPUs both. 
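The blend state configured above assumes the intermediate texture holds premultiplied alpha (path_rasterization_fragment multiplies rgb by alpha before writing), so compositing it onto the drawable reduces to the usual "over" operator. With Add as the blend operation, the configured factors work out to:

    out.rgb = src.rgb * 1 + dst.rgb * (1 - src.a)
    out.a   = src.a   * 1 + dst.a   * 1            // destination alpha factor is One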
- // Reference: https://developer.apple.com/documentation/metal/choosing-a-resource-storage-mode-for-apple-gpus - texture_descriptor.set_storage_mode(metal::MTLStorageMode::Private); + let descriptor = metal::RenderPipelineDescriptor::new(); + descriptor.set_label(label); + descriptor.set_vertex_function(Some(vertex_fn.as_ref())); + descriptor.set_fragment_function(Some(fragment_fn.as_ref())); + if path_sample_count > 1 { + descriptor.set_raster_sample_count(path_sample_count as _); + descriptor.set_alpha_to_coverage_enabled(false); + } + let color_attachment = descriptor.color_attachments().object_at(0).unwrap(); + color_attachment.set_pixel_format(pixel_format); + color_attachment.set_blending_enabled(true); + color_attachment.set_rgb_blend_operation(metal::MTLBlendOperation::Add); + color_attachment.set_alpha_blend_operation(metal::MTLBlendOperation::Add); + color_attachment.set_source_rgb_blend_factor(metal::MTLBlendFactor::One); + color_attachment.set_source_alpha_blend_factor(metal::MTLBlendFactor::One); + color_attachment.set_destination_rgb_blend_factor(metal::MTLBlendFactor::OneMinusSourceAlpha); + color_attachment.set_destination_alpha_blend_factor(metal::MTLBlendFactor::OneMinusSourceAlpha); - texture_descriptor.set_width(width); - texture_descriptor.set_height(height); - texture_descriptor.set_pixel_format(layer.pixel_format()); - texture_descriptor.set_usage(metal::MTLTextureUsage::RenderTarget); - texture_descriptor.set_sample_count(sample_count); + device + .new_render_pipeline_state(&descriptor) + .expect("could not create render pipeline state") +} - let metal_texture = device.new_texture(&texture_descriptor); - Some(metal_texture) +// Align to multiples of 256 make Metal happy. +fn align_offset(offset: &mut usize) { + *offset = (*offset).div_ceil(256) * 256; } #[repr(C)] @@ -1162,17 +1327,15 @@ enum SurfaceInputIndex { } #[repr(C)] -enum PathInputIndex { +enum PathRasterizationInputIndex { Vertices = 0, ViewportSize = 1, - Sprites = 2, } #[derive(Clone, Debug, Eq, PartialEq)] #[repr(C)] pub struct PathSprite { pub bounds: Bounds<ScaledPixels>, - pub color: Background, } #[derive(Clone, Debug, Eq, PartialEq)] diff --git a/crates/gpui/src/platform/mac/platform.rs b/crates/gpui/src/platform/mac/platform.rs index d9bb665469002bd89e248a8593f56b12cfebcca1..1d2146cf73562beed6c26754396dc2c4c0c915f9 100644 --- a/crates/gpui/src/platform/mac/platform.rs +++ b/crates/gpui/src/platform/mac/platform.rs @@ -583,7 +583,7 @@ impl Platform for MacPlatform { #[cfg(feature = "screen-capture")] fn screen_capture_sources( &self, - ) -> oneshot::Receiver<Result<Vec<Box<dyn crate::ScreenCaptureSource>>>> { + ) -> oneshot::Receiver<Result<Vec<Rc<dyn crate::ScreenCaptureSource>>>> { super::screen_capture::get_sources() } diff --git a/crates/gpui/src/platform/mac/screen_capture.rs b/crates/gpui/src/platform/mac/screen_capture.rs index af5e02fc06cbd6a82c4502a5f20e54237b5dc64d..4d4ffa6896520e465dfeb7b1ccc06e1149f9e25d 100644 --- a/crates/gpui/src/platform/mac/screen_capture.rs +++ b/crates/gpui/src/platform/mac/screen_capture.rs @@ -1,5 +1,5 @@ use crate::{ - DevicePixels, ForegroundExecutor, Size, + DevicePixels, ForegroundExecutor, SharedString, SourceMetadata, platform::{ScreenCaptureFrame, ScreenCaptureSource, ScreenCaptureStream}, size, }; @@ -7,8 +7,9 @@ use anyhow::{Result, anyhow}; use block::ConcreteBlock; use cocoa::{ base::{YES, id, nil}, - foundation::NSArray, + foundation::{NSArray, NSString}, }; +use collections::HashMap; use core_foundation::base::TCFType; use 
core_graphics::display::{ CGDirectDisplayID, CGDisplayCopyDisplayMode, CGDisplayModeGetPixelHeight, @@ -32,11 +33,13 @@ use super::NSStringExt; #[derive(Clone)] pub struct MacScreenCaptureSource { sc_display: id, + meta: Option<ScreenMeta>, } pub struct MacScreenCaptureStream { sc_stream: id, sc_stream_output: id, + meta: SourceMetadata, } static mut DELEGATE_CLASS: *const Class = ptr::null(); @@ -47,19 +50,31 @@ const FRAME_CALLBACK_IVAR: &str = "frame_callback"; const SCStreamOutputTypeScreen: NSInteger = 0; impl ScreenCaptureSource for MacScreenCaptureSource { - fn resolution(&self) -> Result<Size<DevicePixels>> { - unsafe { + fn metadata(&self) -> Result<SourceMetadata> { + let (display_id, size) = unsafe { let display_id: CGDirectDisplayID = msg_send![self.sc_display, displayID]; let display_mode_ref = CGDisplayCopyDisplayMode(display_id); let width = CGDisplayModeGetPixelWidth(display_mode_ref); let height = CGDisplayModeGetPixelHeight(display_mode_ref); CGDisplayModeRelease(display_mode_ref); - Ok(size( - DevicePixels(width as i32), - DevicePixels(height as i32), - )) - } + ( + display_id, + size(DevicePixels(width as i32), DevicePixels(height as i32)), + ) + }; + let (label, is_main) = self + .meta + .clone() + .map(|meta| (meta.label, meta.is_main)) + .unzip(); + + Ok(SourceMetadata { + id: display_id as u64, + label, + is_main, + resolution: size, + }) } fn stream( @@ -89,9 +104,9 @@ impl ScreenCaptureSource for MacScreenCaptureSource { Box::into_raw(Box::new(frame_callback)) as *mut c_void, ); - let resolution = self.resolution().unwrap(); - let _: id = msg_send![configuration, setWidth: resolution.width.0 as i64]; - let _: id = msg_send![configuration, setHeight: resolution.height.0 as i64]; + let meta = self.metadata().unwrap(); + let _: id = msg_send![configuration, setWidth: meta.resolution.width.0 as i64]; + let _: id = msg_send![configuration, setHeight: meta.resolution.height.0 as i64]; let stream: id = msg_send![stream, initWithFilter:filter configuration:configuration delegate:delegate]; let (mut tx, rx) = oneshot::channel(); @@ -110,6 +125,7 @@ impl ScreenCaptureSource for MacScreenCaptureSource { move |error: id| { let result = if error == nil { let stream = MacScreenCaptureStream { + meta: meta.clone(), sc_stream: stream, sc_stream_output: output, }; @@ -138,7 +154,11 @@ impl Drop for MacScreenCaptureSource { } } -impl ScreenCaptureStream for MacScreenCaptureStream {} +impl ScreenCaptureStream for MacScreenCaptureStream { + fn metadata(&self) -> Result<SourceMetadata> { + Ok(self.meta.clone()) + } +} impl Drop for MacScreenCaptureStream { fn drop(&mut self) { @@ -164,24 +184,74 @@ impl Drop for MacScreenCaptureStream { } } -pub(crate) fn get_sources() -> oneshot::Receiver<Result<Vec<Box<dyn ScreenCaptureSource>>>> { +#[derive(Clone)] +struct ScreenMeta { + label: SharedString, + // Is this the screen with menu bar? 
+ is_main: bool, +} + +unsafe fn screen_id_to_human_label() -> HashMap<CGDirectDisplayID, ScreenMeta> { + let screens: id = msg_send![class!(NSScreen), screens]; + let count: usize = msg_send![screens, count]; + let mut map = HashMap::default(); + let screen_number_key = unsafe { NSString::alloc(nil).init_str("NSScreenNumber") }; + for i in 0..count { + let screen: id = msg_send![screens, objectAtIndex: i]; + let device_desc: id = msg_send![screen, deviceDescription]; + if device_desc == nil { + continue; + } + + let nsnumber: id = msg_send![device_desc, objectForKey: screen_number_key]; + if nsnumber == nil { + continue; + } + + let screen_id: u32 = msg_send![nsnumber, unsignedIntValue]; + + let name: id = msg_send![screen, localizedName]; + if name != nil { + let cstr: *const std::os::raw::c_char = msg_send![name, UTF8String]; + let rust_str = unsafe { + std::ffi::CStr::from_ptr(cstr) + .to_string_lossy() + .into_owned() + }; + map.insert( + screen_id, + ScreenMeta { + label: rust_str.into(), + is_main: i == 0, + }, + ); + } + } + map +} + +pub(crate) fn get_sources() -> oneshot::Receiver<Result<Vec<Rc<dyn ScreenCaptureSource>>>> { unsafe { let (mut tx, rx) = oneshot::channel(); let tx = Rc::new(RefCell::new(Some(tx))); - + let screen_id_to_label = screen_id_to_human_label(); let block = ConcreteBlock::new(move |shareable_content: id, error: id| { let Some(mut tx) = tx.borrow_mut().take() else { return; }; + let result = if error == nil { let displays: id = msg_send![shareable_content, displays]; let mut result = Vec::new(); for i in 0..displays.count() { let display = displays.objectAtIndex(i); + let id: CGDirectDisplayID = msg_send![display, displayID]; + let meta = screen_id_to_label.get(&id).cloned(); let source = MacScreenCaptureSource { sc_display: msg_send![display, retain], + meta, }; - result.push(Box::new(source) as Box<dyn ScreenCaptureSource>); + result.push(Rc::new(source) as Rc<dyn ScreenCaptureSource>); } Ok(result) } else { diff --git a/crates/gpui/src/platform/mac/shaders.metal b/crates/gpui/src/platform/mac/shaders.metal index 5f0dc3323d4b4cec77a8c25fc9b008ea9da0a578..f9d5bdbf4c4ae1fa6ce098463ce63701a7019bbc 100644 --- a/crates/gpui/src/platform/mac/shaders.metal +++ b/crates/gpui/src/platform/mac/shaders.metal @@ -698,63 +698,120 @@ fragment float4 polychrome_sprite_fragment( return color; } -struct PathVertexOutput { +struct PathRasterizationVertexOutput { float4 position [[position]]; - uint sprite_id [[flat]]; - float4 solid_color [[flat]]; - float4 color0 [[flat]]; - float4 color1 [[flat]]; - float4 clip_distance; + float2 st_position; + uint vertex_id [[flat]]; + float clip_rect_distance [[clip_distance]][4]; }; -vertex PathVertexOutput path_vertex( - uint vertex_id [[vertex_id]], - constant PathVertex_ScaledPixels *vertices [[buffer(PathInputIndex_Vertices)]], - uint sprite_id [[instance_id]], - constant PathSprite *sprites [[buffer(PathInputIndex_Sprites)]], - constant Size_DevicePixels *input_viewport_size [[buffer(PathInputIndex_ViewportSize)]]) { - PathVertex_ScaledPixels v = vertices[vertex_id]; +struct PathRasterizationFragmentInput { + float4 position [[position]]; + float2 st_position; + uint vertex_id [[flat]]; +}; + +vertex PathRasterizationVertexOutput path_rasterization_vertex( + uint vertex_id [[vertex_id]], + constant PathRasterizationVertex *vertices [[buffer(PathRasterizationInputIndex_Vertices)]], + constant Size_DevicePixels *atlas_size [[buffer(PathRasterizationInputIndex_ViewportSize)]] +) { + PathRasterizationVertex v = vertices[vertex_id]; 
float2 vertex_position = float2(v.xy_position.x, v.xy_position.y); - float2 viewport_size = float2((float)input_viewport_size->width, - (float)input_viewport_size->height); - PathSprite sprite = sprites[sprite_id]; - float4 device_position = float4(vertex_position / viewport_size * float2(2., -2.) + float2(-1., 1.), 0., 1.); + float4 position = float4( + vertex_position * float2(2. / atlas_size->width, -2. / atlas_size->height) + float2(-1., 1.), + 0., + 1. + ); + return PathRasterizationVertexOutput{ + position, + float2(v.st_position.x, v.st_position.y), + vertex_id, + { + v.xy_position.x - v.bounds.origin.x, + v.bounds.origin.x + v.bounds.size.width - v.xy_position.x, + v.xy_position.y - v.bounds.origin.y, + v.bounds.origin.y + v.bounds.size.height - v.xy_position.y + } + }; +} - GradientColor gradient = prepare_fill_color( - sprite.color.tag, - sprite.color.color_space, - sprite.color.solid, - sprite.color.colors[0].color, - sprite.color.colors[1].color +fragment float4 path_rasterization_fragment( + PathRasterizationFragmentInput input [[stage_in]], + constant PathRasterizationVertex *vertices [[buffer(PathRasterizationInputIndex_Vertices)]] +) { + float2 dx = dfdx(input.st_position); + float2 dy = dfdy(input.st_position); + + PathRasterizationVertex v = vertices[input.vertex_id]; + Background background = v.color; + Bounds_ScaledPixels path_bounds = v.bounds; + float alpha; + if (length(float2(dx.x, dy.x)) < 0.001) { + alpha = 1.0; + } else { + float2 gradient = float2( + (2. * input.st_position.x) * dx.x - dx.y, + (2. * input.st_position.x) * dy.x - dy.y + ); + float f = (input.st_position.x * input.st_position.x) - input.st_position.y; + float distance = f / length(gradient); + alpha = saturate(0.5 - distance); + } + + GradientColor gradient_color = prepare_fill_color( + background.tag, + background.color_space, + background.solid, + background.colors[0].color, + background.colors[1].color + ); + + float4 color = fill_color( + background, + input.position.xy, + path_bounds, + gradient_color.solid, + gradient_color.color0, + gradient_color.color1 ); + return float4(color.rgb * color.a * alpha, alpha * color.a); +} + +struct PathSpriteVertexOutput { + float4 position [[position]]; + float2 texture_coords; +}; + +vertex PathSpriteVertexOutput path_sprite_vertex( + uint unit_vertex_id [[vertex_id]], + uint sprite_id [[instance_id]], + constant float2 *unit_vertices [[buffer(SpriteInputIndex_Vertices)]], + constant PathSprite *sprites [[buffer(SpriteInputIndex_Sprites)]], + constant Size_DevicePixels *viewport_size [[buffer(SpriteInputIndex_ViewportSize)]] +) { + float2 unit_vertex = unit_vertices[unit_vertex_id]; + PathSprite sprite = sprites[sprite_id]; + // Don't apply content mask because it was already accounted for when + // rasterizing the path. 
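The fragment stage above derives analytic coverage for quadratic curve segments from the interpolated (s, t) curve coordinates; triangles whose screen-space derivatives are near zero are fully interior and short-circuit to alpha = 1. Roughly:

    f(s, t) = s^2 - t              // zero exactly on the curve
    d      ~= f / |grad f|         // gradient assembled from dfdx/dfdy via the chain rule,
                                   // giving an approximate signed distance in pixels
    alpha   = saturate(0.5 - d)    // about one pixel of smoothing; the filled side (d <= 0)
                                   // saturates to full coverage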
+ float4 device_position = + to_device_position(unit_vertex, sprite.bounds, viewport_size); - return PathVertexOutput{ + float2 screen_position = float2(sprite.bounds.origin.x, sprite.bounds.origin.y) + unit_vertex * float2(sprite.bounds.size.width, sprite.bounds.size.height); + float2 texture_coords = screen_position / float2(viewport_size->width, viewport_size->height); + + return PathSpriteVertexOutput{ device_position, - sprite_id, - gradient.solid, - gradient.color0, - gradient.color1, - {v.xy_position.x - v.content_mask.bounds.origin.x, - v.content_mask.bounds.origin.x + v.content_mask.bounds.size.width - - v.xy_position.x, - v.xy_position.y - v.content_mask.bounds.origin.y, - v.content_mask.bounds.origin.y + v.content_mask.bounds.size.height - - v.xy_position.y} + texture_coords }; } -fragment float4 path_fragment( - PathVertexOutput input [[stage_in]], - constant PathSprite *sprites [[buffer(PathInputIndex_Sprites)]]) { - if (any(input.clip_distance < float4(0.0))) { - return float4(0.0); - } - - PathSprite sprite = sprites[input.sprite_id]; - Background background = sprite.color; - float4 color = fill_color(background, input.position.xy, sprite.bounds, - input.solid_color, input.color0, input.color1); - return color; +fragment float4 path_sprite_fragment( + PathSpriteVertexOutput input [[stage_in]], + texture2d<float> intermediate_texture [[texture(SpriteInputIndex_AtlasTexture)]] +) { + constexpr sampler intermediate_texture_sampler(mag_filter::linear, min_filter::linear); + return intermediate_texture.sample(intermediate_texture_sampler, input.texture_coords); } struct SurfaceVertexOutput { diff --git a/crates/gpui/src/platform/scap_screen_capture.rs b/crates/gpui/src/platform/scap_screen_capture.rs index c5e2267a37c794aeab70bc06d88d849b64be1c6f..32041b655fdc20b046717291c623dcb5c4d5146c 100644 --- a/crates/gpui/src/platform/scap_screen_capture.rs +++ b/crates/gpui/src/platform/scap_screen_capture.rs @@ -1,10 +1,12 @@ //! 
Screen capture for Linux and Windows use crate::{ DevicePixels, ForegroundExecutor, ScreenCaptureFrame, ScreenCaptureSource, ScreenCaptureStream, - Size, size, + Size, SourceMetadata, size, }; use anyhow::{Context as _, Result, anyhow}; use futures::channel::oneshot; +use scap::Target; +use std::rc::Rc; use std::sync::Arc; use std::sync::atomic::{self, AtomicBool}; @@ -15,7 +17,7 @@ use std::sync::atomic::{self, AtomicBool}; #[allow(dead_code)] pub(crate) fn scap_screen_sources( foreground_executor: &ForegroundExecutor, -) -> oneshot::Receiver<Result<Vec<Box<dyn ScreenCaptureSource>>>> { +) -> oneshot::Receiver<Result<Vec<Rc<dyn ScreenCaptureSource>>>> { let (sources_tx, sources_rx) = oneshot::channel(); get_screen_targets(sources_tx); to_dyn_screen_capture_sources(sources_rx, foreground_executor) @@ -29,14 +31,14 @@ pub(crate) fn scap_screen_sources( #[allow(dead_code)] pub(crate) fn start_scap_default_target_source( foreground_executor: &ForegroundExecutor, -) -> oneshot::Receiver<Result<Vec<Box<dyn ScreenCaptureSource>>>> { +) -> oneshot::Receiver<Result<Vec<Rc<dyn ScreenCaptureSource>>>> { let (sources_tx, sources_rx) = oneshot::channel(); start_default_target_screen_capture(sources_tx); to_dyn_screen_capture_sources(sources_rx, foreground_executor) } struct ScapCaptureSource { - target: scap::Target, + target: scap::Display, size: Size<DevicePixels>, } @@ -52,7 +54,7 @@ fn get_screen_targets(sources_tx: oneshot::Sender<Result<Vec<ScapCaptureSource>> } }; let sources = targets - .iter() + .into_iter() .filter_map(|target| match target { scap::Target::Display(display) => { let size = Size { @@ -60,7 +62,7 @@ fn get_screen_targets(sources_tx: oneshot::Sender<Result<Vec<ScapCaptureSource>> height: DevicePixels(display.height as i32), }; Some(ScapCaptureSource { - target: target.clone(), + target: display, size, }) } @@ -72,8 +74,13 @@ fn get_screen_targets(sources_tx: oneshot::Sender<Result<Vec<ScapCaptureSource>> } impl ScreenCaptureSource for ScapCaptureSource { - fn resolution(&self) -> Result<Size<DevicePixels>> { - Ok(self.size) + fn metadata(&self) -> Result<SourceMetadata> { + Ok(SourceMetadata { + resolution: self.size, + label: Some(self.target.title.clone().into()), + is_main: None, + id: self.target.id as u64, + }) } fn stream( @@ -85,13 +92,15 @@ impl ScreenCaptureSource for ScapCaptureSource { let target = self.target.clone(); // Due to use of blocking APIs, a dedicated thread is used. - std::thread::spawn(move || match new_scap_capturer(Some(target)) { - Ok(mut capturer) => { - capturer.start_capture(); - run_capture(capturer, frame_callback, stream_tx); - } - Err(e) => { - stream_tx.send(Err(e)).ok(); + std::thread::spawn(move || { + match new_scap_capturer(Some(scap::Target::Display(target.clone()))) { + Ok(mut capturer) => { + capturer.start_capture(); + run_capture(capturer, target.clone(), frame_callback, stream_tx); + } + Err(e) => { + stream_tx.send(Err(e)).ok(); + } } }); @@ -107,6 +116,7 @@ struct ScapDefaultTargetCaptureSource { // Callback for frames. 
Box<dyn Fn(ScreenCaptureFrame) + Send>, )>, + target: scap::Display, size: Size<DevicePixels>, } @@ -123,33 +133,48 @@ fn start_default_target_screen_capture( .get_next_frame() .context("Failed to get first frame of screenshare to get the size.")?; let size = frame_size(&first_frame); - Ok((capturer, size)) + let target = capturer + .target() + .context("Unable to determine the target display.")?; + let target = target.clone(); + Ok((capturer, size, target)) }); match start_result { - Err(e) => { - sources_tx.send(Err(e)).ok(); - } - Ok((capturer, size)) => { + Ok((capturer, size, Target::Display(display))) => { let (stream_call_tx, stream_rx) = std::sync::mpsc::sync_channel(1); sources_tx .send(Ok(vec![ScapDefaultTargetCaptureSource { stream_call_tx, size, + target: display.clone(), }])) .ok(); let Ok((stream_tx, frame_callback)) = stream_rx.recv() else { return; }; - run_capture(capturer, frame_callback, stream_tx); + run_capture(capturer, display, frame_callback, stream_tx); + } + Err(e) => { + sources_tx.send(Err(e)).ok(); + } + _ => { + sources_tx + .send(Err(anyhow!("The screen capture source is not a display"))) + .ok(); } } }); } impl ScreenCaptureSource for ScapDefaultTargetCaptureSource { - fn resolution(&self) -> Result<Size<DevicePixels>> { - Ok(self.size) + fn metadata(&self) -> Result<SourceMetadata> { + Ok(SourceMetadata { + resolution: self.size, + label: None, + is_main: None, + id: self.target.id as u64, + }) } fn stream( @@ -189,12 +214,19 @@ fn new_scap_capturer(target: Option<scap::Target>) -> Result<scap::capturer::Cap fn run_capture( mut capturer: scap::capturer::Capturer, + display: scap::Display, frame_callback: Box<dyn Fn(ScreenCaptureFrame) + Send>, stream_tx: oneshot::Sender<Result<ScapStream>>, ) { let cancel_stream = Arc::new(AtomicBool::new(false)); + let size = Size { + width: DevicePixels(display.width as i32), + height: DevicePixels(display.height as i32), + }; let stream_send_result = stream_tx.send(Ok(ScapStream { cancel_stream: cancel_stream.clone(), + display, + size, })); if let Err(_) = stream_send_result { return; @@ -213,9 +245,20 @@ fn run_capture( struct ScapStream { cancel_stream: Arc<AtomicBool>, + display: scap::Display, + size: Size<DevicePixels>, } -impl ScreenCaptureStream for ScapStream {} +impl ScreenCaptureStream for ScapStream { + fn metadata(&self) -> Result<SourceMetadata> { + Ok(SourceMetadata { + resolution: self.size, + label: Some(self.display.title.clone().into()), + is_main: None, + id: self.display.id as u64, + }) + } +} impl Drop for ScapStream { fn drop(&mut self) { @@ -237,12 +280,12 @@ fn frame_size(frame: &scap::frame::Frame) -> Size<DevicePixels> { } /// This is used by `get_screen_targets` and `start_default_target_screen_capture` to turn their -/// results into `Box<dyn ScreenCaptureSource>`. They need to `Send` their capture source, and so -/// the capture source structs are used as `Box<dyn ScreenCaptureSource>` is not `Send`. +/// results into `Rc<dyn ScreenCaptureSource>`. They need to `Send` their capture source, and so +/// the capture source structs are used as `Rc<dyn ScreenCaptureSource>` is not `Send`. 
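With resolution() replaced by metadata(), callers can inspect a source's id, label, and primary-display flag before starting a stream. A hypothetical caller (the trait and metadata struct are the gpui types modified in this change; the function itself is only an example):

    use gpui::ScreenCaptureSource;
    use std::rc::Rc;

    fn pick_main_screen(
        sources: Vec<Rc<dyn ScreenCaptureSource>>,
    ) -> Option<Rc<dyn ScreenCaptureSource>> {
        sources.into_iter().find(|source| {
            source
                .metadata()
                .map(|meta| meta.is_main == Some(true))
                .unwrap_or(false)
        })
    }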
fn to_dyn_screen_capture_sources<T: ScreenCaptureSource + 'static>( sources_rx: oneshot::Receiver<Result<Vec<T>>>, foreground_executor: &ForegroundExecutor, -) -> oneshot::Receiver<Result<Vec<Box<dyn ScreenCaptureSource>>>> { +) -> oneshot::Receiver<Result<Vec<Rc<dyn ScreenCaptureSource>>>> { let (dyn_sources_tx, dyn_sources_rx) = oneshot::channel(); foreground_executor .spawn(async move { @@ -250,7 +293,7 @@ fn to_dyn_screen_capture_sources<T: ScreenCaptureSource + 'static>( Ok(Ok(results)) => dyn_sources_tx .send(Ok(results .into_iter() - .map(|source| Box::new(source) as Box<dyn ScreenCaptureSource>) + .map(|source| Rc::new(source) as Rc<dyn ScreenCaptureSource>) .collect::<Vec<_>>())) .ok(), Ok(Err(err)) => dyn_sources_tx.send(Err(err)).ok(), diff --git a/crates/gpui/src/platform/test.rs b/crates/gpui/src/platform/test.rs index e4173b7c6ba2011bdea80514ac12fa862cade1e4..9227df5b63314b44a3c641835d00ba340aa909e8 100644 --- a/crates/gpui/src/platform/test.rs +++ b/crates/gpui/src/platform/test.rs @@ -8,4 +8,4 @@ pub(crate) use display::*; pub(crate) use platform::*; pub(crate) use window::*; -pub use platform::TestScreenCaptureSource; +pub use platform::{TestScreenCaptureSource, TestScreenCaptureStream}; diff --git a/crates/gpui/src/platform/test/platform.rs b/crates/gpui/src/platform/test/platform.rs index bef05399e52a6eb1a05552bb8693f5850274e98a..a26b65576cc49e290494762eed597d5bd8d0af26 100644 --- a/crates/gpui/src/platform/test/platform.rs +++ b/crates/gpui/src/platform/test/platform.rs @@ -2,7 +2,7 @@ use crate::{ AnyWindowHandle, BackgroundExecutor, ClipboardItem, CursorStyle, DevicePixels, ForegroundExecutor, Keymap, NoopTextSystem, Platform, PlatformDisplay, PlatformKeyboardLayout, PlatformTextSystem, PromptButton, ScreenCaptureFrame, ScreenCaptureSource, ScreenCaptureStream, - Size, Task, TestDisplay, TestWindow, WindowAppearance, WindowParams, size, + SourceMetadata, Task, TestDisplay, TestWindow, WindowAppearance, WindowParams, size, }; use anyhow::Result; use collections::VecDeque; @@ -44,11 +44,17 @@ pub(crate) struct TestPlatform { /// A fake screen capture source, used for testing. pub struct TestScreenCaptureSource {} +/// A fake screen capture stream, used for testing. 
pub struct TestScreenCaptureStream {} impl ScreenCaptureSource for TestScreenCaptureSource { - fn resolution(&self) -> Result<Size<DevicePixels>> { - Ok(size(DevicePixels(1), DevicePixels(1))) + fn metadata(&self) -> Result<SourceMetadata> { + Ok(SourceMetadata { + id: 0, + is_main: None, + label: None, + resolution: size(DevicePixels(1), DevicePixels(1)), + }) } fn stream( @@ -64,7 +70,11 @@ impl ScreenCaptureSource for TestScreenCaptureSource { } } -impl ScreenCaptureStream for TestScreenCaptureStream {} +impl ScreenCaptureStream for TestScreenCaptureStream { + fn metadata(&self) -> Result<SourceMetadata> { + TestScreenCaptureSource {}.metadata() + } +} struct TestPrompt { msg: String, @@ -271,13 +281,13 @@ impl Platform for TestPlatform { #[cfg(feature = "screen-capture")] fn screen_capture_sources( &self, - ) -> oneshot::Receiver<Result<Vec<Box<dyn ScreenCaptureSource>>>> { + ) -> oneshot::Receiver<Result<Vec<Rc<dyn ScreenCaptureSource>>>> { let (mut tx, rx) = oneshot::channel(); tx.send(Ok(self .screen_capture_sources .borrow() .iter() - .map(|source| Box::new(source.clone()) as Box<dyn ScreenCaptureSource>) + .map(|source| Rc::new(source.clone()) as Rc<dyn ScreenCaptureSource>) .collect())) .ok(); rx diff --git a/crates/gpui/src/platform/test/window.rs b/crates/gpui/src/platform/test/window.rs index 65ee10a13ffa8a73f377ae4ac4e5f7a4381519ec..e15bd7aeecec5932eb6386bd47d168eda906dd63 100644 --- a/crates/gpui/src/platform/test/window.rs +++ b/crates/gpui/src/platform/test/window.rs @@ -341,7 +341,7 @@ impl PlatformAtlas for TestAtlas { crate::AtlasTile { texture_id: AtlasTextureId { index: texture_id, - kind: crate::AtlasTextureKind::Polychrome, + kind: crate::AtlasTextureKind::Monochrome, }, tile_id: TileId(tile_id), padding: 0, diff --git a/crates/gpui/src/platform/windows.rs b/crates/gpui/src/platform/windows.rs index 4bdf42080d9c6becd339a134c4a3dfc1ac3502e2..5268d3ccba217c996f1ab2f3f664bac2c1032627 100644 --- a/crates/gpui/src/platform/windows.rs +++ b/crates/gpui/src/platform/windows.rs @@ -1,6 +1,8 @@ mod clipboard; mod destination_list; mod direct_write; +mod directx_atlas; +mod directx_renderer; mod dispatcher; mod display; mod events; @@ -14,6 +16,8 @@ mod wrapper; pub(crate) use clipboard::*; pub(crate) use destination_list::*; pub(crate) use direct_write::*; +pub(crate) use directx_atlas::*; +pub(crate) use directx_renderer::*; pub(crate) use dispatcher::*; pub(crate) use display::*; pub(crate) use events::*; diff --git a/crates/gpui/src/platform/windows/color_text_raster.hlsl b/crates/gpui/src/platform/windows/color_text_raster.hlsl new file mode 100644 index 0000000000000000000000000000000000000000..ccc5fa26f00d57f2b69e85965a66b6ecea98a833 --- /dev/null +++ b/crates/gpui/src/platform/windows/color_text_raster.hlsl @@ -0,0 +1,39 @@ +struct RasterVertexOutput { + float4 position : SV_Position; + float2 texcoord : TEXCOORD0; +}; + +RasterVertexOutput emoji_rasterization_vertex(uint vertexID : SV_VERTEXID) +{ + RasterVertexOutput output; + output.texcoord = float2((vertexID << 1) & 2, vertexID & 2); + output.position = float4(output.texcoord * 2.0f - 1.0f, 0.0f, 1.0f); + output.position.y = -output.position.y; + + return output; +} + +struct PixelInput { + float4 position: SV_Position; + float2 texcoord : TEXCOORD0; +}; + +struct Bounds { + int2 origin; + int2 size; +}; + +Texture2D<float4> t_layer : register(t0); +SamplerState s_layer : register(s0); + +cbuffer GlyphLayerTextureParams : register(b0) { + Bounds bounds; + float4 run_color; +}; + +float4 
emoji_rasterization_fragment(PixelInput input): SV_Target { + float3 sampled = t_layer.Sample(s_layer, input.texcoord.xy).rgb; + float alpha = (sampled.r + sampled.g + sampled.b) / 3; + + return float4(run_color.rgb, alpha); +} diff --git a/crates/gpui/src/platform/windows/direct_write.rs b/crates/gpui/src/platform/windows/direct_write.rs index ada306c15c187e7014812c3026d12e966a563c80..587cb7b4a6c18aca2ee4f8a10e453eeca3ade037 100644 --- a/crates/gpui/src/platform/windows/direct_write.rs +++ b/crates/gpui/src/platform/windows/direct_write.rs @@ -10,10 +10,11 @@ use windows::{ Foundation::*, Globalization::GetUserDefaultLocaleName, Graphics::{ - Direct2D::{Common::*, *}, + Direct3D::D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP, + Direct3D11::*, DirectWrite::*, Dxgi::Common::*, - Gdi::LOGFONTW, + Gdi::{IsRectEmpty, LOGFONTW}, Imaging::*, }, System::SystemServices::LOCALE_NAME_MAX_LENGTH, @@ -40,16 +41,21 @@ struct DirectWriteComponent { locale: String, factory: IDWriteFactory5, bitmap_factory: AgileReference<IWICImagingFactory>, - d2d1_factory: ID2D1Factory, in_memory_loader: IDWriteInMemoryFontFileLoader, builder: IDWriteFontSetBuilder1, text_renderer: Arc<TextRendererWrapper>, - render_context: GlyphRenderContext, + + render_params: IDWriteRenderingParams3, + gpu_state: GPUState, } -struct GlyphRenderContext { - params: IDWriteRenderingParams3, - dc_target: ID2D1DeviceContext4, +struct GPUState { + device: ID3D11Device, + device_context: ID3D11DeviceContext, + sampler: [Option<ID3D11SamplerState>; 1], + blend_state: ID3D11BlendState, + vertex_shader: ID3D11VertexShader, + pixel_shader: ID3D11PixelShader, } struct DirectWriteState { @@ -70,12 +76,11 @@ struct FontIdentifier { } impl DirectWriteComponent { - pub fn new(bitmap_factory: &IWICImagingFactory) -> Result<Self> { + pub fn new(bitmap_factory: &IWICImagingFactory, gpu_context: &DirectXDevices) -> Result<Self> { + // todo: ideally this would not be a large unsafe block but smaller isolated ones for easier auditing unsafe { let factory: IDWriteFactory5 = DWriteCreateFactory(DWRITE_FACTORY_TYPE_SHARED)?; let bitmap_factory = AgileReference::new(bitmap_factory)?; - let d2d1_factory: ID2D1Factory = - D2D1CreateFactory(D2D1_FACTORY_TYPE_MULTI_THREADED, None)?; // The `IDWriteInMemoryFontFileLoader` here is supported starting from // Windows 10 Creators Update, which consequently requires the entire // `DirectWriteTextSystem` to run on `win10 1703`+. @@ -86,60 +91,132 @@ impl DirectWriteComponent { GetUserDefaultLocaleName(&mut locale_vec); let locale = String::from_utf16_lossy(&locale_vec); let text_renderer = Arc::new(TextRendererWrapper::new(&locale)); - let render_context = GlyphRenderContext::new(&factory, &d2d1_factory)?; + + let render_params = { + let default_params: IDWriteRenderingParams3 = + factory.CreateRenderingParams()?.cast()?; + let gamma = default_params.GetGamma(); + let enhanced_contrast = default_params.GetEnhancedContrast(); + let gray_contrast = default_params.GetGrayscaleEnhancedContrast(); + let cleartype_level = default_params.GetClearTypeLevel(); + let grid_fit_mode = default_params.GetGridFitMode(); + + factory.CreateCustomRenderingParams( + gamma, + enhanced_contrast, + gray_contrast, + cleartype_level, + DWRITE_PIXEL_GEOMETRY_RGB, + DWRITE_RENDERING_MODE1_NATURAL_SYMMETRIC, + grid_fit_mode, + )? 
+ }; + + let gpu_state = GPUState::new(gpu_context)?; Ok(DirectWriteComponent { locale, factory, bitmap_factory, - d2d1_factory, in_memory_loader, builder, text_renderer, - render_context, + render_params, + gpu_state, }) } } } -impl GlyphRenderContext { - pub fn new(factory: &IDWriteFactory5, d2d1_factory: &ID2D1Factory) -> Result<Self> { - unsafe { - let default_params: IDWriteRenderingParams3 = - factory.CreateRenderingParams()?.cast()?; - let gamma = default_params.GetGamma(); - let enhanced_contrast = default_params.GetEnhancedContrast(); - let gray_contrast = default_params.GetGrayscaleEnhancedContrast(); - let cleartype_level = default_params.GetClearTypeLevel(); - let grid_fit_mode = default_params.GetGridFitMode(); - - let params = factory.CreateCustomRenderingParams( - gamma, - enhanced_contrast, - gray_contrast, - cleartype_level, - DWRITE_PIXEL_GEOMETRY_RGB, - DWRITE_RENDERING_MODE1_NATURAL_SYMMETRIC, - grid_fit_mode, - )?; - let dc_target = { - let target = d2d1_factory.CreateDCRenderTarget(&get_render_target_property( - DXGI_FORMAT_B8G8R8A8_UNORM, - D2D1_ALPHA_MODE_PREMULTIPLIED, - ))?; - let target = target.cast::<ID2D1DeviceContext4>()?; - target.SetTextRenderingParams(¶ms); - target +impl GPUState { + fn new(gpu_context: &DirectXDevices) -> Result<Self> { + let device = gpu_context.device.clone(); + let device_context = gpu_context.device_context.clone(); + + let blend_state = { + let mut blend_state = None; + let desc = D3D11_BLEND_DESC { + AlphaToCoverageEnable: false.into(), + IndependentBlendEnable: false.into(), + RenderTarget: [ + D3D11_RENDER_TARGET_BLEND_DESC { + BlendEnable: true.into(), + SrcBlend: D3D11_BLEND_SRC_ALPHA, + DestBlend: D3D11_BLEND_INV_SRC_ALPHA, + BlendOp: D3D11_BLEND_OP_ADD, + SrcBlendAlpha: D3D11_BLEND_SRC_ALPHA, + DestBlendAlpha: D3D11_BLEND_INV_SRC_ALPHA, + BlendOpAlpha: D3D11_BLEND_OP_ADD, + RenderTargetWriteMask: D3D11_COLOR_WRITE_ENABLE_ALL.0 as u8, + }, + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + ], }; + unsafe { device.CreateBlendState(&desc, Some(&mut blend_state)) }?; + blend_state.unwrap() + }; - Ok(Self { params, dc_target }) - } + let sampler = { + let mut sampler = None; + let desc = D3D11_SAMPLER_DESC { + Filter: D3D11_FILTER_MIN_MAG_MIP_POINT, + AddressU: D3D11_TEXTURE_ADDRESS_BORDER, + AddressV: D3D11_TEXTURE_ADDRESS_BORDER, + AddressW: D3D11_TEXTURE_ADDRESS_BORDER, + MipLODBias: 0.0, + MaxAnisotropy: 1, + ComparisonFunc: D3D11_COMPARISON_ALWAYS, + BorderColor: [0.0, 0.0, 0.0, 0.0], + MinLOD: 0.0, + MaxLOD: 0.0, + }; + unsafe { device.CreateSamplerState(&desc, Some(&mut sampler)) }?; + [sampler] + }; + + let vertex_shader = { + let source = shader_resources::RawShaderBytes::new( + shader_resources::ShaderModule::EmojiRasterization, + shader_resources::ShaderTarget::Vertex, + )?; + let mut shader = None; + unsafe { device.CreateVertexShader(source.as_bytes(), None, Some(&mut shader)) }?; + shader.unwrap() + }; + + let pixel_shader = { + let source = shader_resources::RawShaderBytes::new( + shader_resources::ShaderModule::EmojiRasterization, + shader_resources::ShaderTarget::Fragment, + )?; + let mut shader = None; + unsafe { device.CreatePixelShader(source.as_bytes(), None, Some(&mut shader)) }?; + shader.unwrap() + }; + + Ok(Self { + device, + device_context, + sampler, + blend_state, + vertex_shader, + pixel_shader, + }) } } impl DirectWriteTextSystem { - pub(crate) fn new(bitmap_factory: &IWICImagingFactory) -> 
Result<Self> {
- let components = DirectWriteComponent::new(bitmap_factory)?;
+ pub(crate) fn new(
+ gpu_context: &DirectXDevices,
+ bitmap_factory: &IWICImagingFactory,
+ ) -> Result<Self> {
+ let components = DirectWriteComponent::new(bitmap_factory, gpu_context)?;
let system_font_collection = unsafe { let mut result = std::mem::zeroed(); components
@@ -648,15 +725,13 @@ impl DirectWriteState { } }
- fn raster_bounds(&self, params: &RenderGlyphParams) -> Result<Bounds<DevicePixels>> {
- let render_target = &self.components.render_context.dc_target;
- unsafe {
- render_target.SetUnitMode(D2D1_UNIT_MODE_DIPS);
- render_target.SetDpi(96.0 * params.scale_factor, 96.0 * params.scale_factor);
- }
+ fn create_glyph_run_analysis(
+ &self,
+ params: &RenderGlyphParams,
+ ) -> Result<IDWriteGlyphRunAnalysis> {
let font = &self.fonts[params.font_id.0]; let glyph_id = [params.glyph_id.0 as u16];
- let advance = [0.0f32];
+ let advance = [0.0];
let offset = [DWRITE_GLYPH_OFFSET::default()]; let glyph_run = DWRITE_GLYPH_RUN { fontFace: unsafe { std::mem::transmute_copy(&font.font_face) },
@@ -668,44 +743,87 @@ impl DirectWriteState { isSideways: BOOL(0), bidiLevel: 0, };
- let bounds = unsafe {
- render_target.GetGlyphRunWorldBounds(
- Vector2 { X: 0.0, Y: 0.0 },
- &glyph_run,
- DWRITE_MEASURING_MODE_NATURAL,
- )?
+ let transform = DWRITE_MATRIX {
+ m11: params.scale_factor,
+ m12: 0.0,
+ m21: 0.0,
+ m22: params.scale_factor,
+ dx: 0.0,
+ dy: 0.0,
};
- // todo(windows)
- // This is a walkaround, deleted when figured out.
- let y_offset;
- let extra_height;
- if params.is_emoji {
- y_offset = 0;
- extra_height = 0;
- } else {
- // make some room for scaler.
- y_offset = -1;
- extra_height = 2;
+ let subpixel_shift = params
+ .subpixel_variant
+ .map(|v| v as f32 / SUBPIXEL_VARIANTS as f32);
+ let baseline_origin_x = subpixel_shift.x / params.scale_factor;
+ let baseline_origin_y = subpixel_shift.y / params.scale_factor;
+
+ let mut rendering_mode = DWRITE_RENDERING_MODE1::default();
+ let mut grid_fit_mode = DWRITE_GRID_FIT_MODE::default();
+ unsafe {
+ font.font_face.GetRecommendedRenderingMode(
+ params.font_size.0,
+ // The dpi here seems to have the same effect as `Some(&transform)`
+ 1.0,
+ 1.0,
+ Some(&transform),
+ false,
+ DWRITE_OUTLINE_THRESHOLD_ANTIALIASED,
+ DWRITE_MEASURING_MODE_NATURAL,
+ &self.components.render_params,
+ &mut rendering_mode,
+ &mut grid_fit_mode,
+ )?;
}
- if bounds.right < bounds.left {
- Ok(Bounds {
- origin: point(0.into(), 0.into()),
- size: size(0.into(), 0.into()),
- })
- } else {
+ let glyph_analysis = unsafe {
+ self.components.factory.CreateGlyphRunAnalysis(
+ &glyph_run,
+ Some(&transform),
+ rendering_mode,
+ DWRITE_MEASURING_MODE_NATURAL,
+ grid_fit_mode,
+ // We're using cleartype rather than grayscale for monochrome because it provides better quality
+ DWRITE_TEXT_ANTIALIAS_MODE_CLEARTYPE,
+ baseline_origin_x,
+ baseline_origin_y,
+ )
+ }?;
+ Ok(glyph_analysis)
+ }
+
+ fn raster_bounds(&self, params: &RenderGlyphParams) -> Result<Bounds<DevicePixels>> {
+ let glyph_analysis = self.create_glyph_run_analysis(params)?;
+
+ let bounds = unsafe { glyph_analysis.GetAlphaTextureBounds(DWRITE_TEXTURE_CLEARTYPE_3x1)? };
+ // Some glyphs cannot be drawn with ClearType, such as bitmap fonts. In that case
+ // GetAlphaTextureBounds() supposedly returns an empty RECT, but I haven't tested that yet.
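+ // (If the ClearType bounds do come back empty, the code below retries with
+ // DWRITE_TEXTURE_ALIASED_1x1 bounds before reporting a zero-sized glyph.)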
+ if !unsafe { IsRectEmpty(&bounds) }.as_bool() { Ok(Bounds {
- origin: point(
- ((bounds.left * params.scale_factor).ceil() as i32).into(),
- ((bounds.top * params.scale_factor).ceil() as i32 + y_offset).into(),
- ),
+ origin: point(bounds.left.into(), bounds.top.into()), size: size(
- (((bounds.right - bounds.left) * params.scale_factor).ceil() as i32).into(),
- (((bounds.bottom - bounds.top) * params.scale_factor).ceil() as i32
- + extra_height)
- .into(),
+ (bounds.right - bounds.left).into(),
+ (bounds.bottom - bounds.top).into(),
), })
+ } else {
+ // If it's empty, retry with grayscale AA.
+ let bounds =
+ unsafe { glyph_analysis.GetAlphaTextureBounds(DWRITE_TEXTURE_ALIASED_1x1)? };
+
+ if bounds.right < bounds.left {
+ Ok(Bounds {
+ origin: point(0.into(), 0.into()),
+ size: size(0.into(), 0.into()),
+ })
+ } else {
+ Ok(Bounds {
+ origin: point(bounds.left.into(), bounds.top.into()),
+ size: size(
+ (bounds.right - bounds.left).into(),
+ (bounds.bottom - bounds.top).into(),
+ ),
+ })
+ }
+ } } }
@@ -731,7 +849,95 @@ impl DirectWriteState { anyhow::bail!("glyph bounds are empty"); }
- let font_info = &self.fonts[params.font_id.0];
+ let bitmap_data = if params.is_emoji {
+ if let Ok(color) = self.rasterize_color(&params, glyph_bounds) {
+ color
+ } else {
+ let monochrome = self.rasterize_monochrome(params, glyph_bounds)?;
+ monochrome
+ .into_iter()
+ .flat_map(|pixel| [0, 0, 0, pixel])
+ .collect::<Vec<_>>()
+ }
+ } else {
+ self.rasterize_monochrome(params, glyph_bounds)?
+ };
+
+ Ok((glyph_bounds.size, bitmap_data))
+ }
+
+ fn rasterize_monochrome(
+ &self,
+ params: &RenderGlyphParams,
+ glyph_bounds: Bounds<DevicePixels>,
+ ) -> Result<Vec<u8>> {
+ let mut bitmap_data =
+ vec![0u8; glyph_bounds.size.width.0 as usize * glyph_bounds.size.height.0 as usize * 3];
+
+ let glyph_analysis = self.create_glyph_run_analysis(params)?;
+ unsafe {
+ glyph_analysis.CreateAlphaTexture(
+ // We're using cleartype rather than grayscale for monochrome because it provides better quality
+ DWRITE_TEXTURE_CLEARTYPE_3x1,
+ &RECT {
+ left: glyph_bounds.origin.x.0,
+ top: glyph_bounds.origin.y.0,
+ right: glyph_bounds.size.width.0 + glyph_bounds.origin.x.0,
+ bottom: glyph_bounds.size.height.0 + glyph_bounds.origin.y.0,
+ },
+ &mut bitmap_data,
+ )?;
+ }
+
+ let bitmap_factory = self.components.bitmap_factory.resolve()?;
+ let bitmap = unsafe {
+ bitmap_factory.CreateBitmapFromMemory(
+ glyph_bounds.size.width.0 as u32,
+ glyph_bounds.size.height.0 as u32,
+ &GUID_WICPixelFormat24bppRGB,
+ glyph_bounds.size.width.0 as u32 * 3,
+ &bitmap_data,
+ )
+ }?;
+
+ let grayscale_bitmap =
+ unsafe { WICConvertBitmapSource(&GUID_WICPixelFormat8bppGray, &bitmap) }?;
+
+ let mut bitmap_data =
+ vec![0u8; glyph_bounds.size.width.0 as usize * glyph_bounds.size.height.0 as usize];
+ unsafe {
+ grayscale_bitmap.CopyPixels(
+ std::ptr::null() as _,
+ glyph_bounds.size.width.0 as u32,
+ &mut bitmap_data,
+ )
+ }?;
+
+ Ok(bitmap_data)
+ }
+
+ fn rasterize_color(
+ &self,
+ params: &RenderGlyphParams,
+ glyph_bounds: Bounds<DevicePixels>,
+ ) -> Result<Vec<u8>> {
+ let bitmap_size = glyph_bounds.size;
+ let subpixel_shift = params
+ .subpixel_variant
+ .map(|v| v as f32 / SUBPIXEL_VARIANTS as f32);
+ let baseline_origin_x = subpixel_shift.x / params.scale_factor;
+ let baseline_origin_y = subpixel_shift.y / params.scale_factor;
+
+ let transform = DWRITE_MATRIX {
+ m11: params.scale_factor,
+ m12: 0.0,
+ m21: 0.0,
+ m22: params.scale_factor,
+ dx: 0.0,
+ dy: 0.0,
+ };
+
+ let font = &self.fonts[params.font_id.0]; let glyph_id =
[params.glyph_id.0 as u16]; let advance = [glyph_bounds.size.width.0 as f32]; let offset = [DWRITE_GLYPH_OFFSET { @@ -739,7 +945,7 @@ impl DirectWriteState { ascenderOffset: glyph_bounds.origin.y.0 as f32 / params.scale_factor, }]; let glyph_run = DWRITE_GLYPH_RUN { - fontFace: unsafe { std::mem::transmute_copy(&font_info.font_face) }, + fontFace: unsafe { std::mem::transmute_copy(&font.font_face) }, fontEmSize: params.font_size.0, glyphCount: 1, glyphIndices: glyph_id.as_ptr(), @@ -749,160 +955,254 @@ impl DirectWriteState { bidiLevel: 0, }; - // Add an extra pixel when the subpixel variant isn't zero to make room for anti-aliasing. - let mut bitmap_size = glyph_bounds.size; - if params.subpixel_variant.x > 0 { - bitmap_size.width += DevicePixels(1); - } - if params.subpixel_variant.y > 0 { - bitmap_size.height += DevicePixels(1); - } - let bitmap_size = bitmap_size; - - let total_bytes; - let bitmap_format; - let render_target_property; - let bitmap_width; - let bitmap_height; - let bitmap_stride; - let bitmap_dpi; - if params.is_emoji { - total_bytes = bitmap_size.height.0 as usize * bitmap_size.width.0 as usize * 4; - bitmap_format = &GUID_WICPixelFormat32bppPBGRA; - render_target_property = get_render_target_property( - DXGI_FORMAT_B8G8R8A8_UNORM, - D2D1_ALPHA_MODE_PREMULTIPLIED, - ); - bitmap_width = bitmap_size.width.0 as u32; - bitmap_height = bitmap_size.height.0 as u32; - bitmap_stride = bitmap_size.width.0 as u32 * 4; - bitmap_dpi = 96.0; - } else { - total_bytes = bitmap_size.height.0 as usize * bitmap_size.width.0 as usize; - bitmap_format = &GUID_WICPixelFormat8bppAlpha; - render_target_property = - get_render_target_property(DXGI_FORMAT_A8_UNORM, D2D1_ALPHA_MODE_STRAIGHT); - bitmap_width = bitmap_size.width.0 as u32 * 2; - bitmap_height = bitmap_size.height.0 as u32 * 2; - bitmap_stride = bitmap_size.width.0 as u32; - bitmap_dpi = 192.0; - } + // todo: support formats other than COLR + let color_enumerator = unsafe { + self.components.factory.TranslateColorGlyphRun( + Vector2::new(baseline_origin_x, baseline_origin_y), + &glyph_run, + None, + DWRITE_GLYPH_IMAGE_FORMATS_COLR, + DWRITE_MEASURING_MODE_NATURAL, + Some(&transform), + 0, + ) + }?; + + let mut glyph_layers = Vec::new(); + loop { + let color_run = unsafe { color_enumerator.GetCurrentRun() }?; + let color_run = unsafe { &*color_run }; + let image_format = color_run.glyphImageFormat & !DWRITE_GLYPH_IMAGE_FORMATS_TRUETYPE; + if image_format == DWRITE_GLYPH_IMAGE_FORMATS_COLR { + let color_analysis = unsafe { + self.components.factory.CreateGlyphRunAnalysis( + &color_run.Base.glyphRun as *const _, + Some(&transform), + DWRITE_RENDERING_MODE1_NATURAL_SYMMETRIC, + DWRITE_MEASURING_MODE_NATURAL, + DWRITE_GRID_FIT_MODE_DEFAULT, + DWRITE_TEXT_ANTIALIAS_MODE_CLEARTYPE, + baseline_origin_x, + baseline_origin_y, + ) + }?; - let bitmap_factory = self.components.bitmap_factory.resolve()?; - unsafe { - let bitmap = bitmap_factory.CreateBitmap( - bitmap_width, - bitmap_height, - bitmap_format, - WICBitmapCacheOnLoad, - )?; - let render_target = self - .components - .d2d1_factory - .CreateWicBitmapRenderTarget(&bitmap, &render_target_property)?; - let brush = render_target.CreateSolidColorBrush(&BRUSH_COLOR, None)?; - let subpixel_shift = params - .subpixel_variant - .map(|v| v as f32 / SUBPIXEL_VARIANTS as f32); - let baseline_origin = Vector2 { - X: subpixel_shift.x / params.scale_factor, - Y: subpixel_shift.y / params.scale_factor, - }; + let color_bounds = + unsafe { 
color_analysis.GetAlphaTextureBounds(DWRITE_TEXTURE_CLEARTYPE_3x1) }?; - // This `cast()` action here should never fail since we are running on Win10+, and - // ID2D1DeviceContext4 requires Win8+ - let render_target = render_target.cast::<ID2D1DeviceContext4>().unwrap(); - render_target.SetUnitMode(D2D1_UNIT_MODE_DIPS); - render_target.SetDpi( - bitmap_dpi * params.scale_factor, - bitmap_dpi * params.scale_factor, - ); - render_target.SetTextRenderingParams(&self.components.render_context.params); - render_target.BeginDraw(); - - if params.is_emoji { - // WARN: only DWRITE_GLYPH_IMAGE_FORMATS_COLR has been tested - let enumerator = self.components.factory.TranslateColorGlyphRun( - baseline_origin, - &glyph_run as _, - None, - DWRITE_GLYPH_IMAGE_FORMATS_COLR - | DWRITE_GLYPH_IMAGE_FORMATS_SVG - | DWRITE_GLYPH_IMAGE_FORMATS_PNG - | DWRITE_GLYPH_IMAGE_FORMATS_JPEG - | DWRITE_GLYPH_IMAGE_FORMATS_PREMULTIPLIED_B8G8R8A8, - DWRITE_MEASURING_MODE_NATURAL, - None, - 0, - )?; - while enumerator.MoveNext().is_ok() { - let Ok(color_glyph) = enumerator.GetCurrentRun() else { - break; + let color_size = size( + color_bounds.right - color_bounds.left, + color_bounds.bottom - color_bounds.top, + ); + if color_size.width > 0 && color_size.height > 0 { + let mut alpha_data = + vec![0u8; (color_size.width * color_size.height * 3) as usize]; + unsafe { + color_analysis.CreateAlphaTexture( + DWRITE_TEXTURE_CLEARTYPE_3x1, + &color_bounds, + &mut alpha_data, + ) + }?; + + let run_color = { + let run_color = color_run.Base.runColor; + Rgba { + r: run_color.r, + g: run_color.g, + b: run_color.b, + a: run_color.a, + } }; - let color_glyph = &*color_glyph; - let brush_color = translate_color(&color_glyph.Base.runColor); - brush.SetColor(&brush_color); - match color_glyph.glyphImageFormat { - DWRITE_GLYPH_IMAGE_FORMATS_PNG - | DWRITE_GLYPH_IMAGE_FORMATS_JPEG - | DWRITE_GLYPH_IMAGE_FORMATS_PREMULTIPLIED_B8G8R8A8 => render_target - .DrawColorBitmapGlyphRun( - color_glyph.glyphImageFormat, - baseline_origin, - &color_glyph.Base.glyphRun, - color_glyph.measuringMode, - D2D1_COLOR_BITMAP_GLYPH_SNAP_OPTION_DEFAULT, - ), - DWRITE_GLYPH_IMAGE_FORMATS_SVG => render_target.DrawSvgGlyphRun( - baseline_origin, - &color_glyph.Base.glyphRun, - &brush, - None, - color_glyph.Base.paletteIndex as u32, - color_glyph.measuringMode, - ), - _ => render_target.DrawGlyphRun( - baseline_origin, - &color_glyph.Base.glyphRun, - Some(color_glyph.Base.glyphRunDescription as *const _), - &brush, - color_glyph.measuringMode, - ), - } + let bounds = bounds(point(color_bounds.left, color_bounds.top), color_size); + let alpha_data = alpha_data + .chunks_exact(3) + .flat_map(|chunk| [chunk[0], chunk[1], chunk[2], 255]) + .collect::<Vec<_>>(); + glyph_layers.push(GlyphLayerTexture::new( + &self.components.gpu_state, + run_color, + bounds, + &alpha_data, + )?); } - } else { - render_target.DrawGlyphRun( - baseline_origin, - &glyph_run, - None, - &brush, - DWRITE_MEASURING_MODE_NATURAL, - ); } - render_target.EndDraw(None, None)?; - - let mut raw_data = vec![0u8; total_bytes]; - if params.is_emoji { - bitmap.CopyPixels(std::ptr::null() as _, bitmap_stride, &mut raw_data)?; - // Convert from BGRA with premultiplied alpha to BGRA with straight alpha. 
- for pixel in raw_data.chunks_exact_mut(4) {
- let a = pixel[3] as f32 / 255.;
- pixel[0] = (pixel[0] as f32 / a) as u8;
- pixel[1] = (pixel[1] as f32 / a) as u8;
- pixel[2] = (pixel[2] as f32 / a) as u8;
- }
- } else {
- let scaler = bitmap_factory.CreateBitmapScaler()?;
- scaler.Initialize(
- &bitmap,
- bitmap_size.width.0 as u32,
- bitmap_size.height.0 as u32,
- WICBitmapInterpolationModeHighQualityCubic,
- )?;
- scaler.CopyPixels(std::ptr::null() as _, bitmap_stride, &mut raw_data)?;
+
+ let has_next = unsafe { color_enumerator.MoveNext() }
+ .map(|e| e.as_bool())
+ .unwrap_or(false);
+ if !has_next {
+ break; }
- Ok((bitmap_size, raw_data)) }
+
+ let gpu_state = &self.components.gpu_state;
+ let params_buffer = {
+ let desc = D3D11_BUFFER_DESC {
+ ByteWidth: std::mem::size_of::<GlyphLayerTextureParams>() as u32,
+ Usage: D3D11_USAGE_DYNAMIC,
+ BindFlags: D3D11_BIND_CONSTANT_BUFFER.0 as u32,
+ CPUAccessFlags: D3D11_CPU_ACCESS_WRITE.0 as u32,
+ MiscFlags: 0,
+ StructureByteStride: 0,
+ };
+
+ let mut buffer = None;
+ unsafe {
+ gpu_state
+ .device
+ .CreateBuffer(&desc, None, Some(&mut buffer))
+ }?;
+ [buffer]
+ };
+
+ let render_target_texture = {
+ let mut texture = None;
+ let desc = D3D11_TEXTURE2D_DESC {
+ Width: bitmap_size.width.0 as u32,
+ Height: bitmap_size.height.0 as u32,
+ MipLevels: 1,
+ ArraySize: 1,
+ Format: DXGI_FORMAT_B8G8R8A8_UNORM,
+ SampleDesc: DXGI_SAMPLE_DESC {
+ Count: 1,
+ Quality: 0,
+ },
+ Usage: D3D11_USAGE_DEFAULT,
+ BindFlags: D3D11_BIND_RENDER_TARGET.0 as u32,
+ CPUAccessFlags: 0,
+ MiscFlags: 0,
+ };
+ unsafe {
+ gpu_state
+ .device
+ .CreateTexture2D(&desc, None, Some(&mut texture))
+ }?;
+ texture.unwrap()
+ };
+
+ let render_target_view = {
+ let desc = D3D11_RENDER_TARGET_VIEW_DESC {
+ Format: DXGI_FORMAT_B8G8R8A8_UNORM,
+ ViewDimension: D3D11_RTV_DIMENSION_TEXTURE2D,
+ Anonymous: D3D11_RENDER_TARGET_VIEW_DESC_0 {
+ Texture2D: D3D11_TEX2D_RTV { MipSlice: 0 },
+ },
+ };
+ let mut rtv = None;
+ unsafe {
+ gpu_state.device.CreateRenderTargetView(
+ &render_target_texture,
+ Some(&desc),
+ Some(&mut rtv),
+ )
+ }?;
+ [rtv]
+ };
+
+ let staging_texture = {
+ let mut texture = None;
+ let desc = D3D11_TEXTURE2D_DESC {
+ Width: bitmap_size.width.0 as u32,
+ Height: bitmap_size.height.0 as u32,
+ MipLevels: 1,
+ ArraySize: 1,
+ Format: DXGI_FORMAT_B8G8R8A8_UNORM,
+ SampleDesc: DXGI_SAMPLE_DESC {
+ Count: 1,
+ Quality: 0,
+ },
+ Usage: D3D11_USAGE_STAGING,
+ BindFlags: 0,
+ CPUAccessFlags: D3D11_CPU_ACCESS_READ.0 as u32,
+ MiscFlags: 0,
+ };
+ unsafe {
+ gpu_state
+ .device
+ .CreateTexture2D(&desc, None, Some(&mut texture))
+ }?;
+ texture.unwrap()
+ };
+
+ let device_context = &gpu_state.device_context;
+ unsafe { device_context.IASetPrimitiveTopology(D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP) };
+ unsafe { device_context.VSSetShader(&gpu_state.vertex_shader, None) };
+ unsafe { device_context.PSSetShader(&gpu_state.pixel_shader, None) };
+ unsafe { device_context.VSSetConstantBuffers(0, Some(&params_buffer)) };
+ unsafe { device_context.PSSetConstantBuffers(0, Some(&params_buffer)) };
+ unsafe { device_context.OMSetRenderTargets(Some(&render_target_view), None) };
+ unsafe { device_context.PSSetSamplers(0, Some(&gpu_state.sampler)) };
+ unsafe { device_context.OMSetBlendState(&gpu_state.blend_state, None, 0xffffffff) };
+
+ for layer in glyph_layers {
+ let params = GlyphLayerTextureParams {
+ run_color: layer.run_color,
+ bounds: layer.bounds,
+ };
+ unsafe {
+ let mut dest = std::mem::zeroed();
+ gpu_state.device_context.Map(
+ params_buffer[0].as_ref().unwrap(),
0,
+ D3D11_MAP_WRITE_DISCARD,
+ 0,
+ Some(&mut dest),
+ )?;
+ std::ptr::copy_nonoverlapping(&params as *const _, dest.pData as *mut _, 1);
+ gpu_state
+ .device_context
+ .Unmap(params_buffer[0].as_ref().unwrap(), 0);
+ };
+
+ let texture = [Some(layer.texture_view)];
+ unsafe { device_context.PSSetShaderResources(0, Some(&texture)) };
+
+ let viewport = [D3D11_VIEWPORT {
+ TopLeftX: layer.bounds.origin.x as f32,
+ TopLeftY: layer.bounds.origin.y as f32,
+ Width: layer.bounds.size.width as f32,
+ Height: layer.bounds.size.height as f32,
+ MinDepth: 0.0,
+ MaxDepth: 1.0,
+ }];
+ unsafe { device_context.RSSetViewports(Some(&viewport)) };
+
+ unsafe { device_context.Draw(4, 0) };
+ }
+
+ unsafe { device_context.CopyResource(&staging_texture, &render_target_texture) };
+
+ let mapped_data = {
+ let mut mapped_data = D3D11_MAPPED_SUBRESOURCE::default();
+ unsafe {
+ device_context.Map(
+ &staging_texture,
+ 0,
+ D3D11_MAP_READ,
+ 0,
+ Some(&mut mapped_data),
+ )
+ }?;
+ mapped_data
+ };
+ let mut rasterized =
+ vec![0u8; (bitmap_size.width.0 as u32 * bitmap_size.height.0 as u32 * 4) as usize];
+
+ for y in 0..bitmap_size.height.0 as usize {
+ let width = bitmap_size.width.0 as usize;
+ unsafe {
+ std::ptr::copy_nonoverlapping::<u8>(
+ (mapped_data.pData as *const u8).byte_add(mapped_data.RowPitch as usize * y),
+ rasterized
+ .as_mut_ptr()
+ .byte_add(width * y * std::mem::size_of::<u32>()),
+ width * std::mem::size_of::<u32>(),
+ )
+ };
+ }
+
+ Ok(rasterized) } fn get_typographic_bounds(&self, font_id: FontId, glyph_id: GlyphId) -> Result<Bounds<f32>> {
@@ -976,6 +1276,84 @@ impl Drop for DirectWriteState { } }
+struct GlyphLayerTexture {
+ run_color: Rgba,
+ bounds: Bounds<i32>,
+ texture_view: ID3D11ShaderResourceView,
+ // holding on to the texture so RAII doesn't drop it while the view is still in use
+ _texture: ID3D11Texture2D,
+}
+
+impl GlyphLayerTexture {
+ pub fn new(
+ gpu_state: &GPUState,
+ run_color: Rgba,
+ bounds: Bounds<i32>,
+ alpha_data: &[u8],
+ ) -> Result<Self> {
+ let texture_size = bounds.size;
+
+ let desc = D3D11_TEXTURE2D_DESC {
+ Width: texture_size.width as u32,
+ Height: texture_size.height as u32,
+ MipLevels: 1,
+ ArraySize: 1,
+ Format: DXGI_FORMAT_R8G8B8A8_UNORM,
+ SampleDesc: DXGI_SAMPLE_DESC {
+ Count: 1,
+ Quality: 0,
+ },
+ Usage: D3D11_USAGE_DEFAULT,
+ BindFlags: D3D11_BIND_SHADER_RESOURCE.0 as u32,
+ CPUAccessFlags: D3D11_CPU_ACCESS_WRITE.0 as u32,
+ MiscFlags: 0,
+ };
+
+ let texture = {
+ let mut texture: Option<ID3D11Texture2D> = None;
+ unsafe {
+ gpu_state
+ .device
+ .CreateTexture2D(&desc, None, Some(&mut texture))?
+ };
+ texture.unwrap()
+ };
+ let texture_view = {
+ let mut view: Option<ID3D11ShaderResourceView> = None;
+ unsafe {
+ gpu_state
+ .device
+ .CreateShaderResourceView(&texture, None, Some(&mut view))?
+ }; + view.unwrap() + }; + + unsafe { + gpu_state.device_context.UpdateSubresource( + &texture, + 0, + None, + alpha_data.as_ptr() as _, + (texture_size.width * 4) as u32, + 0, + ) + }; + + Ok(GlyphLayerTexture { + run_color, + bounds, + texture_view, + _texture: texture, + }) + } +} + +#[repr(C)] +struct GlyphLayerTextureParams { + bounds: Bounds<i32>, + run_color: Rgba, +} + struct TextRendererWrapper(pub IDWriteTextRenderer); impl TextRendererWrapper { @@ -1470,16 +1848,6 @@ fn get_name(string: IDWriteLocalizedStrings, locale: &str) -> Result<String> { Ok(String::from_utf16_lossy(&name_vec[..name_length])) } -#[inline] -fn translate_color(color: &DWRITE_COLOR_F) -> D2D1_COLOR_F { - D2D1_COLOR_F { - r: color.r, - g: color.g, - b: color.b, - a: color.a, - } -} - fn get_system_ui_font_name() -> SharedString { unsafe { let mut info: LOGFONTW = std::mem::zeroed(); @@ -1504,24 +1872,6 @@ fn get_system_ui_font_name() -> SharedString { } } -#[inline] -fn get_render_target_property( - pixel_format: DXGI_FORMAT, - alpha_mode: D2D1_ALPHA_MODE, -) -> D2D1_RENDER_TARGET_PROPERTIES { - D2D1_RENDER_TARGET_PROPERTIES { - r#type: D2D1_RENDER_TARGET_TYPE_DEFAULT, - pixelFormat: D2D1_PIXEL_FORMAT { - format: pixel_format, - alphaMode: alpha_mode, - }, - dpiX: 96.0, - dpiY: 96.0, - usage: D2D1_RENDER_TARGET_USAGE_NONE, - minLevel: D2D1_FEATURE_LEVEL_DEFAULT, - } -} - // One would think that with newer DirectWrite method: IDWriteFontFace4::GetGlyphImageFormats // but that doesn't seem to work for some glyphs, say ❤ fn is_color_glyph( @@ -1561,12 +1911,6 @@ fn is_color_glyph( } const DEFAULT_LOCALE_NAME: PCWSTR = windows::core::w!("en-US"); -const BRUSH_COLOR: D2D1_COLOR_F = D2D1_COLOR_F { - r: 1.0, - g: 1.0, - b: 1.0, - a: 1.0, -}; #[cfg(test)] mod tests { diff --git a/crates/gpui/src/platform/windows/directx_atlas.rs b/crates/gpui/src/platform/windows/directx_atlas.rs new file mode 100644 index 0000000000000000000000000000000000000000..6bced4c11d922ed2c514b9a70fe7e582d7b15a6b --- /dev/null +++ b/crates/gpui/src/platform/windows/directx_atlas.rs @@ -0,0 +1,309 @@ +use collections::FxHashMap; +use etagere::BucketedAtlasAllocator; +use parking_lot::Mutex; +use windows::Win32::Graphics::{ + Direct3D11::{ + D3D11_BIND_SHADER_RESOURCE, D3D11_BOX, D3D11_CPU_ACCESS_WRITE, D3D11_TEXTURE2D_DESC, + D3D11_USAGE_DEFAULT, ID3D11Device, ID3D11DeviceContext, ID3D11ShaderResourceView, + ID3D11Texture2D, + }, + Dxgi::Common::*, +}; + +use crate::{ + AtlasKey, AtlasTextureId, AtlasTextureKind, AtlasTile, Bounds, DevicePixels, PlatformAtlas, + Point, Size, platform::AtlasTextureList, +}; + +pub(crate) struct DirectXAtlas(Mutex<DirectXAtlasState>); + +struct DirectXAtlasState { + device: ID3D11Device, + device_context: ID3D11DeviceContext, + monochrome_textures: AtlasTextureList<DirectXAtlasTexture>, + polychrome_textures: AtlasTextureList<DirectXAtlasTexture>, + tiles_by_key: FxHashMap<AtlasKey, AtlasTile>, +} + +struct DirectXAtlasTexture { + id: AtlasTextureId, + bytes_per_pixel: u32, + allocator: BucketedAtlasAllocator, + texture: ID3D11Texture2D, + view: [Option<ID3D11ShaderResourceView>; 1], + live_atlas_keys: u32, +} + +impl DirectXAtlas { + pub(crate) fn new(device: &ID3D11Device, device_context: &ID3D11DeviceContext) -> Self { + DirectXAtlas(Mutex::new(DirectXAtlasState { + device: device.clone(), + device_context: device_context.clone(), + monochrome_textures: Default::default(), + polychrome_textures: Default::default(), + tiles_by_key: Default::default(), + })) + } + + pub(crate) fn get_texture_view( + &self, + 
id: AtlasTextureId, + ) -> [Option<ID3D11ShaderResourceView>; 1] { + let lock = self.0.lock(); + let tex = lock.texture(id); + tex.view.clone() + } + + pub(crate) fn handle_device_lost( + &self, + device: &ID3D11Device, + device_context: &ID3D11DeviceContext, + ) { + let mut lock = self.0.lock(); + lock.device = device.clone(); + lock.device_context = device_context.clone(); + lock.monochrome_textures = AtlasTextureList::default(); + lock.polychrome_textures = AtlasTextureList::default(); + lock.tiles_by_key.clear(); + } +} + +impl PlatformAtlas for DirectXAtlas { + fn get_or_insert_with<'a>( + &self, + key: &AtlasKey, + build: &mut dyn FnMut() -> anyhow::Result< + Option<(Size<DevicePixels>, std::borrow::Cow<'a, [u8]>)>, + >, + ) -> anyhow::Result<Option<AtlasTile>> { + let mut lock = self.0.lock(); + if let Some(tile) = lock.tiles_by_key.get(key) { + Ok(Some(tile.clone())) + } else { + let Some((size, bytes)) = build()? else { + return Ok(None); + }; + let tile = lock + .allocate(size, key.texture_kind()) + .ok_or_else(|| anyhow::anyhow!("failed to allocate"))?; + let texture = lock.texture(tile.texture_id); + texture.upload(&lock.device_context, tile.bounds, &bytes); + lock.tiles_by_key.insert(key.clone(), tile.clone()); + Ok(Some(tile)) + } + } + + fn remove(&self, key: &AtlasKey) { + let mut lock = self.0.lock(); + + let Some(id) = lock.tiles_by_key.remove(key).map(|tile| tile.texture_id) else { + return; + }; + + let textures = match id.kind { + AtlasTextureKind::Monochrome => &mut lock.monochrome_textures, + AtlasTextureKind::Polychrome => &mut lock.polychrome_textures, + }; + + let Some(texture_slot) = textures.textures.get_mut(id.index as usize) else { + return; + }; + + if let Some(mut texture) = texture_slot.take() { + texture.decrement_ref_count(); + if texture.is_unreferenced() { + textures.free_list.push(texture.id.index as usize); + lock.tiles_by_key.remove(key); + } else { + *texture_slot = Some(texture); + } + } + } +} + +impl DirectXAtlasState { + fn allocate( + &mut self, + size: Size<DevicePixels>, + texture_kind: AtlasTextureKind, + ) -> Option<AtlasTile> { + { + let textures = match texture_kind { + AtlasTextureKind::Monochrome => &mut self.monochrome_textures, + AtlasTextureKind::Polychrome => &mut self.polychrome_textures, + }; + + if let Some(tile) = textures + .iter_mut() + .rev() + .find_map(|texture| texture.allocate(size)) + { + return Some(tile); + } + } + + let texture = self.push_texture(size, texture_kind)?; + texture.allocate(size) + } + + fn push_texture( + &mut self, + min_size: Size<DevicePixels>, + kind: AtlasTextureKind, + ) -> Option<&mut DirectXAtlasTexture> { + const DEFAULT_ATLAS_SIZE: Size<DevicePixels> = Size { + width: DevicePixels(1024), + height: DevicePixels(1024), + }; + // Max texture size for DirectX. 
See: + // https://learn.microsoft.com/en-us/windows/win32/direct3d11/overviews-direct3d-11-resources-limits + const MAX_ATLAS_SIZE: Size<DevicePixels> = Size { + width: DevicePixels(16384), + height: DevicePixels(16384), + }; + let size = min_size.min(&MAX_ATLAS_SIZE).max(&DEFAULT_ATLAS_SIZE); + let pixel_format; + let bind_flag; + let bytes_per_pixel; + match kind { + AtlasTextureKind::Monochrome => { + pixel_format = DXGI_FORMAT_R8_UNORM; + bind_flag = D3D11_BIND_SHADER_RESOURCE; + bytes_per_pixel = 1; + } + AtlasTextureKind::Polychrome => { + pixel_format = DXGI_FORMAT_B8G8R8A8_UNORM; + bind_flag = D3D11_BIND_SHADER_RESOURCE; + bytes_per_pixel = 4; + } + } + let texture_desc = D3D11_TEXTURE2D_DESC { + Width: size.width.0 as u32, + Height: size.height.0 as u32, + MipLevels: 1, + ArraySize: 1, + Format: pixel_format, + SampleDesc: DXGI_SAMPLE_DESC { + Count: 1, + Quality: 0, + }, + Usage: D3D11_USAGE_DEFAULT, + BindFlags: bind_flag.0 as u32, + CPUAccessFlags: D3D11_CPU_ACCESS_WRITE.0 as u32, + MiscFlags: 0, + }; + let mut texture: Option<ID3D11Texture2D> = None; + unsafe { + // This only returns None if the device is lost, which we will recreate later. + // So it's ok to return None here. + self.device + .CreateTexture2D(&texture_desc, None, Some(&mut texture)) + .ok()?; + } + let texture = texture.unwrap(); + + let texture_list = match kind { + AtlasTextureKind::Monochrome => &mut self.monochrome_textures, + AtlasTextureKind::Polychrome => &mut self.polychrome_textures, + }; + let index = texture_list.free_list.pop(); + let view = unsafe { + let mut view = None; + self.device + .CreateShaderResourceView(&texture, None, Some(&mut view)) + .ok()?; + [view] + }; + let atlas_texture = DirectXAtlasTexture { + id: AtlasTextureId { + index: index.unwrap_or(texture_list.textures.len()) as u32, + kind, + }, + bytes_per_pixel, + allocator: etagere::BucketedAtlasAllocator::new(size.into()), + texture, + view, + live_atlas_keys: 0, + }; + if let Some(ix) = index { + texture_list.textures[ix] = Some(atlas_texture); + texture_list.textures.get_mut(ix).unwrap().as_mut() + } else { + texture_list.textures.push(Some(atlas_texture)); + texture_list.textures.last_mut().unwrap().as_mut() + } + } + + fn texture(&self, id: AtlasTextureId) -> &DirectXAtlasTexture { + let textures = match id.kind { + crate::AtlasTextureKind::Monochrome => &self.monochrome_textures, + crate::AtlasTextureKind::Polychrome => &self.polychrome_textures, + }; + textures[id.index as usize].as_ref().unwrap() + } +} + +impl DirectXAtlasTexture { + fn allocate(&mut self, size: Size<DevicePixels>) -> Option<AtlasTile> { + let allocation = self.allocator.allocate(size.into())?; + let tile = AtlasTile { + texture_id: self.id, + tile_id: allocation.id.into(), + bounds: Bounds { + origin: allocation.rectangle.min.into(), + size, + }, + padding: 0, + }; + self.live_atlas_keys += 1; + Some(tile) + } + + fn upload( + &self, + device_context: &ID3D11DeviceContext, + bounds: Bounds<DevicePixels>, + bytes: &[u8], + ) { + unsafe { + device_context.UpdateSubresource( + &self.texture, + 0, + Some(&D3D11_BOX { + left: bounds.left().0 as u32, + top: bounds.top().0 as u32, + front: 0, + right: bounds.right().0 as u32, + bottom: bounds.bottom().0 as u32, + back: 1, + }), + bytes.as_ptr() as _, + bounds.size.width.to_bytes(self.bytes_per_pixel as u8), + 0, + ); + } + } + + fn decrement_ref_count(&mut self) { + self.live_atlas_keys -= 1; + } + + fn is_unreferenced(&mut self) -> bool { + self.live_atlas_keys == 0 + } +} + +impl From<Size<DevicePixels>> for 
etagere::Size { + fn from(size: Size<DevicePixels>) -> Self { + etagere::Size::new(size.width.into(), size.height.into()) + } +} + +impl From<etagere::Point> for Point<DevicePixels> { + fn from(value: etagere::Point) -> Self { + Point { + x: DevicePixels::from(value.x), + y: DevicePixels::from(value.y), + } + } +} diff --git a/crates/gpui/src/platform/windows/directx_renderer.rs b/crates/gpui/src/platform/windows/directx_renderer.rs new file mode 100644 index 0000000000000000000000000000000000000000..ac285b79acee82571456943d2f660947687b56c3 --- /dev/null +++ b/crates/gpui/src/platform/windows/directx_renderer.rs @@ -0,0 +1,1807 @@ +use std::{mem::ManuallyDrop, sync::Arc}; + +use ::util::ResultExt; +use anyhow::{Context, Result}; +use windows::{ + Win32::{ + Foundation::{HMODULE, HWND}, + Graphics::{ + Direct3D::*, + Direct3D11::*, + DirectComposition::*, + Dxgi::{Common::*, *}, + }, + }, + core::Interface, +}; + +use crate::{ + platform::windows::directx_renderer::shader_resources::{ + RawShaderBytes, ShaderModule, ShaderTarget, + }, + *, +}; + +pub(crate) const DISABLE_DIRECT_COMPOSITION: &str = "GPUI_DISABLE_DIRECT_COMPOSITION"; +const RENDER_TARGET_FORMAT: DXGI_FORMAT = DXGI_FORMAT_B8G8R8A8_UNORM; +// This configuration is used for MSAA rendering on paths only, and it's guaranteed to be supported by DirectX 11. +const PATH_MULTISAMPLE_COUNT: u32 = 4; + +pub(crate) struct DirectXRenderer { + hwnd: HWND, + atlas: Arc<DirectXAtlas>, + devices: ManuallyDrop<DirectXDevices>, + resources: ManuallyDrop<DirectXResources>, + globals: DirectXGlobalElements, + pipelines: DirectXRenderPipelines, + direct_composition: Option<DirectComposition>, +} + +/// Direct3D objects +#[derive(Clone)] +pub(crate) struct DirectXDevices { + adapter: IDXGIAdapter1, + dxgi_factory: IDXGIFactory6, + pub(crate) device: ID3D11Device, + pub(crate) device_context: ID3D11DeviceContext, + dxgi_device: Option<IDXGIDevice>, +} + +struct DirectXResources { + // Direct3D rendering objects + swap_chain: IDXGISwapChain1, + render_target: ManuallyDrop<ID3D11Texture2D>, + render_target_view: [Option<ID3D11RenderTargetView>; 1], + + // Path intermediate textures (with MSAA) + path_intermediate_texture: ID3D11Texture2D, + path_intermediate_srv: [Option<ID3D11ShaderResourceView>; 1], + path_intermediate_msaa_texture: ID3D11Texture2D, + path_intermediate_msaa_view: [Option<ID3D11RenderTargetView>; 1], + + // Cached window size and viewport + width: u32, + height: u32, + viewport: [D3D11_VIEWPORT; 1], +} + +struct DirectXRenderPipelines { + shadow_pipeline: PipelineState<Shadow>, + quad_pipeline: PipelineState<Quad>, + path_rasterization_pipeline: PipelineState<PathRasterizationSprite>, + path_sprite_pipeline: PipelineState<PathSprite>, + underline_pipeline: PipelineState<Underline>, + mono_sprites: PipelineState<MonochromeSprite>, + poly_sprites: PipelineState<PolychromeSprite>, +} + +struct DirectXGlobalElements { + global_params_buffer: [Option<ID3D11Buffer>; 1], + sampler: [Option<ID3D11SamplerState>; 1], +} + +struct DirectComposition { + comp_device: IDCompositionDevice, + comp_target: IDCompositionTarget, + comp_visual: IDCompositionVisual, +} + +impl DirectXDevices { + pub(crate) fn new(disable_direct_composition: bool) -> Result<ManuallyDrop<Self>> { + let debug_layer_available = check_debug_layer_available(); + let dxgi_factory = + get_dxgi_factory(debug_layer_available).context("Creating DXGI factory")?; + let adapter = + get_adapter(&dxgi_factory, debug_layer_available).context("Getting DXGI adapter")?; + let (device, 
device_context) = { + let mut device: Option<ID3D11Device> = None; + let mut context: Option<ID3D11DeviceContext> = None; + let mut feature_level = D3D_FEATURE_LEVEL::default(); + get_device( + &adapter, + Some(&mut device), + Some(&mut context), + Some(&mut feature_level), + debug_layer_available, + ) + .context("Creating Direct3D device")?; + match feature_level { + D3D_FEATURE_LEVEL_11_1 => { + log::info!("Created device with Direct3D 11.1 feature level.") + } + D3D_FEATURE_LEVEL_11_0 => { + log::info!("Created device with Direct3D 11.0 feature level.") + } + D3D_FEATURE_LEVEL_10_1 => { + log::info!("Created device with Direct3D 10.1 feature level.") + } + _ => unreachable!(), + } + (device.unwrap(), context.unwrap()) + }; + let dxgi_device = if disable_direct_composition { + None + } else { + Some(device.cast().context("Creating DXGI device")?) + }; + + Ok(ManuallyDrop::new(Self { + adapter, + dxgi_factory, + dxgi_device, + device, + device_context, + })) + } +} + +impl DirectXRenderer { + pub(crate) fn new(hwnd: HWND, disable_direct_composition: bool) -> Result<Self> { + if disable_direct_composition { + log::info!("Direct Composition is disabled."); + } + + let devices = + DirectXDevices::new(disable_direct_composition).context("Creating DirectX devices")?; + let atlas = Arc::new(DirectXAtlas::new(&devices.device, &devices.device_context)); + + let resources = DirectXResources::new(&devices, 1, 1, hwnd, disable_direct_composition) + .context("Creating DirectX resources")?; + let globals = DirectXGlobalElements::new(&devices.device) + .context("Creating DirectX global elements")?; + let pipelines = DirectXRenderPipelines::new(&devices.device) + .context("Creating DirectX render pipelines")?; + + let direct_composition = if disable_direct_composition { + None + } else { + let composition = DirectComposition::new(devices.dxgi_device.as_ref().unwrap(), hwnd) + .context("Creating DirectComposition")?; + composition + .set_swap_chain(&resources.swap_chain) + .context("Setting swap chain for DirectComposition")?; + Some(composition) + }; + + Ok(DirectXRenderer { + hwnd, + atlas, + devices, + resources, + globals, + pipelines, + direct_composition, + }) + } + + pub(crate) fn sprite_atlas(&self) -> Arc<dyn PlatformAtlas> { + self.atlas.clone() + } + + fn pre_draw(&self) -> Result<()> { + update_buffer( + &self.devices.device_context, + self.globals.global_params_buffer[0].as_ref().unwrap(), + &[GlobalParams { + viewport_size: [ + self.resources.viewport[0].Width, + self.resources.viewport[0].Height, + ], + _pad: 0, + }], + )?; + unsafe { + self.devices.device_context.ClearRenderTargetView( + self.resources.render_target_view[0].as_ref().unwrap(), + &[0.0; 4], + ); + self.devices + .device_context + .OMSetRenderTargets(Some(&self.resources.render_target_view), None); + self.devices + .device_context + .RSSetViewports(Some(&self.resources.viewport)); + } + Ok(()) + } + + fn present(&mut self) -> Result<()> { + unsafe { + let result = self.resources.swap_chain.Present(1, DXGI_PRESENT(0)); + // Presenting the swap chain can fail if the DirectX device was removed or reset. + if result == DXGI_ERROR_DEVICE_REMOVED || result == DXGI_ERROR_DEVICE_RESET { + let reason = self.devices.device.GetDeviceRemovedReason(); + log::error!( + "DirectX device removed or reset when drawing. 
Reason: {:?}", reason ); self.handle_device_lost()?; } else { result.ok()?; } } Ok(()) }
+
+ fn handle_device_lost(&mut self) -> Result<()> {
+ // Here we wait a bit to ensure the system has time to recover from the device lost state.
+ // If we don't wait, the final drawing result will be blank.
+ std::thread::sleep(std::time::Duration::from_millis(300));
+ let disable_direct_composition = self.direct_composition.is_none();
+
+ unsafe {
+ #[cfg(debug_assertions)]
+ report_live_objects(&self.devices.device)
+ .context("Failed to report live objects after device lost")
+ .log_err();
+
+ ManuallyDrop::drop(&mut self.resources);
+ self.devices.device_context.OMSetRenderTargets(None, None);
+ self.devices.device_context.ClearState();
+ self.devices.device_context.Flush();
+
+ #[cfg(debug_assertions)]
+ report_live_objects(&self.devices.device)
+ .context("Failed to report live objects after device lost")
+ .log_err();
+
+ drop(self.direct_composition.take());
+ ManuallyDrop::drop(&mut self.devices);
+ }
+
+ let devices = DirectXDevices::new(disable_direct_composition)
+ .context("Recreating DirectX devices")?;
+ let resources = DirectXResources::new(
+ &devices,
+ self.resources.width,
+ self.resources.height,
+ self.hwnd,
+ disable_direct_composition,
+ )?;
+ let globals = DirectXGlobalElements::new(&devices.device)?;
+ let pipelines = DirectXRenderPipelines::new(&devices.device)?;
+
+ let direct_composition = if disable_direct_composition {
+ None
+ } else {
+ let composition =
+ DirectComposition::new(devices.dxgi_device.as_ref().unwrap(), self.hwnd)?;
+ composition.set_swap_chain(&resources.swap_chain)?;
+ Some(composition)
+ };
+
+ self.atlas
+ .handle_device_lost(&devices.device, &devices.device_context);
+ self.devices = devices;
+ self.resources = resources;
+ self.globals = globals;
+ self.pipelines = pipelines;
+ self.direct_composition = direct_composition;
+
+ unsafe {
+ self.devices
+ .device_context
+ .OMSetRenderTargets(Some(&self.resources.render_target_view), None);
+ }
+ Ok(())
+ }
+
+ pub(crate) fn draw(&mut self, scene: &Scene) -> Result<()> {
+ self.pre_draw()?;
+ for batch in scene.batches() {
+ match batch {
+ PrimitiveBatch::Shadows(shadows) => self.draw_shadows(shadows),
+ PrimitiveBatch::Quads(quads) => self.draw_quads(quads),
+ PrimitiveBatch::Paths(paths) => {
+ self.draw_paths_to_intermediate(paths)?;
+ self.draw_paths_from_intermediate(paths)
+ }
+ PrimitiveBatch::Underlines(underlines) => self.draw_underlines(underlines),
+ PrimitiveBatch::MonochromeSprites {
+ texture_id,
+ sprites,
+ } => self.draw_monochrome_sprites(texture_id, sprites),
+ PrimitiveBatch::PolychromeSprites {
+ texture_id,
+ sprites,
+ } => self.draw_polychrome_sprites(texture_id, sprites),
+ PrimitiveBatch::Surfaces(surfaces) => self.draw_surfaces(surfaces),
+ }.context(format!("scene too large: {} paths, {} shadows, {} quads, {} underlines, {} mono, {} poly, {} surfaces",
+ scene.paths.len(),
+ scene.shadows.len(),
+ scene.quads.len(),
+ scene.underlines.len(),
+ scene.monochrome_sprites.len(),
+ scene.polychrome_sprites.len(),
+ scene.surfaces.len(),))?;
+ }
+ self.present()
+ }
+
+ pub(crate) fn resize(&mut self, new_size: Size<DevicePixels>) -> Result<()> {
+ let width = new_size.width.0.max(1) as u32;
+ let height = new_size.height.0.max(1) as u32;
+ if self.resources.width == width && self.resources.height == height {
+ return Ok(());
+ }
+ unsafe {
+ // Clear the render target before resizing
+ self.devices.device_context.OMSetRenderTargets(None, None);
+
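+ // Every outstanding reference to the swap chain's back buffer (the render target
+ // texture and its view) has to be released before ResizeBuffers, or the resize fails.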
ManuallyDrop::drop(&mut self.resources.render_target); + drop(self.resources.render_target_view[0].take().unwrap()); + + let result = self.resources.swap_chain.ResizeBuffers( + BUFFER_COUNT as u32, + width, + height, + RENDER_TARGET_FORMAT, + DXGI_SWAP_CHAIN_FLAG(0), + ); + // Resizing the swap chain requires a call to the underlying DXGI adapter, which can return the device removed error. + // The app might have moved to a monitor that's attached to a different graphics device. + // When a graphics device is removed or reset, the desktop resolution often changes, resulting in a window size change. + match result { + Ok(_) => {} + Err(e) => { + if e.code() == DXGI_ERROR_DEVICE_REMOVED || e.code() == DXGI_ERROR_DEVICE_RESET + { + let reason = self.devices.device.GetDeviceRemovedReason(); + log::error!( + "DirectX device removed or reset when resizing. Reason: {:?}", + reason + ); + self.resources.width = width; + self.resources.height = height; + self.handle_device_lost()?; + return Ok(()); + } else { + log::error!("Failed to resize swap chain: {:?}", e); + return Err(e.into()); + } + } + } + + self.resources + .recreate_resources(&self.devices, width, height)?; + self.devices + .device_context + .OMSetRenderTargets(Some(&self.resources.render_target_view), None); + } + Ok(()) + } + + fn draw_shadows(&mut self, shadows: &[Shadow]) -> Result<()> { + if shadows.is_empty() { + return Ok(()); + } + self.pipelines.shadow_pipeline.update_buffer( + &self.devices.device, + &self.devices.device_context, + shadows, + )?; + self.pipelines.shadow_pipeline.draw( + &self.devices.device_context, + &self.resources.viewport, + &self.globals.global_params_buffer, + D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP, + 4, + shadows.len() as u32, + ) + } + + fn draw_quads(&mut self, quads: &[Quad]) -> Result<()> { + if quads.is_empty() { + return Ok(()); + } + self.pipelines.quad_pipeline.update_buffer( + &self.devices.device, + &self.devices.device_context, + quads, + )?; + self.pipelines.quad_pipeline.draw( + &self.devices.device_context, + &self.resources.viewport, + &self.globals.global_params_buffer, + D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP, + 4, + quads.len() as u32, + ) + } + + fn draw_paths_to_intermediate(&mut self, paths: &[Path<ScaledPixels>]) -> Result<()> { + if paths.is_empty() { + return Ok(()); + } + + // Clear intermediate MSAA texture + unsafe { + self.devices.device_context.ClearRenderTargetView( + self.resources.path_intermediate_msaa_view[0] + .as_ref() + .unwrap(), + &[0.0; 4], + ); + // Set intermediate MSAA texture as render target + self.devices + .device_context + .OMSetRenderTargets(Some(&self.resources.path_intermediate_msaa_view), None); + } + + // Collect all vertices and sprites for a single draw call + let mut vertices = Vec::new(); + + for path in paths { + vertices.extend(path.vertices.iter().map(|v| PathRasterizationSprite { + xy_position: v.xy_position, + st_position: v.st_position, + color: path.color, + bounds: path.clipped_bounds(), + })); + } + + self.pipelines.path_rasterization_pipeline.update_buffer( + &self.devices.device, + &self.devices.device_context, + &vertices, + )?; + self.pipelines.path_rasterization_pipeline.draw( + &self.devices.device_context, + &self.resources.viewport, + &self.globals.global_params_buffer, + D3D_PRIMITIVE_TOPOLOGY_TRIANGLELIST, + vertices.len() as u32, + 1, + )?; + + // Resolve MSAA to non-MSAA intermediate texture + unsafe { + self.devices.device_context.ResolveSubresource( + &self.resources.path_intermediate_texture, + 0, + 
&self.resources.path_intermediate_msaa_texture, + 0, + RENDER_TARGET_FORMAT, + ); + // Restore main render target + self.devices + .device_context + .OMSetRenderTargets(Some(&self.resources.render_target_view), None); + } + + Ok(()) + } + + fn draw_paths_from_intermediate(&mut self, paths: &[Path<ScaledPixels>]) -> Result<()> { + let Some(first_path) = paths.first() else { + return Ok(()); + }; + + // When copying paths from the intermediate texture to the drawable, + // each pixel must only be copied once, in case of transparent paths. + // + // If all paths have the same draw order, then their bounds are all + // disjoint, so we can copy each path's bounds individually. If this + // batch combines different draw orders, we perform a single copy + // for a minimal spanning rect. + let sprites = if paths.last().unwrap().order == first_path.order { + paths + .iter() + .map(|path| PathSprite { + bounds: path.clipped_bounds(), + }) + .collect::<Vec<_>>() + } else { + let mut bounds = first_path.clipped_bounds(); + for path in paths.iter().skip(1) { + bounds = bounds.union(&path.clipped_bounds()); + } + vec![PathSprite { bounds }] + }; + + self.pipelines.path_sprite_pipeline.update_buffer( + &self.devices.device, + &self.devices.device_context, + &sprites, + )?; + + // Draw the sprites with the path texture + self.pipelines.path_sprite_pipeline.draw_with_texture( + &self.devices.device_context, + &self.resources.path_intermediate_srv, + &self.resources.viewport, + &self.globals.global_params_buffer, + &self.globals.sampler, + sprites.len() as u32, + ) + } + + fn draw_underlines(&mut self, underlines: &[Underline]) -> Result<()> { + if underlines.is_empty() { + return Ok(()); + } + self.pipelines.underline_pipeline.update_buffer( + &self.devices.device, + &self.devices.device_context, + underlines, + )?; + self.pipelines.underline_pipeline.draw( + &self.devices.device_context, + &self.resources.viewport, + &self.globals.global_params_buffer, + D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP, + 4, + underlines.len() as u32, + ) + } + + fn draw_monochrome_sprites( + &mut self, + texture_id: AtlasTextureId, + sprites: &[MonochromeSprite], + ) -> Result<()> { + if sprites.is_empty() { + return Ok(()); + } + self.pipelines.mono_sprites.update_buffer( + &self.devices.device, + &self.devices.device_context, + sprites, + )?; + let texture_view = self.atlas.get_texture_view(texture_id); + self.pipelines.mono_sprites.draw_with_texture( + &self.devices.device_context, + &texture_view, + &self.resources.viewport, + &self.globals.global_params_buffer, + &self.globals.sampler, + sprites.len() as u32, + ) + } + + fn draw_polychrome_sprites( + &mut self, + texture_id: AtlasTextureId, + sprites: &[PolychromeSprite], + ) -> Result<()> { + if sprites.is_empty() { + return Ok(()); + } + self.pipelines.poly_sprites.update_buffer( + &self.devices.device, + &self.devices.device_context, + sprites, + )?; + let texture_view = self.atlas.get_texture_view(texture_id); + self.pipelines.poly_sprites.draw_with_texture( + &self.devices.device_context, + &texture_view, + &self.resources.viewport, + &self.globals.global_params_buffer, + &self.globals.sampler, + sprites.len() as u32, + ) + } + + fn draw_surfaces(&mut self, surfaces: &[PaintSurface]) -> Result<()> { + if surfaces.is_empty() { + return Ok(()); + } + Ok(()) + } + + pub(crate) fn gpu_specs(&self) -> Result<GpuSpecs> { + let desc = unsafe { self.devices.adapter.GetDesc1() }?; + let is_software_emulated = (desc.Flags & DXGI_ADAPTER_FLAG_SOFTWARE.0 as u32) != 0; + let 
device_name = String::from_utf16_lossy(&desc.Description) + .trim_matches(char::from(0)) + .to_string(); + let driver_name = match desc.VendorId { + 0x10DE => "NVIDIA Corporation".to_string(), + 0x1002 => "AMD Corporation".to_string(), + 0x8086 => "Intel Corporation".to_string(), + id => format!("Unknown Vendor (ID: {:#X})", id), + }; + let driver_version = match desc.VendorId { + 0x10DE => nvidia::get_driver_version(), + 0x1002 => amd::get_driver_version(), + // For Intel and other vendors, we use the DXGI API to get the driver version. + _ => dxgi::get_driver_version(&self.devices.adapter), + } + .context("Failed to get gpu driver info") + .log_err() + .unwrap_or("Unknown Driver".to_string()); + Ok(GpuSpecs { + is_software_emulated, + device_name, + driver_name, + driver_info: driver_version, + }) + } +} + +impl DirectXResources { + pub fn new( + devices: &DirectXDevices, + width: u32, + height: u32, + hwnd: HWND, + disable_direct_composition: bool, + ) -> Result<ManuallyDrop<Self>> { + let swap_chain = if disable_direct_composition { + create_swap_chain(&devices.dxgi_factory, &devices.device, hwnd, width, height)? + } else { + create_swap_chain_for_composition( + &devices.dxgi_factory, + &devices.device, + width, + height, + )? + }; + + let ( + render_target, + render_target_view, + path_intermediate_texture, + path_intermediate_srv, + path_intermediate_msaa_texture, + path_intermediate_msaa_view, + viewport, + ) = create_resources(devices, &swap_chain, width, height)?; + set_rasterizer_state(&devices.device, &devices.device_context)?; + + Ok(ManuallyDrop::new(Self { + swap_chain, + render_target, + render_target_view, + path_intermediate_texture, + path_intermediate_msaa_texture, + path_intermediate_msaa_view, + path_intermediate_srv, + viewport, + width, + height, + })) + } + + #[inline] + fn recreate_resources( + &mut self, + devices: &DirectXDevices, + width: u32, + height: u32, + ) -> Result<()> { + let ( + render_target, + render_target_view, + path_intermediate_texture, + path_intermediate_srv, + path_intermediate_msaa_texture, + path_intermediate_msaa_view, + viewport, + ) = create_resources(devices, &self.swap_chain, width, height)?; + self.render_target = render_target; + self.render_target_view = render_target_view; + self.path_intermediate_texture = path_intermediate_texture; + self.path_intermediate_msaa_texture = path_intermediate_msaa_texture; + self.path_intermediate_msaa_view = path_intermediate_msaa_view; + self.path_intermediate_srv = path_intermediate_srv; + self.viewport = viewport; + self.width = width; + self.height = height; + Ok(()) + } +} + +impl DirectXRenderPipelines { + pub fn new(device: &ID3D11Device) -> Result<Self> { + let shadow_pipeline = PipelineState::new( + device, + "shadow_pipeline", + ShaderModule::Shadow, + 4, + create_blend_state(device)?, + )?; + let quad_pipeline = PipelineState::new( + device, + "quad_pipeline", + ShaderModule::Quad, + 64, + create_blend_state(device)?, + )?; + let path_rasterization_pipeline = PipelineState::new( + device, + "path_rasterization_pipeline", + ShaderModule::PathRasterization, + 32, + create_blend_state_for_path_rasterization(device)?, + )?; + let path_sprite_pipeline = PipelineState::new( + device, + "path_sprite_pipeline", + ShaderModule::PathSprite, + 4, + create_blend_state_for_path_sprite(device)?, + )?; + let underline_pipeline = PipelineState::new( + device, + "underline_pipeline", + ShaderModule::Underline, + 4, + create_blend_state(device)?, + )?; + let mono_sprites = PipelineState::new( + device, + 
"monochrome_sprite_pipeline", + ShaderModule::MonochromeSprite, + 512, + create_blend_state(device)?, + )?; + let poly_sprites = PipelineState::new( + device, + "polychrome_sprite_pipeline", + ShaderModule::PolychromeSprite, + 16, + create_blend_state(device)?, + )?; + + Ok(Self { + shadow_pipeline, + quad_pipeline, + path_rasterization_pipeline, + path_sprite_pipeline, + underline_pipeline, + mono_sprites, + poly_sprites, + }) + } +} + +impl DirectComposition { + pub fn new(dxgi_device: &IDXGIDevice, hwnd: HWND) -> Result<Self> { + let comp_device = get_comp_device(&dxgi_device)?; + let comp_target = unsafe { comp_device.CreateTargetForHwnd(hwnd, true) }?; + let comp_visual = unsafe { comp_device.CreateVisual() }?; + + Ok(Self { + comp_device, + comp_target, + comp_visual, + }) + } + + pub fn set_swap_chain(&self, swap_chain: &IDXGISwapChain1) -> Result<()> { + unsafe { + self.comp_visual.SetContent(swap_chain)?; + self.comp_target.SetRoot(&self.comp_visual)?; + self.comp_device.Commit()?; + } + Ok(()) + } +} + +impl DirectXGlobalElements { + pub fn new(device: &ID3D11Device) -> Result<Self> { + let global_params_buffer = unsafe { + let desc = D3D11_BUFFER_DESC { + ByteWidth: std::mem::size_of::<GlobalParams>() as u32, + Usage: D3D11_USAGE_DYNAMIC, + BindFlags: D3D11_BIND_CONSTANT_BUFFER.0 as u32, + CPUAccessFlags: D3D11_CPU_ACCESS_WRITE.0 as u32, + ..Default::default() + }; + let mut buffer = None; + device.CreateBuffer(&desc, None, Some(&mut buffer))?; + [buffer] + }; + + let sampler = unsafe { + let desc = D3D11_SAMPLER_DESC { + Filter: D3D11_FILTER_MIN_MAG_MIP_LINEAR, + AddressU: D3D11_TEXTURE_ADDRESS_WRAP, + AddressV: D3D11_TEXTURE_ADDRESS_WRAP, + AddressW: D3D11_TEXTURE_ADDRESS_WRAP, + MipLODBias: 0.0, + MaxAnisotropy: 1, + ComparisonFunc: D3D11_COMPARISON_ALWAYS, + BorderColor: [0.0; 4], + MinLOD: 0.0, + MaxLOD: D3D11_FLOAT32_MAX, + }; + let mut output = None; + device.CreateSamplerState(&desc, Some(&mut output))?; + [output] + }; + + Ok(Self { + global_params_buffer, + sampler, + }) + } +} + +#[derive(Debug, Default)] +#[repr(C)] +struct GlobalParams { + viewport_size: [f32; 2], + _pad: u64, +} + +struct PipelineState<T> { + label: &'static str, + vertex: ID3D11VertexShader, + fragment: ID3D11PixelShader, + buffer: ID3D11Buffer, + buffer_size: usize, + view: [Option<ID3D11ShaderResourceView>; 1], + blend_state: ID3D11BlendState, + _marker: std::marker::PhantomData<T>, +} + +impl<T> PipelineState<T> { + fn new( + device: &ID3D11Device, + label: &'static str, + shader_module: ShaderModule, + buffer_size: usize, + blend_state: ID3D11BlendState, + ) -> Result<Self> { + let vertex = { + let raw_shader = RawShaderBytes::new(shader_module, ShaderTarget::Vertex)?; + create_vertex_shader(device, raw_shader.as_bytes())? + }; + let fragment = { + let raw_shader = RawShaderBytes::new(shader_module, ShaderTarget::Fragment)?; + create_fragment_shader(device, raw_shader.as_bytes())? 
+ }; + let buffer = create_buffer(device, std::mem::size_of::<T>(), buffer_size)?; + let view = create_buffer_view(device, &buffer)?; + + Ok(PipelineState { + label, + vertex, + fragment, + buffer, + buffer_size, + view, + blend_state, + _marker: std::marker::PhantomData, + }) + } + + fn update_buffer( + &mut self, + device: &ID3D11Device, + device_context: &ID3D11DeviceContext, + data: &[T], + ) -> Result<()> { + if self.buffer_size < data.len() { + let new_buffer_size = data.len().next_power_of_two(); + log::info!( + "Updating {} buffer size from {} to {}", + self.label, + self.buffer_size, + new_buffer_size + ); + let buffer = create_buffer(device, std::mem::size_of::<T>(), new_buffer_size)?; + let view = create_buffer_view(device, &buffer)?; + self.buffer = buffer; + self.view = view; + self.buffer_size = new_buffer_size; + } + update_buffer(device_context, &self.buffer, data) + } + + fn draw( + &self, + device_context: &ID3D11DeviceContext, + viewport: &[D3D11_VIEWPORT], + global_params: &[Option<ID3D11Buffer>], + topology: D3D_PRIMITIVE_TOPOLOGY, + vertex_count: u32, + instance_count: u32, + ) -> Result<()> { + set_pipeline_state( + device_context, + &self.view, + topology, + viewport, + &self.vertex, + &self.fragment, + global_params, + &self.blend_state, + ); + unsafe { + device_context.DrawInstanced(vertex_count, instance_count, 0, 0); + } + Ok(()) + } + + fn draw_with_texture( + &self, + device_context: &ID3D11DeviceContext, + texture: &[Option<ID3D11ShaderResourceView>], + viewport: &[D3D11_VIEWPORT], + global_params: &[Option<ID3D11Buffer>], + sampler: &[Option<ID3D11SamplerState>], + instance_count: u32, + ) -> Result<()> { + set_pipeline_state( + device_context, + &self.view, + D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP, + viewport, + &self.vertex, + &self.fragment, + global_params, + &self.blend_state, + ); + unsafe { + device_context.PSSetSamplers(0, Some(sampler)); + device_context.VSSetShaderResources(0, Some(texture)); + device_context.PSSetShaderResources(0, Some(texture)); + + device_context.DrawInstanced(4, instance_count, 0, 0); + } + Ok(()) + } +} + +#[derive(Clone, Copy)] +#[repr(C)] +struct PathRasterizationSprite { + xy_position: Point<ScaledPixels>, + st_position: Point<f32>, + color: Background, + bounds: Bounds<ScaledPixels>, +} + +#[derive(Clone, Copy)] +#[repr(C)] +struct PathSprite { + bounds: Bounds<ScaledPixels>, +} + +impl Drop for DirectXRenderer { + fn drop(&mut self) { + #[cfg(debug_assertions)] + report_live_objects(&self.devices.device).ok(); + unsafe { + ManuallyDrop::drop(&mut self.devices); + ManuallyDrop::drop(&mut self.resources); + } + } +} + +impl Drop for DirectXResources { + fn drop(&mut self) { + unsafe { + ManuallyDrop::drop(&mut self.render_target); + } + } +} + +#[inline] +fn check_debug_layer_available() -> bool { + #[cfg(debug_assertions)] + { + unsafe { DXGIGetDebugInterface1::<IDXGIInfoQueue>(0) } + .log_err() + .is_some() + } + #[cfg(not(debug_assertions))] + { + false + } +} + +#[inline] +fn get_dxgi_factory(debug_layer_available: bool) -> Result<IDXGIFactory6> { + let factory_flag = if debug_layer_available { + DXGI_CREATE_FACTORY_DEBUG + } else { + #[cfg(debug_assertions)] + log::warn!( + "Failed to get DXGI debug interface. DirectX debugging features will be disabled." + ); + DXGI_CREATE_FACTORY_FLAGS::default() + }; + unsafe { Ok(CreateDXGIFactory2(factory_flag)?) } +} + +fn get_adapter(dxgi_factory: &IDXGIFactory6, debug_layer_available: bool) -> Result<IDXGIAdapter1> { + for adapter_index in 0.. 
{ + let adapter: IDXGIAdapter1 = unsafe { + dxgi_factory + .EnumAdapterByGpuPreference(adapter_index, DXGI_GPU_PREFERENCE_MINIMUM_POWER) + }?; + if let Ok(desc) = unsafe { adapter.GetDesc1() } { + let gpu_name = String::from_utf16_lossy(&desc.Description) + .trim_matches(char::from(0)) + .to_string(); + log::info!("Using GPU: {}", gpu_name); + } + // Check to see whether the adapter supports Direct3D 11, but don't + // create the actual device yet. + if get_device(&adapter, None, None, None, debug_layer_available) + .log_err() + .is_some() + { + return Ok(adapter); + } + } + + unreachable!() +} + +fn get_device( + adapter: &IDXGIAdapter1, + device: Option<*mut Option<ID3D11Device>>, + context: Option<*mut Option<ID3D11DeviceContext>>, + feature_level: Option<*mut D3D_FEATURE_LEVEL>, + debug_layer_available: bool, +) -> Result<()> { + let device_flags = if debug_layer_available { + D3D11_CREATE_DEVICE_BGRA_SUPPORT | D3D11_CREATE_DEVICE_DEBUG + } else { + D3D11_CREATE_DEVICE_BGRA_SUPPORT + }; + unsafe { + D3D11CreateDevice( + adapter, + D3D_DRIVER_TYPE_UNKNOWN, + HMODULE::default(), + device_flags, + // 4x MSAA is required for Direct3D Feature Level 10.1 or better + Some(&[ + D3D_FEATURE_LEVEL_11_1, + D3D_FEATURE_LEVEL_11_0, + D3D_FEATURE_LEVEL_10_1, + ]), + D3D11_SDK_VERSION, + device, + feature_level, + context, + )?; + } + Ok(()) +} + +#[inline] +fn get_comp_device(dxgi_device: &IDXGIDevice) -> Result<IDCompositionDevice> { + Ok(unsafe { DCompositionCreateDevice(dxgi_device)? }) +} + +fn create_swap_chain_for_composition( + dxgi_factory: &IDXGIFactory6, + device: &ID3D11Device, + width: u32, + height: u32, +) -> Result<IDXGISwapChain1> { + let desc = DXGI_SWAP_CHAIN_DESC1 { + Width: width, + Height: height, + Format: RENDER_TARGET_FORMAT, + Stereo: false.into(), + SampleDesc: DXGI_SAMPLE_DESC { + Count: 1, + Quality: 0, + }, + BufferUsage: DXGI_USAGE_RENDER_TARGET_OUTPUT, + BufferCount: BUFFER_COUNT as u32, + // Composition SwapChains only support the DXGI_SCALING_STRETCH Scaling. + Scaling: DXGI_SCALING_STRETCH, + SwapEffect: DXGI_SWAP_EFFECT_FLIP_SEQUENTIAL, + AlphaMode: DXGI_ALPHA_MODE_PREMULTIPLIED, + Flags: 0, + }; + Ok(unsafe { dxgi_factory.CreateSwapChainForComposition(device, &desc, None)? 
}) +} + +fn create_swap_chain( + dxgi_factory: &IDXGIFactory6, + device: &ID3D11Device, + hwnd: HWND, + width: u32, + height: u32, +) -> Result<IDXGISwapChain1> { + use windows::Win32::Graphics::Dxgi::DXGI_MWA_NO_ALT_ENTER; + + let desc = DXGI_SWAP_CHAIN_DESC1 { + Width: width, + Height: height, + Format: RENDER_TARGET_FORMAT, + Stereo: false.into(), + SampleDesc: DXGI_SAMPLE_DESC { + Count: 1, + Quality: 0, + }, + BufferUsage: DXGI_USAGE_RENDER_TARGET_OUTPUT, + BufferCount: BUFFER_COUNT as u32, + Scaling: DXGI_SCALING_NONE, + SwapEffect: DXGI_SWAP_EFFECT_FLIP_SEQUENTIAL, + AlphaMode: DXGI_ALPHA_MODE_IGNORE, + Flags: 0, + }; + let swap_chain = + unsafe { dxgi_factory.CreateSwapChainForHwnd(device, hwnd, &desc, None, None) }?; + unsafe { dxgi_factory.MakeWindowAssociation(hwnd, DXGI_MWA_NO_ALT_ENTER) }?; + Ok(swap_chain) +} + +#[inline] +fn create_resources( + devices: &DirectXDevices, + swap_chain: &IDXGISwapChain1, + width: u32, + height: u32, +) -> Result<( + ManuallyDrop<ID3D11Texture2D>, + [Option<ID3D11RenderTargetView>; 1], + ID3D11Texture2D, + [Option<ID3D11ShaderResourceView>; 1], + ID3D11Texture2D, + [Option<ID3D11RenderTargetView>; 1], + [D3D11_VIEWPORT; 1], +)> { + let (render_target, render_target_view) = + create_render_target_and_its_view(&swap_chain, &devices.device)?; + let (path_intermediate_texture, path_intermediate_srv) = + create_path_intermediate_texture(&devices.device, width, height)?; + let (path_intermediate_msaa_texture, path_intermediate_msaa_view) = + create_path_intermediate_msaa_texture_and_view(&devices.device, width, height)?; + let viewport = set_viewport(&devices.device_context, width as f32, height as f32); + Ok(( + render_target, + render_target_view, + path_intermediate_texture, + path_intermediate_srv, + path_intermediate_msaa_texture, + path_intermediate_msaa_view, + viewport, + )) +} + +#[inline] +fn create_render_target_and_its_view( + swap_chain: &IDXGISwapChain1, + device: &ID3D11Device, +) -> Result<( + ManuallyDrop<ID3D11Texture2D>, + [Option<ID3D11RenderTargetView>; 1], +)> { + let render_target: ID3D11Texture2D = unsafe { swap_chain.GetBuffer(0) }?; + let mut render_target_view = None; + unsafe { device.CreateRenderTargetView(&render_target, None, Some(&mut render_target_view))? }; + Ok(( + ManuallyDrop::new(render_target), + [Some(render_target_view.unwrap())], + )) +} + +#[inline] +fn create_path_intermediate_texture( + device: &ID3D11Device, + width: u32, + height: u32, +) -> Result<(ID3D11Texture2D, [Option<ID3D11ShaderResourceView>; 1])> { + let texture = unsafe { + let mut output = None; + let desc = D3D11_TEXTURE2D_DESC { + Width: width, + Height: height, + MipLevels: 1, + ArraySize: 1, + Format: RENDER_TARGET_FORMAT, + SampleDesc: DXGI_SAMPLE_DESC { + Count: 1, + Quality: 0, + }, + Usage: D3D11_USAGE_DEFAULT, + BindFlags: (D3D11_BIND_RENDER_TARGET.0 | D3D11_BIND_SHADER_RESOURCE.0) as u32, + CPUAccessFlags: 0, + MiscFlags: 0, + }; + device.CreateTexture2D(&desc, None, Some(&mut output))?; + output.unwrap() + }; + + let mut shader_resource_view = None; + unsafe { device.CreateShaderResourceView(&texture, None, Some(&mut shader_resource_view))? 
}; + + Ok((texture, [Some(shader_resource_view.unwrap())])) +} + +#[inline] +fn create_path_intermediate_msaa_texture_and_view( + device: &ID3D11Device, + width: u32, + height: u32, +) -> Result<(ID3D11Texture2D, [Option<ID3D11RenderTargetView>; 1])> { + let msaa_texture = unsafe { + let mut output = None; + let desc = D3D11_TEXTURE2D_DESC { + Width: width, + Height: height, + MipLevels: 1, + ArraySize: 1, + Format: RENDER_TARGET_FORMAT, + SampleDesc: DXGI_SAMPLE_DESC { + Count: PATH_MULTISAMPLE_COUNT, + Quality: D3D11_STANDARD_MULTISAMPLE_PATTERN.0 as u32, + }, + Usage: D3D11_USAGE_DEFAULT, + BindFlags: D3D11_BIND_RENDER_TARGET.0 as u32, + CPUAccessFlags: 0, + MiscFlags: 0, + }; + device.CreateTexture2D(&desc, None, Some(&mut output))?; + output.unwrap() + }; + let mut msaa_view = None; + unsafe { device.CreateRenderTargetView(&msaa_texture, None, Some(&mut msaa_view))? }; + Ok((msaa_texture, [Some(msaa_view.unwrap())])) +} + +#[inline] +fn set_viewport( + device_context: &ID3D11DeviceContext, + width: f32, + height: f32, +) -> [D3D11_VIEWPORT; 1] { + let viewport = [D3D11_VIEWPORT { + TopLeftX: 0.0, + TopLeftY: 0.0, + Width: width, + Height: height, + MinDepth: 0.0, + MaxDepth: 1.0, + }]; + unsafe { device_context.RSSetViewports(Some(&viewport)) }; + viewport +} + +#[inline] +fn set_rasterizer_state(device: &ID3D11Device, device_context: &ID3D11DeviceContext) -> Result<()> { + let desc = D3D11_RASTERIZER_DESC { + FillMode: D3D11_FILL_SOLID, + CullMode: D3D11_CULL_NONE, + FrontCounterClockwise: false.into(), + DepthBias: 0, + DepthBiasClamp: 0.0, + SlopeScaledDepthBias: 0.0, + DepthClipEnable: true.into(), + ScissorEnable: false.into(), + MultisampleEnable: true.into(), + AntialiasedLineEnable: false.into(), + }; + let rasterizer_state = unsafe { + let mut state = None; + device.CreateRasterizerState(&desc, Some(&mut state))?; + state.unwrap() + }; + unsafe { device_context.RSSetState(&rasterizer_state) }; + Ok(()) +} + +// https://learn.microsoft.com/en-us/windows/win32/api/d3d11/ns-d3d11-d3d11_blend_desc +#[inline] +fn create_blend_state(device: &ID3D11Device) -> Result<ID3D11BlendState> { + // If the feature level is set to greater than D3D_FEATURE_LEVEL_9_3, the display + // device performs the blend in linear space, which is ideal. + let mut desc = D3D11_BLEND_DESC::default(); + desc.RenderTarget[0].BlendEnable = true.into(); + desc.RenderTarget[0].BlendOp = D3D11_BLEND_OP_ADD; + desc.RenderTarget[0].BlendOpAlpha = D3D11_BLEND_OP_ADD; + desc.RenderTarget[0].SrcBlend = D3D11_BLEND_SRC_ALPHA; + desc.RenderTarget[0].SrcBlendAlpha = D3D11_BLEND_ONE; + desc.RenderTarget[0].DestBlend = D3D11_BLEND_INV_SRC_ALPHA; + desc.RenderTarget[0].DestBlendAlpha = D3D11_BLEND_ONE; + desc.RenderTarget[0].RenderTargetWriteMask = D3D11_COLOR_WRITE_ENABLE_ALL.0 as u8; + unsafe { + let mut state = None; + device.CreateBlendState(&desc, Some(&mut state))?; + Ok(state.unwrap()) + } +} + +#[inline] +fn create_blend_state_for_path_rasterization(device: &ID3D11Device) -> Result<ID3D11BlendState> { + // If the feature level is set to greater than D3D_FEATURE_LEVEL_9_3, the display + // device performs the blend in linear space, which is ideal. 
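+    // This differs from `create_blend_state` above: SrcBlend is ONE rather than SRC_ALPHA, and +    // DestBlendAlpha is INV_SRC_ALPHA rather than ONE, i.e. premultiplied-alpha ("over") blending on +    // both color and alpha for the path pass that renders into the intermediate MSAA target (see +    // `draw_paths_to_intermediate`).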
+ let mut desc = D3D11_BLEND_DESC::default(); + desc.RenderTarget[0].BlendEnable = true.into(); + desc.RenderTarget[0].BlendOp = D3D11_BLEND_OP_ADD; + desc.RenderTarget[0].BlendOpAlpha = D3D11_BLEND_OP_ADD; + desc.RenderTarget[0].SrcBlend = D3D11_BLEND_ONE; + desc.RenderTarget[0].SrcBlendAlpha = D3D11_BLEND_ONE; + desc.RenderTarget[0].DestBlend = D3D11_BLEND_INV_SRC_ALPHA; + desc.RenderTarget[0].DestBlendAlpha = D3D11_BLEND_INV_SRC_ALPHA; + desc.RenderTarget[0].RenderTargetWriteMask = D3D11_COLOR_WRITE_ENABLE_ALL.0 as u8; + unsafe { + let mut state = None; + device.CreateBlendState(&desc, Some(&mut state))?; + Ok(state.unwrap()) + } +} + +#[inline] +fn create_blend_state_for_path_sprite(device: &ID3D11Device) -> Result<ID3D11BlendState> { + // If the feature level is set to greater than D3D_FEATURE_LEVEL_9_3, the display + // device performs the blend in linear space, which is ideal. + let mut desc = D3D11_BLEND_DESC::default(); + desc.RenderTarget[0].BlendEnable = true.into(); + desc.RenderTarget[0].BlendOp = D3D11_BLEND_OP_ADD; + desc.RenderTarget[0].BlendOpAlpha = D3D11_BLEND_OP_ADD; + desc.RenderTarget[0].SrcBlend = D3D11_BLEND_ONE; + desc.RenderTarget[0].SrcBlendAlpha = D3D11_BLEND_ONE; + desc.RenderTarget[0].DestBlend = D3D11_BLEND_INV_SRC_ALPHA; + desc.RenderTarget[0].DestBlendAlpha = D3D11_BLEND_ONE; + desc.RenderTarget[0].RenderTargetWriteMask = D3D11_COLOR_WRITE_ENABLE_ALL.0 as u8; + unsafe { + let mut state = None; + device.CreateBlendState(&desc, Some(&mut state))?; + Ok(state.unwrap()) + } +} + +#[inline] +fn create_vertex_shader(device: &ID3D11Device, bytes: &[u8]) -> Result<ID3D11VertexShader> { + unsafe { + let mut shader = None; + device.CreateVertexShader(bytes, None, Some(&mut shader))?; + Ok(shader.unwrap()) + } +} + +#[inline] +fn create_fragment_shader(device: &ID3D11Device, bytes: &[u8]) -> Result<ID3D11PixelShader> { + unsafe { + let mut shader = None; + device.CreatePixelShader(bytes, None, Some(&mut shader))?; + Ok(shader.unwrap()) + } +} + +#[inline] +fn create_buffer( + device: &ID3D11Device, + element_size: usize, + buffer_size: usize, +) -> Result<ID3D11Buffer> { + let desc = D3D11_BUFFER_DESC { + ByteWidth: (element_size * buffer_size) as u32, + Usage: D3D11_USAGE_DYNAMIC, + BindFlags: D3D11_BIND_SHADER_RESOURCE.0 as u32, + CPUAccessFlags: D3D11_CPU_ACCESS_WRITE.0 as u32, + MiscFlags: D3D11_RESOURCE_MISC_BUFFER_STRUCTURED.0 as u32, + StructureByteStride: element_size as u32, + }; + let mut buffer = None; + unsafe { device.CreateBuffer(&desc, None, Some(&mut buffer)) }?; + Ok(buffer.unwrap()) +} + +#[inline] +fn create_buffer_view( + device: &ID3D11Device, + buffer: &ID3D11Buffer, +) -> Result<[Option<ID3D11ShaderResourceView>; 1]> { + let mut view = None; + unsafe { device.CreateShaderResourceView(buffer, None, Some(&mut view)) }?; + Ok([view]) +} + +#[inline] +fn update_buffer<T>( + device_context: &ID3D11DeviceContext, + buffer: &ID3D11Buffer, + data: &[T], +) -> Result<()> { + unsafe { + let mut dest = std::mem::zeroed(); + device_context.Map(buffer, 0, D3D11_MAP_WRITE_DISCARD, 0, Some(&mut dest))?; + std::ptr::copy_nonoverlapping(data.as_ptr(), dest.pData as _, data.len()); + device_context.Unmap(buffer, 0); + } + Ok(()) +} + +#[inline] +fn set_pipeline_state( + device_context: &ID3D11DeviceContext, + buffer_view: &[Option<ID3D11ShaderResourceView>], + topology: D3D_PRIMITIVE_TOPOLOGY, + viewport: &[D3D11_VIEWPORT], + vertex_shader: &ID3D11VertexShader, + fragment_shader: &ID3D11PixelShader, + global_params: &[Option<ID3D11Buffer>], + blend_state: 
&ID3D11BlendState, +) { + unsafe { + device_context.VSSetShaderResources(1, Some(buffer_view)); + device_context.PSSetShaderResources(1, Some(buffer_view)); + device_context.IASetPrimitiveTopology(topology); + device_context.RSSetViewports(Some(viewport)); + device_context.VSSetShader(vertex_shader, None); + device_context.PSSetShader(fragment_shader, None); + device_context.VSSetConstantBuffers(0, Some(global_params)); + device_context.PSSetConstantBuffers(0, Some(global_params)); + device_context.OMSetBlendState(blend_state, None, 0xFFFFFFFF); + } +} + +#[cfg(debug_assertions)] +fn report_live_objects(device: &ID3D11Device) -> Result<()> { + let debug_device: ID3D11Debug = device.cast()?; + unsafe { + debug_device.ReportLiveDeviceObjects(D3D11_RLDO_DETAIL)?; + } + Ok(()) +} + +const BUFFER_COUNT: usize = 3; + +pub(crate) mod shader_resources { + use anyhow::Result; + + #[cfg(debug_assertions)] + use windows::{ + Win32::Graphics::Direct3D::{ + Fxc::{D3DCOMPILE_DEBUG, D3DCOMPILE_SKIP_OPTIMIZATION, D3DCompileFromFile}, + ID3DBlob, + }, + core::{HSTRING, PCSTR}, + }; + + #[derive(Copy, Clone, Debug, Eq, PartialEq)] + pub(crate) enum ShaderModule { + Quad, + Shadow, + Underline, + PathRasterization, + PathSprite, + MonochromeSprite, + PolychromeSprite, + EmojiRasterization, + } + + #[derive(Copy, Clone, Debug, Eq, PartialEq)] + pub(crate) enum ShaderTarget { + Vertex, + Fragment, + } + + pub(crate) struct RawShaderBytes<'t> { + inner: &'t [u8], + + #[cfg(debug_assertions)] + _blob: ID3DBlob, + } + + impl<'t> RawShaderBytes<'t> { + pub(crate) fn new(module: ShaderModule, target: ShaderTarget) -> Result<Self> { + #[cfg(not(debug_assertions))] + { + Ok(Self::from_bytes(module, target)) + } + #[cfg(debug_assertions)] + { + let blob = build_shader_blob(module, target)?; + let inner = unsafe { + std::slice::from_raw_parts( + blob.GetBufferPointer() as *const u8, + blob.GetBufferSize(), + ) + }; + Ok(Self { inner, _blob: blob }) + } + } + + pub(crate) fn as_bytes(&'t self) -> &'t [u8] { + self.inner + } + + #[cfg(not(debug_assertions))] + fn from_bytes(module: ShaderModule, target: ShaderTarget) -> Self { + let bytes = match module { + ShaderModule::Quad => match target { + ShaderTarget::Vertex => QUAD_VERTEX_BYTES, + ShaderTarget::Fragment => QUAD_FRAGMENT_BYTES, + }, + ShaderModule::Shadow => match target { + ShaderTarget::Vertex => SHADOW_VERTEX_BYTES, + ShaderTarget::Fragment => SHADOW_FRAGMENT_BYTES, + }, + ShaderModule::Underline => match target { + ShaderTarget::Vertex => UNDERLINE_VERTEX_BYTES, + ShaderTarget::Fragment => UNDERLINE_FRAGMENT_BYTES, + }, + ShaderModule::PathRasterization => match target { + ShaderTarget::Vertex => PATH_RASTERIZATION_VERTEX_BYTES, + ShaderTarget::Fragment => PATH_RASTERIZATION_FRAGMENT_BYTES, + }, + ShaderModule::PathSprite => match target { + ShaderTarget::Vertex => PATH_SPRITE_VERTEX_BYTES, + ShaderTarget::Fragment => PATH_SPRITE_FRAGMENT_BYTES, + }, + ShaderModule::MonochromeSprite => match target { + ShaderTarget::Vertex => MONOCHROME_SPRITE_VERTEX_BYTES, + ShaderTarget::Fragment => MONOCHROME_SPRITE_FRAGMENT_BYTES, + }, + ShaderModule::PolychromeSprite => match target { + ShaderTarget::Vertex => POLYCHROME_SPRITE_VERTEX_BYTES, + ShaderTarget::Fragment => POLYCHROME_SPRITE_FRAGMENT_BYTES, + }, + ShaderModule::EmojiRasterization => match target { + ShaderTarget::Vertex => EMOJI_RASTERIZATION_VERTEX_BYTES, + ShaderTarget::Fragment => EMOJI_RASTERIZATION_FRAGMENT_BYTES, + }, + }; + Self { inner: bytes } + } + } + + #[cfg(debug_assertions)] + pub(super) fn 
build_shader_blob(entry: ShaderModule, target: ShaderTarget) -> Result<ID3DBlob> { + unsafe { + let shader_name = if matches!(entry, ShaderModule::EmojiRasterization) { + "color_text_raster.hlsl" + } else { + "shaders.hlsl" + }; + + let entry = format!( + "{}_{}\0", + entry.as_str(), + match target { + ShaderTarget::Vertex => "vertex", + ShaderTarget::Fragment => "fragment", + } + ); + let target = match target { + ShaderTarget::Vertex => "vs_4_1\0", + ShaderTarget::Fragment => "ps_4_1\0", + }; + + let mut compile_blob = None; + let mut error_blob = None; + let shader_path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join(&format!("src/platform/windows/{}", shader_name)) + .canonicalize()?; + + let entry_point = PCSTR::from_raw(entry.as_ptr()); + let target_cstr = PCSTR::from_raw(target.as_ptr()); + + let ret = D3DCompileFromFile( + &HSTRING::from(shader_path.to_str().unwrap()), + None, + None, + entry_point, + target_cstr, + D3DCOMPILE_DEBUG | D3DCOMPILE_SKIP_OPTIMIZATION, + 0, + &mut compile_blob, + Some(&mut error_blob), + ); + if ret.is_err() { + let Some(error_blob) = error_blob else { + return Err(anyhow::anyhow!("{ret:?}")); + }; + + let error_string = + std::ffi::CStr::from_ptr(error_blob.GetBufferPointer() as *const i8) + .to_string_lossy(); + log::error!("Shader compile error: {}", error_string); + return Err(anyhow::anyhow!("Compile error: {}", error_string)); + } + Ok(compile_blob.unwrap()) + } + } + + #[cfg(not(debug_assertions))] + include!(concat!(env!("OUT_DIR"), "/shaders_bytes.rs")); + + #[cfg(debug_assertions)] + impl ShaderModule { + pub fn as_str(&self) -> &str { + match self { + ShaderModule::Quad => "quad", + ShaderModule::Shadow => "shadow", + ShaderModule::Underline => "underline", + ShaderModule::PathRasterization => "path_rasterization", + ShaderModule::PathSprite => "path_sprite", + ShaderModule::MonochromeSprite => "monochrome_sprite", + ShaderModule::PolychromeSprite => "polychrome_sprite", + ShaderModule::EmojiRasterization => "emoji_rasterization", + } + } + } +} + +mod nvidia { + use std::{ + ffi::CStr, + os::raw::{c_char, c_int, c_uint}, + }; + + use anyhow::{Context, Result}; + use windows::{ + Win32::System::LibraryLoader::{GetProcAddress, LoadLibraryA}, + core::s, + }; + + // https://github.com/NVIDIA/nvapi/blob/7cb76fce2f52de818b3da497af646af1ec16ce27/nvapi_lite_common.h#L180 + const NVAPI_SHORT_STRING_MAX: usize = 64; + + // https://github.com/NVIDIA/nvapi/blob/7cb76fce2f52de818b3da497af646af1ec16ce27/nvapi_lite_common.h#L235 + #[allow(non_camel_case_types)] + type NvAPI_ShortString = [c_char; NVAPI_SHORT_STRING_MAX]; + + // https://github.com/NVIDIA/nvapi/blob/7cb76fce2f52de818b3da497af646af1ec16ce27/nvapi_lite_common.h#L447 + #[allow(non_camel_case_types)] + type NvAPI_SYS_GetDriverAndBranchVersion_t = unsafe extern "C" fn( + driver_version: *mut c_uint, + build_branch_string: *mut NvAPI_ShortString, + ) -> c_int; + + pub(super) fn get_driver_version() -> Result<String> { + unsafe { + // Try to load the NVIDIA driver DLL + #[cfg(target_pointer_width = "64")] + let nvidia_dll = LoadLibraryA(s!("nvapi64.dll")).context("Can't load nvapi64.dll")?; + #[cfg(target_pointer_width = "32")] + let nvidia_dll = LoadLibraryA(s!("nvapi.dll")).context("Can't load nvapi.dll")?; + + let nvapi_query_addr = GetProcAddress(nvidia_dll, s!("nvapi_QueryInterface")) + .ok_or_else(|| anyhow::anyhow!("Failed to get nvapi_QueryInterface address"))?; + let nvapi_query: extern "C" fn(u32) -> *mut () = std::mem::transmute(nvapi_query_addr); + + // 
https://github.com/NVIDIA/nvapi/blob/7cb76fce2f52de818b3da497af646af1ec16ce27/nvapi_interface.h#L41 + let nvapi_get_driver_version_ptr = nvapi_query(0x2926aaad); + if nvapi_get_driver_version_ptr.is_null() { + anyhow::bail!("Failed to get NVIDIA driver version function pointer"); + } + let nvapi_get_driver_version: NvAPI_SYS_GetDriverAndBranchVersion_t = + std::mem::transmute(nvapi_get_driver_version_ptr); + + let mut driver_version: c_uint = 0; + let mut build_branch_string: NvAPI_ShortString = [0; NVAPI_SHORT_STRING_MAX]; + let result = nvapi_get_driver_version( + &mut driver_version as *mut c_uint, + &mut build_branch_string as *mut NvAPI_ShortString, + ); + + if result != 0 { + anyhow::bail!( + "Failed to get NVIDIA driver version, error code: {}", + result + ); + } + let major = driver_version / 100; + let minor = driver_version % 100; + let branch_string = CStr::from_ptr(build_branch_string.as_ptr()); + Ok(format!( + "{}.{} {}", + major, + minor, + branch_string.to_string_lossy() + )) + } + } +} + +mod amd { + use std::os::raw::{c_char, c_int, c_void}; + + use anyhow::{Context, Result}; + use windows::{ + Win32::System::LibraryLoader::{GetProcAddress, LoadLibraryA}, + core::s, + }; + + // https://github.com/GPUOpen-LibrariesAndSDKs/AGS_SDK/blob/5d8812d703d0335741b6f7ffc37838eeb8b967f7/ags_lib/inc/amd_ags.h#L145 + const AGS_CURRENT_VERSION: i32 = (6 << 22) | (3 << 12); + + // https://github.com/GPUOpen-LibrariesAndSDKs/AGS_SDK/blob/5d8812d703d0335741b6f7ffc37838eeb8b967f7/ags_lib/inc/amd_ags.h#L204 + // This is an opaque type, using struct to represent it properly for FFI + #[repr(C)] + struct AGSContext { + _private: [u8; 0], + } + + #[repr(C)] + pub struct AGSGPUInfo { + pub driver_version: *const c_char, + pub radeon_software_version: *const c_char, + pub num_devices: c_int, + pub devices: *mut c_void, + } + + // https://github.com/GPUOpen-LibrariesAndSDKs/AGS_SDK/blob/5d8812d703d0335741b6f7ffc37838eeb8b967f7/ags_lib/inc/amd_ags.h#L429 + #[allow(non_camel_case_types)] + type agsInitialize_t = unsafe extern "C" fn( + version: c_int, + config: *const c_void, + context: *mut *mut AGSContext, + gpu_info: *mut AGSGPUInfo, + ) -> c_int; + + // https://github.com/GPUOpen-LibrariesAndSDKs/AGS_SDK/blob/5d8812d703d0335741b6f7ffc37838eeb8b967f7/ags_lib/inc/amd_ags.h#L436 + #[allow(non_camel_case_types)] + type agsDeInitialize_t = unsafe extern "C" fn(context: *mut AGSContext) -> c_int; + + pub(super) fn get_driver_version() -> Result<String> { + unsafe { + #[cfg(target_pointer_width = "64")] + let amd_dll = + LoadLibraryA(s!("amd_ags_x64.dll")).context("Failed to load AMD AGS library")?; + #[cfg(target_pointer_width = "32")] + let amd_dll = + LoadLibraryA(s!("amd_ags_x86.dll")).context("Failed to load AMD AGS library")?; + + let ags_initialize_addr = GetProcAddress(amd_dll, s!("agsInitialize")) + .ok_or_else(|| anyhow::anyhow!("Failed to get agsInitialize address"))?; + let ags_deinitialize_addr = GetProcAddress(amd_dll, s!("agsDeInitialize")) + .ok_or_else(|| anyhow::anyhow!("Failed to get agsDeInitialize address"))?; + + let ags_initialize: agsInitialize_t = std::mem::transmute(ags_initialize_addr); + let ags_deinitialize: agsDeInitialize_t = std::mem::transmute(ags_deinitialize_addr); + + let mut context: *mut AGSContext = std::ptr::null_mut(); + let mut gpu_info: AGSGPUInfo = AGSGPUInfo { + driver_version: std::ptr::null(), + radeon_software_version: std::ptr::null(), + num_devices: 0, + devices: std::ptr::null_mut(), + }; + + let result = ags_initialize( + AGS_CURRENT_VERSION, + 
std::ptr::null(), + &mut context, + &mut gpu_info, + ); + if result != 0 { + anyhow::bail!("Failed to initialize AMD AGS, error code: {}", result); + } + + // Vulkan acctually returns this as the driver version + let software_version = if !gpu_info.radeon_software_version.is_null() { + std::ffi::CStr::from_ptr(gpu_info.radeon_software_version) + .to_string_lossy() + .into_owned() + } else { + "Unknown Radeon Software Version".to_string() + }; + + let driver_version = if !gpu_info.driver_version.is_null() { + std::ffi::CStr::from_ptr(gpu_info.driver_version) + .to_string_lossy() + .into_owned() + } else { + "Unknown Radeon Driver Version".to_string() + }; + + ags_deinitialize(context); + Ok(format!("{} ({})", software_version, driver_version)) + } + } +} + +mod dxgi { + use windows::{ + Win32::Graphics::Dxgi::{IDXGIAdapter1, IDXGIDevice}, + core::Interface, + }; + + pub(super) fn get_driver_version(adapter: &IDXGIAdapter1) -> anyhow::Result<String> { + let number = unsafe { adapter.CheckInterfaceSupport(&IDXGIDevice::IID as _) }?; + Ok(format!( + "{}.{}.{}.{}", + number >> 48, + (number >> 32) & 0xFFFF, + (number >> 16) & 0xFFFF, + number & 0xFFFF + )) + } +} diff --git a/crates/gpui/src/platform/windows/events.rs b/crates/gpui/src/platform/windows/events.rs index 839fd10375b04180b5699e0bd6ab5fa3441f8b2b..4ab257d27a69fc5fed458655150e1c09c3ebbba8 100644 --- a/crates/gpui/src/platform/windows/events.rs +++ b/crates/gpui/src/platform/windows/events.rs @@ -23,1027 +23,894 @@ pub(crate) const WM_GPUI_CURSOR_STYLE_CHANGED: u32 = WM_USER + 1; pub(crate) const WM_GPUI_CLOSE_ONE_WINDOW: u32 = WM_USER + 2; pub(crate) const WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD: u32 = WM_USER + 3; pub(crate) const WM_GPUI_DOCK_MENU_ACTION: u32 = WM_USER + 4; +pub(crate) const WM_GPUI_FORCE_UPDATE_WINDOW: u32 = WM_USER + 5; const SIZE_MOVE_LOOP_TIMER_ID: usize = 1; const AUTO_HIDE_TASKBAR_THICKNESS_PX: i32 = 1; -pub(crate) fn handle_msg( - handle: HWND, - msg: u32, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> LRESULT { - let handled = match msg { - WM_ACTIVATE => handle_activate_msg(wparam, state_ptr), - WM_CREATE => handle_create_msg(handle, state_ptr), - WM_MOVE => handle_move_msg(handle, lparam, state_ptr), - WM_SIZE => handle_size_msg(wparam, lparam, state_ptr), - WM_GETMINMAXINFO => handle_get_min_max_info_msg(lparam, state_ptr), - WM_ENTERSIZEMOVE | WM_ENTERMENULOOP => handle_size_move_loop(handle), - WM_EXITSIZEMOVE | WM_EXITMENULOOP => handle_size_move_loop_exit(handle), - WM_TIMER => handle_timer_msg(handle, wparam, state_ptr), - WM_NCCALCSIZE => handle_calc_client_size(handle, wparam, lparam, state_ptr), - WM_DPICHANGED => handle_dpi_changed_msg(handle, wparam, lparam, state_ptr), - WM_DISPLAYCHANGE => handle_display_change_msg(handle, state_ptr), - WM_NCHITTEST => handle_hit_test_msg(handle, msg, wparam, lparam, state_ptr), - WM_PAINT => handle_paint_msg(handle, state_ptr), - WM_CLOSE => handle_close_msg(handle, state_ptr), - WM_DESTROY => handle_destroy_msg(handle, state_ptr), - WM_MOUSEMOVE => handle_mouse_move_msg(handle, lparam, wparam, state_ptr), - WM_MOUSELEAVE | WM_NCMOUSELEAVE => handle_mouse_leave_msg(state_ptr), - WM_NCMOUSEMOVE => handle_nc_mouse_move_msg(handle, lparam, state_ptr), - WM_NCLBUTTONDOWN => { - handle_nc_mouse_down_msg(handle, MouseButton::Left, wparam, lparam, state_ptr) - } - WM_NCRBUTTONDOWN => { - handle_nc_mouse_down_msg(handle, MouseButton::Right, wparam, lparam, state_ptr) - } - WM_NCMBUTTONDOWN => { - handle_nc_mouse_down_msg(handle, 
MouseButton::Middle, wparam, lparam, state_ptr) - } - WM_NCLBUTTONUP => { - handle_nc_mouse_up_msg(handle, MouseButton::Left, wparam, lparam, state_ptr) - } - WM_NCRBUTTONUP => { - handle_nc_mouse_up_msg(handle, MouseButton::Right, wparam, lparam, state_ptr) - } - WM_NCMBUTTONUP => { - handle_nc_mouse_up_msg(handle, MouseButton::Middle, wparam, lparam, state_ptr) - } - WM_LBUTTONDOWN => handle_mouse_down_msg(handle, MouseButton::Left, lparam, state_ptr), - WM_RBUTTONDOWN => handle_mouse_down_msg(handle, MouseButton::Right, lparam, state_ptr), - WM_MBUTTONDOWN => handle_mouse_down_msg(handle, MouseButton::Middle, lparam, state_ptr), - WM_XBUTTONDOWN => { - handle_xbutton_msg(handle, wparam, lparam, handle_mouse_down_msg, state_ptr) +impl WindowsWindowInner { + pub(crate) fn handle_msg( + self: &Rc<Self>, + handle: HWND, + msg: u32, + wparam: WPARAM, + lparam: LPARAM, + ) -> LRESULT { + let handled = match msg { + WM_ACTIVATE => self.handle_activate_msg(wparam), + WM_CREATE => self.handle_create_msg(handle), + WM_DEVICECHANGE => self.handle_device_change_msg(handle, wparam), + WM_MOVE => self.handle_move_msg(handle, lparam), + WM_SIZE => self.handle_size_msg(wparam, lparam), + WM_GETMINMAXINFO => self.handle_get_min_max_info_msg(lparam), + WM_ENTERSIZEMOVE | WM_ENTERMENULOOP => self.handle_size_move_loop(handle), + WM_EXITSIZEMOVE | WM_EXITMENULOOP => self.handle_size_move_loop_exit(handle), + WM_TIMER => self.handle_timer_msg(handle, wparam), + WM_NCCALCSIZE => self.handle_calc_client_size(handle, wparam, lparam), + WM_DPICHANGED => self.handle_dpi_changed_msg(handle, wparam, lparam), + WM_DISPLAYCHANGE => self.handle_display_change_msg(handle), + WM_NCHITTEST => self.handle_hit_test_msg(handle, msg, wparam, lparam), + WM_PAINT => self.handle_paint_msg(handle), + WM_CLOSE => self.handle_close_msg(), + WM_DESTROY => self.handle_destroy_msg(handle), + WM_MOUSEMOVE => self.handle_mouse_move_msg(handle, lparam, wparam), + WM_MOUSELEAVE | WM_NCMOUSELEAVE => self.handle_mouse_leave_msg(), + WM_NCMOUSEMOVE => self.handle_nc_mouse_move_msg(handle, lparam), + WM_NCLBUTTONDOWN => { + self.handle_nc_mouse_down_msg(handle, MouseButton::Left, wparam, lparam) + } + WM_NCRBUTTONDOWN => { + self.handle_nc_mouse_down_msg(handle, MouseButton::Right, wparam, lparam) + } + WM_NCMBUTTONDOWN => { + self.handle_nc_mouse_down_msg(handle, MouseButton::Middle, wparam, lparam) + } + WM_NCLBUTTONUP => { + self.handle_nc_mouse_up_msg(handle, MouseButton::Left, wparam, lparam) + } + WM_NCRBUTTONUP => { + self.handle_nc_mouse_up_msg(handle, MouseButton::Right, wparam, lparam) + } + WM_NCMBUTTONUP => { + self.handle_nc_mouse_up_msg(handle, MouseButton::Middle, wparam, lparam) + } + WM_LBUTTONDOWN => self.handle_mouse_down_msg(handle, MouseButton::Left, lparam), + WM_RBUTTONDOWN => self.handle_mouse_down_msg(handle, MouseButton::Right, lparam), + WM_MBUTTONDOWN => self.handle_mouse_down_msg(handle, MouseButton::Middle, lparam), + WM_XBUTTONDOWN => { + self.handle_xbutton_msg(handle, wparam, lparam, Self::handle_mouse_down_msg) + } + WM_LBUTTONUP => self.handle_mouse_up_msg(handle, MouseButton::Left, lparam), + WM_RBUTTONUP => self.handle_mouse_up_msg(handle, MouseButton::Right, lparam), + WM_MBUTTONUP => self.handle_mouse_up_msg(handle, MouseButton::Middle, lparam), + WM_XBUTTONUP => { + self.handle_xbutton_msg(handle, wparam, lparam, Self::handle_mouse_up_msg) + } + WM_MOUSEWHEEL => self.handle_mouse_wheel_msg(handle, wparam, lparam), + WM_MOUSEHWHEEL => self.handle_mouse_horizontal_wheel_msg(handle, wparam, lparam), + 
WM_SYSKEYDOWN => self.handle_syskeydown_msg(handle, wparam, lparam), + WM_SYSKEYUP => self.handle_syskeyup_msg(handle, wparam, lparam), + WM_SYSCOMMAND => self.handle_system_command(wparam), + WM_KEYDOWN => self.handle_keydown_msg(handle, wparam, lparam), + WM_KEYUP => self.handle_keyup_msg(handle, wparam, lparam), + WM_CHAR => self.handle_char_msg(wparam), + WM_DEADCHAR => self.handle_dead_char_msg(wparam), + WM_IME_STARTCOMPOSITION => self.handle_ime_position(handle), + WM_IME_COMPOSITION => self.handle_ime_composition(handle, lparam), + WM_SETCURSOR => self.handle_set_cursor(handle, lparam), + WM_SETTINGCHANGE => self.handle_system_settings_changed(handle, wparam, lparam), + WM_INPUTLANGCHANGE => self.handle_input_language_changed(lparam), + WM_GPUI_CURSOR_STYLE_CHANGED => self.handle_cursor_changed(lparam), + WM_GPUI_FORCE_UPDATE_WINDOW => self.draw_window(handle, true), + _ => None, + }; + if let Some(n) = handled { + LRESULT(n) + } else { + unsafe { DefWindowProcW(handle, msg, wparam, lparam) } } - WM_LBUTTONUP => handle_mouse_up_msg(handle, MouseButton::Left, lparam, state_ptr), - WM_RBUTTONUP => handle_mouse_up_msg(handle, MouseButton::Right, lparam, state_ptr), - WM_MBUTTONUP => handle_mouse_up_msg(handle, MouseButton::Middle, lparam, state_ptr), - WM_XBUTTONUP => handle_xbutton_msg(handle, wparam, lparam, handle_mouse_up_msg, state_ptr), - WM_MOUSEWHEEL => handle_mouse_wheel_msg(handle, wparam, lparam, state_ptr), - WM_MOUSEHWHEEL => handle_mouse_horizontal_wheel_msg(handle, wparam, lparam, state_ptr), - WM_SYSKEYDOWN => handle_syskeydown_msg(handle, wparam, lparam, state_ptr), - WM_SYSKEYUP => handle_syskeyup_msg(handle, wparam, lparam, state_ptr), - WM_SYSCOMMAND => handle_system_command(wparam, state_ptr), - WM_KEYDOWN => handle_keydown_msg(handle, wparam, lparam, state_ptr), - WM_KEYUP => handle_keyup_msg(handle, wparam, lparam, state_ptr), - WM_CHAR => handle_char_msg(wparam, state_ptr), - WM_DEADCHAR => handle_dead_char_msg(wparam, state_ptr), - WM_IME_STARTCOMPOSITION => handle_ime_position(handle, state_ptr), - WM_IME_COMPOSITION => handle_ime_composition(handle, lparam, state_ptr), - WM_SETCURSOR => handle_set_cursor(handle, lparam, state_ptr), - WM_SETTINGCHANGE => handle_system_settings_changed(handle, wparam, lparam, state_ptr), - WM_INPUTLANGCHANGE => handle_input_language_changed(lparam, state_ptr), - WM_GPUI_CURSOR_STYLE_CHANGED => handle_cursor_changed(lparam, state_ptr), - _ => None, - }; - if let Some(n) = handled { - LRESULT(n) - } else { - unsafe { DefWindowProcW(handle, msg, wparam, lparam) } } -} -fn handle_move_msg( - handle: HWND, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let mut lock = state_ptr.state.borrow_mut(); - let origin = logical_point( - lparam.signed_loword() as f32, - lparam.signed_hiword() as f32, - lock.scale_factor, - ); - lock.origin = origin; - let size = lock.logical_size; - let center_x = origin.x.0 + size.width.0 / 2.; - let center_y = origin.y.0 + size.height.0 / 2.; - let monitor_bounds = lock.display.bounds(); - if center_x < monitor_bounds.left().0 - || center_x > monitor_bounds.right().0 - || center_y < monitor_bounds.top().0 - || center_y > monitor_bounds.bottom().0 - { - // center of the window may have moved to another monitor - let monitor = unsafe { MonitorFromWindow(handle, MONITOR_DEFAULTTONULL) }; - // minimize the window can trigger this event too, in this case, - // monitor is invalid, we do nothing. 
- if !monitor.is_invalid() && lock.display.handle != monitor { - // we will get the same monitor if we only have one - lock.display = WindowsDisplay::new_with_handle(monitor); + fn handle_move_msg(&self, handle: HWND, lparam: LPARAM) -> Option<isize> { + let mut lock = self.state.borrow_mut(); + let origin = logical_point( + lparam.signed_loword() as f32, + lparam.signed_hiword() as f32, + lock.scale_factor, + ); + lock.origin = origin; + let size = lock.logical_size; + let center_x = origin.x.0 + size.width.0 / 2.; + let center_y = origin.y.0 + size.height.0 / 2.; + let monitor_bounds = lock.display.bounds(); + if center_x < monitor_bounds.left().0 + || center_x > monitor_bounds.right().0 + || center_y < monitor_bounds.top().0 + || center_y > monitor_bounds.bottom().0 + { + // center of the window may have moved to another monitor + let monitor = unsafe { MonitorFromWindow(handle, MONITOR_DEFAULTTONULL) }; + // minimize the window can trigger this event too, in this case, + // monitor is invalid, we do nothing. + if !monitor.is_invalid() && lock.display.handle != monitor { + // we will get the same monitor if we only have one + lock.display = WindowsDisplay::new_with_handle(monitor); + } + } + if let Some(mut callback) = lock.callbacks.moved.take() { + drop(lock); + callback(); + self.state.borrow_mut().callbacks.moved = Some(callback); } + Some(0) } - if let Some(mut callback) = lock.callbacks.moved.take() { + + fn handle_get_min_max_info_msg(&self, lparam: LPARAM) -> Option<isize> { + let lock = self.state.borrow(); + let min_size = lock.min_size?; + let scale_factor = lock.scale_factor; + let boarder_offset = lock.border_offset; drop(lock); - callback(); - state_ptr.state.borrow_mut().callbacks.moved = Some(callback); + unsafe { + let minmax_info = &mut *(lparam.0 as *mut MINMAXINFO); + minmax_info.ptMinTrackSize.x = + min_size.width.scale(scale_factor).0 as i32 + boarder_offset.width_offset; + minmax_info.ptMinTrackSize.y = + min_size.height.scale(scale_factor).0 as i32 + boarder_offset.height_offset; + } + Some(0) } - Some(0) -} -fn handle_get_min_max_info_msg( - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let lock = state_ptr.state.borrow(); - let min_size = lock.min_size?; - let scale_factor = lock.scale_factor; - let boarder_offset = lock.border_offset; - drop(lock); - unsafe { - let minmax_info = &mut *(lparam.0 as *mut MINMAXINFO); - minmax_info.ptMinTrackSize.x = - min_size.width.scale(scale_factor).0 as i32 + boarder_offset.width_offset; - minmax_info.ptMinTrackSize.y = - min_size.height.scale(scale_factor).0 as i32 + boarder_offset.height_offset; - } - Some(0) -} + fn handle_size_msg(&self, wparam: WPARAM, lparam: LPARAM) -> Option<isize> { + let mut lock = self.state.borrow_mut(); -fn handle_size_msg( - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let mut lock = state_ptr.state.borrow_mut(); - - // Don't resize the renderer when the window is minimized, but record that it was minimized so - // that on restore the swap chain can be recreated via `update_drawable_size_even_if_unchanged`. - if wparam.0 == SIZE_MINIMIZED as usize { - lock.restore_from_minimized = lock.callbacks.request_frame.take(); - return Some(0); - } + // Don't resize the renderer when the window is minimized, but record that it was minimized so + // that on restore the swap chain can be recreated via `update_drawable_size_even_if_unchanged`. 
+ if wparam.0 == SIZE_MINIMIZED as usize { + lock.restore_from_minimized = lock.callbacks.request_frame.take(); + return Some(0); + } - let width = lparam.loword().max(1) as i32; - let height = lparam.hiword().max(1) as i32; - let new_size = size(DevicePixels(width), DevicePixels(height)); - let scale_factor = lock.scale_factor; - if lock.restore_from_minimized.is_some() { - lock.renderer - .update_drawable_size_even_if_unchanged(new_size); - lock.callbacks.request_frame = lock.restore_from_minimized.take(); - } else { - lock.renderer.update_drawable_size(new_size); - } - let new_size = new_size.to_pixels(scale_factor); - lock.logical_size = new_size; - if let Some(mut callback) = lock.callbacks.resize.take() { + let width = lparam.loword().max(1) as i32; + let height = lparam.hiword().max(1) as i32; + let new_size = size(DevicePixels(width), DevicePixels(height)); + + let scale_factor = lock.scale_factor; + let mut should_resize_renderer = false; + if lock.restore_from_minimized.is_some() { + lock.callbacks.request_frame = lock.restore_from_minimized.take(); + } else { + should_resize_renderer = true; + } drop(lock); - callback(new_size, scale_factor); - state_ptr.state.borrow_mut().callbacks.resize = Some(callback); + + self.handle_size_change(new_size, scale_factor, should_resize_renderer); + Some(0) } - Some(0) -} -fn handle_size_move_loop(handle: HWND) -> Option<isize> { - unsafe { - let ret = SetTimer( - Some(handle), - SIZE_MOVE_LOOP_TIMER_ID, - USER_TIMER_MINIMUM, - None, - ); - if ret == 0 { - log::error!( - "unable to create timer: {}", - std::io::Error::last_os_error() - ); + fn handle_size_change( + &self, + device_size: Size<DevicePixels>, + scale_factor: f32, + should_resize_renderer: bool, + ) { + let new_logical_size = device_size.to_pixels(scale_factor); + let mut lock = self.state.borrow_mut(); + lock.logical_size = new_logical_size; + if should_resize_renderer { + lock.renderer.resize(device_size).log_err(); + } + if let Some(mut callback) = lock.callbacks.resize.take() { + drop(lock); + callback(new_logical_size, scale_factor); + self.state.borrow_mut().callbacks.resize = Some(callback); } } - None -} -fn handle_size_move_loop_exit(handle: HWND) -> Option<isize> { - unsafe { - KillTimer(Some(handle), SIZE_MOVE_LOOP_TIMER_ID).log_err(); + fn handle_size_move_loop(&self, handle: HWND) -> Option<isize> { + unsafe { + let ret = SetTimer( + Some(handle), + SIZE_MOVE_LOOP_TIMER_ID, + USER_TIMER_MINIMUM, + None, + ); + if ret == 0 { + log::error!( + "unable to create timer: {}", + std::io::Error::last_os_error() + ); + } + } + None } - None -} -fn handle_timer_msg( - handle: HWND, - wparam: WPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - if wparam.0 == SIZE_MOVE_LOOP_TIMER_ID { - for runnable in state_ptr.main_receiver.drain() { - runnable.run(); + fn handle_size_move_loop_exit(&self, handle: HWND) -> Option<isize> { + unsafe { + KillTimer(Some(handle), SIZE_MOVE_LOOP_TIMER_ID).log_err(); } - handle_paint_msg(handle, state_ptr) - } else { None } -} -fn handle_paint_msg(handle: HWND, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { - let mut lock = state_ptr.state.borrow_mut(); - if let Some(mut request_frame) = lock.callbacks.request_frame.take() { - drop(lock); - request_frame(Default::default()); - state_ptr.state.borrow_mut().callbacks.request_frame = Some(request_frame); + fn handle_timer_msg(&self, handle: HWND, wparam: WPARAM) -> Option<isize> { + if wparam.0 == SIZE_MOVE_LOOP_TIMER_ID { + for runnable in self.main_receiver.drain() { + 
runnable.run(); + } + self.handle_paint_msg(handle) + } else { + None + } } - unsafe { ValidateRect(Some(handle), None).ok().log_err() }; - Some(0) -} -fn handle_close_msg(handle: HWND, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { - let mut lock = state_ptr.state.borrow_mut(); - let output = if let Some(mut callback) = lock.callbacks.should_close.take() { - drop(lock); + fn handle_paint_msg(&self, handle: HWND) -> Option<isize> { + self.draw_window(handle, false) + } + + fn handle_close_msg(&self) -> Option<isize> { + let mut callback = self.state.borrow_mut().callbacks.should_close.take()?; let should_close = callback(); - state_ptr.state.borrow_mut().callbacks.should_close = Some(callback); + self.state.borrow_mut().callbacks.should_close = Some(callback); if should_close { None } else { Some(0) } - } else { - None - }; + } - // Workaround as window close animation is not played with `WS_EX_LAYERED` enabled. - if output.is_none() { + fn handle_destroy_msg(&self, handle: HWND) -> Option<isize> { + let callback = { + let mut lock = self.state.borrow_mut(); + lock.callbacks.close.take() + }; + if let Some(callback) = callback { + callback(); + } unsafe { - let current_style = get_window_long(handle, GWL_EXSTYLE); - set_window_long( - handle, - GWL_EXSTYLE, - current_style & !WS_EX_LAYERED.0 as isize, - ); + PostThreadMessageW( + self.main_thread_id_win32, + WM_GPUI_CLOSE_ONE_WINDOW, + WPARAM(self.validation_number), + LPARAM(handle.0 as isize), + ) + .log_err(); } + Some(0) } - output -} + fn handle_mouse_move_msg(&self, handle: HWND, lparam: LPARAM, wparam: WPARAM) -> Option<isize> { + self.start_tracking_mouse(handle, TME_LEAVE); -fn handle_destroy_msg(handle: HWND, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { - let callback = { - let mut lock = state_ptr.state.borrow_mut(); - lock.callbacks.close.take() - }; - if let Some(callback) = callback { - callback(); - } - unsafe { - PostThreadMessageW( - state_ptr.main_thread_id_win32, - WM_GPUI_CLOSE_ONE_WINDOW, - WPARAM(state_ptr.validation_number), - LPARAM(handle.0 as isize), - ) - .log_err(); - } - Some(0) -} + let mut lock = self.state.borrow_mut(); + let Some(mut func) = lock.callbacks.input.take() else { + return Some(1); + }; + let scale_factor = lock.scale_factor; + drop(lock); -fn handle_mouse_move_msg( - handle: HWND, - lparam: LPARAM, - wparam: WPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - start_tracking_mouse(handle, &state_ptr, TME_LEAVE); + let pressed_button = match MODIFIERKEYS_FLAGS(wparam.loword() as u32) { + flags if flags.contains(MK_LBUTTON) => Some(MouseButton::Left), + flags if flags.contains(MK_RBUTTON) => Some(MouseButton::Right), + flags if flags.contains(MK_MBUTTON) => Some(MouseButton::Middle), + flags if flags.contains(MK_XBUTTON1) => { + Some(MouseButton::Navigate(NavigationDirection::Back)) + } + flags if flags.contains(MK_XBUTTON2) => { + Some(MouseButton::Navigate(NavigationDirection::Forward)) + } + _ => None, + }; + let x = lparam.signed_loword() as f32; + let y = lparam.signed_hiword() as f32; + let input = PlatformInput::MouseMove(MouseMoveEvent { + position: logical_point(x, y, scale_factor), + pressed_button, + modifiers: current_modifiers(), + }); + let handled = !func(input).propagate; + self.state.borrow_mut().callbacks.input = Some(func); - let mut lock = state_ptr.state.borrow_mut(); - let Some(mut func) = lock.callbacks.input.take() else { - return Some(1); - }; - let scale_factor = lock.scale_factor; - drop(lock); - - let pressed_button = match 
MODIFIERKEYS_FLAGS(wparam.loword() as u32) { - flags if flags.contains(MK_LBUTTON) => Some(MouseButton::Left), - flags if flags.contains(MK_RBUTTON) => Some(MouseButton::Right), - flags if flags.contains(MK_MBUTTON) => Some(MouseButton::Middle), - flags if flags.contains(MK_XBUTTON1) => { - Some(MouseButton::Navigate(NavigationDirection::Back)) - } - flags if flags.contains(MK_XBUTTON2) => { - Some(MouseButton::Navigate(NavigationDirection::Forward)) + if handled { Some(0) } else { Some(1) } + } + + fn handle_mouse_leave_msg(&self) -> Option<isize> { + let mut lock = self.state.borrow_mut(); + lock.hovered = false; + if let Some(mut callback) = lock.callbacks.hovered_status_change.take() { + drop(lock); + callback(false); + self.state.borrow_mut().callbacks.hovered_status_change = Some(callback); } - _ => None, - }; - let x = lparam.signed_loword() as f32; - let y = lparam.signed_hiword() as f32; - let input = PlatformInput::MouseMove(MouseMoveEvent { - position: logical_point(x, y, scale_factor), - pressed_button, - modifiers: current_modifiers(), - }); - let handled = !func(input).propagate; - state_ptr.state.borrow_mut().callbacks.input = Some(func); - - if handled { Some(0) } else { Some(1) } -} -fn handle_mouse_leave_msg(state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { - let mut lock = state_ptr.state.borrow_mut(); - lock.hovered = false; - if let Some(mut callback) = lock.callbacks.hovered_status_change.take() { - drop(lock); - callback(false); - state_ptr.state.borrow_mut().callbacks.hovered_status_change = Some(callback); + Some(0) } - Some(0) -} + fn handle_syskeydown_msg(&self, handle: HWND, wparam: WPARAM, lparam: LPARAM) -> Option<isize> { + let mut lock = self.state.borrow_mut(); + let input = handle_key_event(handle, wparam, lparam, &mut lock, |keystroke| { + PlatformInput::KeyDown(KeyDownEvent { + keystroke, + is_held: lparam.0 & (0x1 << 30) > 0, + }) + })?; + let mut func = lock.callbacks.input.take()?; + drop(lock); -fn handle_syskeydown_msg( - handle: HWND, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let mut lock = state_ptr.state.borrow_mut(); - let input = handle_key_event(handle, wparam, lparam, &mut lock, |keystroke| { - PlatformInput::KeyDown(KeyDownEvent { - keystroke, - is_held: lparam.0 & (0x1 << 30) > 0, - }) - })?; - let mut func = lock.callbacks.input.take()?; - drop(lock); + let handled = !func(input).propagate; - let handled = !func(input).propagate; + let mut lock = self.state.borrow_mut(); + lock.callbacks.input = Some(func); - let mut lock = state_ptr.state.borrow_mut(); - lock.callbacks.input = Some(func); + if handled { + lock.system_key_handled = true; + Some(0) + } else { + // we need to call `DefWindowProcW`, or we will lose the system-wide `Alt+F4`, `Alt+{other keys}` + // shortcuts. + None + } + } + + fn handle_syskeyup_msg(&self, handle: HWND, wparam: WPARAM, lparam: LPARAM) -> Option<isize> { + let mut lock = self.state.borrow_mut(); + let input = handle_key_event(handle, wparam, lparam, &mut lock, |keystroke| { + PlatformInput::KeyUp(KeyUpEvent { keystroke }) + })?; + let mut func = lock.callbacks.input.take()?; + drop(lock); + func(input); + self.state.borrow_mut().callbacks.input = Some(func); - if handled { - lock.system_key_handled = true; + // Always return 0 to indicate that the message was handled, so we could properly handle `ModifiersChanged` event. 
Some(0) - } else { - // we need to call `DefWindowProcW`, or we will lose the system-wide `Alt+F4`, `Alt+{other keys}` - // shortcuts. - None } -} -fn handle_syskeyup_msg( - handle: HWND, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let mut lock = state_ptr.state.borrow_mut(); - let input = handle_key_event(handle, wparam, lparam, &mut lock, |keystroke| { - PlatformInput::KeyUp(KeyUpEvent { keystroke }) - })?; - let mut func = lock.callbacks.input.take()?; - drop(lock); - func(input); - state_ptr.state.borrow_mut().callbacks.input = Some(func); + // It's a known bug that you can't trigger `ctrl-shift-0`. See: + // https://superuser.com/questions/1455762/ctrl-shift-number-key-combination-has-stopped-working-for-a-few-numbers + fn handle_keydown_msg(&self, handle: HWND, wparam: WPARAM, lparam: LPARAM) -> Option<isize> { + let mut lock = self.state.borrow_mut(); + let Some(input) = handle_key_event(handle, wparam, lparam, &mut lock, |keystroke| { + PlatformInput::KeyDown(KeyDownEvent { + keystroke, + is_held: lparam.0 & (0x1 << 30) > 0, + }) + }) else { + return Some(1); + }; + drop(lock); - // Always return 0 to indicate that the message was handled, so we could properly handle `ModifiersChanged` event. - Some(0) -} + let is_composing = self + .with_input_handler(|input_handler| input_handler.marked_text_range()) + .flatten() + .is_some(); + if is_composing { + translate_message(handle, wparam, lparam); + return Some(0); + } -// It's a known bug that you can't trigger `ctrl-shift-0`. See: -// https://superuser.com/questions/1455762/ctrl-shift-number-key-combination-has-stopped-working-for-a-few-numbers -fn handle_keydown_msg( - handle: HWND, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let mut lock = state_ptr.state.borrow_mut(); - let Some(input) = handle_key_event(handle, wparam, lparam, &mut lock, |keystroke| { - PlatformInput::KeyDown(KeyDownEvent { - keystroke, - is_held: lparam.0 & (0x1 << 30) > 0, - }) - }) else { - return Some(1); - }; - drop(lock); + let Some(mut func) = self.state.borrow_mut().callbacks.input.take() else { + return Some(1); + }; - let is_composing = with_input_handler(&state_ptr, |input_handler| { - input_handler.marked_text_range() - }) - .flatten() - .is_some(); - if is_composing { - translate_message(handle, wparam, lparam); - return Some(0); + let handled = !func(input).propagate; + + self.state.borrow_mut().callbacks.input = Some(func); + + if handled { + Some(0) + } else { + translate_message(handle, wparam, lparam); + Some(1) + } } - let Some(mut func) = state_ptr.state.borrow_mut().callbacks.input.take() else { - return Some(1); - }; + fn handle_keyup_msg(&self, handle: HWND, wparam: WPARAM, lparam: LPARAM) -> Option<isize> { + let mut lock = self.state.borrow_mut(); + let Some(input) = handle_key_event(handle, wparam, lparam, &mut lock, |keystroke| { + PlatformInput::KeyUp(KeyUpEvent { keystroke }) + }) else { + return Some(1); + }; - let handled = !func(input).propagate; + let Some(mut func) = lock.callbacks.input.take() else { + return Some(1); + }; + drop(lock); - state_ptr.state.borrow_mut().callbacks.input = Some(func); + let handled = !func(input).propagate; + self.state.borrow_mut().callbacks.input = Some(func); + + if handled { Some(0) } else { Some(1) } + } + + fn handle_char_msg(&self, wparam: WPARAM) -> Option<isize> { + let input = self.parse_char_message(wparam)?; + self.with_input_handler(|input_handler| { + 
input_handler.replace_text_in_range(None, &input); + }); - if handled { Some(0) - } else { - translate_message(handle, wparam, lparam); - Some(1) } -} -fn handle_keyup_msg( - handle: HWND, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let mut lock = state_ptr.state.borrow_mut(); - let Some(input) = handle_key_event(handle, wparam, lparam, &mut lock, |keystroke| { - PlatformInput::KeyUp(KeyUpEvent { keystroke }) - }) else { - return Some(1); - }; + fn handle_dead_char_msg(&self, wparam: WPARAM) -> Option<isize> { + let ch = char::from_u32(wparam.0 as u32)?.to_string(); + self.with_input_handler(|input_handler| { + input_handler.replace_and_mark_text_in_range(None, &ch, None); + }); + None + } - let Some(mut func) = lock.callbacks.input.take() else { - return Some(1); - }; - drop(lock); + fn handle_mouse_down_msg( + &self, + handle: HWND, + button: MouseButton, + lparam: LPARAM, + ) -> Option<isize> { + unsafe { SetCapture(handle) }; + let mut lock = self.state.borrow_mut(); + let Some(mut func) = lock.callbacks.input.take() else { + return Some(1); + }; + let x = lparam.signed_loword(); + let y = lparam.signed_hiword(); + let physical_point = point(DevicePixels(x as i32), DevicePixels(y as i32)); + let click_count = lock.click_state.update(button, physical_point); + let scale_factor = lock.scale_factor; + drop(lock); - let handled = !func(input).propagate; - state_ptr.state.borrow_mut().callbacks.input = Some(func); + let input = PlatformInput::MouseDown(MouseDownEvent { + button, + position: logical_point(x as f32, y as f32, scale_factor), + modifiers: current_modifiers(), + click_count, + first_mouse: false, + }); + let handled = !func(input).propagate; + self.state.borrow_mut().callbacks.input = Some(func); - if handled { Some(0) } else { Some(1) } -} + if handled { Some(0) } else { Some(1) } + } -fn handle_char_msg(wparam: WPARAM, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { - let input = parse_char_message(wparam, &state_ptr)?; - with_input_handler(&state_ptr, |input_handler| { - input_handler.replace_text_in_range(None, &input); - }); + fn handle_mouse_up_msg( + &self, + _handle: HWND, + button: MouseButton, + lparam: LPARAM, + ) -> Option<isize> { + unsafe { ReleaseCapture().log_err() }; + let mut lock = self.state.borrow_mut(); + let Some(mut func) = lock.callbacks.input.take() else { + return Some(1); + }; + let x = lparam.signed_loword() as f32; + let y = lparam.signed_hiword() as f32; + let click_count = lock.click_state.current_count; + let scale_factor = lock.scale_factor; + drop(lock); - Some(0) -} + let input = PlatformInput::MouseUp(MouseUpEvent { + button, + position: logical_point(x, y, scale_factor), + modifiers: current_modifiers(), + click_count, + }); + let handled = !func(input).propagate; + self.state.borrow_mut().callbacks.input = Some(func); -fn handle_dead_char_msg(wparam: WPARAM, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { - let ch = char::from_u32(wparam.0 as u32)?.to_string(); - with_input_handler(&state_ptr, |input_handler| { - input_handler.replace_and_mark_text_in_range(None, &ch, None); - }); - None -} + if handled { Some(0) } else { Some(1) } + } -fn handle_mouse_down_msg( - handle: HWND, - button: MouseButton, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - unsafe { SetCapture(handle) }; - let mut lock = state_ptr.state.borrow_mut(); - let Some(mut func) = lock.callbacks.input.take() else { - return Some(1); - }; - let x = 
lparam.signed_loword(); - let y = lparam.signed_hiword(); - let physical_point = point(DevicePixels(x as i32), DevicePixels(y as i32)); - let click_count = lock.click_state.update(button, physical_point); - let scale_factor = lock.scale_factor; - drop(lock); - - let input = PlatformInput::MouseDown(MouseDownEvent { - button, - position: logical_point(x as f32, y as f32, scale_factor), - modifiers: current_modifiers(), - click_count, - first_mouse: false, - }); - let handled = !func(input).propagate; - state_ptr.state.borrow_mut().callbacks.input = Some(func); - - if handled { Some(0) } else { Some(1) } -} + fn handle_xbutton_msg( + &self, + handle: HWND, + wparam: WPARAM, + lparam: LPARAM, + handler: impl Fn(&Self, HWND, MouseButton, LPARAM) -> Option<isize>, + ) -> Option<isize> { + let nav_dir = match wparam.hiword() { + XBUTTON1 => NavigationDirection::Back, + XBUTTON2 => NavigationDirection::Forward, + _ => return Some(1), + }; + handler(self, handle, MouseButton::Navigate(nav_dir), lparam) + } -fn handle_mouse_up_msg( - _handle: HWND, - button: MouseButton, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - unsafe { ReleaseCapture().log_err() }; - let mut lock = state_ptr.state.borrow_mut(); - let Some(mut func) = lock.callbacks.input.take() else { - return Some(1); - }; - let x = lparam.signed_loword() as f32; - let y = lparam.signed_hiword() as f32; - let click_count = lock.click_state.current_count; - let scale_factor = lock.scale_factor; - drop(lock); - - let input = PlatformInput::MouseUp(MouseUpEvent { - button, - position: logical_point(x, y, scale_factor), - modifiers: current_modifiers(), - click_count, - }); - let handled = !func(input).propagate; - state_ptr.state.borrow_mut().callbacks.input = Some(func); - - if handled { Some(0) } else { Some(1) } -} + fn handle_mouse_wheel_msg( + &self, + handle: HWND, + wparam: WPARAM, + lparam: LPARAM, + ) -> Option<isize> { + let modifiers = current_modifiers(); + let mut lock = self.state.borrow_mut(); + let Some(mut func) = lock.callbacks.input.take() else { + return Some(1); + }; + let scale_factor = lock.scale_factor; + let wheel_scroll_amount = match modifiers.shift { + true => lock.system_settings.mouse_wheel_settings.wheel_scroll_chars, + false => lock.system_settings.mouse_wheel_settings.wheel_scroll_lines, + }; + drop(lock); -fn handle_xbutton_msg( - handle: HWND, - wparam: WPARAM, - lparam: LPARAM, - handler: impl Fn(HWND, MouseButton, LPARAM, Rc<WindowsWindowStatePtr>) -> Option<isize>, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let nav_dir = match wparam.hiword() { - XBUTTON1 => NavigationDirection::Back, - XBUTTON2 => NavigationDirection::Forward, - _ => return Some(1), - }; - handler(handle, MouseButton::Navigate(nav_dir), lparam, state_ptr) -} + let wheel_distance = + (wparam.signed_hiword() as f32 / WHEEL_DELTA as f32) * wheel_scroll_amount as f32; + let mut cursor_point = POINT { + x: lparam.signed_loword().into(), + y: lparam.signed_hiword().into(), + }; + unsafe { ScreenToClient(handle, &mut cursor_point).ok().log_err() }; + let input = PlatformInput::ScrollWheel(ScrollWheelEvent { + position: logical_point(cursor_point.x as f32, cursor_point.y as f32, scale_factor), + delta: ScrollDelta::Lines(match modifiers.shift { + true => Point { + x: wheel_distance, + y: 0.0, + }, + false => Point { + y: wheel_distance, + x: 0.0, + }, + }), + modifiers, + touch_phase: TouchPhase::Moved, + }); + let handled = !func(input).propagate; + self.state.borrow_mut().callbacks.input = 
Some(func); -fn handle_mouse_wheel_msg( - handle: HWND, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let modifiers = current_modifiers(); - let mut lock = state_ptr.state.borrow_mut(); - let Some(mut func) = lock.callbacks.input.take() else { - return Some(1); - }; - let scale_factor = lock.scale_factor; - let wheel_scroll_amount = match modifiers.shift { - true => lock.system_settings.mouse_wheel_settings.wheel_scroll_chars, - false => lock.system_settings.mouse_wheel_settings.wheel_scroll_lines, - }; - drop(lock); + if handled { Some(0) } else { Some(1) } + } - let wheel_distance = - (wparam.signed_hiword() as f32 / WHEEL_DELTA as f32) * wheel_scroll_amount as f32; - let mut cursor_point = POINT { - x: lparam.signed_loword().into(), - y: lparam.signed_hiword().into(), - }; - unsafe { ScreenToClient(handle, &mut cursor_point).ok().log_err() }; - let input = PlatformInput::ScrollWheel(ScrollWheelEvent { - position: logical_point(cursor_point.x as f32, cursor_point.y as f32, scale_factor), - delta: ScrollDelta::Lines(match modifiers.shift { - true => Point { + fn handle_mouse_horizontal_wheel_msg( + &self, + handle: HWND, + wparam: WPARAM, + lparam: LPARAM, + ) -> Option<isize> { + let mut lock = self.state.borrow_mut(); + let Some(mut func) = lock.callbacks.input.take() else { + return Some(1); + }; + let scale_factor = lock.scale_factor; + let wheel_scroll_chars = lock.system_settings.mouse_wheel_settings.wheel_scroll_chars; + drop(lock); + + let wheel_distance = + (-wparam.signed_hiword() as f32 / WHEEL_DELTA as f32) * wheel_scroll_chars as f32; + let mut cursor_point = POINT { + x: lparam.signed_loword().into(), + y: lparam.signed_hiword().into(), + }; + unsafe { ScreenToClient(handle, &mut cursor_point).ok().log_err() }; + let event = PlatformInput::ScrollWheel(ScrollWheelEvent { + position: logical_point(cursor_point.x as f32, cursor_point.y as f32, scale_factor), + delta: ScrollDelta::Lines(Point { x: wheel_distance, y: 0.0, - }, - false => Point { - y: wheel_distance, - x: 0.0, - }, - }), - modifiers, - touch_phase: TouchPhase::Moved, - }); - let handled = !func(input).propagate; - state_ptr.state.borrow_mut().callbacks.input = Some(func); - - if handled { Some(0) } else { Some(1) } -} + }), + modifiers: current_modifiers(), + touch_phase: TouchPhase::Moved, + }); + let handled = !func(event).propagate; + self.state.borrow_mut().callbacks.input = Some(func); -fn handle_mouse_horizontal_wheel_msg( - handle: HWND, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let mut lock = state_ptr.state.borrow_mut(); - let Some(mut func) = lock.callbacks.input.take() else { - return Some(1); - }; - let scale_factor = lock.scale_factor; - let wheel_scroll_chars = lock.system_settings.mouse_wheel_settings.wheel_scroll_chars; - drop(lock); - - let wheel_distance = - (-wparam.signed_hiword() as f32 / WHEEL_DELTA as f32) * wheel_scroll_chars as f32; - let mut cursor_point = POINT { - x: lparam.signed_loword().into(), - y: lparam.signed_hiword().into(), - }; - unsafe { ScreenToClient(handle, &mut cursor_point).ok().log_err() }; - let event = PlatformInput::ScrollWheel(ScrollWheelEvent { - position: logical_point(cursor_point.x as f32, cursor_point.y as f32, scale_factor), - delta: ScrollDelta::Lines(Point { - x: wheel_distance, - y: 0.0, - }), - modifiers: current_modifiers(), - touch_phase: TouchPhase::Moved, - }); - let handled = !func(event).propagate; - 
state_ptr.state.borrow_mut().callbacks.input = Some(func); - - if handled { Some(0) } else { Some(1) } -} + if handled { Some(0) } else { Some(1) } + } -fn retrieve_caret_position(state_ptr: &Rc<WindowsWindowStatePtr>) -> Option<POINT> { - with_input_handler_and_scale_factor(state_ptr, |input_handler, scale_factor| { - let caret_range = input_handler.selected_text_range(false)?; - let caret_position = input_handler.bounds_for_range(caret_range.range)?; - Some(POINT { - // logical to physical - x: (caret_position.origin.x.0 * scale_factor) as i32, - y: (caret_position.origin.y.0 * scale_factor) as i32 - + ((caret_position.size.height.0 * scale_factor) as i32 / 2), + fn retrieve_caret_position(&self) -> Option<POINT> { + self.with_input_handler_and_scale_factor(|input_handler, scale_factor| { + let caret_range = input_handler.selected_text_range(false)?; + let caret_position = input_handler.bounds_for_range(caret_range.range)?; + Some(POINT { + // logical to physical + x: (caret_position.origin.x.0 * scale_factor) as i32, + y: (caret_position.origin.y.0 * scale_factor) as i32 + + ((caret_position.size.height.0 * scale_factor) as i32 / 2), + }) }) - }) -} - -fn handle_ime_position(handle: HWND, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { - unsafe { - let ctx = ImmGetContext(handle); - - let Some(caret_position) = retrieve_caret_position(&state_ptr) else { - return Some(0); - }; - { - let config = COMPOSITIONFORM { - dwStyle: CFS_POINT, - ptCurrentPos: caret_position, - ..Default::default() - }; - ImmSetCompositionWindow(ctx, &config as _).ok().log_err(); - } - { - let config = CANDIDATEFORM { - dwStyle: CFS_CANDIDATEPOS, - ptCurrentPos: caret_position, - ..Default::default() - }; - ImmSetCandidateWindow(ctx, &config as _).ok().log_err(); - } - ImmReleaseContext(handle, ctx).ok().log_err(); - Some(0) } -} -fn handle_ime_composition( - handle: HWND, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let ctx = unsafe { ImmGetContext(handle) }; - let result = handle_ime_composition_inner(ctx, lparam, state_ptr); - unsafe { ImmReleaseContext(handle, ctx).ok().log_err() }; - result -} + fn handle_ime_position(&self, handle: HWND) -> Option<isize> { + unsafe { + let ctx = ImmGetContext(handle); -fn handle_ime_composition_inner( - ctx: HIMC, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let lparam = lparam.0 as u32; - if lparam == 0 { - // Japanese IME may send this message with lparam = 0, which indicates that - // there is no composition string. 
- with_input_handler(&state_ptr, |input_handler| { - input_handler.replace_text_in_range(None, ""); - })?; - Some(0) - } else { - if lparam & GCS_COMPSTR.0 > 0 { - let comp_string = parse_ime_composition_string(ctx, GCS_COMPSTR)?; - let caret_pos = (!comp_string.is_empty() && lparam & GCS_CURSORPOS.0 > 0).then(|| { - let pos = retrieve_composition_cursor_position(ctx); - pos..pos - }); - with_input_handler(&state_ptr, |input_handler| { - input_handler.replace_and_mark_text_in_range(None, &comp_string, caret_pos); - })?; - } - if lparam & GCS_RESULTSTR.0 > 0 { - let comp_result = parse_ime_composition_string(ctx, GCS_RESULTSTR)?; - with_input_handler(&state_ptr, |input_handler| { - input_handler.replace_text_in_range(None, &comp_result); - })?; - return Some(0); + let Some(caret_position) = self.retrieve_caret_position() else { + return Some(0); + }; + { + let config = COMPOSITIONFORM { + dwStyle: CFS_POINT, + ptCurrentPos: caret_position, + ..Default::default() + }; + ImmSetCompositionWindow(ctx, &config as _).ok().log_err(); + } + { + let config = CANDIDATEFORM { + dwStyle: CFS_CANDIDATEPOS, + ptCurrentPos: caret_position, + ..Default::default() + }; + ImmSetCandidateWindow(ctx, &config as _).ok().log_err(); + } + ImmReleaseContext(handle, ctx).ok().log_err(); + Some(0) } - - // currently, we don't care other stuff - None } -} -/// SEE: https://learn.microsoft.com/en-us/windows/win32/winmsg/wm-nccalcsize -fn handle_calc_client_size( - handle: HWND, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - if !state_ptr.hide_title_bar || state_ptr.state.borrow().is_fullscreen() || wparam.0 == 0 { - return None; + fn handle_ime_composition(&self, handle: HWND, lparam: LPARAM) -> Option<isize> { + let ctx = unsafe { ImmGetContext(handle) }; + let result = self.handle_ime_composition_inner(ctx, lparam); + unsafe { ImmReleaseContext(handle, ctx).ok().log_err() }; + result } - let is_maximized = state_ptr.state.borrow().is_maximized(); - let insets = get_client_area_insets(handle, is_maximized, state_ptr.windows_version); - // wparam is TRUE so lparam points to an NCCALCSIZE_PARAMS structure - let mut params = lparam.0 as *mut NCCALCSIZE_PARAMS; - let mut requested_client_rect = unsafe { &mut ((*params).rgrc) }; - - requested_client_rect[0].left += insets.left; - requested_client_rect[0].top += insets.top; - requested_client_rect[0].right -= insets.right; - requested_client_rect[0].bottom -= insets.bottom; - - // Fix auto hide taskbar not showing. This solution is based on the approach - // used by Chrome. However, it may result in one row of pixels being obscured - // in our client area. But as Chrome says, "there seems to be no better solution." - if is_maximized { - if let Some(ref taskbar_position) = state_ptr - .state - .borrow() - .system_settings - .auto_hide_taskbar_position - { - // Fot the auto-hide taskbar, adjust in by 1 pixel on taskbar edge, - // so the window isn't treated as a "fullscreen app", which would cause - // the taskbar to disappear. 
- match taskbar_position { - AutoHideTaskbarPosition::Left => { - requested_client_rect[0].left += AUTO_HIDE_TASKBAR_THICKNESS_PX - } - AutoHideTaskbarPosition::Top => { - requested_client_rect[0].top += AUTO_HIDE_TASKBAR_THICKNESS_PX - } - AutoHideTaskbarPosition::Right => { - requested_client_rect[0].right -= AUTO_HIDE_TASKBAR_THICKNESS_PX - } - AutoHideTaskbarPosition::Bottom => { - requested_client_rect[0].bottom -= AUTO_HIDE_TASKBAR_THICKNESS_PX - } + fn handle_ime_composition_inner(&self, ctx: HIMC, lparam: LPARAM) -> Option<isize> { + let lparam = lparam.0 as u32; + if lparam == 0 { + // Japanese IME may send this message with lparam = 0, which indicates that + // there is no composition string. + self.with_input_handler(|input_handler| { + input_handler.replace_text_in_range(None, ""); + })?; + Some(0) + } else { + if lparam & GCS_COMPSTR.0 > 0 { + let comp_string = parse_ime_composition_string(ctx, GCS_COMPSTR)?; + let caret_pos = + (!comp_string.is_empty() && lparam & GCS_CURSORPOS.0 > 0).then(|| { + let pos = retrieve_composition_cursor_position(ctx); + pos..pos + }); + self.with_input_handler(|input_handler| { + input_handler.replace_and_mark_text_in_range(None, &comp_string, caret_pos); + })?; + } + if lparam & GCS_RESULTSTR.0 > 0 { + let comp_result = parse_ime_composition_string(ctx, GCS_RESULTSTR)?; + self.with_input_handler(|input_handler| { + input_handler.replace_text_in_range(None, &comp_result); + })?; + return Some(0); } + + // currently, we don't care other stuff + None } } - Some(0) -} + /// SEE: https://learn.microsoft.com/en-us/windows/win32/winmsg/wm-nccalcsize + fn handle_calc_client_size( + &self, + handle: HWND, + wparam: WPARAM, + lparam: LPARAM, + ) -> Option<isize> { + if !self.hide_title_bar || self.state.borrow().is_fullscreen() || wparam.0 == 0 { + return None; + } -fn handle_activate_msg(wparam: WPARAM, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { - let activated = wparam.loword() > 0; - let this = state_ptr.clone(); - state_ptr - .executor - .spawn(async move { - let mut lock = this.state.borrow_mut(); - if let Some(mut func) = lock.callbacks.active_status_change.take() { - drop(lock); - func(activated); - this.state.borrow_mut().callbacks.active_status_change = Some(func); + let is_maximized = self.state.borrow().is_maximized(); + let insets = get_client_area_insets(handle, is_maximized, self.windows_version); + // wparam is TRUE so lparam points to an NCCALCSIZE_PARAMS structure + let mut params = lparam.0 as *mut NCCALCSIZE_PARAMS; + let mut requested_client_rect = unsafe { &mut ((*params).rgrc) }; + + requested_client_rect[0].left += insets.left; + requested_client_rect[0].top += insets.top; + requested_client_rect[0].right -= insets.right; + requested_client_rect[0].bottom -= insets.bottom; + + // Fix auto hide taskbar not showing. This solution is based on the approach + // used by Chrome. However, it may result in one row of pixels being obscured + // in our client area. But as Chrome says, "there seems to be no better solution." + if is_maximized { + if let Some(ref taskbar_position) = self + .state + .borrow() + .system_settings + .auto_hide_taskbar_position + { + // Fot the auto-hide taskbar, adjust in by 1 pixel on taskbar edge, + // so the window isn't treated as a "fullscreen app", which would cause + // the taskbar to disappear. 
+ match taskbar_position { + AutoHideTaskbarPosition::Left => { + requested_client_rect[0].left += AUTO_HIDE_TASKBAR_THICKNESS_PX + } + AutoHideTaskbarPosition::Top => { + requested_client_rect[0].top += AUTO_HIDE_TASKBAR_THICKNESS_PX + } + AutoHideTaskbarPosition::Right => { + requested_client_rect[0].right -= AUTO_HIDE_TASKBAR_THICKNESS_PX + } + AutoHideTaskbarPosition::Bottom => { + requested_client_rect[0].bottom -= AUTO_HIDE_TASKBAR_THICKNESS_PX + } + } } - }) - .detach(); - - None -} + } -fn handle_create_msg(handle: HWND, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { - if state_ptr.hide_title_bar { - notify_frame_changed(handle); Some(0) - } else { - None - } -} - -fn handle_dpi_changed_msg( - handle: HWND, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let new_dpi = wparam.loword() as f32; - let mut lock = state_ptr.state.borrow_mut(); - lock.scale_factor = new_dpi / USER_DEFAULT_SCREEN_DPI as f32; - lock.border_offset.update(handle).log_err(); - drop(lock); - - let rect = unsafe { &*(lparam.0 as *const RECT) }; - let width = rect.right - rect.left; - let height = rect.bottom - rect.top; - // this will emit `WM_SIZE` and `WM_MOVE` right here - // even before this function returns - // the new size is handled in `WM_SIZE` - unsafe { - SetWindowPos( - handle, - None, - rect.left, - rect.top, - width, - height, - SWP_NOZORDER | SWP_NOACTIVATE, - ) - .context("unable to set window position after dpi has changed") - .log_err(); } - Some(0) -} + fn handle_activate_msg(self: &Rc<Self>, wparam: WPARAM) -> Option<isize> { + let activated = wparam.loword() > 0; + let this = self.clone(); + self.executor + .spawn(async move { + let mut lock = this.state.borrow_mut(); + if let Some(mut func) = lock.callbacks.active_status_change.take() { + drop(lock); + func(activated); + this.state.borrow_mut().callbacks.active_status_change = Some(func); + } + }) + .detach(); -/// The following conditions will trigger this event: -/// 1. The monitor on which the window is located goes offline or changes resolution. -/// 2. Another monitor goes offline, is plugged in, or changes resolution. -/// -/// In either case, the window will only receive information from the monitor on which -/// it is located. -/// -/// For example, in the case of condition 2, where the monitor on which the window is -/// located has actually changed nothing, it will still receive this event. -fn handle_display_change_msg(handle: HWND, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { - // NOTE: - // Even the `lParam` holds the resolution of the screen, we just ignore it. - // Because WM_DPICHANGED, WM_MOVE, WM_SIZE will come first, window reposition and resize - // are handled there. - // So we only care about if monitor is disconnected. - let previous_monitor = state_ptr.state.borrow().display; - if WindowsDisplay::is_connected(previous_monitor.handle) { - // we are fine, other display changed - return None; - } - // display disconnected - // in this case, the OS will move our window to another monitor, and minimize it. 
- // we deminimize the window and query the monitor after moving - unsafe { - let _ = ShowWindow(handle, SW_SHOWNORMAL); - }; - let new_monitor = unsafe { MonitorFromWindow(handle, MONITOR_DEFAULTTONULL) }; - // all monitors disconnected - if new_monitor.is_invalid() { - log::error!("No monitor detected!"); - return None; + None } - let new_display = WindowsDisplay::new_with_handle(new_monitor); - state_ptr.state.borrow_mut().display = new_display; - Some(0) -} -fn handle_hit_test_msg( - handle: HWND, - msg: u32, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - if !state_ptr.is_movable || state_ptr.state.borrow().is_fullscreen() { - return None; + fn handle_create_msg(&self, handle: HWND) -> Option<isize> { + if self.hide_title_bar { + notify_frame_changed(handle); + Some(0) + } else { + None + } } - let mut lock = state_ptr.state.borrow_mut(); - if let Some(mut callback) = lock.callbacks.hit_test_window_control.take() { + fn handle_dpi_changed_msg( + &self, + handle: HWND, + wparam: WPARAM, + lparam: LPARAM, + ) -> Option<isize> { + let new_dpi = wparam.loword() as f32; + let mut lock = self.state.borrow_mut(); + let is_maximized = lock.is_maximized(); + let new_scale_factor = new_dpi / USER_DEFAULT_SCREEN_DPI as f32; + lock.scale_factor = new_scale_factor; + lock.border_offset.update(handle).log_err(); drop(lock); - let area = callback(); - state_ptr - .state - .borrow_mut() - .callbacks - .hit_test_window_control = Some(callback); - if let Some(area) = area { - return match area { - WindowControlArea::Drag => Some(HTCAPTION as _), - WindowControlArea::Close => Some(HTCLOSE as _), - WindowControlArea::Max => Some(HTMAXBUTTON as _), - WindowControlArea::Min => Some(HTMINBUTTON as _), - }; + + let rect = unsafe { &*(lparam.0 as *const RECT) }; + let width = rect.right - rect.left; + let height = rect.bottom - rect.top; + // this will emit `WM_SIZE` and `WM_MOVE` right here + // even before this function returns + // the new size is handled in `WM_SIZE` + unsafe { + SetWindowPos( + handle, + None, + rect.left, + rect.top, + width, + height, + SWP_NOZORDER | SWP_NOACTIVATE, + ) + .context("unable to set window position after dpi has changed") + .log_err(); } - } else { - drop(lock); - } - if !state_ptr.hide_title_bar { - // If the OS draws the title bar, we don't need to handle hit test messages. - return None; - } + // When maximized, SetWindowPos doesn't send WM_SIZE, so we need to manually + // update the size and call the resize callback + if is_maximized { + let device_size = size(DevicePixels(width), DevicePixels(height)); + self.handle_size_change(device_size, new_scale_factor, true); + } - // default handler for resize areas - let hit = unsafe { DefWindowProcW(handle, msg, wparam, lparam) }; - if matches!( - hit.0 as u32, - HTNOWHERE - | HTRIGHT - | HTLEFT - | HTTOPLEFT - | HTTOP - | HTTOPRIGHT - | HTBOTTOMRIGHT - | HTBOTTOM - | HTBOTTOMLEFT - ) { - return Some(hit.0); + Some(0) } - if state_ptr.state.borrow().is_fullscreen() { - return Some(HTCLIENT as _); + /// The following conditions will trigger this event: + /// 1. The monitor on which the window is located goes offline or changes resolution. + /// 2. Another monitor goes offline, is plugged in, or changes resolution. + /// + /// In either case, the window will only receive information from the monitor on which + /// it is located. 
+ /// + /// For example, in the case of condition 2, where the monitor on which the window is + /// located has actually changed nothing, it will still receive this event. + fn handle_display_change_msg(&self, handle: HWND) -> Option<isize> { + // NOTE: + // Even the `lParam` holds the resolution of the screen, we just ignore it. + // Because WM_DPICHANGED, WM_MOVE, WM_SIZE will come first, window reposition and resize + // are handled there. + // So we only care about if monitor is disconnected. + let previous_monitor = self.state.borrow().display; + if WindowsDisplay::is_connected(previous_monitor.handle) { + // we are fine, other display changed + return None; + } + // display disconnected + // in this case, the OS will move our window to another monitor, and minimize it. + // we deminimize the window and query the monitor after moving + unsafe { + let _ = ShowWindow(handle, SW_SHOWNORMAL); + }; + let new_monitor = unsafe { MonitorFromWindow(handle, MONITOR_DEFAULTTONULL) }; + // all monitors disconnected + if new_monitor.is_invalid() { + log::error!("No monitor detected!"); + return None; + } + let new_display = WindowsDisplay::new_with_handle(new_monitor); + self.state.borrow_mut().display = new_display; + Some(0) } - let dpi = unsafe { GetDpiForWindow(handle) }; - let frame_y = unsafe { GetSystemMetricsForDpi(SM_CYFRAME, dpi) }; + fn handle_hit_test_msg( + &self, + handle: HWND, + msg: u32, + wparam: WPARAM, + lparam: LPARAM, + ) -> Option<isize> { + if !self.is_movable || self.state.borrow().is_fullscreen() { + return None; + } - let mut cursor_point = POINT { - x: lparam.signed_loword().into(), - y: lparam.signed_hiword().into(), - }; - unsafe { ScreenToClient(handle, &mut cursor_point).ok().log_err() }; - if !state_ptr.state.borrow().is_maximized() && cursor_point.y >= 0 && cursor_point.y <= frame_y - { - return Some(HTTOP as _); - } + let mut lock = self.state.borrow_mut(); + if let Some(mut callback) = lock.callbacks.hit_test_window_control.take() { + drop(lock); + let area = callback(); + self.state.borrow_mut().callbacks.hit_test_window_control = Some(callback); + if let Some(area) = area { + return match area { + WindowControlArea::Drag => Some(HTCAPTION as _), + WindowControlArea::Close => Some(HTCLOSE as _), + WindowControlArea::Max => Some(HTMAXBUTTON as _), + WindowControlArea::Min => Some(HTMINBUTTON as _), + }; + } + } else { + drop(lock); + } - Some(HTCLIENT as _) -} + if !self.hide_title_bar { + // If the OS draws the title bar, we don't need to handle hit test messages. 
+ return None; + } -fn handle_nc_mouse_move_msg( - handle: HWND, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - start_tracking_mouse(handle, &state_ptr, TME_LEAVE | TME_NONCLIENT); - - let mut lock = state_ptr.state.borrow_mut(); - let mut func = lock.callbacks.input.take()?; - let scale_factor = lock.scale_factor; - drop(lock); - - let mut cursor_point = POINT { - x: lparam.signed_loword().into(), - y: lparam.signed_hiword().into(), - }; - unsafe { ScreenToClient(handle, &mut cursor_point).ok().log_err() }; - let input = PlatformInput::MouseMove(MouseMoveEvent { - position: logical_point(cursor_point.x as f32, cursor_point.y as f32, scale_factor), - pressed_button: None, - modifiers: current_modifiers(), - }); - let handled = !func(input).propagate; - state_ptr.state.borrow_mut().callbacks.input = Some(func); - - if handled { Some(0) } else { None } -} + // default handler for resize areas + let hit = unsafe { DefWindowProcW(handle, msg, wparam, lparam) }; + if matches!( + hit.0 as u32, + HTNOWHERE + | HTRIGHT + | HTLEFT + | HTTOPLEFT + | HTTOP + | HTTOPRIGHT + | HTBOTTOMRIGHT + | HTBOTTOM + | HTBOTTOMLEFT + ) { + return Some(hit.0); + } + + if self.state.borrow().is_fullscreen() { + return Some(HTCLIENT as _); + } + + let dpi = unsafe { GetDpiForWindow(handle) }; + let frame_y = unsafe { GetSystemMetricsForDpi(SM_CYFRAME, dpi) }; -fn handle_nc_mouse_down_msg( - handle: HWND, - button: MouseButton, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let mut lock = state_ptr.state.borrow_mut(); - if let Some(mut func) = lock.callbacks.input.take() { - let scale_factor = lock.scale_factor; let mut cursor_point = POINT { x: lparam.signed_loword().into(), y: lparam.signed_hiword().into(), }; unsafe { ScreenToClient(handle, &mut cursor_point).ok().log_err() }; - let physical_point = point(DevicePixels(cursor_point.x), DevicePixels(cursor_point.y)); - let click_count = lock.click_state.update(button, physical_point); - drop(lock); - - let input = PlatformInput::MouseDown(MouseDownEvent { - button, - position: logical_point(cursor_point.x as f32, cursor_point.y as f32, scale_factor), - modifiers: current_modifiers(), - click_count, - first_mouse: false, - }); - let result = func(input.clone()); - let handled = !result.propagate || result.default_prevented; - state_ptr.state.borrow_mut().callbacks.input = Some(func); - - if handled { - return Some(0); + if !self.state.borrow().is_maximized() && cursor_point.y >= 0 && cursor_point.y <= frame_y { + return Some(HTTOP as _); } - } else { - drop(lock); - }; - // Since these are handled in handle_nc_mouse_up_msg we must prevent the default window proc - if button == MouseButton::Left { - match wparam.0 as u32 { - HTMINBUTTON => state_ptr.state.borrow_mut().nc_button_pressed = Some(HTMINBUTTON), - HTMAXBUTTON => state_ptr.state.borrow_mut().nc_button_pressed = Some(HTMAXBUTTON), - HTCLOSE => state_ptr.state.borrow_mut().nc_button_pressed = Some(HTCLOSE), - _ => return None, - }; - Some(0) - } else { - None + Some(HTCLIENT as _) } -} -fn handle_nc_mouse_up_msg( - handle: HWND, - button: MouseButton, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let mut lock = state_ptr.state.borrow_mut(); - if let Some(mut func) = lock.callbacks.input.take() { + fn handle_nc_mouse_move_msg(&self, handle: HWND, lparam: LPARAM) -> Option<isize> { + self.start_tracking_mouse(handle, TME_LEAVE | TME_NONCLIENT); + + let mut lock = 
self.state.borrow_mut(); + let mut func = lock.callbacks.input.take()?; let scale_factor = lock.scale_factor; drop(lock); @@ -1052,207 +919,356 @@ fn handle_nc_mouse_up_msg( y: lparam.signed_hiword().into(), }; unsafe { ScreenToClient(handle, &mut cursor_point).ok().log_err() }; - let input = PlatformInput::MouseUp(MouseUpEvent { - button, + let input = PlatformInput::MouseMove(MouseMoveEvent { position: logical_point(cursor_point.x as f32, cursor_point.y as f32, scale_factor), + pressed_button: None, modifiers: current_modifiers(), - click_count: 1, }); let handled = !func(input).propagate; - state_ptr.state.borrow_mut().callbacks.input = Some(func); + self.state.borrow_mut().callbacks.input = Some(func); - if handled { - return Some(0); + if handled { Some(0) } else { None } + } + + fn handle_nc_mouse_down_msg( + &self, + handle: HWND, + button: MouseButton, + wparam: WPARAM, + lparam: LPARAM, + ) -> Option<isize> { + let mut lock = self.state.borrow_mut(); + if let Some(mut func) = lock.callbacks.input.take() { + let scale_factor = lock.scale_factor; + let mut cursor_point = POINT { + x: lparam.signed_loword().into(), + y: lparam.signed_hiword().into(), + }; + unsafe { ScreenToClient(handle, &mut cursor_point).ok().log_err() }; + let physical_point = point(DevicePixels(cursor_point.x), DevicePixels(cursor_point.y)); + let click_count = lock.click_state.update(button, physical_point); + drop(lock); + + let input = PlatformInput::MouseDown(MouseDownEvent { + button, + position: logical_point(cursor_point.x as f32, cursor_point.y as f32, scale_factor), + modifiers: current_modifiers(), + click_count, + first_mouse: false, + }); + let result = func(input.clone()); + let handled = !result.propagate || result.default_prevented; + self.state.borrow_mut().callbacks.input = Some(func); + + if handled { + return Some(0); + } + } else { + drop(lock); + }; + + // Since these are handled in handle_nc_mouse_up_msg we must prevent the default window proc + if button == MouseButton::Left { + match wparam.0 as u32 { + HTMINBUTTON => self.state.borrow_mut().nc_button_pressed = Some(HTMINBUTTON), + HTMAXBUTTON => self.state.borrow_mut().nc_button_pressed = Some(HTMAXBUTTON), + HTCLOSE => self.state.borrow_mut().nc_button_pressed = Some(HTCLOSE), + _ => return None, + }; + Some(0) + } else { + None } - } else { - drop(lock); } - let last_pressed = state_ptr.state.borrow_mut().nc_button_pressed.take(); - if button == MouseButton::Left - && let Some(last_pressed) = last_pressed - { - let handled = match (wparam.0 as u32, last_pressed) { - (HTMINBUTTON, HTMINBUTTON) => { - unsafe { ShowWindowAsync(handle, SW_MINIMIZE).ok().log_err() }; - true + fn handle_nc_mouse_up_msg( + &self, + handle: HWND, + button: MouseButton, + wparam: WPARAM, + lparam: LPARAM, + ) -> Option<isize> { + let mut lock = self.state.borrow_mut(); + if let Some(mut func) = lock.callbacks.input.take() { + let scale_factor = lock.scale_factor; + drop(lock); + + let mut cursor_point = POINT { + x: lparam.signed_loword().into(), + y: lparam.signed_hiword().into(), + }; + unsafe { ScreenToClient(handle, &mut cursor_point).ok().log_err() }; + let input = PlatformInput::MouseUp(MouseUpEvent { + button, + position: logical_point(cursor_point.x as f32, cursor_point.y as f32, scale_factor), + modifiers: current_modifiers(), + click_count: 1, + }); + let handled = !func(input).propagate; + self.state.borrow_mut().callbacks.input = Some(func); + + if handled { + return Some(0); } - (HTMAXBUTTON, HTMAXBUTTON) => { - if 
state_ptr.state.borrow().is_maximized() { - unsafe { ShowWindowAsync(handle, SW_NORMAL).ok().log_err() }; - } else { - unsafe { ShowWindowAsync(handle, SW_MAXIMIZE).ok().log_err() }; + } else { + drop(lock); + } + + let last_pressed = self.state.borrow_mut().nc_button_pressed.take(); + if button == MouseButton::Left + && let Some(last_pressed) = last_pressed + { + let handled = match (wparam.0 as u32, last_pressed) { + (HTMINBUTTON, HTMINBUTTON) => { + unsafe { ShowWindowAsync(handle, SW_MINIMIZE).ok().log_err() }; + true } - true - } - (HTCLOSE, HTCLOSE) => { - unsafe { - PostMessageW(Some(handle), WM_CLOSE, WPARAM::default(), LPARAM::default()) - .log_err() - }; - true + (HTMAXBUTTON, HTMAXBUTTON) => { + if self.state.borrow().is_maximized() { + unsafe { ShowWindowAsync(handle, SW_NORMAL).ok().log_err() }; + } else { + unsafe { ShowWindowAsync(handle, SW_MAXIMIZE).ok().log_err() }; + } + true + } + (HTCLOSE, HTCLOSE) => { + unsafe { + PostMessageW(Some(handle), WM_CLOSE, WPARAM::default(), LPARAM::default()) + .log_err() + }; + true + } + _ => false, + }; + if handled { + return Some(0); } - _ => false, - }; - if handled { - return Some(0); } + + None } - None -} + fn handle_cursor_changed(&self, lparam: LPARAM) -> Option<isize> { + let mut state = self.state.borrow_mut(); + let had_cursor = state.current_cursor.is_some(); -fn handle_cursor_changed(lparam: LPARAM, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { - let mut state = state_ptr.state.borrow_mut(); - let had_cursor = state.current_cursor.is_some(); + state.current_cursor = if lparam.0 == 0 { + None + } else { + Some(HCURSOR(lparam.0 as _)) + }; - state.current_cursor = if lparam.0 == 0 { - None - } else { - Some(HCURSOR(lparam.0 as _)) - }; + if had_cursor != state.current_cursor.is_some() { + unsafe { SetCursor(state.current_cursor) }; + } - if had_cursor != state.current_cursor.is_some() { - unsafe { SetCursor(state.current_cursor) }; + Some(0) } - Some(0) -} - -fn handle_set_cursor( - handle: HWND, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - if unsafe { !IsWindowEnabled(handle).as_bool() } - || matches!( - lparam.loword() as u32, - HTLEFT - | HTRIGHT - | HTTOP - | HTTOPLEFT - | HTTOPRIGHT - | HTBOTTOM - | HTBOTTOMLEFT - | HTBOTTOMRIGHT - ) - { - return None; + fn handle_set_cursor(&self, handle: HWND, lparam: LPARAM) -> Option<isize> { + if unsafe { !IsWindowEnabled(handle).as_bool() } + || matches!( + lparam.loword() as u32, + HTLEFT + | HTRIGHT + | HTTOP + | HTTOPLEFT + | HTTOPRIGHT + | HTBOTTOM + | HTBOTTOMLEFT + | HTBOTTOMRIGHT + ) + { + return None; + } + unsafe { + SetCursor(self.state.borrow().current_cursor); + }; + Some(1) } - unsafe { - SetCursor(state_ptr.state.borrow().current_cursor); - }; - Some(1) -} -fn handle_system_settings_changed( - handle: HWND, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - if wparam.0 != 0 { - let mut lock = state_ptr.state.borrow_mut(); - let display = lock.display; - lock.system_settings.update(display, wparam.0); - lock.click_state.system_update(wparam.0); - lock.border_offset.update(handle).log_err(); - } else { - handle_system_theme_changed(handle, lparam, state_ptr)?; - }; - // Force to trigger WM_NCCALCSIZE event to ensure that we handle auto hide - // taskbar correctly. 
- notify_frame_changed(handle); + fn handle_system_settings_changed( + &self, + handle: HWND, + wparam: WPARAM, + lparam: LPARAM, + ) -> Option<isize> { + if wparam.0 != 0 { + let mut lock = self.state.borrow_mut(); + let display = lock.display; + lock.system_settings.update(display, wparam.0); + lock.click_state.system_update(wparam.0); + lock.border_offset.update(handle).log_err(); + } else { + self.handle_system_theme_changed(handle, lparam)?; + }; + // Force to trigger WM_NCCALCSIZE event to ensure that we handle auto hide + // taskbar correctly. + notify_frame_changed(handle); - Some(0) -} + Some(0) + } -fn handle_system_command(wparam: WPARAM, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { - if wparam.0 == SC_KEYMENU as usize { - let mut lock = state_ptr.state.borrow_mut(); - if lock.system_key_handled { - lock.system_key_handled = false; - return Some(0); + fn handle_system_command(&self, wparam: WPARAM) -> Option<isize> { + if wparam.0 == SC_KEYMENU as usize { + let mut lock = self.state.borrow_mut(); + if lock.system_key_handled { + lock.system_key_handled = false; + return Some(0); + } } + None } - None -} -fn handle_system_theme_changed( - handle: HWND, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - // lParam is a pointer to a string that indicates the area containing the system parameter - // that was changed. - let parameter = PCWSTR::from_raw(lparam.0 as _); - if unsafe { !parameter.is_null() && !parameter.is_empty() } { - if let Some(parameter_string) = unsafe { parameter.to_string() }.log_err() { - log::info!("System settings changed: {}", parameter_string); - match parameter_string.as_str() { - "ImmersiveColorSet" => { - let new_appearance = system_appearance() - .context("unable to get system appearance when handling ImmersiveColorSet") - .log_err()?; - let mut lock = state_ptr.state.borrow_mut(); - if new_appearance != lock.appearance { - lock.appearance = new_appearance; - let mut callback = lock.callbacks.appearance_changed.take()?; - drop(lock); - callback(); - state_ptr.state.borrow_mut().callbacks.appearance_changed = Some(callback); - configure_dwm_dark_mode(handle, new_appearance); + fn handle_system_theme_changed(&self, handle: HWND, lparam: LPARAM) -> Option<isize> { + // lParam is a pointer to a string that indicates the area containing the system parameter + // that was changed. 
+ let parameter = PCWSTR::from_raw(lparam.0 as _); + if unsafe { !parameter.is_null() && !parameter.is_empty() } { + if let Some(parameter_string) = unsafe { parameter.to_string() }.log_err() { + log::info!("System settings changed: {}", parameter_string); + match parameter_string.as_str() { + "ImmersiveColorSet" => { + let new_appearance = system_appearance() + .context( + "unable to get system appearance when handling ImmersiveColorSet", + ) + .log_err()?; + let mut lock = self.state.borrow_mut(); + if new_appearance != lock.appearance { + lock.appearance = new_appearance; + let mut callback = lock.callbacks.appearance_changed.take()?; + drop(lock); + callback(); + self.state.borrow_mut().callbacks.appearance_changed = Some(callback); + configure_dwm_dark_mode(handle, new_appearance); + } } + _ => {} } - _ => {} } } + Some(0) } - Some(0) -} -fn handle_input_language_changed( - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let thread = state_ptr.main_thread_id_win32; - let validation = state_ptr.validation_number; - unsafe { - PostThreadMessageW(thread, WM_INPUTLANGCHANGE, WPARAM(validation), lparam).log_err(); + fn handle_input_language_changed(&self, lparam: LPARAM) -> Option<isize> { + let thread = self.main_thread_id_win32; + let validation = self.validation_number; + unsafe { + PostThreadMessageW(thread, WM_INPUTLANGCHANGE, WPARAM(validation), lparam).log_err(); + } + Some(0) } - Some(0) -} -#[inline] -fn parse_char_message(wparam: WPARAM, state_ptr: &Rc<WindowsWindowStatePtr>) -> Option<String> { - let code_point = wparam.loword(); - let mut lock = state_ptr.state.borrow_mut(); - // https://www.unicode.org/versions/Unicode16.0.0/core-spec/chapter-3/#G2630 - match code_point { - 0xD800..=0xDBFF => { - // High surrogate, wait for low surrogate - lock.pending_surrogate = Some(code_point); + fn handle_device_change_msg(&self, handle: HWND, wparam: WPARAM) -> Option<isize> { + if wparam.0 == DBT_DEVNODES_CHANGED as usize { + // The reason for sending this message is to actually trigger a redraw of the window. + unsafe { + PostMessageW( + Some(handle), + WM_GPUI_FORCE_UPDATE_WINDOW, + WPARAM(0), + LPARAM(0), + ) + .log_err(); + } + // If the GPU device is lost, this redraw will take care of recreating the device context. + // The WM_GPUI_FORCE_UPDATE_WINDOW message will take care of redrawing the window, after + // the device context has been recreated. + self.draw_window(handle, true) + } else { + // Other device change messages are not handled. 
None } - 0xDC00..=0xDFFF => { - if let Some(high_surrogate) = lock.pending_surrogate.take() { - // Low surrogate, combine with pending high surrogate - String::from_utf16(&[high_surrogate, code_point]).ok() - } else { - // Invalid low surrogate without a preceding high surrogate - log::warn!( - "Received low surrogate without a preceding high surrogate: {code_point:x}" - ); + } + + #[inline] + fn draw_window(&self, handle: HWND, force_render: bool) -> Option<isize> { + let mut request_frame = self.state.borrow_mut().callbacks.request_frame.take()?; + request_frame(RequestFrameOptions { + require_presentation: false, + force_render, + }); + self.state.borrow_mut().callbacks.request_frame = Some(request_frame); + unsafe { ValidateRect(Some(handle), None).ok().log_err() }; + Some(0) + } + + #[inline] + fn parse_char_message(&self, wparam: WPARAM) -> Option<String> { + let code_point = wparam.loword(); + let mut lock = self.state.borrow_mut(); + // https://www.unicode.org/versions/Unicode16.0.0/core-spec/chapter-3/#G2630 + match code_point { + 0xD800..=0xDBFF => { + // High surrogate, wait for low surrogate + lock.pending_surrogate = Some(code_point); None } + 0xDC00..=0xDFFF => { + if let Some(high_surrogate) = lock.pending_surrogate.take() { + // Low surrogate, combine with pending high surrogate + String::from_utf16(&[high_surrogate, code_point]).ok() + } else { + // Invalid low surrogate without a preceding high surrogate + log::warn!( + "Received low surrogate without a preceding high surrogate: {code_point:x}" + ); + None + } + } + _ => { + lock.pending_surrogate = None; + char::from_u32(code_point as u32) + .filter(|c| !c.is_control()) + .map(|c| c.to_string()) + } } - _ => { - lock.pending_surrogate = None; - char::from_u32(code_point as u32) - .filter(|c| !c.is_control()) - .map(|c| c.to_string()) + } + + fn start_tracking_mouse(&self, handle: HWND, flags: TRACKMOUSEEVENT_FLAGS) { + let mut lock = self.state.borrow_mut(); + if !lock.hovered { + lock.hovered = true; + unsafe { + TrackMouseEvent(&mut TRACKMOUSEEVENT { + cbSize: std::mem::size_of::<TRACKMOUSEEVENT>() as u32, + dwFlags: flags, + hwndTrack: handle, + dwHoverTime: HOVER_DEFAULT, + }) + .log_err() + }; + if let Some(mut callback) = lock.callbacks.hovered_status_change.take() { + drop(lock); + callback(true); + self.state.borrow_mut().callbacks.hovered_status_change = Some(callback); + } } } + + fn with_input_handler<F, R>(&self, f: F) -> Option<R> + where + F: FnOnce(&mut PlatformInputHandler) -> R, + { + let mut input_handler = self.state.borrow_mut().input_handler.take()?; + let result = f(&mut input_handler); + self.state.borrow_mut().input_handler = Some(input_handler); + Some(result) + } + + fn with_input_handler_and_scale_factor<F, R>(&self, f: F) -> Option<R> + where + F: FnOnce(&mut PlatformInputHandler, f32) -> Option<R>, + { + let mut lock = self.state.borrow_mut(); + let mut input_handler = lock.input_handler.take()?; + let scale_factor = lock.scale_factor; + drop(lock); + let result = f(&mut input_handler, scale_factor); + self.state.borrow_mut().input_handler = Some(input_handler); + result + } } #[inline] @@ -1521,54 +1537,3 @@ fn notify_frame_changed(handle: HWND) { .log_err(); } } - -fn start_tracking_mouse( - handle: HWND, - state_ptr: &Rc<WindowsWindowStatePtr>, - flags: TRACKMOUSEEVENT_FLAGS, -) { - let mut lock = state_ptr.state.borrow_mut(); - if !lock.hovered { - lock.hovered = true; - unsafe { - TrackMouseEvent(&mut TRACKMOUSEEVENT { - cbSize: std::mem::size_of::<TRACKMOUSEEVENT>() as u32, - dwFlags: 
flags, - hwndTrack: handle, - dwHoverTime: HOVER_DEFAULT, - }) - .log_err() - }; - if let Some(mut callback) = lock.callbacks.hovered_status_change.take() { - drop(lock); - callback(true); - state_ptr.state.borrow_mut().callbacks.hovered_status_change = Some(callback); - } - } -} - -fn with_input_handler<F, R>(state_ptr: &Rc<WindowsWindowStatePtr>, f: F) -> Option<R> -where - F: FnOnce(&mut PlatformInputHandler) -> R, -{ - let mut input_handler = state_ptr.state.borrow_mut().input_handler.take()?; - let result = f(&mut input_handler); - state_ptr.state.borrow_mut().input_handler = Some(input_handler); - Some(result) -} - -fn with_input_handler_and_scale_factor<F, R>( - state_ptr: &Rc<WindowsWindowStatePtr>, - f: F, -) -> Option<R> -where - F: FnOnce(&mut PlatformInputHandler, f32) -> Option<R>, -{ - let mut lock = state_ptr.state.borrow_mut(); - let mut input_handler = lock.input_handler.take()?; - let scale_factor = lock.scale_factor; - drop(lock); - let result = f(&mut input_handler, scale_factor); - state_ptr.state.borrow_mut().input_handler = Some(input_handler); - result -} diff --git a/crates/gpui/src/platform/windows/keyboard.rs b/crates/gpui/src/platform/windows/keyboard.rs index f5a148a97e986c8c33d30cb516474d8103d0c5f7..371feb70c25ab593ce612c7a90381a4cffdeff7d 100644 --- a/crates/gpui/src/platform/windows/keyboard.rs +++ b/crates/gpui/src/platform/windows/keyboard.rs @@ -130,11 +130,13 @@ pub(crate) fn generate_key_char( let mut buffer = [0; 8]; let len = unsafe { ToUnicode(vkey.0 as u32, scan_code, Some(&state), &mut buffer, 1 << 2) }; - if len > 0 { - let candidate = String::from_utf16_lossy(&buffer[..len as usize]); - if !candidate.is_empty() && !candidate.chars().next().unwrap().is_control() { - return Some(candidate); - } + match len { + len if len > 0 => String::from_utf16(&buffer[..len as usize]) + .ok() + .filter(|candidate| { + !candidate.is_empty() && !candidate.chars().next().unwrap().is_control() + }), + len if len < 0 => String::from_utf16(&buffer[..(-len as usize)]).ok(), + _ => None, } - None } diff --git a/crates/gpui/src/platform/windows/platform.rs b/crates/gpui/src/platform/windows/platform.rs index f69a802da07fab1636404e3aae0dfd8487d69479..01b043a755240868699ad7872b4e46202b5bd90d 100644 --- a/crates/gpui/src/platform/windows/platform.rs +++ b/crates/gpui/src/platform/windows/platform.rs @@ -28,13 +28,12 @@ use windows::{ core::*, }; -use crate::{platform::blade::BladeContext, *}; +use crate::*; pub(crate) struct WindowsPlatform { state: RefCell<WindowsPlatformState>, raw_window_handles: RwLock<SmallVec<[HWND; 4]>>, // The below members will never change throughout the entire lifecycle of the app. 
- gpu_context: BladeContext, icon: HICON, main_receiver: flume::Receiver<Runnable>, background_executor: BackgroundExecutor, @@ -45,6 +44,7 @@ pub(crate) struct WindowsPlatform { drop_target_helper: IDropTargetHelper, validation_number: usize, main_thread_id_win32: u32, + disable_direct_composition: bool, } pub(crate) struct WindowsPlatformState { @@ -94,14 +94,18 @@ impl WindowsPlatform { main_thread_id_win32, validation_number, )); + let disable_direct_composition = std::env::var(DISABLE_DIRECT_COMPOSITION) + .is_ok_and(|value| value == "true" || value == "1"); let background_executor = BackgroundExecutor::new(dispatcher.clone()); let foreground_executor = ForegroundExecutor::new(dispatcher); + let directx_devices = DirectXDevices::new(disable_direct_composition) + .context("Unable to init directx devices.")?; let bitmap_factory = ManuallyDrop::new(unsafe { CoCreateInstance(&CLSID_WICImagingFactory, None, CLSCTX_INPROC_SERVER) .context("Error creating bitmap factory.")? }); let text_system = Arc::new( - DirectWriteTextSystem::new(&bitmap_factory) + DirectWriteTextSystem::new(&directx_devices, &bitmap_factory) .context("Error creating DirectWriteTextSystem")?, ); let drop_target_helper: IDropTargetHelper = unsafe { @@ -111,18 +115,17 @@ impl WindowsPlatform { let icon = load_icon().unwrap_or_default(); let state = RefCell::new(WindowsPlatformState::new()); let raw_window_handles = RwLock::new(SmallVec::new()); - let gpu_context = BladeContext::new().context("Unable to init GPU context")?; let windows_version = WindowsVersion::new().context("Error retrieve windows version")?; Ok(Self { state, raw_window_handles, - gpu_context, icon, main_receiver, background_executor, foreground_executor, text_system, + disable_direct_composition, windows_version, bitmap_factory, drop_target_helper, @@ -141,12 +144,12 @@ impl WindowsPlatform { } } - pub fn try_get_windows_inner_from_hwnd(&self, hwnd: HWND) -> Option<Rc<WindowsWindowStatePtr>> { + pub fn window_from_hwnd(&self, hwnd: HWND) -> Option<Rc<WindowsWindowInner>> { self.raw_window_handles .read() .iter() .find(|entry| *entry == &hwnd) - .and_then(|hwnd| try_get_window_inner(*hwnd)) + .and_then(|hwnd| window_from_hwnd(*hwnd)) } #[inline] @@ -187,6 +190,7 @@ impl WindowsPlatform { validation_number: self.validation_number, main_receiver: self.main_receiver.clone(), main_thread_id_win32: self.main_thread_id_win32, + disable_direct_composition: self.disable_direct_composition, } } @@ -343,27 +347,11 @@ impl Platform for WindowsPlatform { fn run(&self, on_finish_launching: Box<dyn 'static + FnOnce()>) { on_finish_launching(); - let vsync_event = unsafe { Owned::new(CreateEventW(None, false, false, None).unwrap()) }; - begin_vsync(*vsync_event); - 'a: loop { - let wait_result = unsafe { - MsgWaitForMultipleObjects(Some(&[*vsync_event]), false, INFINITE, QS_ALLINPUT) - }; - - match wait_result { - // compositor clock ticked so we should draw a frame - WAIT_EVENT(0) => self.redraw_all(), - // Windows thread messages are posted - WAIT_EVENT(1) => { - if self.handle_events() { - break 'a; - } - } - _ => { - log::error!("Something went wrong while waiting {:?}", wait_result); - break; - } + loop { + if self.handle_events() { + break; } + self.redraw_all(); } if let Some(ref mut callback) = self.state.borrow_mut().callbacks.quit { @@ -440,13 +428,13 @@ impl Platform for WindowsPlatform { #[cfg(feature = "screen-capture")] fn screen_capture_sources( &self, - ) -> oneshot::Receiver<Result<Vec<Box<dyn ScreenCaptureSource>>>> { + ) -> 
oneshot::Receiver<Result<Vec<Rc<dyn ScreenCaptureSource>>>> { crate::platform::scap_screen_capture::scap_screen_sources(&self.foreground_executor) } fn active_window(&self) -> Option<AnyWindowHandle> { let active_window_hwnd = unsafe { GetActiveWindow() }; - self.try_get_windows_inner_from_hwnd(active_window_hwnd) + self.window_from_hwnd(active_window_hwnd) .map(|inner| inner.handle) } @@ -455,12 +443,7 @@ impl Platform for WindowsPlatform { handle: AnyWindowHandle, options: WindowParams, ) -> Result<Box<dyn PlatformWindow>> { - let window = WindowsWindow::new( - handle, - options, - self.generate_creation_info(), - &self.gpu_context, - )?; + let window = WindowsWindow::new(handle, options, self.generate_creation_info())?; let handle = window.get_raw_handle(); self.raw_window_handles.write().push(handle); @@ -739,6 +722,7 @@ pub(crate) struct WindowCreationInfo { pub(crate) validation_number: usize, pub(crate) main_receiver: flume::Receiver<Runnable>, pub(crate) main_thread_id_win32: u32, + pub(crate) disable_direct_composition: bool, } fn open_target(target: &str) { @@ -846,16 +830,6 @@ fn file_save_dialog(directory: PathBuf, window: Option<HWND>) -> Result<Option<P Ok(Some(PathBuf::from(file_path_string))) } -fn begin_vsync(vsync_event: HANDLE) { - let event: SafeHandle = vsync_event.into(); - std::thread::spawn(move || unsafe { - loop { - windows::Win32::Graphics::Dwm::DwmFlush().log_err(); - SetEvent(*event).log_err(); - } - }); -} - fn load_icon() -> Result<HICON> { let module = unsafe { GetModuleHandleW(None).context("unable to get module handle")? }; let handle = unsafe { diff --git a/crates/gpui/src/platform/windows/shaders.hlsl b/crates/gpui/src/platform/windows/shaders.hlsl new file mode 100644 index 0000000000000000000000000000000000000000..25830e4b6c3183772aaa3c0a73bc1cef33b908a9 --- /dev/null +++ b/crates/gpui/src/platform/windows/shaders.hlsl @@ -0,0 +1,1159 @@ +cbuffer GlobalParams: register(b0) { + float2 global_viewport_size; + uint2 _pad; +}; + +Texture2D<float4> t_sprite: register(t0); +SamplerState s_sprite: register(s0); + +struct Bounds { + float2 origin; + float2 size; +}; + +struct Corners { + float top_left; + float top_right; + float bottom_right; + float bottom_left; +}; + +struct Edges { + float top; + float right; + float bottom; + float left; +}; + +struct Hsla { + float h; + float s; + float l; + float a; +}; + +struct LinearColorStop { + Hsla color; + float percentage; +}; + +struct Background { + // 0u is Solid + // 1u is LinearGradient + // 2u is PatternSlash + uint tag; + // 0u is sRGB linear color + // 1u is Oklab color + uint color_space; + Hsla solid; + float gradient_angle_or_pattern_height; + LinearColorStop colors[2]; + uint pad; +}; + +struct GradientColor { + float4 solid; + float4 color0; + float4 color1; +}; + +struct AtlasTextureId { + uint index; + uint kind; +}; + +struct AtlasBounds { + int2 origin; + int2 size; +}; + +struct AtlasTile { + AtlasTextureId texture_id; + uint tile_id; + uint padding; + AtlasBounds bounds; +}; + +struct TransformationMatrix { + float2x2 rotation_scale; + float2 translation; +}; + +static const float M_PI_F = 3.141592653f; +static const float3 GRAYSCALE_FACTORS = float3(0.2126f, 0.7152f, 0.0722f); + +float4 to_device_position_impl(float2 position) { + float2 device_position = position / global_viewport_size * float2(2.0, -2.0) + float2(-1.0, 1.0); + return float4(device_position, 0., 1.); +} + +float4 to_device_position(float2 unit_vertex, Bounds bounds) { + float2 position = unit_vertex * bounds.size + 
bounds.origin; + return to_device_position_impl(position); +} + +float4 distance_from_clip_rect_impl(float2 position, Bounds clip_bounds) { + float2 tl = position - clip_bounds.origin; + float2 br = clip_bounds.origin + clip_bounds.size - position; + return float4(tl.x, br.x, tl.y, br.y); +} + +float4 distance_from_clip_rect(float2 unit_vertex, Bounds bounds, Bounds clip_bounds) { + float2 position = unit_vertex * bounds.size + bounds.origin; + return distance_from_clip_rect_impl(position, clip_bounds); +} + +// Convert linear RGB to sRGB +float3 linear_to_srgb(float3 color) { + return pow(color, float3(2.2, 2.2, 2.2)); +} + +// Convert sRGB to linear RGB +float3 srgb_to_linear(float3 color) { + return pow(color, float3(1.0 / 2.2, 1.0 / 2.2, 1.0 / 2.2)); +} + +/// Hsla to linear RGBA conversion. +float4 hsla_to_rgba(Hsla hsla) { + float h = hsla.h * 6.0; // Now, it's an angle but scaled in [0, 6) range + float s = hsla.s; + float l = hsla.l; + float a = hsla.a; + + float c = (1.0 - abs(2.0 * l - 1.0)) * s; + float x = c * (1.0 - abs(fmod(h, 2.0) - 1.0)); + float m = l - c / 2.0; + + float r = 0.0; + float g = 0.0; + float b = 0.0; + + if (h >= 0.0 && h < 1.0) { + r = c; + g = x; + b = 0.0; + } else if (h >= 1.0 && h < 2.0) { + r = x; + g = c; + b = 0.0; + } else if (h >= 2.0 && h < 3.0) { + r = 0.0; + g = c; + b = x; + } else if (h >= 3.0 && h < 4.0) { + r = 0.0; + g = x; + b = c; + } else if (h >= 4.0 && h < 5.0) { + r = x; + g = 0.0; + b = c; + } else { + r = c; + g = 0.0; + b = x; + } + + float4 rgba; + rgba.x = (r + m); + rgba.y = (g + m); + rgba.z = (b + m); + rgba.w = a; + return rgba; +} + +// Converts a sRGB color to the Oklab color space. +// Reference: https://bottosson.github.io/posts/oklab/#converting-from-linear-srgb-to-oklab +float4 srgb_to_oklab(float4 color) { + // Convert non-linear sRGB to linear sRGB + color = float4(srgb_to_linear(color.rgb), color.a); + + float l = 0.4122214708 * color.r + 0.5363325363 * color.g + 0.0514459929 * color.b; + float m = 0.2119034982 * color.r + 0.6806995451 * color.g + 0.1073969566 * color.b; + float s = 0.0883024619 * color.r + 0.2817188376 * color.g + 0.6299787005 * color.b; + + float l_ = pow(l, 1.0/3.0); + float m_ = pow(m, 1.0/3.0); + float s_ = pow(s, 1.0/3.0); + + return float4( + 0.2104542553 * l_ + 0.7936177850 * m_ - 0.0040720468 * s_, + 1.9779984951 * l_ - 2.4285922050 * m_ + 0.4505937099 * s_, + 0.0259040371 * l_ + 0.7827717662 * m_ - 0.8086757660 * s_, + color.a + ); +} + +// Converts an Oklab color to the sRGB color space. +float4 oklab_to_srgb(float4 color) { + float l_ = color.r + 0.3963377774 * color.g + 0.2158037573 * color.b; + float m_ = color.r - 0.1055613458 * color.g - 0.0638541728 * color.b; + float s_ = color.r - 0.0894841775 * color.g - 1.2914855480 * color.b; + + float l = l_ * l_ * l_; + float m = m_ * m_ * m_; + float s = s_ * s_ * s_; + + float3 linear_rgb = float3( + 4.0767416621 * l - 3.3077115913 * m + 0.2309699292 * s, + -1.2684380046 * l + 2.6097574011 * m - 0.3413193965 * s, + -0.0041960863 * l - 0.7034186147 * m + 1.7076147010 * s + ); + + // Convert linear sRGB to non-linear sRGB + return float4(linear_to_srgb(linear_rgb), color.a); +} + +// This approximates the error function, needed for the gaussian integral +float2 erf(float2 x) { + float2 s = sign(x); + float2 a = abs(x); + x = 1. 
+ (0.278393 + (0.230389 + 0.078108 * (a * a)) * a) * a; + x *= x; + return s - s / (x * x); +} + +float blur_along_x(float x, float y, float sigma, float corner, float2 half_size) { + float delta = min(half_size.y - corner - abs(y), 0.); + float curved = half_size.x - corner + sqrt(max(0., corner * corner - delta * delta)); + float2 integral = 0.5 + 0.5 * erf((x + float2(-curved, curved)) * (sqrt(0.5) / sigma)); + return integral.y - integral.x; +} + +// A standard gaussian function, used for weighting samples +float gaussian(float x, float sigma) { + return exp(-(x * x) / (2. * sigma * sigma)) / (sqrt(2. * M_PI_F) * sigma); +} + +float4 over(float4 below, float4 above) { + float4 result; + float alpha = above.a + below.a * (1.0 - above.a); + result.rgb = (above.rgb * above.a + below.rgb * below.a * (1.0 - above.a)) / alpha; + result.a = alpha; + return result; +} + +float2 to_tile_position(float2 unit_vertex, AtlasTile tile) { + float2 atlas_size; + t_sprite.GetDimensions(atlas_size.x, atlas_size.y); + return (float2(tile.bounds.origin) + unit_vertex * float2(tile.bounds.size)) / atlas_size; +} + +// Selects corner radius based on quadrant. +float pick_corner_radius(float2 center_to_point, Corners corner_radii) { + if (center_to_point.x < 0.) { + if (center_to_point.y < 0.) { + return corner_radii.top_left; + } else { + return corner_radii.bottom_left; + } + } else { + if (center_to_point.y < 0.) { + return corner_radii.top_right; + } else { + return corner_radii.bottom_right; + } + } +} + +float4 to_device_position_transformed(float2 unit_vertex, Bounds bounds, + TransformationMatrix transformation) { + float2 position = unit_vertex * bounds.size + bounds.origin; + float2 transformed = mul(position, transformation.rotation_scale) + transformation.translation; + float2 device_position = transformed / global_viewport_size * float2(2.0, -2.0) + float2(-1.0, 1.0); + return float4(device_position, 0.0, 1.0); +} + +// Implementation of quad signed distance field +float quad_sdf_impl(float2 corner_center_to_point, float corner_radius) { + if (corner_radius == 0.0) { + // Fast path for unrounded corners + return max(corner_center_to_point.x, corner_center_to_point.y); + } else { + // Signed distance of the point from a quad that is inset by corner_radius + // It is negative inside this quad, and positive outside + float signed_distance_to_inset_quad = + // 0 inside the inset quad, and positive outside + length(max(float2(0.0, 0.0), corner_center_to_point)) + + // 0 outside the inset quad, and negative inside + min(0.0, max(corner_center_to_point.x, corner_center_to_point.y)); + + return signed_distance_to_inset_quad - corner_radius; + } +} + +float quad_sdf(float2 pt, Bounds bounds, Corners corner_radii) { + float2 half_size = bounds.size / 2.; + float2 center = bounds.origin + half_size; + float2 center_to_point = pt - center; + float corner_radius = pick_corner_radius(center_to_point, corner_radii); + float2 corner_to_point = abs(center_to_point) - half_size; + float2 corner_center_to_point = corner_to_point + corner_radius; + return quad_sdf_impl(corner_center_to_point, corner_radius); +} + +GradientColor prepare_gradient_color(uint tag, uint color_space, Hsla solid, LinearColorStop colors[2]) { + GradientColor output; + if (tag == 0 || tag == 2) { + output.solid = hsla_to_rgba(solid); + } else if (tag == 1) { + output.color0 = hsla_to_rgba(colors[0].color); + output.color1 = hsla_to_rgba(colors[1].color); + + // Prepare color space in vertex for avoid conversion + // in fragment shader for 
performance reasons + if (color_space == 1) { + // Oklab + output.color0 = srgb_to_oklab(output.color0); + output.color1 = srgb_to_oklab(output.color1); + } + } + + return output; +} + +float2x2 rotate2d(float angle) { + float s = sin(angle); + float c = cos(angle); + return float2x2(c, -s, s, c); +} + +float4 gradient_color(Background background, + float2 position, + Bounds bounds, + float4 solid_color, float4 color0, float4 color1) { + float4 color; + + switch (background.tag) { + case 0: + color = solid_color; + break; + case 1: { + // -90 degrees to match the CSS gradient angle. + float gradient_angle = background.gradient_angle_or_pattern_height; + float radians = (fmod(gradient_angle, 360.0) - 90.0) * (M_PI_F / 180.0); + float2 direction = float2(cos(radians), sin(radians)); + + // Expand the short side to be the same as the long side + if (bounds.size.x > bounds.size.y) { + direction.y *= bounds.size.y / bounds.size.x; + } else { + direction.x *= bounds.size.x / bounds.size.y; + } + + // Get the t value for the linear gradient with the color stop percentages. + float2 half_size = bounds.size * 0.5; + float2 center = bounds.origin + half_size; + float2 center_to_point = position - center; + float t = dot(center_to_point, direction) / length(direction); + // Check the direct to determine the use x or y + if (abs(direction.x) > abs(direction.y)) { + t = (t + half_size.x) / bounds.size.x; + } else { + t = (t + half_size.y) / bounds.size.y; + } + + // Adjust t based on the stop percentages + t = (t - background.colors[0].percentage) + / (background.colors[1].percentage + - background.colors[0].percentage); + t = clamp(t, 0.0, 1.0); + + switch (background.color_space) { + case 0: + color = lerp(color0, color1, t); + break; + case 1: { + float4 oklab_color = lerp(color0, color1, t); + color = oklab_to_srgb(oklab_color); + break; + } + } + break; + } + case 2: { + float gradient_angle_or_pattern_height = background.gradient_angle_or_pattern_height; + float pattern_width = (gradient_angle_or_pattern_height / 65535.0f) / 255.0f; + float pattern_interval = fmod(gradient_angle_or_pattern_height, 65535.0f) / 255.0f; + float pattern_height = pattern_width + pattern_interval; + float stripe_angle = M_PI_F / 4.0; + float pattern_period = pattern_height * sin(stripe_angle); + float2x2 rotation = rotate2d(stripe_angle); + float2 relative_position = position - bounds.origin; + float2 rotated_point = mul(rotation, relative_position); + float pattern = fmod(rotated_point.x, pattern_period); + float distance = min(pattern, pattern_period - pattern) - pattern_period * (pattern_width / pattern_height) / 2.0f; + color = solid_color; + color.a *= saturate(0.5 - distance); + break; + } + } + + return color; +} + +// Returns the dash velocity of a corner given the dash velocity of the two +// sides, by returning the slower velocity (larger dashes). +// +// Since 0 is used for dash velocity when the border width is 0 (instead of +// +inf), this returns the other dash velocity in that case. +// +// An alternative to this might be to appropriately interpolate the dash +// velocity around the corner, but that seems overcomplicated. +float corner_dash_velocity(float dv1, float dv2) { + if (dv1 == 0.0) { + return dv2; + } else if (dv2 == 0.0) { + return dv1; + } else { + return min(dv1, dv2); + } +} + +// Returns alpha used to render antialiased dashes. +// `t` is within the dash when `fmod(t, period) < length`. 
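+// `dash_velocity` is in dash periods per pixel, so dividing the dash-space
+// distance by it converts back to pixels and keeps the antialiased edge
+// roughly one pixel wide.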
+float dash_alpha( + float t, float period, float length, float dash_velocity, + float antialias_threshold +) { + float half_period = period / 2.0; + float half_length = length / 2.0; + // Value in [-half_period, half_period] + // The dash is in [-half_length, half_length] + float centered = fmod(t + half_period - half_length, period) - half_period; + // Signed distance for the dash, negative values are inside the dash + float signed_distance = abs(centered) - half_length; + // Antialiased alpha based on the signed distance + return saturate(antialias_threshold - signed_distance / dash_velocity); +} + +// This approximates distance to the nearest point to a quarter ellipse in a way +// that is sufficient for anti-aliasing when the ellipse is not very eccentric. +// The components of `point` are expected to be positive. +// +// Negative on the outside and positive on the inside. +float quarter_ellipse_sdf(float2 pt, float2 radii) { + // Scale the space to treat the ellipse like a unit circle + float2 circle_vec = pt / radii; + float unit_circle_sdf = length(circle_vec) - 1.0; + // Approximate up-scaling of the length by using the average of the radii. + // + // TODO: A better solution would be to use the gradient of the implicit + // function for an ellipse to approximate a scaling factor. + return unit_circle_sdf * (radii.x + radii.y) * -0.5; +} + +/* +** +** Quads +** +*/ + +struct Quad { + uint order; + uint border_style; + Bounds bounds; + Bounds content_mask; + Background background; + Hsla border_color; + Corners corner_radii; + Edges border_widths; +}; + +struct QuadVertexOutput { + nointerpolation uint quad_id: TEXCOORD0; + float4 position: SV_Position; + nointerpolation float4 border_color: COLOR0; + nointerpolation float4 background_solid: COLOR1; + nointerpolation float4 background_color0: COLOR2; + nointerpolation float4 background_color1: COLOR3; + float4 clip_distance: SV_ClipDistance; +}; + +struct QuadFragmentInput { + nointerpolation uint quad_id: TEXCOORD0; + float4 position: SV_Position; + nointerpolation float4 border_color: COLOR0; + nointerpolation float4 background_solid: COLOR1; + nointerpolation float4 background_color0: COLOR2; + nointerpolation float4 background_color1: COLOR3; +}; + +StructuredBuffer<Quad> quads: register(t1); + +QuadVertexOutput quad_vertex(uint vertex_id: SV_VertexID, uint quad_id: SV_InstanceID) { + float2 unit_vertex = float2(float(vertex_id & 1u), 0.5 * float(vertex_id & 2u)); + Quad quad = quads[quad_id]; + float4 device_position = to_device_position(unit_vertex, quad.bounds); + + GradientColor gradient = prepare_gradient_color( + quad.background.tag, + quad.background.color_space, + quad.background.solid, + quad.background.colors + ); + float4 clip_distance = distance_from_clip_rect(unit_vertex, quad.bounds, quad.content_mask); + float4 border_color = hsla_to_rgba(quad.border_color); + + QuadVertexOutput output; + output.position = device_position; + output.border_color = border_color; + output.quad_id = quad_id; + output.background_solid = gradient.solid; + output.background_color0 = gradient.color0; + output.background_color1 = gradient.color1; + output.clip_distance = clip_distance; + return output; +} + +float4 quad_fragment(QuadFragmentInput input): SV_Target { + Quad quad = quads[input.quad_id]; + float4 background_color = gradient_color(quad.background, input.position.xy, quad.bounds, + input.background_solid, input.background_color0, input.background_color1); + + bool unrounded = quad.corner_radii.top_left == 0.0 && + 
quad.corner_radii.top_right == 0.0 && + quad.corner_radii.bottom_left == 0.0 && + quad.corner_radii.bottom_right == 0.0; + + // Fast path when the quad is not rounded and doesn't have any border + if (quad.border_widths.top == 0.0 && + quad.border_widths.left == 0.0 && + quad.border_widths.right == 0.0 && + quad.border_widths.bottom == 0.0 && + unrounded) { + return background_color; + } + + float2 size = quad.bounds.size; + float2 half_size = size / 2.; + float2 the_point = input.position.xy - quad.bounds.origin; + float2 center_to_point = the_point - half_size; + + // Signed distance field threshold for inclusion of pixels. 0.5 is the + // minimum distance between the center of the pixel and the edge. + const float antialias_threshold = 0.5; + + // Radius of the nearest corner + float corner_radius = pick_corner_radius(center_to_point, quad.corner_radii); + + float2 border = float2( + center_to_point.x < 0.0 ? quad.border_widths.left : quad.border_widths.right, + center_to_point.y < 0.0 ? quad.border_widths.top : quad.border_widths.bottom + ); + + // 0-width borders are reduced so that `inner_sdf >= antialias_threshold`. + // The purpose of this is to not draw antialiasing pixels in this case. + float2 reduced_border = float2( + border.x == 0.0 ? -antialias_threshold : border.x, + border.y == 0.0 ? -antialias_threshold : border.y + ); + + // Vector from the corner of the quad bounds to the point, after mirroring + // the point into the bottom right quadrant. Both components are <= 0. + float2 corner_to_point = abs(center_to_point) - half_size; + + // Vector from the point to the center of the rounded corner's circle, also + // mirrored into bottom right quadrant. + float2 corner_center_to_point = corner_to_point + corner_radius; + + // Whether the nearest point on the border is rounded + bool is_near_rounded_corner = + corner_center_to_point.x >= 0.0 && + corner_center_to_point.y >= 0.0; + + // Vector from straight border inner corner to point. + // + // 0-width borders are turned into width -1 so that inner_sdf is > 1.0 near + // the border. Without this, antialiasing pixels would be drawn. + float2 straight_border_inner_corner_to_point = corner_to_point + reduced_border; + + // Whether the point is beyond the inner edge of the straight border + bool is_beyond_inner_straight_border = + straight_border_inner_corner_to_point.x > 0.0 || + straight_border_inner_corner_to_point.y > 0.0; + + // Whether the point is far enough inside the quad, such that the pixels are + // not affected by the straight border. + bool is_within_inner_straight_border = + straight_border_inner_corner_to_point.x < -antialias_threshold && + straight_border_inner_corner_to_point.y < -antialias_threshold; + + // Fast path for points that must be part of the background + if (is_within_inner_straight_border && !is_near_rounded_corner) { + return background_color; + } + + // Signed distance of the point to the outside edge of the quad's border + float outer_sdf = quad_sdf_impl(corner_center_to_point, corner_radius); + + // Approximate signed distance of the point to the inside edge of the quad's + // border. It is negative outside this edge (within the border), and + // positive inside. + // + // This is not always an accurate signed distance: + // * The rounded portions with varying border width use an approximation of + // nearest-point-on-ellipse. + // * When it is quickly known to be outside the edge, -1.0 is used. 
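+    // * Straight edges, and rounded corners whose adjacent border widths are
+    //   equal (a circular inner edge), use exact distances.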
+ float inner_sdf = 0.0; + if (corner_center_to_point.x <= 0.0 || corner_center_to_point.y <= 0.0) { + // Fast paths for straight borders + inner_sdf = -max(straight_border_inner_corner_to_point.x, + straight_border_inner_corner_to_point.y); + } else if (is_beyond_inner_straight_border) { + // Fast path for points that must be outside the inner edge + inner_sdf = -1.0; + } else if (reduced_border.x == reduced_border.y) { + // Fast path for circular inner edge. + inner_sdf = -(outer_sdf + reduced_border.x); + } else { + float2 ellipse_radii = max(float2(0.0, 0.0), float2(corner_radius, corner_radius) - reduced_border); + inner_sdf = quarter_ellipse_sdf(corner_center_to_point, ellipse_radii); + } + + // Negative when inside the border + float border_sdf = max(inner_sdf, outer_sdf); + + float4 color = background_color; + if (border_sdf < antialias_threshold) { + float4 border_color = input.border_color; + // Dashed border logic when border_style == 1 + if (quad.border_style == 1) { + // Position along the perimeter in "dash space", where each dash + // period has length 1 + float t = 0.0; + + // Total number of dash periods, so that the dash spacing can be + // adjusted to evenly divide it + float max_t = 0.0; + + // Border width is proportional to dash size. This is the behavior + // used by browsers, but also avoids dashes from different segments + // overlapping when dash size is smaller than the border width. + // + // Dash pattern: (2 * border width) dash, (1 * border width) gap + const float dash_length_per_width = 2.0; + const float dash_gap_per_width = 1.0; + const float dash_period_per_width = dash_length_per_width + dash_gap_per_width; + + // Since the dash size is determined by border width, the density of + // dashes varies. Multiplying a pixel distance by this returns a + // position in dash space - it has units (dash period / pixels). So + // a dash velocity of (1 / 10) is 1 dash every 10 pixels. + float dash_velocity = 0.0; + + // Dividing this by the border width gives the dash velocity + const float dv_numerator = 1.0 / dash_period_per_width; + + if (unrounded) { + // When corners aren't rounded, the dashes are separately laid + // out on each straight line, rather than around the whole + // perimeter. This way each line starts and ends with a dash. + bool is_horizontal = corner_center_to_point.x < corner_center_to_point.y; + float border_width = is_horizontal ? border.x : border.y; + dash_velocity = dv_numerator / border_width; + t = is_horizontal ? the_point.x : the_point.y; + t *= dash_velocity; + max_t = is_horizontal ? size.x : size.y; + max_t *= dash_velocity; + } else { + // When corners are rounded, the dashes are laid out clockwise + // around the whole perimeter. + + float r_tr = quad.corner_radii.top_right; + float r_br = quad.corner_radii.bottom_right; + float r_bl = quad.corner_radii.bottom_left; + float r_tl = quad.corner_radii.top_left; + + float w_t = quad.border_widths.top; + float w_r = quad.border_widths.right; + float w_b = quad.border_widths.bottom; + float w_l = quad.border_widths.left; + + // Straight side dash velocities + float dv_t = w_t <= 0.0 ? 0.0 : dv_numerator / w_t; + float dv_r = w_r <= 0.0 ? 0.0 : dv_numerator / w_r; + float dv_b = w_b <= 0.0 ? 0.0 : dv_numerator / w_b; + float dv_l = w_l <= 0.0 ? 
0.0 : dv_numerator / w_l; + + // Straight side lengths in dash space + float s_t = (size.x - r_tl - r_tr) * dv_t; + float s_r = (size.y - r_tr - r_br) * dv_r; + float s_b = (size.x - r_br - r_bl) * dv_b; + float s_l = (size.y - r_bl - r_tl) * dv_l; + + float corner_dash_velocity_tr = corner_dash_velocity(dv_t, dv_r); + float corner_dash_velocity_br = corner_dash_velocity(dv_b, dv_r); + float corner_dash_velocity_bl = corner_dash_velocity(dv_b, dv_l); + float corner_dash_velocity_tl = corner_dash_velocity(dv_t, dv_l); + + // Corner lengths in dash space + float c_tr = r_tr * (M_PI_F / 2.0) * corner_dash_velocity_tr; + float c_br = r_br * (M_PI_F / 2.0) * corner_dash_velocity_br; + float c_bl = r_bl * (M_PI_F / 2.0) * corner_dash_velocity_bl; + float c_tl = r_tl * (M_PI_F / 2.0) * corner_dash_velocity_tl; + + // Cumulative dash space upto each segment + float upto_tr = s_t; + float upto_r = upto_tr + c_tr; + float upto_br = upto_r + s_r; + float upto_b = upto_br + c_br; + float upto_bl = upto_b + s_b; + float upto_l = upto_bl + c_bl; + float upto_tl = upto_l + s_l; + max_t = upto_tl + c_tl; + + if (is_near_rounded_corner) { + float radians = atan2(corner_center_to_point.y, corner_center_to_point.x); + float corner_t = radians * corner_radius; + + if (center_to_point.x >= 0.0) { + if (center_to_point.y < 0.0) { + dash_velocity = corner_dash_velocity_tr; + // Subtracted because radians is pi/2 to 0 when + // going clockwise around the top right corner, + // since the y axis has been flipped + t = upto_r - corner_t * dash_velocity; + } else { + dash_velocity = corner_dash_velocity_br; + // Added because radians is 0 to pi/2 when going + // clockwise around the bottom-right corner + t = upto_br + corner_t * dash_velocity; + } + } else { + if (center_to_point.y >= 0.0) { + dash_velocity = corner_dash_velocity_bl; + // Subtracted because radians is pi/1 to 0 when + // going clockwise around the bottom-left corner, + // since the x axis has been flipped + t = upto_l - corner_t * dash_velocity; + } else { + dash_velocity = corner_dash_velocity_tl; + // Added because radians is 0 to pi/2 when going + // clockwise around the top-left corner, since both + // axis were flipped + t = upto_tl + corner_t * dash_velocity; + } + } + } else { + // Straight borders + bool is_horizontal = corner_center_to_point.x < corner_center_to_point.y; + if (is_horizontal) { + if (center_to_point.y < 0.0) { + dash_velocity = dv_t; + t = (the_point.x - r_tl) * dash_velocity; + } else { + dash_velocity = dv_b; + t = upto_bl - (the_point.x - r_bl) * dash_velocity; + } + } else { + if (center_to_point.x < 0.0) { + dash_velocity = dv_l; + t = upto_tl - (the_point.y - r_tl) * dash_velocity; + } else { + dash_velocity = dv_r; + t = upto_r + (the_point.y - r_tr) * dash_velocity; + } + } + } + } + float dash_length = dash_length_per_width / dash_period_per_width; + float desired_dash_gap = dash_gap_per_width / dash_period_per_width; + + // Straight borders should start and end with a dash, so max_t is + // reduced to cause this. + max_t -= unrounded ? dash_length : 0.0; + if (max_t >= 1.0) { + // Adjust dash gap to evenly divide max_t + float dash_count = floor(max_t); + float dash_period = max_t / dash_count; + border_color.a *= dash_alpha(t, dash_period, dash_length, dash_velocity, antialias_threshold); + } else if (unrounded) { + // When there isn't enough space for the full gap between the + // two start / end dashes of a straight border, reduce gap to + // make them fit. 
+ float dash_gap = max_t - dash_length; + if (dash_gap > 0.0) { + float dash_period = dash_length + dash_gap; + border_color.a *= dash_alpha(t, dash_period, dash_length, dash_velocity, antialias_threshold); + } + } + } + + // Blend the border on top of the background and then linearly interpolate + // between the two as we slide inside the background. + float4 blended_border = over(background_color, border_color); + color = lerp(background_color, blended_border, + saturate(antialias_threshold - inner_sdf)); + } + + return color * float4(1.0, 1.0, 1.0, saturate(antialias_threshold - outer_sdf)); +} + +/* +** +** Shadows +** +*/ + +struct Shadow { + uint order; + float blur_radius; + Bounds bounds; + Corners corner_radii; + Bounds content_mask; + Hsla color; +}; + +struct ShadowVertexOutput { + nointerpolation uint shadow_id: TEXCOORD0; + float4 position: SV_Position; + nointerpolation float4 color: COLOR; + float4 clip_distance: SV_ClipDistance; +}; + +struct ShadowFragmentInput { + nointerpolation uint shadow_id: TEXCOORD0; + float4 position: SV_Position; + nointerpolation float4 color: COLOR; +}; + +StructuredBuffer<Shadow> shadows: register(t1); + +ShadowVertexOutput shadow_vertex(uint vertex_id: SV_VertexID, uint shadow_id: SV_InstanceID) { + float2 unit_vertex = float2(float(vertex_id & 1u), 0.5 * float(vertex_id & 2u)); + Shadow shadow = shadows[shadow_id]; + + float margin = 3.0 * shadow.blur_radius; + Bounds bounds = shadow.bounds; + bounds.origin -= margin; + bounds.size += 2.0 * margin; + + float4 device_position = to_device_position(unit_vertex, bounds); + float4 clip_distance = distance_from_clip_rect(unit_vertex, bounds, shadow.content_mask); + float4 color = hsla_to_rgba(shadow.color); + + ShadowVertexOutput output; + output.position = device_position; + output.color = color; + output.shadow_id = shadow_id; + output.clip_distance = clip_distance; + + return output; +} + +float4 shadow_fragment(ShadowFragmentInput input): SV_TARGET { + Shadow shadow = shadows[input.shadow_id]; + + float2 half_size = shadow.bounds.size / 2.; + float2 center = shadow.bounds.origin + half_size; + float2 point0 = input.position.xy - center; + float corner_radius = pick_corner_radius(point0, shadow.corner_radii); + + // The signal is only non-zero in a limited range, so don't waste samples + float low = point0.y - half_size.y; + float high = point0.y + half_size.y; + float start = clamp(-3. * shadow.blur_radius, low, high); + float end = clamp(3. 
* shadow.blur_radius, low, high); + + // Accumulate samples (we can get away with surprisingly few samples) + float step = (end - start) / 4.; + float y = start + step * 0.5; + float alpha = 0.; + for (int i = 0; i < 4; i++) { + alpha += blur_along_x(point0.x, point0.y - y, shadow.blur_radius, + corner_radius, half_size) * + gaussian(y, shadow.blur_radius) * step; + y += step; + } + + return input.color * float4(1., 1., 1., alpha); +} + +/* +** +** Path Rasterization +** +*/ + +struct PathRasterizationSprite { + float2 xy_position; + float2 st_position; + Background color; + Bounds bounds; +}; + +StructuredBuffer<PathRasterizationSprite> path_rasterization_sprites: register(t1); + +struct PathVertexOutput { + float4 position: SV_Position; + float2 st_position: TEXCOORD0; + nointerpolation uint vertex_id: TEXCOORD1; + float4 clip_distance: SV_ClipDistance; +}; + +struct PathFragmentInput { + float4 position: SV_Position; + float2 st_position: TEXCOORD0; + nointerpolation uint vertex_id: TEXCOORD1; +}; + +PathVertexOutput path_rasterization_vertex(uint vertex_id: SV_VertexID) { + PathRasterizationSprite sprite = path_rasterization_sprites[vertex_id]; + + PathVertexOutput output; + output.position = to_device_position_impl(sprite.xy_position); + output.st_position = sprite.st_position; + output.vertex_id = vertex_id; + output.clip_distance = distance_from_clip_rect_impl(sprite.xy_position, sprite.bounds); + + return output; +} + +float4 path_rasterization_fragment(PathFragmentInput input): SV_Target { + float2 dx = ddx(input.st_position); + float2 dy = ddy(input.st_position); + PathRasterizationSprite sprite = path_rasterization_sprites[input.vertex_id]; + + Background background = sprite.color; + Bounds bounds = sprite.bounds; + + float alpha; + if (length(float2(dx.x, dy.x))) { + alpha = 1.0; + } else { + float2 gradient = 2.0 * input.st_position.xx * float2(dx.x, dy.x) - float2(dx.y, dy.y); + float f = input.st_position.x * input.st_position.x - input.st_position.y; + float distance = f / length(gradient); + alpha = saturate(0.5 - distance); + } + + GradientColor gradient = prepare_gradient_color( + background.tag, background.color_space, background.solid, background.colors); + + float4 color = gradient_color(background, input.position.xy, bounds, + gradient.solid, gradient.color0, gradient.color1); + return float4(color.rgb * color.a * alpha, alpha * color.a); +} + +/* +** +** Path Sprites +** +*/ + +struct PathSprite { + Bounds bounds; +}; + +struct PathSpriteVertexOutput { + float4 position: SV_Position; + float2 texture_coords: TEXCOORD0; +}; + +StructuredBuffer<PathSprite> path_sprites: register(t1); + +PathSpriteVertexOutput path_sprite_vertex(uint vertex_id: SV_VertexID, uint sprite_id: SV_InstanceID) { + float2 unit_vertex = float2(float(vertex_id & 1u), 0.5 * float(vertex_id & 2u)); + PathSprite sprite = path_sprites[sprite_id]; + + // Don't apply content mask because it was already accounted for when rasterizing the path + float4 device_position = to_device_position(unit_vertex, sprite.bounds); + + float2 screen_position = sprite.bounds.origin + unit_vertex * sprite.bounds.size; + float2 texture_coords = screen_position / global_viewport_size; + + PathSpriteVertexOutput output; + output.position = device_position; + output.texture_coords = texture_coords; + return output; +} + +float4 path_sprite_fragment(PathSpriteVertexOutput input): SV_Target { + return t_sprite.Sample(s_sprite, input.texture_coords); +} + +/* +** +** Underlines +** +*/ + +struct Underline { + uint order; + 
uint pad; + Bounds bounds; + Bounds content_mask; + Hsla color; + float thickness; + uint wavy; +}; + +struct UnderlineVertexOutput { + nointerpolation uint underline_id: TEXCOORD0; + float4 position: SV_Position; + nointerpolation float4 color: COLOR; + float4 clip_distance: SV_ClipDistance; +}; + +struct UnderlineFragmentInput { + nointerpolation uint underline_id: TEXCOORD0; + float4 position: SV_Position; + nointerpolation float4 color: COLOR; +}; + +StructuredBuffer<Underline> underlines: register(t1); + +UnderlineVertexOutput underline_vertex(uint vertex_id: SV_VertexID, uint underline_id: SV_InstanceID) { + float2 unit_vertex = float2(float(vertex_id & 1u), 0.5 * float(vertex_id & 2u)); + Underline underline = underlines[underline_id]; + float4 device_position = to_device_position(unit_vertex, underline.bounds); + float4 clip_distance = distance_from_clip_rect(unit_vertex, underline.bounds, + underline.content_mask); + float4 color = hsla_to_rgba(underline.color); + + UnderlineVertexOutput output; + output.position = device_position; + output.color = color; + output.underline_id = underline_id; + output.clip_distance = clip_distance; + return output; +} + +float4 underline_fragment(UnderlineFragmentInput input): SV_Target { + Underline underline = underlines[input.underline_id]; + if (underline.wavy) { + float half_thickness = underline.thickness * 0.5; + float2 origin = underline.bounds.origin; + float2 st = ((input.position.xy - origin) / underline.bounds.size.y) - float2(0., 0.5); + float frequency = (M_PI_F * (3. * underline.thickness)) / 8.; + float amplitude = 1. / (2. * underline.thickness); + float sine = sin(st.x * frequency) * amplitude; + float dSine = cos(st.x * frequency) * amplitude * frequency; + float distance = (st.y - sine) / sqrt(1. 
+ dSine * dSine); + float distance_in_pixels = distance * underline.bounds.size.y; + float distance_from_top_border = distance_in_pixels - half_thickness; + float distance_from_bottom_border = distance_in_pixels + half_thickness; + float alpha = saturate( + 0.5 - max(-distance_from_bottom_border, distance_from_top_border)); + return input.color * float4(1., 1., 1., alpha); + } else { + return input.color; + } +} + +/* +** +** Monochrome sprites +** +*/ + +struct MonochromeSprite { + uint order; + uint pad; + Bounds bounds; + Bounds content_mask; + Hsla color; + AtlasTile tile; + TransformationMatrix transformation; +}; + +struct MonochromeSpriteVertexOutput { + float4 position: SV_Position; + float2 tile_position: POSITION; + nointerpolation float4 color: COLOR; + float4 clip_distance: SV_ClipDistance; +}; + +struct MonochromeSpriteFragmentInput { + float4 position: SV_Position; + float2 tile_position: POSITION; + nointerpolation float4 color: COLOR; + float4 clip_distance: SV_ClipDistance; +}; + +StructuredBuffer<MonochromeSprite> mono_sprites: register(t1); + +MonochromeSpriteVertexOutput monochrome_sprite_vertex(uint vertex_id: SV_VertexID, uint sprite_id: SV_InstanceID) { + float2 unit_vertex = float2(float(vertex_id & 1u), 0.5 * float(vertex_id & 2u)); + MonochromeSprite sprite = mono_sprites[sprite_id]; + float4 device_position = + to_device_position_transformed(unit_vertex, sprite.bounds, sprite.transformation); + float4 clip_distance = distance_from_clip_rect(unit_vertex, sprite.bounds, sprite.content_mask); + float2 tile_position = to_tile_position(unit_vertex, sprite.tile); + float4 color = hsla_to_rgba(sprite.color); + + MonochromeSpriteVertexOutput output; + output.position = device_position; + output.tile_position = tile_position; + output.color = color; + output.clip_distance = clip_distance; + return output; +} + +float4 monochrome_sprite_fragment(MonochromeSpriteFragmentInput input): SV_Target { + float sample = t_sprite.Sample(s_sprite, input.tile_position).r; + return float4(input.color.rgb, input.color.a * sample); +} + +/* +** +** Polychrome sprites +** +*/ + +struct PolychromeSprite { + uint order; + uint pad; + uint grayscale; + float opacity; + Bounds bounds; + Bounds content_mask; + Corners corner_radii; + AtlasTile tile; +}; + +struct PolychromeSpriteVertexOutput { + nointerpolation uint sprite_id: TEXCOORD0; + float4 position: SV_Position; + float2 tile_position: POSITION; + float4 clip_distance: SV_ClipDistance; +}; + +struct PolychromeSpriteFragmentInput { + nointerpolation uint sprite_id: TEXCOORD0; + float4 position: SV_Position; + float2 tile_position: POSITION; +}; + +StructuredBuffer<PolychromeSprite> poly_sprites: register(t1); + +PolychromeSpriteVertexOutput polychrome_sprite_vertex(uint vertex_id: SV_VertexID, uint sprite_id: SV_InstanceID) { + float2 unit_vertex = float2(float(vertex_id & 1u), 0.5 * float(vertex_id & 2u)); + PolychromeSprite sprite = poly_sprites[sprite_id]; + float4 device_position = to_device_position(unit_vertex, sprite.bounds); + float4 clip_distance = distance_from_clip_rect(unit_vertex, sprite.bounds, + sprite.content_mask); + float2 tile_position = to_tile_position(unit_vertex, sprite.tile); + + PolychromeSpriteVertexOutput output; + output.position = device_position; + output.tile_position = tile_position; + output.sprite_id = sprite_id; + output.clip_distance = clip_distance; + return output; +} + +float4 polychrome_sprite_fragment(PolychromeSpriteFragmentInput input): SV_Target { + PolychromeSprite sprite = 
poly_sprites[input.sprite_id]; + float4 sample = t_sprite.Sample(s_sprite, input.tile_position); + float distance = quad_sdf(input.position.xy, sprite.bounds, sprite.corner_radii); + + float4 color = sample; + if ((sprite.grayscale & 0xFFu) != 0u) { + float3 grayscale = dot(color.rgb, GRAYSCALE_FACTORS); + color = float4(grayscale, sample.a); + } + color.a *= sprite.opacity * saturate(0.5 - distance); + return color; +} diff --git a/crates/gpui/src/platform/windows/window.rs b/crates/gpui/src/platform/windows/window.rs index 5703a82815eb0679ca3668a13c08f3e9affa3696..32a6da23915d1e2bdf61c662e364119e9c6a8c64 100644 --- a/crates/gpui/src/platform/windows/window.rs +++ b/crates/gpui/src/platform/windows/window.rs @@ -26,10 +26,9 @@ use windows::{ core::*, }; -use crate::platform::blade::{BladeContext, BladeRenderer}; use crate::*; -pub(crate) struct WindowsWindow(pub Rc<WindowsWindowStatePtr>); +pub(crate) struct WindowsWindow(pub Rc<WindowsWindowInner>); pub struct WindowsWindowState { pub origin: Point<Pixels>, @@ -49,7 +48,7 @@ pub struct WindowsWindowState { pub system_key_handled: bool, pub hovered: bool, - pub renderer: BladeRenderer, + pub renderer: DirectXRenderer, pub click_state: ClickState, pub system_settings: WindowsSystemSettings, @@ -62,9 +61,9 @@ pub struct WindowsWindowState { hwnd: HWND, } -pub(crate) struct WindowsWindowStatePtr { +pub(crate) struct WindowsWindowInner { hwnd: HWND, - this: Weak<Self>, + pub(super) this: Weak<Self>, drop_target_helper: IDropTargetHelper, pub(crate) state: RefCell<WindowsWindowState>, pub(crate) handle: AnyWindowHandle, @@ -80,21 +79,23 @@ pub(crate) struct WindowsWindowStatePtr { impl WindowsWindowState { fn new( hwnd: HWND, - transparent: bool, - cs: &CREATESTRUCTW, + window_params: &CREATESTRUCTW, current_cursor: Option<HCURSOR>, display: WindowsDisplay, - gpu_context: &BladeContext, min_size: Option<Size<Pixels>>, appearance: WindowAppearance, + disable_direct_composition: bool, ) -> Result<Self> { let scale_factor = { let monitor_dpi = unsafe { GetDpiForWindow(hwnd) } as f32; monitor_dpi / USER_DEFAULT_SCREEN_DPI as f32 }; - let origin = logical_point(cs.x as f32, cs.y as f32, scale_factor); + let origin = logical_point(window_params.x as f32, window_params.y as f32, scale_factor); let logical_size = { - let physical_size = size(DevicePixels(cs.cx), DevicePixels(cs.cy)); + let physical_size = size( + DevicePixels(window_params.cx), + DevicePixels(window_params.cy), + ); physical_size.to_pixels(scale_factor) }; let fullscreen_restore_bounds = Bounds { @@ -103,7 +104,8 @@ impl WindowsWindowState { }; let border_offset = WindowBorderOffset::default(); let restore_from_minimized = None; - let renderer = windows_renderer::init(gpu_context, hwnd, transparent)?; + let renderer = DirectXRenderer::new(hwnd, disable_direct_composition) + .context("Creating DirectX renderer")?; let callbacks = Callbacks::default(); let input_handler = None; let pending_surrogate = None; @@ -202,17 +204,16 @@ impl WindowsWindowState { } } -impl WindowsWindowStatePtr { +impl WindowsWindowInner { fn new(context: &WindowCreateContext, hwnd: HWND, cs: &CREATESTRUCTW) -> Result<Rc<Self>> { let state = RefCell::new(WindowsWindowState::new( hwnd, - context.transparent, cs, context.current_cursor, context.display, - context.gpu_context, context.min_size, context.appearance, + context.disable_direct_composition, )?); Ok(Rc::new_cyclic(|this| Self { @@ -232,13 +233,13 @@ impl WindowsWindowStatePtr { } fn toggle_fullscreen(&self) { - let Some(state_ptr) = self.this.upgrade() 
else { + let Some(this) = self.this.upgrade() else { log::error!("Unable to toggle fullscreen: window has been dropped"); return; }; self.executor .spawn(async move { - let mut lock = state_ptr.state.borrow_mut(); + let mut lock = this.state.borrow_mut(); let StyleAndBounds { style, x, @@ -250,10 +251,9 @@ impl WindowsWindowStatePtr { } else { let (window_bounds, _) = lock.calculate_window_bounds(); lock.fullscreen_restore_bounds = window_bounds; - let style = - WINDOW_STYLE(unsafe { get_window_long(state_ptr.hwnd, GWL_STYLE) } as _); + let style = WINDOW_STYLE(unsafe { get_window_long(this.hwnd, GWL_STYLE) } as _); let mut rc = RECT::default(); - unsafe { GetWindowRect(state_ptr.hwnd, &mut rc) }.log_err(); + unsafe { GetWindowRect(this.hwnd, &mut rc) }.log_err(); let _ = lock.fullscreen.insert(StyleAndBounds { style, x: rc.left, @@ -277,10 +277,10 @@ impl WindowsWindowStatePtr { } }; drop(lock); - unsafe { set_window_long(state_ptr.hwnd, GWL_STYLE, style.0 as isize) }; + unsafe { set_window_long(this.hwnd, GWL_STYLE, style.0 as isize) }; unsafe { SetWindowPos( - state_ptr.hwnd, + this.hwnd, None, x, y, @@ -329,12 +329,11 @@ pub(crate) struct Callbacks { pub(crate) appearance_changed: Option<Box<dyn FnMut()>>, } -struct WindowCreateContext<'a> { - inner: Option<Result<Rc<WindowsWindowStatePtr>>>, +struct WindowCreateContext { + inner: Option<Result<Rc<WindowsWindowInner>>>, handle: AnyWindowHandle, hide_title_bar: bool, display: WindowsDisplay, - transparent: bool, is_movable: bool, min_size: Option<Size<Pixels>>, executor: ForegroundExecutor, @@ -343,9 +342,9 @@ struct WindowCreateContext<'a> { drop_target_helper: IDropTargetHelper, validation_number: usize, main_receiver: flume::Receiver<Runnable>, - gpu_context: &'a BladeContext, main_thread_id_win32: u32, appearance: WindowAppearance, + disable_direct_composition: bool, } impl WindowsWindow { @@ -353,7 +352,6 @@ impl WindowsWindow { handle: AnyWindowHandle, params: WindowParams, creation_info: WindowCreationInfo, - gpu_context: &BladeContext, ) -> Result<Self> { let WindowCreationInfo { icon, @@ -364,14 +362,15 @@ impl WindowsWindow { validation_number, main_receiver, main_thread_id_win32, + disable_direct_composition, } = creation_info; - let classname = register_wnd_class(icon); + register_window_class(icon); let hide_title_bar = params .titlebar .as_ref() .map(|titlebar| titlebar.appears_transparent) .unwrap_or(true); - let windowname = HSTRING::from( + let window_name = HSTRING::from( params .titlebar .as_ref() @@ -379,14 +378,18 @@ impl WindowsWindow { .map(|title| title.as_ref()) .unwrap_or(""), ); - let (dwexstyle, mut dwstyle) = if params.kind == WindowKind::PopUp { - (WS_EX_TOOLWINDOW | WS_EX_LAYERED, WINDOW_STYLE(0x0)) + + let (mut dwexstyle, dwstyle) = if params.kind == WindowKind::PopUp { + (WS_EX_TOOLWINDOW, WINDOW_STYLE(0x0)) } else { ( - WS_EX_APPWINDOW | WS_EX_LAYERED, + WS_EX_APPWINDOW, WS_THICKFRAME | WS_SYSMENU | WS_MAXIMIZEBOX | WS_MINIMIZEBOX, ) }; + if !disable_direct_composition { + dwexstyle |= WS_EX_NOREDIRECTIONBITMAP; + } let hinstance = get_module_handle(); let display = if let Some(display_id) = params.display_id { @@ -401,7 +404,6 @@ impl WindowsWindow { handle, hide_title_bar, display, - transparent: true, is_movable: params.is_movable, min_size: params.window_min_size, executor, @@ -410,16 +412,15 @@ impl WindowsWindow { drop_target_helper, validation_number, main_receiver, - gpu_context, main_thread_id_win32, appearance, + disable_direct_composition, }; - let lpparam = Some(&context as *const _ as 
*const _); let creation_result = unsafe { CreateWindowExW( dwexstyle, - classname, - &windowname, + WINDOW_CLASS_NAME, + &window_name, dwstyle, CW_USEDEFAULT, CW_USEDEFAULT, @@ -428,41 +429,35 @@ impl WindowsWindow { None, None, Some(hinstance.into()), - lpparam, + Some(&context as *const _ as *const _), ) }; - // We should call `?` on state_ptr first, then call `?` on hwnd. - // Or, we will lose the error info reported by `WindowsWindowState::new` - let state_ptr = context.inner.take().unwrap()?; + + // Failure to create a `WindowsWindowState` can cause window creation to fail, + // so check the inner result first. + let this = context.inner.take().unwrap()?; let hwnd = creation_result?; - register_drag_drop(state_ptr.clone())?; + + register_drag_drop(&this)?; configure_dwm_dark_mode(hwnd, appearance); - state_ptr.state.borrow_mut().border_offset.update(hwnd)?; + this.state.borrow_mut().border_offset.update(hwnd)?; let placement = retrieve_window_placement( hwnd, display, params.bounds, - state_ptr.state.borrow().scale_factor, - state_ptr.state.borrow().border_offset, + this.state.borrow().scale_factor, + this.state.borrow().border_offset, )?; if params.show { unsafe { SetWindowPlacement(hwnd, &placement)? }; } else { - state_ptr.state.borrow_mut().initial_placement = Some(WindowOpenStatus { + this.state.borrow_mut().initial_placement = Some(WindowOpenStatus { placement, state: WindowOpenState::Windowed, }); } - // The render pipeline will perform compositing on the GPU when the - // swapchain is configured correctly (see downstream of - // update_transparency). - // The following configuration is a one-time setup to ensure that the - // window is going to be composited with per-pixel alpha, but the render - // pipeline is responsible for effectively calling UpdateLayeredWindow - // at the appropriate time. - unsafe { SetLayeredWindowAttributes(hwnd, COLORREF(0), 255, LWA_ALPHA)? 
}; - Ok(Self(state_ptr)) + Ok(Self(this)) } } @@ -485,7 +480,6 @@ impl rwh::HasDisplayHandle for WindowsWindow { impl Drop for WindowsWindow { fn drop(&mut self) { - self.0.state.borrow_mut().renderer.destroy(); // clone this `Rc` to prevent early release of the pointer let this = self.0.clone(); self.0 @@ -683,6 +677,36 @@ impl PlatformWindow for WindowsWindow { this.set_window_placement().log_err(); unsafe { SetActiveWindow(hwnd).log_err() }; unsafe { SetFocus(Some(hwnd)).log_err() }; + + // premium ragebait by windows, this is needed because the window + // must have received an input event to be able to set itself to foreground + // so let's just simulate user input as that seems to be the most reliable way + // some more info: https://gist.github.com/Aetopia/1581b40f00cc0cadc93a0e8ccb65dc8c + // bonus: this bug also doesn't manifest if you have vs attached to the process + let inputs = [ + INPUT { + r#type: INPUT_KEYBOARD, + Anonymous: INPUT_0 { + ki: KEYBDINPUT { + wVk: VK_MENU, + dwFlags: KEYBD_EVENT_FLAGS(0), + ..Default::default() + }, + }, + }, + INPUT { + r#type: INPUT_KEYBOARD, + Anonymous: INPUT_0 { + ki: KEYBDINPUT { + wVk: VK_MENU, + dwFlags: KEYEVENTF_KEYUP, + ..Default::default() + }, + }, + }, + ]; + unsafe { SendInput(&inputs, std::mem::size_of::<INPUT>() as i32) }; + // todo(windows) // crate `windows 0.56` reports true as Err unsafe { SetForegroundWindow(hwnd).as_bool() }; @@ -705,24 +729,21 @@ impl PlatformWindow for WindowsWindow { } fn set_background_appearance(&self, background_appearance: WindowBackgroundAppearance) { - let mut window_state = self.0.state.borrow_mut(); - window_state - .renderer - .update_transparency(background_appearance != WindowBackgroundAppearance::Opaque); + let hwnd = self.0.hwnd; match background_appearance { WindowBackgroundAppearance::Opaque => { // ACCENT_DISABLED - set_window_composition_attribute(window_state.hwnd, None, 0); + set_window_composition_attribute(hwnd, None, 0); } WindowBackgroundAppearance::Transparent => { // Use ACCENT_ENABLE_TRANSPARENTGRADIENT for transparent background - set_window_composition_attribute(window_state.hwnd, None, 2); + set_window_composition_attribute(hwnd, None, 2); } WindowBackgroundAppearance::Blurred => { // Enable acrylic blur // ACCENT_ENABLE_ACRYLICBLURBEHIND - set_window_composition_attribute(window_state.hwnd, Some((0, 0, 0, 0)), 4); + set_window_composition_attribute(hwnd, Some((0, 0, 0, 0)), 4); } } } @@ -794,11 +815,11 @@ impl PlatformWindow for WindowsWindow { } fn draw(&self, scene: &Scene) { - self.0.state.borrow_mut().renderer.draw(scene) + self.0.state.borrow_mut().renderer.draw(scene).log_err(); } fn sprite_atlas(&self) -> Arc<dyn PlatformAtlas> { - self.0.state.borrow().renderer.sprite_atlas().clone() + self.0.state.borrow().renderer.sprite_atlas() } fn get_raw_handle(&self) -> HWND { @@ -806,16 +827,16 @@ impl PlatformWindow for WindowsWindow { } fn gpu_specs(&self) -> Option<GpuSpecs> { - Some(self.0.state.borrow().renderer.gpu_specs()) + self.0.state.borrow().renderer.gpu_specs().log_err() } fn update_ime_position(&self, _bounds: Bounds<ScaledPixels>) { - // todo(windows) + // There is no such thing on Windows. 
} } #[implement(IDropTarget)] -struct WindowsDragDropHandler(pub Rc<WindowsWindowStatePtr>); +struct WindowsDragDropHandler(pub Rc<WindowsWindowInner>); impl WindowsDragDropHandler { fn handle_drag_drop(&self, input: PlatformInput) { @@ -1096,15 +1117,15 @@ enum WindowOpenState { Windowed, } -fn register_wnd_class(icon_handle: HICON) -> PCWSTR { - const CLASS_NAME: PCWSTR = w!("Zed::Window"); +const WINDOW_CLASS_NAME: PCWSTR = w!("Zed::Window"); +fn register_window_class(icon_handle: HICON) { static ONCE: Once = Once::new(); ONCE.call_once(|| { let wc = WNDCLASSW { - lpfnWndProc: Some(wnd_proc), + lpfnWndProc: Some(window_procedure), hIcon: icon_handle, - lpszClassName: PCWSTR(CLASS_NAME.as_ptr()), + lpszClassName: PCWSTR(WINDOW_CLASS_NAME.as_ptr()), style: CS_HREDRAW | CS_VREDRAW, hInstance: get_module_handle().into(), hbrBackground: unsafe { CreateSolidBrush(COLORREF(0x00000000)) }, @@ -1112,54 +1133,58 @@ fn register_wnd_class(icon_handle: HICON) -> PCWSTR { }; unsafe { RegisterClassW(&wc) }; }); - - CLASS_NAME } -unsafe extern "system" fn wnd_proc( +unsafe extern "system" fn window_procedure( hwnd: HWND, msg: u32, wparam: WPARAM, lparam: LPARAM, ) -> LRESULT { if msg == WM_NCCREATE { - let cs = lparam.0 as *const CREATESTRUCTW; - let cs = unsafe { &*cs }; - let ctx = cs.lpCreateParams as *mut WindowCreateContext; - let ctx = unsafe { &mut *ctx }; - let creation_result = WindowsWindowStatePtr::new(ctx, hwnd, cs); - if creation_result.is_err() { - ctx.inner = Some(creation_result); - return LRESULT(0); - } - let weak = Box::new(Rc::downgrade(creation_result.as_ref().unwrap())); - unsafe { set_window_long(hwnd, GWLP_USERDATA, Box::into_raw(weak) as isize) }; - ctx.inner = Some(creation_result); - return unsafe { DefWindowProcW(hwnd, msg, wparam, lparam) }; + let window_params = lparam.0 as *const CREATESTRUCTW; + let window_params = unsafe { &*window_params }; + let window_creation_context = window_params.lpCreateParams as *mut WindowCreateContext; + let window_creation_context = unsafe { &mut *window_creation_context }; + return match WindowsWindowInner::new(window_creation_context, hwnd, window_params) { + Ok(window_state) => { + let weak = Box::new(Rc::downgrade(&window_state)); + unsafe { set_window_long(hwnd, GWLP_USERDATA, Box::into_raw(weak) as isize) }; + window_creation_context.inner = Some(Ok(window_state)); + unsafe { DefWindowProcW(hwnd, msg, wparam, lparam) } + } + Err(error) => { + window_creation_context.inner = Some(Err(error)); + LRESULT(0) + } + }; } - let ptr = unsafe { get_window_long(hwnd, GWLP_USERDATA) } as *mut Weak<WindowsWindowStatePtr>; + + let ptr = unsafe { get_window_long(hwnd, GWLP_USERDATA) } as *mut Weak<WindowsWindowInner>; if ptr.is_null() { return unsafe { DefWindowProcW(hwnd, msg, wparam, lparam) }; } let inner = unsafe { &*ptr }; - let r = if let Some(state) = inner.upgrade() { - handle_msg(hwnd, msg, wparam, lparam, state) + let result = if let Some(inner) = inner.upgrade() { + inner.handle_msg(hwnd, msg, wparam, lparam) } else { unsafe { DefWindowProcW(hwnd, msg, wparam, lparam) } }; + if msg == WM_NCDESTROY { unsafe { set_window_long(hwnd, GWLP_USERDATA, 0) }; unsafe { drop(Box::from_raw(ptr)) }; } - r + + result } -pub(crate) fn try_get_window_inner(hwnd: HWND) -> Option<Rc<WindowsWindowStatePtr>> { +pub(crate) fn window_from_hwnd(hwnd: HWND) -> Option<Rc<WindowsWindowInner>> { if hwnd.is_invalid() { return None; } - let ptr = unsafe { get_window_long(hwnd, GWLP_USERDATA) } as *mut Weak<WindowsWindowStatePtr>; + let ptr = unsafe { 
get_window_long(hwnd, GWLP_USERDATA) } as *mut Weak<WindowsWindowInner>; if !ptr.is_null() { let inner = unsafe { &*ptr }; inner.upgrade() @@ -1182,9 +1207,9 @@ fn get_module_handle() -> HMODULE { } } -fn register_drag_drop(state_ptr: Rc<WindowsWindowStatePtr>) -> Result<()> { - let window_handle = state_ptr.hwnd; - let handler = WindowsDragDropHandler(state_ptr); +fn register_drag_drop(window: &Rc<WindowsWindowInner>) -> Result<()> { + let window_handle = window.hwnd; + let handler = WindowsDragDropHandler(window.clone()); // The lifetime of `IDropTarget` is handled by Windows, it won't release until // we call `RevokeDragDrop`. // So, it's safe to drop it here. @@ -1306,52 +1331,6 @@ fn set_window_composition_attribute(hwnd: HWND, color: Option<Color>, state: u32 } } -mod windows_renderer { - use crate::platform::blade::{BladeContext, BladeRenderer, BladeSurfaceConfig}; - use raw_window_handle as rwh; - use std::num::NonZeroIsize; - use windows::Win32::{Foundation::HWND, UI::WindowsAndMessaging::GWLP_HINSTANCE}; - - use crate::{get_window_long, show_error}; - - pub(super) fn init( - context: &BladeContext, - hwnd: HWND, - transparent: bool, - ) -> anyhow::Result<BladeRenderer> { - let raw = RawWindow { hwnd }; - let config = BladeSurfaceConfig { - size: Default::default(), - transparent, - }; - BladeRenderer::new(context, &raw, config) - .inspect_err(|err| show_error("Failed to initialize BladeRenderer", err.to_string())) - } - - struct RawWindow { - hwnd: HWND, - } - - impl rwh::HasWindowHandle for RawWindow { - fn window_handle(&self) -> Result<rwh::WindowHandle<'_>, rwh::HandleError> { - Ok(unsafe { - let hwnd = NonZeroIsize::new_unchecked(self.hwnd.0 as isize); - let mut handle = rwh::Win32WindowHandle::new(hwnd); - let hinstance = get_window_long(self.hwnd, GWLP_HINSTANCE); - handle.hinstance = NonZeroIsize::new(hinstance); - rwh::WindowHandle::borrow_raw(handle.into()) - }) - } - } - - impl rwh::HasDisplayHandle for RawWindow { - fn display_handle(&self) -> Result<rwh::DisplayHandle<'_>, rwh::HandleError> { - let handle = rwh::WindowsDisplayHandle::new(); - Ok(unsafe { rwh::DisplayHandle::borrow_raw(handle.into()) }) - } - } -} - #[cfg(test)] mod tests { use super::ClickState; diff --git a/crates/gpui/src/platform/windows/wrapper.rs b/crates/gpui/src/platform/windows/wrapper.rs index 6015dffdab299754d9469684ead08a9d5c95d4c6..a1fe98a392fcb4e776a7f0603d080da2ea8d2136 100644 --- a/crates/gpui/src/platform/windows/wrapper.rs +++ b/crates/gpui/src/platform/windows/wrapper.rs @@ -1,28 +1,6 @@ use std::ops::Deref; -use windows::Win32::{Foundation::HANDLE, UI::WindowsAndMessaging::HCURSOR}; - -#[derive(Debug, Clone, Copy)] -pub(crate) struct SafeHandle { - raw: HANDLE, -} - -unsafe impl Send for SafeHandle {} -unsafe impl Sync for SafeHandle {} - -impl From<HANDLE> for SafeHandle { - fn from(value: HANDLE) -> Self { - SafeHandle { raw: value } - } -} - -impl Deref for SafeHandle { - type Target = HANDLE; - - fn deref(&self) -> &Self::Target { - &self.raw - } -} +use windows::Win32::UI::WindowsAndMessaging::HCURSOR; #[derive(Debug, Clone, Copy)] pub(crate) struct SafeCursor { diff --git a/crates/gpui/src/scene.rs b/crates/gpui/src/scene.rs index 681444a4737867bb78ac8081958a4cf1af4f6771..c527dfe750beb2d19ed6750e48f71208ec2720bf 100644 --- a/crates/gpui/src/scene.rs +++ b/crates/gpui/src/scene.rs @@ -6,9 +6,14 @@ use serde::{Deserialize, Serialize}; use crate::{ AtlasTextureId, AtlasTile, Background, Bounds, ContentMask, Corners, Edges, Hsla, Pixels, - Point, Radians, ScaledPixels, Size, 
bounds_tree::BoundsTree, + Point, Radians, ScaledPixels, Size, bounds_tree::BoundsTree, point, +}; +use std::{ + fmt::Debug, + iter::Peekable, + ops::{Add, Range, Sub}, + slice, }; -use std::{fmt::Debug, iter::Peekable, ops::Range, slice}; #[allow(non_camel_case_types, unused)] pub(crate) type PathVertex_ScaledPixels = PathVertex<ScaledPixels>; @@ -43,11 +48,6 @@ impl Scene { self.surfaces.clear(); } - #[allow(dead_code)] - pub fn paths(&self) -> &[Path<ScaledPixels>] { - &self.paths - } - pub fn len(&self) -> usize { self.paint_operations.len() } @@ -675,7 +675,7 @@ pub(crate) struct PathId(pub(crate) usize); #[derive(Clone, Debug)] pub struct Path<P: Clone + Debug + Default + PartialEq> { pub(crate) id: PathId, - order: DrawOrder, + pub(crate) order: DrawOrder, pub(crate) bounds: Bounds<P>, pub(crate) content_mask: ContentMask<P>, pub(crate) vertices: Vec<PathVertex<P>>, @@ -683,7 +683,6 @@ pub struct Path<P: Clone + Debug + Default + PartialEq> { start: Point<P>, current: Point<P>, contour_count: usize, - base_scale: f32, } impl Path<Pixels> { @@ -702,35 +701,25 @@ impl Path<Pixels> { content_mask: Default::default(), color: Default::default(), contour_count: 0, - base_scale: 1.0, } } - /// Set the base scale of the path. - pub fn scale(mut self, factor: f32) -> Self { - self.base_scale = factor; - self - } - - /// Apply a scale to the path. - pub(crate) fn apply_scale(&self, factor: f32) -> Path<ScaledPixels> { + /// Scale this path by the given factor. + pub fn scale(&self, factor: f32) -> Path<ScaledPixels> { Path { id: self.id, order: self.order, - bounds: self.bounds.scale(self.base_scale * factor), - content_mask: self.content_mask.scale(self.base_scale * factor), + bounds: self.bounds.scale(factor), + content_mask: self.content_mask.scale(factor), vertices: self .vertices .iter() - .map(|vertex| vertex.scale(self.base_scale * factor)) + .map(|vertex| vertex.scale(factor)) .collect(), - start: self - .start - .map(|start| start.scale(self.base_scale * factor)), - current: self.current.scale(self.base_scale * factor), + start: self.start.map(|start| start.scale(factor)), + current: self.current.scale(factor), contour_count: self.contour_count, color: self.color, - base_scale: 1.0, } } @@ -745,7 +734,10 @@ impl Path<Pixels> { pub fn line_to(&mut self, to: Point<Pixels>) { self.contour_count += 1; if self.contour_count > 1 { - self.push_triangle((self.start, self.current, to)); + self.push_triangle( + (self.start, self.current, to), + (point(0., 1.), point(0., 1.), point(0., 1.)), + ); } self.current = to; } @@ -754,15 +746,25 @@ impl Path<Pixels> { pub fn curve_to(&mut self, to: Point<Pixels>, ctrl: Point<Pixels>) { self.contour_count += 1; if self.contour_count > 1 { - self.push_triangle((self.start, self.current, to)); + self.push_triangle( + (self.start, self.current, to), + (point(0., 1.), point(0., 1.), point(0., 1.)), + ); } - self.push_triangle((self.current, ctrl, to)); + self.push_triangle( + (self.current, ctrl, to), + (point(0., 0.), point(0.5, 0.), point(1., 1.)), + ); self.current = to; } /// Push a triangle to the Path. 
- pub fn push_triangle(&mut self, xy: (Point<Pixels>, Point<Pixels>, Point<Pixels>)) { + pub fn push_triangle( + &mut self, + xy: (Point<Pixels>, Point<Pixels>, Point<Pixels>), + st: (Point<f32>, Point<f32>, Point<f32>), + ) { self.bounds = self .bounds .union(&Bounds { @@ -780,19 +782,32 @@ impl Path<Pixels> { self.vertices.push(PathVertex { xy_position: xy.0, + st_position: st.0, content_mask: Default::default(), }); self.vertices.push(PathVertex { xy_position: xy.1, + st_position: st.1, content_mask: Default::default(), }); self.vertices.push(PathVertex { xy_position: xy.2, + st_position: st.2, content_mask: Default::default(), }); } } +impl<T> Path<T> +where + T: Clone + Debug + Default + PartialEq + PartialOrd + Add<T, Output = T> + Sub<Output = T>, +{ + #[allow(unused)] + pub(crate) fn clipped_bounds(&self) -> Bounds<T> { + self.bounds.intersect(&self.content_mask.bounds) + } +} + impl From<Path<ScaledPixels>> for Primitive { fn from(path: Path<ScaledPixels>) -> Self { Primitive::Path(path) @@ -803,6 +818,7 @@ impl From<Path<ScaledPixels>> for Primitive { #[repr(C)] pub(crate) struct PathVertex<P: Clone + Debug + Default + PartialEq> { pub(crate) xy_position: Point<P>, + pub(crate) st_position: Point<f32>, pub(crate) content_mask: ContentMask<P>, } @@ -810,6 +826,7 @@ impl PathVertex<Pixels> { pub fn scale(&self, factor: f32) -> PathVertex<ScaledPixels> { PathVertex { xy_position: self.xy_position.scale(factor), + st_position: self.st_position, content_mask: self.content_mask.scale(factor), } } diff --git a/crates/gpui/src/svg_renderer.rs b/crates/gpui/src/svg_renderer.rs index 08d281b850ca80a370130e9f364d6ecb5334a1ce..0107624bc8d0e6a26c6acc4a085cbddc7e14c4c5 100644 --- a/crates/gpui/src/svg_renderer.rs +++ b/crates/gpui/src/svg_renderer.rs @@ -27,7 +27,7 @@ pub enum SvgSize { impl SvgRenderer { pub fn new(asset_source: Arc<dyn AssetSource>) -> Self { - let font_db = LazyLock::new(|| { + static FONT_DB: LazyLock<Arc<usvg::fontdb::Database>> = LazyLock::new(|| { let mut db = usvg::fontdb::Database::new(); db.load_system_fonts(); Arc::new(db) @@ -36,7 +36,7 @@ impl SvgRenderer { let font_resolver = Box::new( move |font: &usvg::Font, db: &mut Arc<usvg::fontdb::Database>| { if db.is_empty() { - *db = font_db.clone(); + *db = FONT_DB.clone(); } default_font_resolver(font, db) }, diff --git a/crates/gpui/src/tab_stop.rs b/crates/gpui/src/tab_stop.rs new file mode 100644 index 0000000000000000000000000000000000000000..7dde42efed8a138de3a29657683d95c60e27dda0 --- /dev/null +++ b/crates/gpui/src/tab_stop.rs @@ -0,0 +1,161 @@ +use crate::{FocusHandle, FocusId}; + +/// Represents a collection of tab handles. +/// +/// Used to manage the `Tab` event to switch between focus handles. 
+#[derive(Default)] +pub(crate) struct TabHandles { + pub(crate) handles: Vec<FocusHandle>, +} + +impl TabHandles { + pub(crate) fn insert(&mut self, focus_handle: &FocusHandle) { + if !focus_handle.tab_stop { + return; + } + + let focus_handle = focus_handle.clone(); + + // Insert handle with same tab_index last + if let Some(ix) = self + .handles + .iter() + .position(|tab| tab.tab_index > focus_handle.tab_index) + { + self.handles.insert(ix, focus_handle); + } else { + self.handles.push(focus_handle); + } + } + + pub(crate) fn clear(&mut self) { + self.handles.clear(); + } + + fn current_index(&self, focused_id: Option<&FocusId>) -> Option<usize> { + self.handles.iter().position(|h| Some(&h.id) == focused_id) + } + + pub(crate) fn next(&self, focused_id: Option<&FocusId>) -> Option<FocusHandle> { + let next_ix = self + .current_index(focused_id) + .and_then(|ix| { + let next_ix = ix + 1; + (next_ix < self.handles.len()).then_some(next_ix) + }) + .unwrap_or_default(); + + if let Some(next_handle) = self.handles.get(next_ix) { + Some(next_handle.clone()) + } else { + None + } + } + + pub(crate) fn prev(&self, focused_id: Option<&FocusId>) -> Option<FocusHandle> { + let ix = self.current_index(focused_id).unwrap_or_default(); + let prev_ix; + if ix == 0 { + prev_ix = self.handles.len().saturating_sub(1); + } else { + prev_ix = ix.saturating_sub(1); + } + + if let Some(prev_handle) = self.handles.get(prev_ix) { + Some(prev_handle.clone()) + } else { + None + } + } +} + +#[cfg(test)] +mod tests { + use crate::{FocusHandle, FocusMap, TabHandles}; + use std::sync::Arc; + + #[test] + fn test_tab_handles() { + let focus_map = Arc::new(FocusMap::default()); + let mut tab = TabHandles::default(); + + let focus_handles = vec![ + FocusHandle::new(&focus_map).tab_stop(true).tab_index(0), + FocusHandle::new(&focus_map).tab_stop(true).tab_index(1), + FocusHandle::new(&focus_map).tab_stop(true).tab_index(1), + FocusHandle::new(&focus_map), + FocusHandle::new(&focus_map).tab_index(2), + FocusHandle::new(&focus_map).tab_stop(true).tab_index(0), + FocusHandle::new(&focus_map).tab_stop(true).tab_index(2), + ]; + + for handle in focus_handles.iter() { + tab.insert(&handle); + } + assert_eq!( + tab.handles + .iter() + .map(|handle| handle.id) + .collect::<Vec<_>>(), + vec![ + focus_handles[0].id, + focus_handles[5].id, + focus_handles[1].id, + focus_handles[2].id, + focus_handles[6].id, + ] + ); + + // Select first tab index if no handle is currently focused. + assert_eq!(tab.next(None), Some(tab.handles[0].clone())); + // Select last tab index if no handle is currently focused. 
+ assert_eq!( + tab.prev(None), + Some(tab.handles[tab.handles.len() - 1].clone()) + ); + + assert_eq!( + tab.next(Some(&tab.handles[0].id)), + Some(tab.handles[1].clone()) + ); + assert_eq!( + tab.next(Some(&tab.handles[1].id)), + Some(tab.handles[2].clone()) + ); + assert_eq!( + tab.next(Some(&tab.handles[2].id)), + Some(tab.handles[3].clone()) + ); + assert_eq!( + tab.next(Some(&tab.handles[3].id)), + Some(tab.handles[4].clone()) + ); + assert_eq!( + tab.next(Some(&tab.handles[4].id)), + Some(tab.handles[0].clone()) + ); + + // prev + assert_eq!(tab.prev(None), Some(tab.handles[4].clone())); + assert_eq!( + tab.prev(Some(&tab.handles[0].id)), + Some(tab.handles[4].clone()) + ); + assert_eq!( + tab.prev(Some(&tab.handles[1].id)), + Some(tab.handles[0].clone()) + ); + assert_eq!( + tab.prev(Some(&tab.handles[2].id)), + Some(tab.handles[1].clone()) + ); + assert_eq!( + tab.prev(Some(&tab.handles[3].id)), + Some(tab.handles[2].clone()) + ); + assert_eq!( + tab.prev(Some(&tab.handles[4].id)), + Some(tab.handles[3].clone()) + ); + } +} diff --git a/crates/gpui/src/taffy.rs b/crates/gpui/src/taffy.rs index f12c62d504395a2afbf698685a4eb3cc5f0e4e1f..f7fa54256df20b38170ecb4d3e48c22913e44ae6 100644 --- a/crates/gpui/src/taffy.rs +++ b/crates/gpui/src/taffy.rs @@ -182,7 +182,7 @@ impl TaffyLayoutEngine { .compute_layout_with_measure( id.into(), available_space.into(), - |known_dimensions, available_space, _id, node_context| { + |known_dimensions, available_space, _id, node_context, _style| { let Some(node_context) = node_context else { return taffy::geometry::Size::default(); }; @@ -283,7 +283,7 @@ impl ToTaffy<taffy::style::LengthPercentageAuto> for Length { fn to_taffy(&self, rem_size: Pixels) -> taffy::prelude::LengthPercentageAuto { match self { Length::Definite(length) => length.to_taffy(rem_size), - Length::Auto => taffy::prelude::LengthPercentageAuto::Auto, + Length::Auto => taffy::prelude::LengthPercentageAuto::auto(), } } } @@ -292,7 +292,7 @@ impl ToTaffy<taffy::style::Dimension> for Length { fn to_taffy(&self, rem_size: Pixels) -> taffy::prelude::Dimension { match self { Length::Definite(length) => length.to_taffy(rem_size), - Length::Auto => taffy::prelude::Dimension::Auto, + Length::Auto => taffy::prelude::Dimension::auto(), } } } @@ -302,14 +302,14 @@ impl ToTaffy<taffy::style::LengthPercentage> for DefiniteLength { match self { DefiniteLength::Absolute(length) => match length { AbsoluteLength::Pixels(pixels) => { - taffy::style::LengthPercentage::Length(pixels.into()) + taffy::style::LengthPercentage::length(pixels.into()) } AbsoluteLength::Rems(rems) => { - taffy::style::LengthPercentage::Length((*rems * rem_size).into()) + taffy::style::LengthPercentage::length((*rems * rem_size).into()) } }, DefiniteLength::Fraction(fraction) => { - taffy::style::LengthPercentage::Percent(*fraction) + taffy::style::LengthPercentage::percent(*fraction) } } } @@ -320,14 +320,14 @@ impl ToTaffy<taffy::style::LengthPercentageAuto> for DefiniteLength { match self { DefiniteLength::Absolute(length) => match length { AbsoluteLength::Pixels(pixels) => { - taffy::style::LengthPercentageAuto::Length(pixels.into()) + taffy::style::LengthPercentageAuto::length(pixels.into()) } AbsoluteLength::Rems(rems) => { - taffy::style::LengthPercentageAuto::Length((*rems * rem_size).into()) + taffy::style::LengthPercentageAuto::length((*rems * rem_size).into()) } }, DefiniteLength::Fraction(fraction) => { - taffy::style::LengthPercentageAuto::Percent(*fraction) + taffy::style::LengthPercentageAuto::percent(*fraction) } } 
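The taffy hunks above track an upstream API change: dimension values are now built with the constructor functions `length()`, `percent()`, and `auto()` instead of the old `Length`/`Percent`/`Auto` enum variants. A minimal illustrative sketch of the new constructors (outside the diff itself):

    use taffy::style::{Dimension, LengthPercentage, LengthPercentageAuto};

    fn taffy_constructor_examples() {
        // Previously spelled Dimension::Length(12.0), Dimension::Auto,
        // and LengthPercentage::Percent(0.5).
        let width = Dimension::length(12.0);
        let height = Dimension::auto();
        let half = LengthPercentage::percent(0.5);
        let margin = LengthPercentageAuto::length(4.0);
        let _ = (width, height, half, margin);
    }
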
} @@ -337,12 +337,12 @@ impl ToTaffy<taffy::style::Dimension> for DefiniteLength { fn to_taffy(&self, rem_size: Pixels) -> taffy::style::Dimension { match self { DefiniteLength::Absolute(length) => match length { - AbsoluteLength::Pixels(pixels) => taffy::style::Dimension::Length(pixels.into()), + AbsoluteLength::Pixels(pixels) => taffy::style::Dimension::length(pixels.into()), AbsoluteLength::Rems(rems) => { - taffy::style::Dimension::Length((*rems * rem_size).into()) + taffy::style::Dimension::length((*rems * rem_size).into()) } }, - DefiniteLength::Fraction(fraction) => taffy::style::Dimension::Percent(*fraction), + DefiniteLength::Fraction(fraction) => taffy::style::Dimension::percent(*fraction), } } } @@ -350,9 +350,9 @@ impl ToTaffy<taffy::style::Dimension> for DefiniteLength { impl ToTaffy<taffy::style::LengthPercentage> for AbsoluteLength { fn to_taffy(&self, rem_size: Pixels) -> taffy::style::LengthPercentage { match self { - AbsoluteLength::Pixels(pixels) => taffy::style::LengthPercentage::Length(pixels.into()), + AbsoluteLength::Pixels(pixels) => taffy::style::LengthPercentage::length(pixels.into()), AbsoluteLength::Rems(rems) => { - taffy::style::LengthPercentage::Length((*rems * rem_size).into()) + taffy::style::LengthPercentage::length((*rems * rem_size).into()) } } } diff --git a/crates/gpui/src/window.rs b/crates/gpui/src/window.rs index e9145bd9f5181da662a882f0f12e340e34d4822f..40d3845ff9600f9dcd16fe55b0b27d5e3eddce09 100644 --- a/crates/gpui/src/window.rs +++ b/crates/gpui/src/window.rs @@ -12,10 +12,11 @@ use crate::{ PlatformInputHandler, PlatformWindow, Point, PolychromeSprite, PromptButton, PromptLevel, Quad, Render, RenderGlyphParams, RenderImage, RenderImageParams, RenderSvgParams, Replay, ResizeEdge, SMOOTH_SVG_SCALE_FACTOR, SUBPIXEL_VARIANTS, ScaledPixels, Scene, Shadow, SharedString, Size, - StrikethroughStyle, Style, SubscriberSet, Subscription, TaffyLayoutEngine, Task, TextStyle, - TextStyleRefinement, TransformationMatrix, Underline, UnderlineStyle, WindowAppearance, - WindowBackgroundAppearance, WindowBounds, WindowControls, WindowDecorations, WindowOptions, - WindowParams, WindowTextSystem, point, prelude::*, px, rems, size, transparent_black, + StrikethroughStyle, Style, SubscriberSet, Subscription, TabHandles, TaffyLayoutEngine, Task, + TextStyle, TextStyleRefinement, TransformationMatrix, Underline, UnderlineStyle, + WindowAppearance, WindowBackgroundAppearance, WindowBounds, WindowControls, WindowDecorations, + WindowOptions, WindowParams, WindowTextSystem, point, prelude::*, px, rems, size, + transparent_black, }; use anyhow::{Context as _, Result, anyhow}; use collections::{FxHashMap, FxHashSet}; @@ -78,11 +79,13 @@ pub enum DispatchPhase { impl DispatchPhase { /// Returns true if this represents the "bubble" phase. + #[inline] pub fn bubble(self) -> bool { self == DispatchPhase::Bubble } /// Returns true if this represents the "capture" phase. + #[inline] pub fn capture(self) -> bool { self == DispatchPhase::Capture } @@ -222,7 +225,12 @@ impl ArenaClearNeeded { } } -pub(crate) type FocusMap = RwLock<SlotMap<FocusId, AtomicUsize>>; +pub(crate) type FocusMap = RwLock<SlotMap<FocusId, FocusRef>>; +pub(crate) struct FocusRef { + pub(crate) ref_count: AtomicUsize, + pub(crate) tab_index: isize, + pub(crate) tab_stop: bool, +} impl FocusId { /// Obtains whether the element associated with this handle is currently focused. 
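The `FocusRef` record above carries the new per-handle `tab_index` and `tab_stop` fields alongside the reference count; the hunks that follow expose them as builder methods on `FocusHandle` and as `Window::focus_next`/`focus_prev`. A minimal usage sketch, assuming the usual `cx.focus_handle()` constructor from gpui:

    use gpui::{App, FocusHandle, Window};

    struct Form {
        name: FocusHandle,
        email: FocusHandle,
    }

    impl Form {
        fn new(cx: &mut App) -> Self {
            Self {
                // Only handles marked as tab stops participate in Tab navigation,
                // ordered by tab_index (equal indices keep insertion order).
                name: cx.focus_handle().tab_index(0).tab_stop(true),
                email: cx.focus_handle().tab_index(1).tab_stop(true),
            }
        }

        fn on_tab(&self, window: &mut Window) {
            // Wraps back to the first tab stop after the last one.
            window.focus_next();
        }

        fn on_shift_tab(&self, window: &mut Window) {
            window.focus_prev();
        }
    }
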
@@ -258,6 +266,10 @@ impl FocusId { pub struct FocusHandle { pub(crate) id: FocusId, handles: Arc<FocusMap>, + /// The index of this element in the tab order. + pub tab_index: isize, + /// Whether this element can be focused by tab navigation. + pub tab_stop: bool, } impl std::fmt::Debug for FocusHandle { @@ -268,25 +280,54 @@ impl std::fmt::Debug for FocusHandle { impl FocusHandle { pub(crate) fn new(handles: &Arc<FocusMap>) -> Self { - let id = handles.write().insert(AtomicUsize::new(1)); + let id = handles.write().insert(FocusRef { + ref_count: AtomicUsize::new(1), + tab_index: 0, + tab_stop: false, + }); + Self { id, + tab_index: 0, + tab_stop: false, handles: handles.clone(), } } pub(crate) fn for_id(id: FocusId, handles: &Arc<FocusMap>) -> Option<Self> { let lock = handles.read(); - let ref_count = lock.get(id)?; - if atomic_incr_if_not_zero(ref_count) == 0 { + let focus = lock.get(id)?; + if atomic_incr_if_not_zero(&focus.ref_count) == 0 { return None; } Some(Self { id, + tab_index: focus.tab_index, + tab_stop: focus.tab_stop, handles: handles.clone(), }) } + /// Sets the tab index of the element associated with this handle. + pub fn tab_index(mut self, index: isize) -> Self { + self.tab_index = index; + if let Some(focus) = self.handles.write().get_mut(self.id) { + focus.tab_index = index; + } + self + } + + /// Sets whether the element associated with this handle is a tab stop. + /// + /// When `false`, the element will not be included in the tab order. + pub fn tab_stop(mut self, tab_stop: bool) -> Self { + self.tab_stop = tab_stop; + if let Some(focus) = self.handles.write().get_mut(self.id) { + focus.tab_stop = tab_stop; + } + self + } + /// Converts this focus handle into a weak variant, which does not prevent it from being released. pub fn downgrade(&self) -> WeakFocusHandle { WeakFocusHandle { @@ -354,6 +395,7 @@ impl Drop for FocusHandle { .read() .get(self.id) .unwrap() + .ref_count .fetch_sub(1, SeqCst); } } @@ -642,6 +684,7 @@ pub(crate) struct Frame { pub(crate) next_inspector_instance_ids: FxHashMap<Rc<crate::InspectorElementPath>, usize>, #[cfg(any(feature = "inspector", debug_assertions))] pub(crate) inspector_hitboxes: FxHashMap<HitboxId, crate::InspectorElementId>, + pub(crate) tab_handles: TabHandles, } #[derive(Clone, Default)] @@ -661,6 +704,7 @@ pub(crate) struct PaintIndex { input_handlers_index: usize, cursor_styles_index: usize, accessed_element_states_index: usize, + tab_handle_index: usize, line_layout_index: LineLayoutIndex, } @@ -689,6 +733,7 @@ impl Frame { #[cfg(any(feature = "inspector", debug_assertions))] inspector_hitboxes: FxHashMap::default(), + tab_handles: TabHandles::default(), } } @@ -704,6 +749,7 @@ impl Frame { self.hitboxes.clear(); self.window_control_hitboxes.clear(); self.deferred_draws.clear(); + self.tab_handles.clear(); self.focus = None; #[cfg(any(feature = "inspector", debug_assertions))] @@ -976,7 +1022,7 @@ impl Window { || (active.get() && last_input_timestamp.get().elapsed() < Duration::from_secs(1)); - if invalidator.is_dirty() { + if invalidator.is_dirty() || request_frame_options.force_render { measure("frame duration", || { handle .update(&mut cx, |_, window, cx| { @@ -1289,6 +1335,28 @@ impl Window { self.focus_enabled = false; } + /// Move focus to next tab stop. + pub fn focus_next(&mut self) { + if !self.focus_enabled { + return; + } + + if let Some(handle) = self.rendered_frame.tab_handles.next(self.focus.as_ref()) { + self.focus(&handle) + } + } + + /// Move focus to previous tab stop. 
+ pub fn focus_prev(&mut self) { + if !self.focus_enabled { + return; + } + + if let Some(handle) = self.rendered_frame.tab_handles.prev(self.focus.as_ref()) { + self.focus(&handle) + } + } + /// Accessor for the text system. pub fn text_system(&self) -> &Arc<WindowTextSystem> { &self.text_system @@ -2143,6 +2211,7 @@ impl Window { input_handlers_index: self.next_frame.input_handlers.len(), cursor_styles_index: self.next_frame.cursor_styles.len(), accessed_element_states_index: self.next_frame.accessed_element_states.len(), + tab_handle_index: self.next_frame.tab_handles.handles.len(), line_layout_index: self.text_system.layout_index(), } } @@ -2172,6 +2241,12 @@ impl Window { .iter() .map(|(id, type_id)| (GlobalElementId(id.0.clone()), *type_id)), ); + self.next_frame.tab_handles.handles.extend( + self.rendered_frame.tab_handles.handles + [range.start.tab_handle_index..range.end.tab_handle_index] + .iter() + .cloned(), + ); self.text_system .reuse_layouts(range.start.line_layout_index..range.end.line_layout_index); @@ -2424,6 +2499,53 @@ impl Window { result } + /// Use a piece of state that exists as long this element is being rendered in consecutive frames. + pub fn use_keyed_state<S: 'static>( + &mut self, + key: impl Into<ElementId>, + cx: &mut App, + init: impl FnOnce(&mut Self, &mut App) -> S, + ) -> Entity<S> { + let current_view = self.current_view(); + self.with_global_id(key.into(), |global_id, window| { + window.with_element_state(global_id, |state: Option<Entity<S>>, window| { + if let Some(state) = state { + (state.clone(), state) + } else { + let new_state = cx.new(|cx| init(window, cx)); + cx.observe(&new_state, move |_, cx| { + cx.notify(current_view); + }) + .detach(); + (new_state.clone(), new_state) + } + }) + }) + } + + /// Immediately push an element ID onto the stack. Useful for simplifying IDs in lists + pub fn with_id<R>(&mut self, id: impl Into<ElementId>, f: impl FnOnce(&mut Self) -> R) -> R { + self.with_global_id(id.into(), |_, window| f(window)) + } + + /// Use a piece of state that exists as long this element is being rendered in consecutive frames, without needing to specify a key + /// + /// NOTE: This method uses the location of the caller to generate an ID for this state. + /// If this is not sufficient to identify your state (e.g. you're rendering a list item), + /// you can provide a custom ElementID using the `use_keyed_state` method. + #[track_caller] + pub fn use_state<S: 'static>( + &mut self, + cx: &mut App, + init: impl FnOnce(&mut Self, &mut App) -> S, + ) -> Entity<S> { + self.use_keyed_state( + ElementId::CodeLocation(*core::panic::Location::caller()), + cx, + init, + ) + } + /// Updates or initializes state for an element with the given id that lives across multiple /// frames. If an element with this ID existed in the rendered frame, its state will be passed /// to the given closure. The state returned by the closure will be stored so it can be referenced @@ -2658,7 +2780,7 @@ impl Window { path.color = color.opacity(opacity); self.next_frame .scene - .insert_primitive(path.apply_scale(scale_factor)); + .insert_primitive(path.scale(scale_factor)); } /// Paint an underline into the scene for the next frame at the current z-index. @@ -4126,6 +4248,25 @@ impl Window { .on_action(action_type, Rc::new(listener)); } + /// Register an action listener on the window for the next frame if the condition is true. + /// The type of action is determined by the first parameter of the given listener. 
+ /// When the next frame is rendered the listener will be cleared. + /// + /// This is a fairly low-level method, so prefer using action handlers on elements unless you have + /// a specific need to register a global listener. + pub fn on_action_when( + &mut self, + condition: bool, + action_type: TypeId, + listener: impl Fn(&dyn Any, DispatchPhase, &mut Window, &mut App) + 'static, + ) { + if condition { + self.next_frame + .dispatch_tree + .on_action(action_type, Rc::new(listener)); + } + } + /// Read information about the GPU backing this window. /// Currently returns None on Mac and Windows. pub fn gpu_specs(&self) -> Option<GpuSpecs> { @@ -4577,6 +4718,10 @@ pub enum ElementId { NamedInteger(SharedString, u64), /// A path. Path(Arc<std::path::Path>), + /// A code location. + CodeLocation(core::panic::Location<'static>), + /// A labeled child of an element. + NamedChild(Box<ElementId>, SharedString), } impl ElementId { @@ -4596,6 +4741,8 @@ impl Display for ElementId { ElementId::NamedInteger(s, i) => write!(f, "{}-{}", s, i)?, ElementId::Uuid(uuid) => write!(f, "{}", uuid)?, ElementId::Path(path) => write!(f, "{}", path.display())?, + ElementId::CodeLocation(location) => write!(f, "{}", location)?, + ElementId::NamedChild(id, name) => write!(f, "{}-{}", id, name)?, } Ok(()) @@ -4686,6 +4833,12 @@ impl From<(&'static str, u32)> for ElementId { } } +impl<T: Into<SharedString>> From<(ElementId, T)> for ElementId { + fn from((id, name): (ElementId, T)) -> Self { + ElementId::NamedChild(Box::new(id), name.into()) + } +} + /// A rectangle to be rendered in the window at the given position and size. /// Passed as an argument [`Window::paint_quad`]. #[derive(Clone)] diff --git a/crates/gpui_macros/src/derive_app_context.rs b/crates/gpui_macros/src/derive_app_context.rs index bca015b8dc5ab43c0a6873f03251d68e2d0b592b..d2dc250d0239769f6834860a128c2653546a926e 100644 --- a/crates/gpui_macros/src/derive_app_context.rs +++ b/crates/gpui_macros/src/derive_app_context.rs @@ -53,6 +53,16 @@ pub fn derive_app_context(input: TokenStream) -> TokenStream { self.#app_variable.update_entity(handle, update) } + fn as_mut<'y, 'z, T>( + &'y mut self, + handle: &'z gpui::Entity<T>, + ) -> Self::Result<gpui::GpuiBorrow<'y, T>> + where + T: 'static, + { + self.#app_variable.as_mut(handle) + } + fn read_entity<T, R>( &self, handle: &gpui::Entity<T>, diff --git a/crates/html_to_markdown/src/markdown_writer.rs b/crates/html_to_markdown/src/markdown_writer.rs index a9caf7afa7a5a275ff4ecabaf38150f4fd2127a3..c32205ae7b349239d0a41a5a63c0793d958b4eea 100644 --- a/crates/html_to_markdown/src/markdown_writer.rs +++ b/crates/html_to_markdown/src/markdown_writer.rs @@ -119,8 +119,10 @@ impl MarkdownWriter { .push_back(current_element.clone()); } - for child in node.children.borrow().iter() { - self.visit_node(child, handlers)?; + if self.current_element_stack.len() < 200 { + for child in node.children.borrow().iter() { + self.visit_node(child, handlers)?; + } } if let Some(current_element) = current_element { diff --git a/crates/http_client/Cargo.toml b/crates/http_client/Cargo.toml index 2b114f240acc32f146f5d0fc4b70bbe150f655da..f63bff295e22c36512dbc6285e68d4686714f411 100644 --- a/crates/http_client/Cargo.toml +++ b/crates/http_client/Cargo.toml @@ -21,7 +21,10 @@ anyhow.workspace = true derive_more.workspace = true futures.workspace = true http.workspace = true +http-body.workspace = true log.workspace = true +parking_lot.workspace = true +reqwest.workspace = true serde.workspace = true serde_json.workspace = true 
url.workspace = true diff --git a/crates/http_client/src/async_body.rs b/crates/http_client/src/async_body.rs index caf8089d0f15d0ce818839bd0672e1a2d1419fc7..473849f3cdca785a802590a60cce922c9ee0b5f9 100644 --- a/crates/http_client/src/async_body.rs +++ b/crates/http_client/src/async_body.rs @@ -6,6 +6,7 @@ use std::{ use bytes::Bytes; use futures::AsyncRead; +use http_body::{Body, Frame}; /// Based on the implementation of AsyncBody in /// <https://github.com/sagebind/isahc/blob/5c533f1ef4d6bdf1fd291b5103c22110f41d0bf0/src/body/mod.rs>. @@ -87,6 +88,17 @@ impl From<&'static str> for AsyncBody { } } +impl TryFrom<reqwest::Body> for AsyncBody { + type Error = anyhow::Error; + + fn try_from(value: reqwest::Body) -> Result<Self, Self::Error> { + value + .as_bytes() + .ok_or_else(|| anyhow::anyhow!("Underlying data is a stream")) + .map(|bytes| Self::from_bytes(Bytes::copy_from_slice(bytes))) + } +} + impl<T: Into<Self>> From<Option<T>> for AsyncBody { fn from(body: Option<T>) -> Self { match body { @@ -114,3 +126,24 @@ impl futures::AsyncRead for AsyncBody { } } } + +impl Body for AsyncBody { + type Data = Bytes; + type Error = std::io::Error; + + fn poll_frame( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> { + let mut buffer = vec![0; 8192]; + match AsyncRead::poll_read(self.as_mut(), cx, &mut buffer) { + Poll::Ready(Ok(0)) => Poll::Ready(None), + Poll::Ready(Ok(n)) => { + let data = Bytes::copy_from_slice(&buffer[..n]); + Poll::Ready(Some(Ok(Frame::data(data)))) + } + Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))), + Poll::Pending => Poll::Pending, + } + } +} diff --git a/crates/http_client/src/github.rs b/crates/http_client/src/github.rs index a038915e2f4865f16308add594763c33d2230c42..a19c13b0ff5ef9fec3c1e6223b63c62ed84f03dc 100644 --- a/crates/http_client/src/github.rs +++ b/crates/http_client/src/github.rs @@ -8,6 +8,7 @@ use url::Url; pub struct GitHubLspBinaryVersion { pub name: String, pub url: String, + pub digest: Option<String>, } #[derive(Deserialize, Debug)] @@ -24,6 +25,7 @@ pub struct GithubRelease { pub struct GithubReleaseAsset { pub name: String, pub browser_download_url: String, + pub digest: Option<String>, } pub async fn latest_github_release( diff --git a/crates/http_client/src/http_client.rs b/crates/http_client/src/http_client.rs index c60a56002f5234f1085c65e25cfc0f1549fdbdb1..a7f75b0962561ac713e57f9ad26cb64ed82f8003 100644 --- a/crates/http_client/src/http_client.rs +++ b/crates/http_client/src/http_client.rs @@ -4,16 +4,18 @@ pub mod github; pub use anyhow::{Result, anyhow}; pub use async_body::{AsyncBody, Inner}; use derive_more::Deref; +use http::HeaderValue; pub use http::{self, Method, Request, Response, StatusCode, Uri}; -use futures::future::BoxFuture; +use futures::{ + FutureExt as _, + future::{self, BoxFuture}, +}; use http::request::Builder; +use parking_lot::Mutex; #[cfg(feature = "test-support")] use std::fmt; -use std::{ - any::type_name, - sync::{Arc, Mutex}, -}; +use std::{any::type_name, sync::Arc}; pub use url::Url; #[derive(Default, Debug, Clone, PartialEq, Eq, Hash)] @@ -39,6 +41,8 @@ impl HttpRequestExt for http::request::Builder { pub trait HttpClient: 'static + Send + Sync { fn type_name(&self) -> &'static str; + fn user_agent(&self) -> Option<&HeaderValue>; + fn send( &self, req: http::Request<AsyncBody>, @@ -83,6 +87,19 @@ pub trait HttpClient: 'static + Send + Sync { } fn proxy(&self) -> Option<&Url>; + + #[cfg(feature = "test-support")] + fn as_fake(&self) -> 
&FakeHttpClient { + panic!("called as_fake on {}", type_name::<Self>()) + } + + fn send_multipart_form<'a>( + &'a self, + _url: &str, + _request: reqwest::multipart::Form, + ) -> BoxFuture<'a, anyhow::Result<Response<AsyncBody>>> { + future::ready(Err(anyhow!("not implemented"))).boxed() + } } /// An [`HttpClient`] that may have a proxy. @@ -118,6 +135,10 @@ impl HttpClient for HttpClientWithProxy { self.client.send(req) } + fn user_agent(&self) -> Option<&HeaderValue> { + self.client.user_agent() + } + fn proxy(&self) -> Option<&Url> { self.proxy.as_ref() } @@ -125,22 +146,18 @@ impl HttpClient for HttpClientWithProxy { fn type_name(&self) -> &'static str { self.client.type_name() } -} - -impl HttpClient for Arc<HttpClientWithProxy> { - fn send( - &self, - req: Request<AsyncBody>, - ) -> BoxFuture<'static, anyhow::Result<Response<AsyncBody>>> { - self.client.send(req) - } - fn proxy(&self) -> Option<&Url> { - self.proxy.as_ref() + #[cfg(feature = "test-support")] + fn as_fake(&self) -> &FakeHttpClient { + self.client.as_fake() } - fn type_name(&self) -> &'static str { - self.client.type_name() + fn send_multipart_form<'a>( + &'a self, + url: &str, + form: reqwest::multipart::Form, + ) -> BoxFuture<'a, anyhow::Result<Response<AsyncBody>>> { + self.client.send_multipart_form(url, form) } } @@ -188,20 +205,13 @@ impl HttpClientWithUrl { /// Returns the base URL. pub fn base_url(&self) -> String { - self.base_url - .lock() - .map_or_else(|_| Default::default(), |url| url.clone()) + self.base_url.lock().clone() } /// Sets the base URL. pub fn set_base_url(&self, base_url: impl Into<String>) { let base_url = base_url.into(); - self.base_url - .lock() - .map(|mut url| { - *url = base_url; - }) - .ok(); + *self.base_url.lock() = base_url; } /// Builds a URL using the given path. @@ -225,22 +235,27 @@ impl HttpClientWithUrl { )?) } + /// Builds a Zed Cloud URL using the given path. + pub fn build_zed_cloud_url(&self, path: &str, query: &[(&str, &str)]) -> Result<Url> { + let base_url = self.base_url(); + let base_api_url = match base_url.as_ref() { + "https://zed.dev" => "https://cloud.zed.dev", + "https://staging.zed.dev" => "https://cloud.zed.dev", + "http://localhost:3000" => "http://localhost:8787", + other => other, + }; + + Ok(Url::parse_with_params( + &format!("{}{}", base_api_url, path), + query, + )?) + } + /// Builds a Zed LLM URL using the given path. 
- pub fn build_zed_llm_url( - &self, - path: &str, - query: &[(&str, &str)], - use_cloud: bool, - ) -> Result<Url> { + pub fn build_zed_llm_url(&self, path: &str, query: &[(&str, &str)]) -> Result<Url> { let base_url = self.base_url(); let base_api_url = match base_url.as_ref() { - "https://zed.dev" => { - if use_cloud { - "https://cloud.zed.dev" - } else { - "https://llm.zed.dev" - } - } + "https://zed.dev" => "https://cloud.zed.dev", "https://staging.zed.dev" => "https://llm-staging.zed.dev", "http://localhost:3000" => "http://localhost:8787", other => other, @@ -253,7 +268,7 @@ impl HttpClientWithUrl { } } -impl HttpClient for Arc<HttpClientWithUrl> { +impl HttpClient for HttpClientWithUrl { fn send( &self, req: Request<AsyncBody>, @@ -261,6 +276,10 @@ impl HttpClient for Arc<HttpClientWithUrl> { self.client.send(req) } + fn user_agent(&self) -> Option<&HeaderValue> { + self.client.user_agent() + } + fn proxy(&self) -> Option<&Url> { self.client.proxy.as_ref() } @@ -268,22 +287,18 @@ impl HttpClient for Arc<HttpClientWithUrl> { fn type_name(&self) -> &'static str { self.client.type_name() } -} - -impl HttpClient for HttpClientWithUrl { - fn send( - &self, - req: Request<AsyncBody>, - ) -> BoxFuture<'static, anyhow::Result<Response<AsyncBody>>> { - self.client.send(req) - } - fn proxy(&self) -> Option<&Url> { - self.client.proxy.as_ref() + #[cfg(feature = "test-support")] + fn as_fake(&self) -> &FakeHttpClient { + self.client.as_fake() } - fn type_name(&self) -> &'static str { - self.client.type_name() + fn send_multipart_form<'a>( + &'a self, + url: &str, + request: reqwest::multipart::Form, + ) -> BoxFuture<'a, anyhow::Result<Response<AsyncBody>>> { + self.client.send_multipart_form(url, request) } } @@ -325,6 +340,10 @@ impl HttpClient for BlockedHttpClient { }) } + fn user_agent(&self) -> Option<&HeaderValue> { + None + } + fn proxy(&self) -> Option<&Url> { None } @@ -332,10 +351,15 @@ impl HttpClient for BlockedHttpClient { fn type_name(&self) -> &'static str { type_name::<Self>() } + + #[cfg(feature = "test-support")] + fn as_fake(&self) -> &FakeHttpClient { + panic!("called as_fake on {}", type_name::<Self>()) + } } #[cfg(feature = "test-support")] -type FakeHttpHandler = Box< +type FakeHttpHandler = Arc< dyn Fn(Request<AsyncBody>) -> BoxFuture<'static, anyhow::Result<Response<AsyncBody>>> + Send + Sync @@ -344,7 +368,8 @@ type FakeHttpHandler = Box< #[cfg(feature = "test-support")] pub struct FakeHttpClient { - handler: FakeHttpHandler, + handler: Mutex<Option<FakeHttpHandler>>, + user_agent: HeaderValue, } #[cfg(feature = "test-support")] @@ -358,7 +383,8 @@ impl FakeHttpClient { base_url: Mutex::new("http://test.example".into()), client: HttpClientWithProxy { client: Arc::new(Self { - handler: Box::new(move |req| Box::pin(handler(req))), + handler: Mutex::new(Some(Arc::new(move |req| Box::pin(handler(req))))), + user_agent: HeaderValue::from_static(type_name::<Self>()), }), proxy: None, }, @@ -382,6 +408,18 @@ impl FakeHttpClient { .unwrap()) }) } + + pub fn replace_handler<Fut, F>(&self, new_handler: F) + where + Fut: futures::Future<Output = anyhow::Result<Response<AsyncBody>>> + Send + 'static, + F: Fn(FakeHttpHandler, Request<AsyncBody>) -> Fut + Send + Sync + 'static, + { + let mut handler = self.handler.lock(); + let old_handler = handler.take().unwrap(); + *handler = Some(Arc::new(move |req| { + Box::pin(new_handler(old_handler.clone(), req)) + })); + } } #[cfg(feature = "test-support")] @@ -397,10 +435,14 @@ impl HttpClient for FakeHttpClient { &self, req: 
Request<AsyncBody>, ) -> BoxFuture<'static, anyhow::Result<Response<AsyncBody>>> { - let future = (self.handler)(req); + let future = (self.handler.lock().as_ref().unwrap())(req); future } + fn user_agent(&self) -> Option<&HeaderValue> { + Some(&self.user_agent) + } + fn proxy(&self) -> Option<&Url> { None } @@ -408,4 +450,8 @@ impl HttpClient for FakeHttpClient { fn type_name(&self) -> &'static str { type_name::<Self>() } + + fn as_fake(&self) -> &FakeHttpClient { + self + } } diff --git a/crates/icons/src/icons.rs b/crates/icons/src/icons.rs index 76a04e036dc92d9ebbf848fc34bb21195862b99a..12805e62e061ccd7e91dc0210e68b87752f7ed6c 100644 --- a/crates/icons/src/icons.rs +++ b/crates/icons/src/icons.rs @@ -11,6 +11,7 @@ pub enum IconName { Ai, AiAnthropic, AiBedrock, + AiClaude, AiDeepSeek, AiEdit, AiGemini, @@ -19,8 +20,10 @@ pub enum IconName { AiMistral, AiOllama, AiOpenAi, + AiOpenAiCompat, AiOpenRouter, AiVZero, + AiXAi, AiZed, ArrowCircle, ArrowDown, @@ -35,7 +38,6 @@ pub enum IconName { ArrowUpFromLine, ArrowUpRight, ArrowUpRightAlt, - AtSign, AudioOff, AudioOn, Backspace, @@ -45,15 +47,13 @@ pub enum IconName { BellRing, Binary, Blocks, - Bolt, + BoltOutlined, BoltFilled, - BoltFilledAlt, Book, BookCopy, - BookPlus, - Brain, BugOff, CaseSensitive, + Chat, Check, CheckDouble, ChevronDown, @@ -68,6 +68,7 @@ pub enum IconName { CircleHelp, Close, Cloud, + CloudDownload, Code, Cog, Command, @@ -103,9 +104,16 @@ pub enum IconName { Disconnected, DocumentText, Download, + EditorAtom, + EditorCursor, + EditorEmacs, + EditorJetBrains, + EditorSublime, + EditorVsCode, Ellipsis, EllipsisVertical, Envelope, + Equal, Eraser, Escape, Exit, @@ -163,6 +171,7 @@ pub enum IconName { ListTree, ListX, LoadCircle, + LocationEdit, LockOutlined, LspDebug, LspRestart, @@ -172,10 +181,8 @@ pub enum IconName { Maximize, Menu, MenuAlt, - MessageBubbles, Mic, MicMute, - Microscope, Minimize, Option, PageDown, @@ -187,9 +194,8 @@ pub enum IconName { PersonCircle, PhoneIncoming, Pin, - Play, - PlayAlt, - PlayBug, + PlayOutlined, + PlayFilled, Plus, PocketKnife, Power, @@ -205,7 +211,6 @@ pub enum IconName { ReplyArrowRight, Rerun, Return, - Reveal, RotateCcw, RotateCw, Route, @@ -219,6 +224,7 @@ pub enum IconName { Server, Settings, SettingsAlt, + ShieldCheck, Shift, Slash, SlashSquare, @@ -229,7 +235,6 @@ pub enum IconName { Sparkle, SparkleAlt, SparkleFilled, - Spinner, Split, SplitAlt, SquareDot, @@ -239,7 +244,6 @@ pub enum IconName { StarFilled, Stop, StopFilled, - Strikethrough, Supermaven, SupermavenDisabled, SupermavenError, @@ -247,10 +251,16 @@ pub enum IconName { SwatchBook, Tab, Terminal, + TerminalAlt, TextSnippet, + TextThread, + Thread, + ThreadFromSummary, ThumbsDown, ThumbsUp, - ToolBulb, + TodoComplete, + TodoPending, + TodoProgress, ToolCopy, ToolDeleteFile, ToolDiagnostics, @@ -262,9 +272,9 @@ pub enum IconName { ToolRegex, ToolSearch, ToolTerminal, + ToolThink, ToolWeb, Trash, - TrashAlt, Triangle, TriangleRight, Undo, diff --git a/crates/inspector_ui/src/div_inspector.rs b/crates/inspector_ui/src/div_inspector.rs index 7d162bcc355b1c29f55a6cb001638809a707599b..bd395aa01bca42ce923073ee6f80472abc7820eb 100644 --- a/crates/inspector_ui/src/div_inspector.rs +++ b/crates/inspector_ui/src/div_inspector.rs @@ -1,5 +1,8 @@ use anyhow::{Result, anyhow}; -use editor::{Bias, CompletionProvider, Editor, EditorEvent, EditorMode, ExcerptId, MultiBuffer}; +use editor::{ + Bias, CompletionProvider, Editor, EditorEvent, EditorMode, ExcerptId, MinimapVisibility, + MultiBuffer, +}; use fuzzy::StringMatch; use 
gpui::{ AsyncWindowContext, DivInspectorState, Entity, InspectorElementId, IntoElement, @@ -499,6 +502,7 @@ impl DivInspector { editor.set_show_git_diff_gutter(false, cx); editor.set_show_runnables(false, cx); editor.set_show_edit_predictions(Some(false), window, cx); + editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx); editor }) } diff --git a/crates/language/Cargo.toml b/crates/language/Cargo.toml index 477b978517d56d0f70270a4bf413b285b455ca94..4ab56d6647db5246bf0af7343c8485d946c8b156 100644 --- a/crates/language/Cargo.toml +++ b/crates/language/Cargo.toml @@ -92,6 +92,7 @@ tree-sitter-python.workspace = true tree-sitter-ruby.workspace = true tree-sitter-rust.workspace = true tree-sitter-typescript.workspace = true +toml.workspace = true unindent.workspace = true util = { workspace = true, features = ["test-support"] } zlog.workspace = true diff --git a/crates/language/src/buffer.rs b/crates/language/src/buffer.rs index ae0184b22a97acfb2adf1080a352479fca2ab82e..83517accc239ecf9d2196f124fc5695a8545ef17 100644 --- a/crates/language/src/buffer.rs +++ b/crates/language/src/buffer.rs @@ -2072,6 +2072,21 @@ impl Buffer { self.text.push_transaction(transaction, now); } + /// Differs from `push_transaction` in that it does not clear the redo + /// stack. Intended to be used to create a parent transaction to merge + /// potential child transactions into. + /// + /// The caller is responsible for removing it from the undo history using + /// `forget_transaction` if no edits are merged into it. Otherwise, if edits + /// are merged into this transaction, the caller is responsible for ensuring + /// the redo stack is cleared. The easiest way to ensure the redo stack is + /// cleared is to create transactions with the usual `start_transaction` and + /// `end_transaction` methods and merging the resulting transactions into + /// the transaction created by this method + pub fn push_empty_transaction(&mut self, now: Instant) -> TransactionId { + self.text.push_empty_transaction(now) + } + /// Prevent the last transaction from being grouped with any subsequent transactions, /// even if they occur with the buffer's undo grouping duration. pub fn finalize_last_transaction(&mut self) -> Option<&Transaction> { @@ -3364,13 +3379,19 @@ impl BufferSnapshot { /// Returns a tuple of the range and character kind of the word /// surrounding the given position. 
- pub fn surrounding_word<T: ToOffset>(&self, start: T) -> (Range<usize>, Option<CharKind>) { + pub fn surrounding_word<T: ToOffset>( + &self, + start: T, + for_completion: bool, + ) -> (Range<usize>, Option<CharKind>) { let mut start = start.to_offset(self); let mut end = start; let mut next_chars = self.chars_at(start).take(128).peekable(); let mut prev_chars = self.reversed_chars_at(start).take(128).peekable(); - let classifier = self.char_classifier_at(start); + let classifier = self + .char_classifier_at(start) + .for_completion(for_completion); let word_kind = cmp::max( prev_chars.peek().copied().map(|c| classifier.kind(c)), next_chars.peek().copied().map(|c| classifier.kind(c)), diff --git a/crates/language/src/buffer_tests.rs b/crates/language/src/buffer_tests.rs index 6955cd054925076f8d2678eff58c44e0b82351d0..2e2df7e658596daaca3b338ef830794fd0d3bef8 100644 --- a/crates/language/src/buffer_tests.rs +++ b/crates/language/src/buffer_tests.rs @@ -2273,7 +2273,12 @@ fn test_language_scope_at_with_javascript(cx: &mut App) { LanguageConfig { name: "JavaScript".into(), line_comments: vec!["// ".into()], - block_comment: Some(("/*".into(), "*/".into())), + block_comment: Some(BlockCommentConfig { + start: "/*".into(), + end: "*/".into(), + prefix: "* ".into(), + tab_size: 1, + }), brackets: BracketPairConfig { pairs: vec![ BracketPair { @@ -2300,7 +2305,12 @@ fn test_language_scope_at_with_javascript(cx: &mut App) { "element".into(), LanguageConfigOverride { line_comments: Override::Remove { remove: true }, - block_comment: Override::Set(("{/*".into(), "*/}".into())), + block_comment: Override::Set(BlockCommentConfig { + start: "{/*".into(), + prefix: "".into(), + end: "*/}".into(), + tab_size: 0, + }), ..Default::default() }, )] @@ -2338,9 +2348,15 @@ fn test_language_scope_at_with_javascript(cx: &mut App) { let config = snapshot.language_scope_at(0).unwrap(); assert_eq!(config.line_comment_prefixes(), &[Arc::from("// ")]); assert_eq!( - config.block_comment_delimiters(), - Some((&"/*".into(), &"*/".into())) + config.block_comment(), + Some(&BlockCommentConfig { + start: "/*".into(), + prefix: "* ".into(), + end: "*/".into(), + tab_size: 1, + }) ); + // Both bracket pairs are enabled assert_eq!( config.brackets().map(|e| e.1).collect::<Vec<_>>(), @@ -2360,8 +2376,13 @@ fn test_language_scope_at_with_javascript(cx: &mut App) { .unwrap(); assert_eq!(string_config.line_comment_prefixes(), &[Arc::from("// ")]); assert_eq!( - string_config.block_comment_delimiters(), - Some((&"/*".into(), &"*/".into())) + string_config.block_comment(), + Some(&BlockCommentConfig { + start: "/*".into(), + prefix: "* ".into(), + end: "*/".into(), + tab_size: 1, + }) ); // Second bracket pair is disabled assert_eq!( @@ -2391,8 +2412,13 @@ fn test_language_scope_at_with_javascript(cx: &mut App) { .unwrap(); assert_eq!(tag_config.line_comment_prefixes(), &[Arc::from("// ")]); assert_eq!( - tag_config.block_comment_delimiters(), - Some((&"/*".into(), &"*/".into())) + tag_config.block_comment(), + Some(&BlockCommentConfig { + start: "/*".into(), + prefix: "* ".into(), + end: "*/".into(), + tab_size: 1, + }) ); assert_eq!( tag_config.brackets().map(|e| e.1).collect::<Vec<_>>(), @@ -2408,8 +2434,13 @@ fn test_language_scope_at_with_javascript(cx: &mut App) { &[Arc::from("// ")] ); assert_eq!( - expression_in_element_config.block_comment_delimiters(), - Some((&"/*".into(), &"*/".into())) + expression_in_element_config.block_comment(), + Some(&BlockCommentConfig { + start: "/*".into(), + prefix: "* ".into(), + end: 
"*/".into(), + tab_size: 1, + }) ); assert_eq!( expression_in_element_config @@ -2528,13 +2559,18 @@ fn test_language_scope_at_with_combined_injections(cx: &mut App) { let html_config = snapshot.language_scope_at(Point::new(2, 4)).unwrap(); assert_eq!(html_config.line_comment_prefixes(), &[]); assert_eq!( - html_config.block_comment_delimiters(), - Some((&"<!--".into(), &"-->".into())) + html_config.block_comment(), + Some(&BlockCommentConfig { + start: "<!--".into(), + end: "-->".into(), + prefix: "".into(), + tab_size: 0, + }) ); let ruby_config = snapshot.language_scope_at(Point::new(3, 12)).unwrap(); assert_eq!(ruby_config.line_comment_prefixes(), &[Arc::from("# ")]); - assert_eq!(ruby_config.block_comment_delimiters(), None); + assert_eq!(ruby_config.block_comment(), None); buffer }); @@ -3490,7 +3526,12 @@ fn html_lang() -> Language { Language::new( LanguageConfig { name: LanguageName::new("HTML"), - block_comment: Some(("<!--".into(), "-->".into())), + block_comment: Some(BlockCommentConfig { + start: "<!--".into(), + prefix: "".into(), + end: "-->".into(), + tab_size: 0, + }), ..Default::default() }, Some(tree_sitter_html::LANGUAGE.into()), @@ -3521,7 +3562,12 @@ fn erb_lang() -> Language { path_suffixes: vec!["erb".to_string()], ..Default::default() }, - block_comment: Some(("<%#".into(), "%>".into())), + block_comment: Some(BlockCommentConfig { + start: "<%#".into(), + prefix: "".into(), + end: "%>".into(), + tab_size: 0, + }), ..Default::default() }, Some(tree_sitter_embedded_template::LANGUAGE.into()), diff --git a/crates/language/src/diagnostic_set.rs b/crates/language/src/diagnostic_set.rs index 661e3ef217a65a24ca59e37f93b137c89ed31dc6..613c445652fbcfe87232afec559480ce943b15e3 100644 --- a/crates/language/src/diagnostic_set.rs +++ b/crates/language/src/diagnostic_set.rs @@ -158,17 +158,17 @@ impl DiagnosticSet { }); if reversed { - cursor.prev(buffer); + cursor.prev(); } else { - cursor.next(buffer); + cursor.next(); } iter::from_fn({ move || { if let Some(diagnostic) = cursor.item() { if reversed { - cursor.prev(buffer); + cursor.prev(); } else { - cursor.next(buffer); + cursor.next(); } Some(diagnostic.resolve(buffer)) } else { diff --git a/crates/language/src/language.rs b/crates/language/src/language.rs index 1ad057ff41eb3eef961d687d9e7ee097c0364c43..b9933dfcec36f1e8c5cb31271668a25b60020c8a 100644 --- a/crates/language/src/language.rs +++ b/crates/language/src/language.rs @@ -161,12 +161,11 @@ pub struct CachedLspAdapter { pub name: LanguageServerName, pub disk_based_diagnostic_sources: Vec<String>, pub disk_based_diagnostics_progress_token: Option<String>, - language_ids: HashMap<String, String>, + language_ids: HashMap<LanguageName, String>, pub adapter: Arc<dyn LspAdapter>, pub reinstall_attempt_count: AtomicU64, cached_binary: futures::lock::Mutex<Option<LanguageServerBinary>>, manifest_name: OnceLock<Option<ManifestName>>, - attach_kind: OnceLock<Attach>, } impl Debug for CachedLspAdapter { @@ -202,7 +201,6 @@ impl CachedLspAdapter { adapter, cached_binary: Default::default(), reinstall_attempt_count: AtomicU64::new(0), - attach_kind: Default::default(), manifest_name: Default::default(), }) } @@ -279,38 +277,25 @@ impl CachedLspAdapter { pub fn language_id(&self, language_name: &LanguageName) -> String { self.language_ids - .get(language_name.as_ref()) + .get(language_name) .cloned() .unwrap_or_else(|| language_name.lsp_id()) } + pub fn manifest_name(&self) -> Option<ManifestName> { self.manifest_name .get_or_init(|| self.adapter.manifest_name()) .clone() } - pub fn 
attach_kind(&self) -> Attach { - *self.attach_kind.get_or_init(|| self.adapter.attach_kind()) - } } +/// Determines what gets sent out as a workspace folders content #[derive(Clone, Copy, Debug, PartialEq)] -pub enum Attach { - /// Create a single language server instance per subproject root. - InstancePerRoot, - /// Use one shared language server instance for all subprojects within a project. - Shared, -} - -impl Attach { - pub fn root_path( - &self, - root_subproject_path: (WorktreeId, Arc<Path>), - ) -> (WorktreeId, Arc<Path>) { - match self { - Attach::InstancePerRoot => root_subproject_path, - Attach::Shared => (root_subproject_path.0, Arc::from(Path::new(""))), - } - } +pub enum WorkspaceFoldersContent { + /// Send out a single entry with the root of the workspace. + WorktreeRoot, + /// Send out a list of subproject roots. + SubprojectRoots, } /// [`LspAdapterDelegate`] allows [`LspAdapter]` implementations to interface with the application @@ -589,8 +574,8 @@ pub trait LspAdapter: 'static + Send + Sync { None } - fn language_ids(&self) -> HashMap<String, String> { - Default::default() + fn language_ids(&self) -> HashMap<LanguageName, String> { + HashMap::default() } /// Support custom initialize params. @@ -602,8 +587,11 @@ pub trait LspAdapter: 'static + Send + Sync { Ok(original) } - fn attach_kind(&self) -> Attach { - Attach::Shared + /// Determines whether a language server supports workspace folders. + /// + /// And does not trip over itself in the process. + fn workspace_folders_content(&self) -> WorkspaceFoldersContent { + WorkspaceFoldersContent::SubprojectRoots } fn manifest_name(&self) -> Option<ManifestName> { @@ -727,9 +715,12 @@ pub struct LanguageConfig { /// used for comment continuations on the next line, but only the first one is used for Editor::ToggleComments. #[serde(default)] pub line_comments: Vec<Arc<str>>, - /// Starting and closing characters of a block comment. + /// Delimiters and configuration for recognizing and formatting block comments. #[serde(default)] - pub block_comment: Option<(Arc<str>, Arc<str>)>, + pub block_comment: Option<BlockCommentConfig>, + /// Delimiters and configuration for recognizing and formatting documentation comments. + #[serde(default, alias = "documentation")] + pub documentation_comment: Option<BlockCommentConfig>, /// A list of additional regex patterns that should be treated as prefixes /// for creating boundaries during rewrapping, ensuring content from one /// prefixed section doesn't merge with another (e.g., markdown list items). @@ -774,10 +765,6 @@ pub struct LanguageConfig { /// A list of preferred debuggers for this language. #[serde(default)] pub debuggers: IndexSet<SharedString>, - /// Whether to treat documentation comment of this language differently by - /// auto adding prefix on new line, adjusting the indenting , etc. - #[serde(default)] - pub documentation: Option<DocumentationConfig>, } #[derive(Clone, Debug, Deserialize, Default, JsonSchema)] @@ -837,17 +824,56 @@ pub struct JsxTagAutoCloseConfig { pub erroneous_close_tag_name_node_name: Option<String>, } -/// The configuration for documentation block for this language. -#[derive(Clone, Deserialize, JsonSchema)] -pub struct DocumentationConfig { - /// A start tag of documentation block. +/// The configuration for block comments for this language. +#[derive(Clone, Debug, JsonSchema, PartialEq)] +pub struct BlockCommentConfig { + /// A start tag of block comment. pub start: Arc<str>, - /// A end tag of documentation block. + /// A end tag of block comment. 
pub end: Arc<str>, - /// A character to add as a prefix when a new line is added to a documentation block. + /// A character to add as a prefix when a new line is added to a block comment. pub prefix: Arc<str>, /// A indent to add for prefix and end line upon new line. - pub tab_size: NonZeroU32, + pub tab_size: u32, +} + +impl<'de> Deserialize<'de> for BlockCommentConfig { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: Deserializer<'de>, + { + #[derive(Deserialize)] + #[serde(untagged)] + enum BlockCommentConfigHelper { + New { + start: Arc<str>, + end: Arc<str>, + prefix: Arc<str>, + tab_size: u32, + }, + Old([Arc<str>; 2]), + } + + match BlockCommentConfigHelper::deserialize(deserializer)? { + BlockCommentConfigHelper::New { + start, + end, + prefix, + tab_size, + } => Ok(BlockCommentConfig { + start, + end, + prefix, + tab_size, + }), + BlockCommentConfigHelper::Old([start, end]) => Ok(BlockCommentConfig { + start, + end, + prefix: "".into(), + tab_size: 0, + }), + } + } } /// Represents a language for the given range. Some languages (e.g. HTML) @@ -864,7 +890,7 @@ pub struct LanguageConfigOverride { #[serde(default)] pub line_comments: Override<Vec<Arc<str>>>, #[serde(default)] - pub block_comment: Override<(Arc<str>, Arc<str>)>, + pub block_comment: Override<BlockCommentConfig>, #[serde(skip)] pub disabled_bracket_ixs: Vec<u16>, #[serde(default)] @@ -916,6 +942,7 @@ impl Default for LanguageConfig { autoclose_before: Default::default(), line_comments: Default::default(), block_comment: Default::default(), + documentation_comment: Default::default(), rewrap_prefixes: Default::default(), scope_opt_in_language_servers: Default::default(), overrides: Default::default(), @@ -929,7 +956,6 @@ impl Default for LanguageConfig { jsx_tag_auto_close: None, completion_query_characters: Default::default(), debuggers: Default::default(), - documentation: None, } } } @@ -1847,12 +1873,17 @@ impl LanguageScope { .map_or([].as_slice(), |e| e.as_slice()) } - pub fn block_comment_delimiters(&self) -> Option<(&Arc<str>, &Arc<str>)> { + /// Config for block comments for this language. + pub fn block_comment(&self) -> Option<&BlockCommentConfig> { Override::as_option( self.config_override().map(|o| &o.block_comment), self.language.config.block_comment.as_ref(), ) - .map(|e| (&e.0, &e.1)) + } + + /// Config for documentation-style block comments for this language. + pub fn documentation_comment(&self) -> Option<&BlockCommentConfig> { + self.language.config.documentation_comment.as_ref() } /// Returns additional regex patterns that act as prefix markers for creating @@ -1897,14 +1928,6 @@ impl LanguageScope { .unwrap_or(false) } - /// Returns config to documentation block for this language. - /// - /// Used for documentation styles that require a leading character on each line, - /// such as the asterisk in JSDoc, Javadoc, etc. - pub fn documentation(&self) -> Option<&DocumentationConfig> { - self.language.config.documentation.as_ref() - } - /// Returns a list of bracket pairs for a given language with an additional /// piece of information about whether the particular bracket pair is currently active for a given language. 
pub fn brackets(&self) -> impl Iterator<Item = (&BracketPair, bool)> { @@ -2299,6 +2322,7 @@ pub fn range_from_lsp(range: lsp::Range) -> Range<Unclipped<PointUtf16>> { mod tests { use super::*; use gpui::TestAppContext; + use pretty_assertions::assert_matches; #[gpui::test(iterations = 10)] async fn test_language_loading(cx: &mut TestAppContext) { @@ -2329,9 +2353,9 @@ mod tests { assert_eq!( languages.language_names(), &[ - "JSON".to_string(), - "Plain Text".to_string(), - "Rust".to_string(), + LanguageName::new("JSON"), + LanguageName::new("Plain Text"), + LanguageName::new("Rust"), ] ); @@ -2342,9 +2366,9 @@ mod tests { assert_eq!( languages.language_names(), &[ - "JSON".to_string(), - "Plain Text".to_string(), - "Rust".to_string(), + LanguageName::new("JSON"), + LanguageName::new("Plain Text"), + LanguageName::new("Rust"), ] ); @@ -2355,9 +2379,9 @@ mod tests { assert_eq!( languages.language_names(), &[ - "JSON".to_string(), - "Plain Text".to_string(), - "Rust".to_string(), + LanguageName::new("JSON"), + LanguageName::new("Plain Text"), + LanguageName::new("Rust"), ] ); @@ -2460,4 +2484,75 @@ mod tests { "LSP completion items with duplicate label and detail, should omit the detail" ); } + + #[test] + fn test_deserializing_comments_backwards_compat() { + // current version of `block_comment` and `documentation_comment` work + { + let config: LanguageConfig = ::toml::from_str( + r#" + name = "Foo" + block_comment = { start = "a", end = "b", prefix = "c", tab_size = 1 } + documentation_comment = { start = "d", end = "e", prefix = "f", tab_size = 2 } + "#, + ) + .unwrap(); + assert_matches!(config.block_comment, Some(BlockCommentConfig { .. })); + assert_matches!( + config.documentation_comment, + Some(BlockCommentConfig { .. }) + ); + + let block_config = config.block_comment.unwrap(); + assert_eq!(block_config.start.as_ref(), "a"); + assert_eq!(block_config.end.as_ref(), "b"); + assert_eq!(block_config.prefix.as_ref(), "c"); + assert_eq!(block_config.tab_size, 1); + + let doc_config = config.documentation_comment.unwrap(); + assert_eq!(doc_config.start.as_ref(), "d"); + assert_eq!(doc_config.end.as_ref(), "e"); + assert_eq!(doc_config.prefix.as_ref(), "f"); + assert_eq!(doc_config.tab_size, 2); + } + + // former `documentation` setting is read into `documentation_comment` + { + let config: LanguageConfig = ::toml::from_str( + r#" + name = "Foo" + documentation = { start = "a", end = "b", prefix = "c", tab_size = 1} + "#, + ) + .unwrap(); + assert_matches!( + config.documentation_comment, + Some(BlockCommentConfig { .. }) + ); + + let config = config.documentation_comment.unwrap(); + assert_eq!(config.start.as_ref(), "a"); + assert_eq!(config.end.as_ref(), "b"); + assert_eq!(config.prefix.as_ref(), "c"); + assert_eq!(config.tab_size, 1); + } + + // old block_comment format is read into BlockCommentConfig + { + let config: LanguageConfig = ::toml::from_str( + r#" + name = "Foo" + block_comment = ["a", "b"] + "#, + ) + .unwrap(); + assert_matches!(config.block_comment, Some(BlockCommentConfig { .. 
})); + + let config = config.block_comment.unwrap(); + assert_eq!(config.start.as_ref(), "a"); + assert_eq!(config.end.as_ref(), "b"); + assert_eq!(config.prefix.as_ref(), ""); + assert_eq!(config.tab_size, 0); + } + } } diff --git a/crates/language/src/language_registry.rs b/crates/language/src/language_registry.rs index ff17d6dd9a9d7bb250f15d358d11eb23ef8f188f..ea988e8098ec2a795e8c0a386b4e162ecd5c89ca 100644 --- a/crates/language/src/language_registry.rs +++ b/crates/language/src/language_registry.rs @@ -334,6 +334,9 @@ impl LanguageRegistry { if let Some(adapters) = state.lsp_adapters.get_mut(language_name) { adapters.retain(|adapter| &adapter.name != name) } + state.all_lsp_adapters.remove(name); + state.available_lsp_adapters.remove(name); + state.version += 1; state.reload_count += 1; *state.subscription.0.borrow_mut() = (); @@ -408,30 +411,6 @@ impl LanguageRegistry { cached } - pub fn get_or_register_lsp_adapter( - &self, - language_name: LanguageName, - server_name: LanguageServerName, - build_adapter: impl FnOnce() -> Arc<dyn LspAdapter> + 'static, - ) -> Arc<CachedLspAdapter> { - let registered = self - .state - .write() - .lsp_adapters - .entry(language_name.clone()) - .or_default() - .iter() - .find(|cached_adapter| cached_adapter.name == server_name) - .cloned(); - - if let Some(found) = registered { - found - } else { - let adapter = build_adapter(); - self.register_lsp_adapter(language_name, adapter) - } - } - /// Register a fake language server and adapter /// The returned channel receives a new instance of the language server every time it is started #[cfg(any(feature = "test-support", test))] @@ -568,15 +547,15 @@ impl LanguageRegistry { self.state.read().language_settings.clone() } - pub fn language_names(&self) -> Vec<String> { + pub fn language_names(&self) -> Vec<LanguageName> { let state = self.state.read(); let mut result = state .available_languages .iter() - .filter_map(|l| l.loaded.not().then_some(l.name.to_string())) - .chain(state.languages.iter().map(|l| l.config.name.to_string())) + .filter_map(|l| l.loaded.not().then_some(l.name.clone())) + .chain(state.languages.iter().map(|l| l.config.name.clone())) .collect::<Vec<_>>(); - result.sort_unstable_by_key(|language_name| language_name.to_lowercase()); + result.sort_unstable_by_key(|language_name| language_name.as_ref().to_lowercase()); result } diff --git a/crates/language/src/syntax_map.rs b/crates/language/src/syntax_map.rs index da05416e894e6a713121affe767f71d953408684..c56ffed0663a9419419201f902f3db8311acb9bd 100644 --- a/crates/language/src/syntax_map.rs +++ b/crates/language/src/syntax_map.rs @@ -17,7 +17,7 @@ use std::{ sync::Arc, }; use streaming_iterator::StreamingIterator; -use sum_tree::{Bias, SeekTarget, SumTree}; +use sum_tree::{Bias, Dimensions, SeekTarget, SumTree}; use text::{Anchor, BufferSnapshot, OffsetRangeExt, Point, Rope, ToOffset, ToPoint}; use tree_sitter::{Node, Query, QueryCapture, QueryCaptures, QueryCursor, QueryMatches, Tree}; @@ -285,7 +285,7 @@ impl SyntaxSnapshot { pub fn interpolate(&mut self, text: &BufferSnapshot) { let edits = text - .anchored_edits_since::<(usize, Point)>(&self.interpolated_version) + .anchored_edits_since::<Dimensions<usize, Point>>(&self.interpolated_version) .collect::<Vec<_>>(); self.interpolated_version = text.version().clone(); @@ -297,10 +297,10 @@ impl SyntaxSnapshot { let mut first_edit_ix_for_depth = 0; let mut prev_depth = 0; let mut cursor = self.layers.cursor::<SyntaxLayerSummary>(text); - cursor.next(text); + cursor.next(); 'outer: loop { - 
let depth = cursor.end(text).max_depth; + let depth = cursor.end().max_depth; if depth > prev_depth { first_edit_ix_for_depth = 0; prev_depth = depth; @@ -313,7 +313,7 @@ impl SyntaxSnapshot { position: edit_range.start, }; if target.cmp(cursor.start(), text).is_gt() { - let slice = cursor.slice(&target, Bias::Left, text); + let slice = cursor.slice(&target, Bias::Left); layers.append(slice, text); } } @@ -327,14 +327,14 @@ impl SyntaxSnapshot { language: None, }, Bias::Left, - text, ); layers.append(slice, text); continue; }; let Some(layer) = cursor.item() else { break }; - let (start_byte, start_point) = layer.range.start.summary::<(usize, Point)>(text); + let Dimensions(start_byte, start_point, _) = + layer.range.start.summary::<Dimensions<usize, Point>>(text); // Ignore edits that end before the start of this layer, and don't consider them // for any subsequent layers at this same depth. @@ -394,10 +394,10 @@ impl SyntaxSnapshot { } layers.push(layer, text); - cursor.next(text); + cursor.next(); } - layers.append(cursor.suffix(text), text); + layers.append(cursor.suffix(), text); drop(cursor); self.layers = layers; } @@ -420,7 +420,7 @@ impl SyntaxSnapshot { let mut cursor = self .layers .filter::<_, ()>(text, |summary| summary.contains_unknown_injections); - cursor.next(text); + cursor.next(); while let Some(layer) = cursor.item() { let SyntaxLayerContent::Pending { language_name } = &layer.content else { unreachable!() @@ -436,7 +436,7 @@ impl SyntaxSnapshot { resolved_injection_ranges.push(range); } - cursor.next(text); + cursor.next(); } drop(cursor); @@ -469,7 +469,7 @@ impl SyntaxSnapshot { let max_depth = self.layers.summary().max_depth; let mut cursor = self.layers.cursor::<SyntaxLayerSummary>(text); - cursor.next(text); + cursor.next(); let mut layers = SumTree::new(text); let mut changed_regions = ChangeRegionSet::default(); @@ -514,7 +514,7 @@ impl SyntaxSnapshot { }; let mut done = cursor.item().is_none(); - while !done && position.cmp(&cursor.end(text), text).is_gt() { + while !done && position.cmp(&cursor.end(), text).is_gt() { done = true; let bounded_position = SyntaxLayerPositionBeforeChange { @@ -522,16 +522,16 @@ impl SyntaxSnapshot { change: changed_regions.start_position(), }; if bounded_position.cmp(cursor.start(), text).is_gt() { - let slice = cursor.slice(&bounded_position, Bias::Left, text); + let slice = cursor.slice(&bounded_position, Bias::Left); if !slice.is_empty() { layers.append(slice, text); - if changed_regions.prune(cursor.end(text), text) { + if changed_regions.prune(cursor.end(), text) { done = false; } } } - while position.cmp(&cursor.end(text), text).is_gt() { + while position.cmp(&cursor.end(), text).is_gt() { let Some(layer) = cursor.item() else { break }; if changed_regions.intersects(layer, text) { @@ -555,16 +555,16 @@ impl SyntaxSnapshot { layers.push(layer.clone(), text); } - cursor.next(text); - if changed_regions.prune(cursor.end(text), text) { + cursor.next(); + if changed_regions.prune(cursor.end(), text) { done = false; } } } let Some(step) = step else { break }; - let (step_start_byte, step_start_point) = - step.range.start.summary::<(usize, Point)>(text); + let Dimensions(step_start_byte, step_start_point, _) = + step.range.start.summary::<Dimensions<usize, Point>>(text); let step_end_byte = step.range.end.to_offset(text); let mut old_layer = cursor.item(); @@ -572,7 +572,7 @@ impl SyntaxSnapshot { if layer.range.to_offset(text) == (step_start_byte..step_end_byte) && layer.content.language_id() == step.language.id() { - 
cursor.next(text); + cursor.next(); } else { old_layer = None; } @@ -918,7 +918,7 @@ impl SyntaxSnapshot { } }); - cursor.next(buffer); + cursor.next(); iter::from_fn(move || { while let Some(layer) = cursor.item() { let mut info = None; @@ -940,7 +940,7 @@ impl SyntaxSnapshot { }); } } - cursor.next(buffer); + cursor.next(); if info.is_some() { return info; } diff --git a/crates/language_extension/src/extension_lsp_adapter.rs b/crates/language_extension/src/extension_lsp_adapter.rs index 58fbe6cda269c768abdbcb90a1e098804eb1d869..98b6fd4b5a2ef6e7f1b5adbc54dcecd0707b60ff 100644 --- a/crates/language_extension/src/extension_lsp_adapter.rs +++ b/crates/language_extension/src/extension_lsp_adapter.rs @@ -242,7 +242,7 @@ impl LspAdapter for ExtensionLspAdapter { ])) } - fn language_ids(&self) -> HashMap<String, String> { + fn language_ids(&self) -> HashMap<LanguageName, String> { // TODO: The language IDs can be provided via the language server options // in `extension.toml now but we're leaving these existing usages in place temporarily // to avoid any compatibility issues between Zed and the extension versions. @@ -250,7 +250,7 @@ impl LspAdapter for ExtensionLspAdapter { // We can remove once the following extension versions no longer see any use: // - php@0.0.1 if self.extension.manifest().id.as_ref() == "php" { - return HashMap::from_iter([("PHP".into(), "php".into())]); + return HashMap::from_iter([(LanguageName::new("PHP"), "php".into())]); } self.extension diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml index b718c530f5cd59f25593fb3c5261bdc706839223..f9920623b5ea3bff79535f92753fae0b723f850f 100644 --- a/crates/language_model/Cargo.toml +++ b/crates/language_model/Cargo.toml @@ -20,6 +20,8 @@ anthropic = { workspace = true, features = ["schemars"] } anyhow.workspace = true base64.workspace = true client.workspace = true +cloud_api_types.workspace = true +cloud_llm_client.workspace = true collections.workspace = true futures.workspace = true gpui.workspace = true @@ -37,7 +39,6 @@ telemetry_events.workspace = true thiserror.workspace = true util.workspace = true workspace-hack.workspace = true -zed_llm_client.workspace = true [dev-dependencies] gpui = { workspace = true, features = ["test-support"] } diff --git a/crates/language_model/src/fake_provider.rs b/crates/language_model/src/fake_provider.rs index f5191016d89d1418922865bc2eaddada945d072a..a9c7d5c0343295ff02d9d693f2cdbe3d92f1e07d 100644 --- a/crates/language_model/src/fake_provider.rs +++ b/crates/language_model/src/fake_provider.rs @@ -10,25 +10,21 @@ use http_client::Result; use parking_lot::Mutex; use std::sync::Arc; -pub fn language_model_id() -> LanguageModelId { - LanguageModelId::from("fake".to_string()) +#[derive(Clone)] +pub struct FakeLanguageModelProvider { + id: LanguageModelProviderId, + name: LanguageModelProviderName, } -pub fn language_model_name() -> LanguageModelName { - LanguageModelName::from("Fake".to_string()) -} - -pub fn provider_id() -> LanguageModelProviderId { - LanguageModelProviderId::from("fake".to_string()) -} - -pub fn provider_name() -> LanguageModelProviderName { - LanguageModelProviderName::from("Fake".to_string()) +impl Default for FakeLanguageModelProvider { + fn default() -> Self { + Self { + id: LanguageModelProviderId::from("fake".to_string()), + name: LanguageModelProviderName::from("Fake".to_string()), + } + } } -#[derive(Clone, Default)] -pub struct FakeLanguageModelProvider; - impl LanguageModelProviderState for FakeLanguageModelProvider { type 
ObservableEntity = (); @@ -39,11 +35,11 @@ impl LanguageModelProviderState for FakeLanguageModelProvider { impl LanguageModelProvider for FakeLanguageModelProvider { fn id(&self) -> LanguageModelProviderId { - provider_id() + self.id.clone() } fn name(&self) -> LanguageModelProviderName { - provider_name() + self.name.clone() } fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> { @@ -76,6 +72,10 @@ impl LanguageModelProvider for FakeLanguageModelProvider { } impl FakeLanguageModelProvider { + pub fn new(id: LanguageModelProviderId, name: LanguageModelProviderName) -> Self { + Self { id, name } + } + pub fn test_model(&self) -> FakeLanguageModel { FakeLanguageModel::default() } @@ -89,9 +89,25 @@ pub struct ToolUseRequest { pub schema: serde_json::Value, } -#[derive(Default)] pub struct FakeLanguageModel { - current_completion_txs: Mutex<Vec<(LanguageModelRequest, mpsc::UnboundedSender<String>)>>, + provider_id: LanguageModelProviderId, + provider_name: LanguageModelProviderName, + current_completion_txs: Mutex< + Vec<( + LanguageModelRequest, + mpsc::UnboundedSender<LanguageModelCompletionEvent>, + )>, + >, +} + +impl Default for FakeLanguageModel { + fn default() -> Self { + Self { + provider_id: LanguageModelProviderId::from("fake".to_string()), + provider_name: LanguageModelProviderName::from("Fake".to_string()), + current_completion_txs: Mutex::new(Vec::new()), + } + } } impl FakeLanguageModel { @@ -107,10 +123,21 @@ impl FakeLanguageModel { self.current_completion_txs.lock().len() } - pub fn stream_completion_response( + pub fn send_completion_stream_text_chunk( &self, request: &LanguageModelRequest, chunk: impl Into<String>, + ) { + self.send_completion_stream_event( + request, + LanguageModelCompletionEvent::Text(chunk.into()), + ); + } + + pub fn send_completion_stream_event( + &self, + request: &LanguageModelRequest, + event: impl Into<LanguageModelCompletionEvent>, ) { let current_completion_txs = self.current_completion_txs.lock(); let tx = current_completion_txs @@ -118,7 +145,7 @@ impl FakeLanguageModel { .find(|(req, _)| req == request) .map(|(_, tx)| tx) .unwrap(); - tx.unbounded_send(chunk.into()).unwrap(); + tx.unbounded_send(event.into()).unwrap(); } pub fn end_completion_stream(&self, request: &LanguageModelRequest) { @@ -127,8 +154,15 @@ impl FakeLanguageModel { .retain(|(req, _)| req != request); } - pub fn stream_last_completion_response(&self, chunk: impl Into<String>) { - self.stream_completion_response(self.pending_completions().last().unwrap(), chunk); + pub fn send_last_completion_stream_text_chunk(&self, chunk: impl Into<String>) { + self.send_completion_stream_text_chunk(self.pending_completions().last().unwrap(), chunk); + } + + pub fn send_last_completion_stream_event( + &self, + event: impl Into<LanguageModelCompletionEvent>, + ) { + self.send_completion_stream_event(self.pending_completions().last().unwrap(), event); } pub fn end_last_completion_stream(&self) { @@ -138,19 +172,19 @@ impl FakeLanguageModel { impl LanguageModel for FakeLanguageModel { fn id(&self) -> LanguageModelId { - language_model_id() + LanguageModelId::from("fake".to_string()) } fn name(&self) -> LanguageModelName { - language_model_name() + LanguageModelName::from("Fake".to_string()) } fn provider_id(&self) -> LanguageModelProviderId { - provider_id() + self.provider_id.clone() } fn provider_name(&self) -> LanguageModelProviderName { - provider_name() + self.provider_name.clone() } fn supports_tools(&self) -> bool { @@ -190,12 +224,7 @@ impl LanguageModel for 
FakeLanguageModel { > { let (tx, rx) = mpsc::unbounded(); self.current_completion_txs.lock().push((request, tx)); - async move { - Ok(rx - .map(|text| Ok(LanguageModelCompletionEvent::Text(text))) - .boxed()) - } - .boxed() + async move { Ok(rx.map(Ok).boxed()) }.boxed() } fn as_fake(&self) -> &Self { diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 81a0f7d8a1c5096989e8f7bf7ce140575950f281..1637d2de8a3c14b910ea345c03a4eb5db13df28d 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -11,6 +11,7 @@ pub mod fake_provider; use anthropic::{AnthropicError, parse_prompt_too_long}; use anyhow::{Result, anyhow}; use client::Client; +use cloud_llm_client::{CompletionMode, CompletionRequestStatus}; use futures::FutureExt; use futures::{StreamExt, future::BoxFuture, stream::BoxStream}; use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window}; @@ -26,7 +27,6 @@ use std::time::Duration; use std::{fmt, io}; use thiserror::Error; use util::serde::is_default; -use zed_llm_client::{CompletionMode, CompletionRequestStatus}; pub use crate::model::*; pub use crate::rate_limiter::*; @@ -116,6 +116,12 @@ pub enum LanguageModelCompletionError { provider: LanguageModelProviderName, message: String, }, + #[error("{message}")] + UpstreamProviderError { + message: String, + status: StatusCode, + retry_after: Option<Duration>, + }, #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")] HttpResponseError { provider: LanguageModelProviderName, @@ -178,6 +184,21 @@ pub enum LanguageModelCompletionError { } impl LanguageModelCompletionError { + fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> { + let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?; + let upstream_status = error_json + .get("upstream_status") + .and_then(|v| v.as_u64()) + .and_then(|status| u16::try_from(status).ok()) + .and_then(|status| StatusCode::from_u16(status).ok())?; + let inner_message = error_json + .get("message") + .and_then(|v| v.as_str()) + .unwrap_or(message) + .to_string(); + Some((upstream_status, inner_message)) + } + pub fn from_cloud_failure( upstream_provider: LanguageModelProviderName, code: String, @@ -191,6 +212,18 @@ impl LanguageModelCompletionError { Self::PromptTooLarge { tokens: Some(tokens), } + } else if code == "upstream_http_error" { + if let Some((upstream_status, inner_message)) = + Self::parse_upstream_error_json(&message) + { + return Self::from_http_status( + upstream_provider, + upstream_status, + inner_message, + retry_after, + ); + } + anyhow!("completion request failed, code: {code}, message: {message}").into() } else if let Some(status_code) = code .strip_prefix("upstream_http_") .and_then(|code| StatusCode::from_str(code).ok()) @@ -621,7 +654,7 @@ pub enum LanguageModelProviderTosView { ThreadEmptyState, /// When there are no past interactions in the Agent Panel. 
ThreadFreshStart, - PromptEditorPopup, + TextThreadPopup, Configuration, } @@ -701,3 +734,116 @@ impl From<String> for LanguageModelProviderName { Self(SharedString::from(value)) } } + +impl From<Arc<str>> for LanguageModelProviderId { + fn from(value: Arc<str>) -> Self { + Self(SharedString::from(value)) + } +} + +impl From<Arc<str>> for LanguageModelProviderName { + fn from(value: Arc<str>) -> Self { + Self(SharedString::from(value)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_from_cloud_failure_with_upstream_http_error() { + let error = LanguageModelCompletionError::from_cloud_failure( + String::from("anthropic").into(), + "upstream_http_error".to_string(), + r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(), + None, + ); + + match error { + LanguageModelCompletionError::ServerOverloaded { provider, .. } => { + assert_eq!(provider.0, "anthropic"); + } + _ => panic!( + "Expected ServerOverloaded error for 503 status, got: {:?}", + error + ), + } + + let error = LanguageModelCompletionError::from_cloud_failure( + String::from("anthropic").into(), + "upstream_http_error".to_string(), + r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(), + None, + ); + + match error { + LanguageModelCompletionError::ApiInternalServerError { provider, message } => { + assert_eq!(provider.0, "anthropic"); + assert_eq!(message, "Internal server error"); + } + _ => panic!( + "Expected ApiInternalServerError for 500 status, got: {:?}", + error + ), + } + } + + #[test] + fn test_from_cloud_failure_with_standard_format() { + let error = LanguageModelCompletionError::from_cloud_failure( + String::from("anthropic").into(), + "upstream_http_503".to_string(), + "Service unavailable".to_string(), + None, + ); + + match error { + LanguageModelCompletionError::ServerOverloaded { provider, .. } => { + assert_eq!(provider.0, "anthropic"); + } + _ => panic!("Expected ServerOverloaded error for upstream_http_503"), + } + } + + #[test] + fn test_upstream_http_error_connection_timeout() { + let error = LanguageModelCompletionError::from_cloud_failure( + String::from("anthropic").into(), + "upstream_http_error".to_string(), + r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(), + None, + ); + + match error { + LanguageModelCompletionError::ServerOverloaded { provider, .. } => { + assert_eq!(provider.0, "anthropic"); + } + _ => panic!( + "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}", + error + ), + } + + let error = LanguageModelCompletionError::from_cloud_failure( + String::from("anthropic").into(), + "upstream_http_error".to_string(), + r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(), + None, + ); + + match error { + LanguageModelCompletionError::ApiInternalServerError { provider, message } => { + assert_eq!(provider.0, "anthropic"); + assert_eq!( + message, + "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. 
reset reason: connection timeout" + ); + } + _ => panic!( + "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}", + error + ), + } + } +} diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index 72b7132c60c5536883107c3f186f0a0f54d46ea1..3b4c1fa269020d1bf17d98cbb67251902536dafc 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -3,10 +3,9 @@ use std::sync::Arc; use anyhow::Result; use client::Client; -use gpui::{ - App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal as _, -}; -use proto::{Plan, TypedEnvelope}; +use cloud_api_types::websocket_protocol::MessageToClient; +use cloud_llm_client::Plan; +use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _}; use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard}; use thiserror::Error; @@ -30,7 +29,7 @@ pub struct ModelRequestLimitReachedError { impl fmt::Display for ModelRequestLimitReachedError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let message = match self.plan { - Plan::Free => "Model request limit reached. Upgrade to Zed Pro for more requests.", + Plan::ZedFree => "Model request limit reached. Upgrade to Zed Pro for more requests.", Plan::ZedPro => { "Model request limit reached. Upgrade to usage-based billing for more requests." } @@ -64,9 +63,14 @@ impl LlmApiToken { mut lock: RwLockWriteGuard<'_, Option<String>>, client: &Arc<Client>, ) -> Result<String> { - let response = client.request(proto::GetLlmToken {}).await?; - *lock = Some(response.token.clone()); - Ok(response.token.clone()) + let system_id = client + .telemetry() + .system_id() + .map(|system_id| system_id.to_string()); + + let response = client.cloud_client().create_llm_token(system_id).await?; + *lock = Some(response.token.0.clone()); + Ok(response.token.0.clone()) } } @@ -76,9 +80,7 @@ impl Global for GlobalRefreshLlmTokenListener {} pub struct RefreshLlmTokenEvent; -pub struct RefreshLlmTokenListener { - _llm_token_subscription: client::Subscription, -} +pub struct RefreshLlmTokenListener; impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {} @@ -93,17 +95,21 @@ impl RefreshLlmTokenListener { } fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self { - Self { - _llm_token_subscription: client - .add_message_handler(cx.weak_entity(), Self::handle_refresh_llm_token), - } + client.add_message_to_client_handler({ + let this = cx.entity(); + move |message, cx| { + Self::handle_refresh_llm_token(this.clone(), message, cx); + } + }); + + Self } - async fn handle_refresh_llm_token( - this: Entity<Self>, - _: TypedEnvelope<proto::RefreshLlmToken>, - mut cx: AsyncApp, - ) -> Result<()> { - this.update(&mut cx, |_this, cx| cx.emit(RefreshLlmTokenEvent)) + fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) { + match message { + MessageToClient::UserUpdated => { + this.update(cx, |_this, cx| cx.emit(RefreshLlmTokenEvent)); + } + } } } diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs index 840fda38dec4714a32f3397a28dd2d116bb67f5d..7cf071808a2c0d95bf9aa5a41eaa260cff533d57 100644 --- a/crates/language_model/src/registry.rs +++ b/crates/language_model/src/registry.rs @@ -125,7 +125,7 @@ impl LanguageModelRegistry { #[cfg(any(test, feature = "test-support"))] pub fn test(cx: &mut App) -> crate::fake_provider::FakeLanguageModelProvider { - let 
fake_provider = crate::fake_provider::FakeLanguageModelProvider;
+        let fake_provider = crate::fake_provider::FakeLanguageModelProvider::default();
         let registry = cx.new(|cx| {
             let mut registry = Self::default();
             registry.register_provider(fake_provider.clone(), cx);
@@ -206,8 +206,8 @@ impl LanguageModelRegistry {
         None
     }
 
-    /// Check that we have at least one provider that is authenticated.
-    fn has_authenticated_provider(&self, cx: &App) -> bool {
+    /// Returns `true` if at least one provider is authenticated.
+    pub fn has_authenticated_provider(&self, cx: &App) -> bool {
         self.providers.values().any(|p| p.is_authenticated(cx))
     }
 
@@ -403,16 +403,17 @@ mod tests {
     fn test_register_providers(cx: &mut App) {
         let registry = cx.new(|_| LanguageModelRegistry::default());
 
+        let provider = FakeLanguageModelProvider::default();
         registry.update(cx, |registry, cx| {
-            registry.register_provider(FakeLanguageModelProvider, cx);
+            registry.register_provider(provider.clone(), cx);
         });
 
         let providers = registry.read(cx).providers();
         assert_eq!(providers.len(), 1);
-        assert_eq!(providers[0].id(), crate::fake_provider::provider_id());
+        assert_eq!(providers[0].id(), provider.id());
 
         registry.update(cx, |registry, cx| {
-            registry.unregister_provider(crate::fake_provider::provider_id(), cx);
+            registry.unregister_provider(provider.id(), cx);
         });
 
         let providers = registry.read(cx).providers();
diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs
index 6f3d420ad5ac1304daf1f3341b2fb05da8662a18..dc485e9937f7e6579f99212609e6f2383c11ae6c 100644
--- a/crates/language_model/src/request.rs
+++ b/crates/language_model/src/request.rs
@@ -1,10 +1,9 @@
 use std::io::{Cursor, Write};
 use std::sync::Arc;
 
-use crate::role::Role;
-use crate::{LanguageModelToolUse, LanguageModelToolUseId};
 use anyhow::Result;
 use base64::write::EncoderWriter;
+use cloud_llm_client::{CompletionIntent, CompletionMode};
 use gpui::{
     App, AppContext as _, DevicePixels, Image, ImageFormat, ObjectFit, SharedString, Size, Task,
     point, px, size,
@@ -12,7 +11,9 @@ use gpui::{
 use image::codecs::png::PngEncoder;
 use serde::{Deserialize, Serialize};
 use util::ResultExt;
-use zed_llm_client::{CompletionIntent, CompletionMode};
+
+use crate::role::Role;
+use crate::{LanguageModelToolUse, LanguageModelToolUseId};
 
 #[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
 pub struct LanguageModelImage {
diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml
index 514443ddec57cff2efa94261f7c94dce25f609cd..b5bfb870f643452bd5be248c9910d99f16a8101e 100644
--- a/crates/language_models/Cargo.toml
+++ b/crates/language_models/Cargo.toml
@@ -12,29 +12,29 @@ workspace = true
 path = "src/language_models.rs"
 
 [dependencies]
+ai_onboarding.workspace = true
 anthropic = { workspace = true, features = ["schemars"] }
 anyhow.workspace = true
 aws-config = { workspace = true, features = ["behavior-version-latest"] }
-aws-credential-types = { workspace = true, features = [
-    "hardcoded-credentials",
-] }
+aws-credential-types = { workspace = true, features = ["hardcoded-credentials"] }
 aws_http_client.workspace = true
 bedrock.workspace = true
 chrono.workspace = true
 client.workspace = true
+cloud_llm_client.workspace = true
 collections.workspace = true
 component.workspace = true
-credentials_provider.workspace = true
+convert_case.workspace = true
 copilot.workspace = true
+credentials_provider.workspace = true
 deepseek = { workspace = true, features = ["schemars"] }
 editor.workspace = true
-feature_flags.workspace = true -fs.workspace = true futures.workspace = true google_ai = { workspace = true, features = ["schemars"] } gpui.workspace = true gpui_tokio.workspace = true http_client.workspace = true +language.workspace = true language_model.workspace = true lmstudio = { workspace = true, features = ["schemars"] } log.workspace = true @@ -43,9 +43,7 @@ mistral = { workspace = true, features = ["schemars"] } ollama = { workspace = true, features = ["schemars"] } open_ai = { workspace = true, features = ["schemars"] } open_router = { workspace = true, features = ["schemars"] } -vercel = { workspace = true, features = ["schemars"] } partial-json-fixer.workspace = true -proto.workspace = true release_channel.workspace = true schemars.workspace = true serde.workspace = true @@ -60,9 +58,9 @@ tokio = { workspace = true, features = ["rt", "rt-multi-thread"] } ui.workspace = true ui_input.workspace = true util.workspace = true +vercel = { workspace = true, features = ["schemars"] } workspace-hack.workspace = true -zed_llm_client.workspace = true -language.workspace = true +x_ai = { workspace = true, features = ["schemars"] } [dev-dependencies] editor = { workspace = true, features = ["test-support"] } diff --git a/crates/language_models/src/language_models.rs b/crates/language_models/src/language_models.rs index c7324732c9bbf88698a1a7280ff80cea077a1d2f..18e6f47ed0591256591df578f98dcaf988ed6444 100644 --- a/crates/language_models/src/language_models.rs +++ b/crates/language_models/src/language_models.rs @@ -1,8 +1,10 @@ use std::sync::Arc; +use ::settings::{Settings, SettingsStore}; use client::{Client, UserStore}; +use collections::HashSet; use gpui::{App, Context, Entity}; -use language_model::LanguageModelRegistry; +use language_model::{LanguageModelProviderId, LanguageModelRegistry}; use provider::deepseek::DeepSeekLanguageModelProvider; pub mod provider; @@ -18,16 +20,81 @@ use crate::provider::lmstudio::LmStudioLanguageModelProvider; use crate::provider::mistral::MistralLanguageModelProvider; use crate::provider::ollama::OllamaLanguageModelProvider; use crate::provider::open_ai::OpenAiLanguageModelProvider; +use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider; use crate::provider::open_router::OpenRouterLanguageModelProvider; use crate::provider::vercel::VercelLanguageModelProvider; +use crate::provider::x_ai::XAiLanguageModelProvider; pub use crate::settings::*; pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) { - crate::settings::init(cx); + crate::settings::init_settings(cx); let registry = LanguageModelRegistry::global(cx); registry.update(cx, |registry, cx| { - register_language_model_providers(registry, user_store, client, cx); + register_language_model_providers(registry, user_store, client.clone(), cx); }); + + let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx) + .openai_compatible + .keys() + .cloned() + .collect::<HashSet<_>>(); + + registry.update(cx, |registry, cx| { + register_openai_compatible_providers( + registry, + &HashSet::default(), + &openai_compatible_providers, + client.clone(), + cx, + ); + }); + cx.observe_global::<SettingsStore>(move |cx| { + let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx) + .openai_compatible + .keys() + .cloned() + .collect::<HashSet<_>>(); + if openai_compatible_providers_new != openai_compatible_providers { + registry.update(cx, |registry, cx| { + register_openai_compatible_providers( + registry, + 
&openai_compatible_providers, + &openai_compatible_providers_new, + client.clone(), + cx, + ); + }); + openai_compatible_providers = openai_compatible_providers_new; + } + }) + .detach(); +} + +fn register_openai_compatible_providers( + registry: &mut LanguageModelRegistry, + old: &HashSet<Arc<str>>, + new: &HashSet<Arc<str>>, + client: Arc<Client>, + cx: &mut Context<LanguageModelRegistry>, +) { + for provider_id in old { + if !new.contains(provider_id) { + registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx); + } + } + + for provider_id in new { + if !old.contains(provider_id) { + registry.register_provider( + OpenAiCompatibleLanguageModelProvider::new( + provider_id.clone(), + client.http_client(), + cx, + ), + cx, + ); + } + } } fn register_language_model_providers( @@ -81,5 +148,6 @@ fn register_language_model_providers( VercelLanguageModelProvider::new(client.http_client(), cx), cx, ); + registry.register_provider(XAiLanguageModelProvider::new(client.http_client(), cx), cx); registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx); } diff --git a/crates/language_models/src/provider.rs b/crates/language_models/src/provider.rs index 6bc93bd3661e86fc2c8f9bacafaf2d4121e0f7a6..d780195c66ec0d19c2b7d53e62b5e3629baa8a43 100644 --- a/crates/language_models/src/provider.rs +++ b/crates/language_models/src/provider.rs @@ -8,5 +8,7 @@ pub mod lmstudio; pub mod mistral; pub mod ollama; pub mod open_ai; +pub mod open_ai_compatible; pub mod open_router; pub mod vercel; +pub mod x_ai; diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index 959cbccf39bcd4660d4336325cc9e5268c8e99c8..ef21e85f711e41722d4ac421ba1d0a89b422b6a6 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -1012,7 +1012,7 @@ impl Render for ConfigurationView { v_flex() .size_full() .on_action(cx.listener(Self::save_api_key)) - .child(Label::new("To use Zed's assistant with Anthropic, you need to add an API key. Follow these steps:")) + .child(Label::new("To use Zed's agent with Anthropic, you need to add an API key. 
Follow these steps:")) .child( List::new() .child( diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs index 65ce1dbc4b61cb1d6432fa6e6011aadc4479613f..6df96c5c566aac6f23af837491292cc89a56c74a 100644 --- a/crates/language_models/src/provider/bedrock.rs +++ b/crates/language_models/src/provider/bedrock.rs @@ -243,7 +243,7 @@ impl State { pub struct BedrockLanguageModelProvider { http_client: AwsHttpClient, - handler: tokio::runtime::Handle, + handle: tokio::runtime::Handle, state: gpui::Entity<State>, } @@ -258,13 +258,9 @@ impl BedrockLanguageModelProvider { }), }); - let tokio_handle = Tokio::handle(cx); - - let coerced_client = AwsHttpClient::new(http_client.clone(), tokio_handle.clone()); - Self { - http_client: coerced_client, - handler: tokio_handle.clone(), + http_client: AwsHttpClient::new(http_client.clone()), + handle: Tokio::handle(cx), state, } } @@ -274,7 +270,7 @@ impl BedrockLanguageModelProvider { id: LanguageModelId::from(model.id().to_string()), model, http_client: self.http_client.clone(), - handler: self.handler.clone(), + handle: self.handle.clone(), state: self.state.clone(), client: OnceCell::new(), request_limiter: RateLimiter::new(4), @@ -375,7 +371,7 @@ struct BedrockModel { id: LanguageModelId, model: Model, http_client: AwsHttpClient, - handler: tokio::runtime::Handle, + handle: tokio::runtime::Handle, client: OnceCell<BedrockClient>, state: gpui::Entity<State>, request_limiter: RateLimiter, @@ -447,7 +443,7 @@ impl BedrockModel { } } - let config = self.handler.block_on(config_builder.load()); + let config = self.handle.block_on(config_builder.load()); anyhow::Ok(BedrockClient::new(&config)) }) .context("initializing Bedrock client")?; @@ -1255,7 +1251,7 @@ impl Render for ConfigurationView { v_flex() .size_full() .on_action(cx.listener(ConfigurationView::save_credentials)) - .child(Label::new("To use Zed's assistant with Bedrock, you can set a custom authentication strategy through the settings.json, or use static credentials.")) + .child(Label::new("To use Zed's agent with Bedrock, you can set a custom authentication strategy through the settings.json, or use static credentials.")) .child(Label::new("But, to access models on AWS, you need to:").mt_1()) .child( List::new() diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index ede84713f17fca49ded8f7c927763d7c907a90c6..40dd12076113ade6cff3563795472482573faf16 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -1,8 +1,15 @@ +use ai_onboarding::YoungAccountBanner; use anthropic::AnthropicModelMode; use anyhow::{Context as _, Result, anyhow}; use chrono::{DateTime, Utc}; use client::{Client, ModelRequestUsage, UserStore, zed_urls}; -use feature_flags::{FeatureFlagAppExt as _, ZedCloudFeatureFlag}; +use cloud_llm_client::{ + CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody, + CompletionEvent, CompletionRequestStatus, CountTokensBody, CountTokensResponse, + EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE, Plan, + SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME, + TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME, +}; use futures::{ AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream, }; @@ -20,7 +27,6 @@ use language_model::{ LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, 
ModelRequestLimitReachedError, PaymentRequiredError, RateLimiter, RefreshLlmTokenListener, }; -use proto::Plan; use release_channel::AppVersion; use schemars::JsonSchema; use serde::{Deserialize, Serialize, de::DeserializeOwned}; @@ -33,13 +39,6 @@ use std::time::Duration; use thiserror::Error; use ui::{TintColor, prelude::*}; use util::{ResultExt as _, maybe}; -use zed_llm_client::{ - CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody, - CompletionRequestStatus, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME, - ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE, - SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME, - TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME, -}; use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic}; use crate::provider::google::{GoogleEventMapper, into_google}; @@ -120,10 +119,10 @@ pub struct State { user_store: Entity<UserStore>, status: client::Status, accept_terms_of_service_task: Option<Task<Result<()>>>, - models: Vec<Arc<zed_llm_client::LanguageModel>>, - default_model: Option<Arc<zed_llm_client::LanguageModel>>, - default_fast_model: Option<Arc<zed_llm_client::LanguageModel>>, - recommended_models: Vec<Arc<zed_llm_client::LanguageModel>>, + models: Vec<Arc<cloud_llm_client::LanguageModel>>, + default_model: Option<Arc<cloud_llm_client::LanguageModel>>, + default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>, + recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>, _fetch_models_task: Task<()>, _settings_subscription: Subscription, _llm_token_subscription: Subscription, @@ -137,12 +136,11 @@ impl State { cx: &mut Context<Self>, ) -> Self { let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx); - let use_cloud = cx.has_flag::<ZedCloudFeatureFlag>(); - + let mut current_user = user_store.read(cx).watch_current_user(); Self { client: client.clone(), llm_api_token: LlmApiToken::default(), - user_store, + user_store: user_store.clone(), status, accept_terms_of_service_task: None, models: Vec::new(), @@ -154,21 +152,14 @@ impl State { let (client, llm_api_token) = this .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?; - loop { - let status = this.read_with(cx, |this, _cx| this.status)?; - if matches!(status, client::Status::Connected { .. 
}) { - break; - } - - cx.background_executor() - .timer(Duration::from_millis(100)) - .await; + while current_user.borrow().is_none() { + current_user.next().await; } - let response = Self::fetch_models(client, llm_api_token, use_cloud).await?; - this.update(cx, |this, cx| { - this.update_models(response, cx); - }) + let response = + Self::fetch_models(client.clone(), llm_api_token.clone()).await?; + this.update(cx, |this, cx| this.update_models(response, cx))?; + anyhow::Ok(()) }) .await .context("failed to fetch Zed models") @@ -184,7 +175,7 @@ impl State { let llm_api_token = this.llm_api_token.clone(); cx.spawn(async move |this, cx| { llm_api_token.refresh(&client).await?; - let response = Self::fetch_models(client, llm_api_token, use_cloud).await?; + let response = Self::fetch_models(client, llm_api_token).await?; this.update(cx, |this, cx| { this.update_models(response, cx); }) @@ -195,26 +186,20 @@ impl State { } } - fn is_signed_out(&self) -> bool { - self.status.is_signed_out() + fn is_signed_out(&self, cx: &App) -> bool { + self.user_store.read(cx).current_user().is_none() } fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> { let client = self.client.clone(); cx.spawn(async move |state, cx| { - client - .authenticate_and_connect(true, &cx) - .await - .into_response()?; + client.sign_in_with_optional_connect(true, &cx).await?; state.update(cx, |_, cx| cx.notify()) }) } fn has_accepted_terms_of_service(&self, cx: &App) -> bool { - self.user_store - .read(cx) - .current_user_has_accepted_terms() - .unwrap_or(false) + self.user_store.read(cx).has_accepted_terms_of_service() } fn accept_terms_of_service(&mut self, cx: &mut Context<Self>) { @@ -239,8 +224,8 @@ impl State { // Right now we represent thinking variants of models as separate models on the client, // so we need to insert variants for any model that supports thinking. if model.supports_thinking { - models.push(Arc::new(zed_llm_client::LanguageModel { - id: zed_llm_client::LanguageModelId(format!("{}-thinking", model.id).into()), + models.push(Arc::new(cloud_llm_client::LanguageModel { + id: cloud_llm_client::LanguageModelId(format!("{}-thinking", model.id).into()), display_name: format!("{} Thinking", model.display_name), ..model })); @@ -268,18 +253,13 @@ impl State { async fn fetch_models( client: Arc<Client>, llm_api_token: LlmApiToken, - use_cloud: bool, ) -> Result<ListModelsResponse> { let http_client = &client.http_client(); let token = llm_api_token.acquire(&client).await?; let request = http_client::Request::builder() .method(Method::GET) - .uri( - http_client - .build_zed_llm_url("/models", &[], use_cloud)? 
- .as_ref(), - ) + .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref()) .header("Authorization", format!("Bearer {token}")) .body(AsyncBody::empty())?; let mut response = http_client @@ -334,7 +314,7 @@ impl CloudLanguageModelProvider { fn create_language_model( &self, - model: Arc<zed_llm_client::LanguageModel>, + model: Arc<cloud_llm_client::LanguageModel>, llm_api_token: LlmApiToken, ) -> Arc<dyn LanguageModel> { Arc::new(CloudLanguageModel { @@ -404,7 +384,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider { fn is_authenticated(&self, cx: &App) -> bool { let state = self.state.read(cx); - !state.is_signed_out() && state.has_accepted_terms_of_service(cx) + !state.is_signed_out(cx) && state.has_accepted_terms_of_service(cx) } fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> { @@ -507,7 +487,7 @@ fn render_accept_terms( ) .child({ match view_kind { - LanguageModelProviderTosView::PromptEditorPopup => { + LanguageModelProviderTosView::TextThreadPopup => { button_container.w_full().justify_end() } LanguageModelProviderTosView::Configuration => { @@ -524,7 +504,7 @@ fn render_accept_terms( pub struct CloudLanguageModel { id: LanguageModelId, - model: Arc<zed_llm_client::LanguageModel>, + model: Arc<cloud_llm_client::LanguageModel>, llm_api_token: LlmApiToken, client: Arc<Client>, request_limiter: RateLimiter, @@ -543,7 +523,6 @@ impl CloudLanguageModel { llm_api_token: LlmApiToken, app_version: Option<SemanticVersion>, body: CompletionBody, - use_cloud: bool, ) -> Result<PerformLlmCompletionResponse> { let http_client = &client.http_client(); @@ -551,11 +530,9 @@ impl CloudLanguageModel { let mut refreshed_token = false; loop { - let request_builder = http_client::Request::builder().method(Method::POST).uri( - http_client - .build_zed_llm_url("/completions", &[], use_cloud)? - .as_ref(), - ); + let request_builder = http_client::Request::builder() + .method(Method::POST) + .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref()); let request_builder = if let Some(app_version) = app_version { request_builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string()) } else { @@ -620,13 +597,8 @@ impl CloudLanguageModel { .headers() .get(CURRENT_PLAN_HEADER_NAME) .and_then(|plan| plan.to_str().ok()) - .and_then(|plan| zed_llm_client::Plan::from_str(plan).ok()) + .and_then(|plan| cloud_llm_client::Plan::from_str(plan).ok()) { - let plan = match plan { - zed_llm_client::Plan::ZedFree => Plan::Free, - zed_llm_client::Plan::ZedPro => Plan::ZedPro, - zed_llm_client::Plan::ZedProTrial => Plan::ZedProTrial, - }; return Err(anyhow!(ModelRequestLimitReachedError { plan })); } } @@ -654,8 +626,62 @@ struct ApiError { headers: HeaderMap<HeaderValue>, } +/// Represents error responses from Zed's cloud API. 
+/// +/// Example JSON for an upstream HTTP error: +/// ```json +/// { +/// "code": "upstream_http_error", +/// "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout", +/// "upstream_status": 503 +/// } +/// ``` +#[derive(Debug, serde::Deserialize)] +struct CloudApiError { + code: String, + message: String, + #[serde(default)] + #[serde(deserialize_with = "deserialize_optional_status_code")] + upstream_status: Option<StatusCode>, + #[serde(default)] + retry_after: Option<f64>, +} + +fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error> +where + D: serde::Deserializer<'de>, +{ + let opt: Option<u16> = Option::deserialize(deserializer)?; + Ok(opt.and_then(|code| StatusCode::from_u16(code).ok())) +} + impl From<ApiError> for LanguageModelCompletionError { fn from(error: ApiError) -> Self { + if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) { + if cloud_error.code.starts_with("upstream_http_") { + let status = if let Some(status) = cloud_error.upstream_status { + status + } else if cloud_error.code.ends_with("_error") { + error.status + } else { + // If there's a status code in the code string (e.g. "upstream_http_429") + // then use that; otherwise, see if the JSON contains a status code. + cloud_error + .code + .strip_prefix("upstream_http_") + .and_then(|code_str| code_str.parse::<u16>().ok()) + .and_then(|code| StatusCode::from_u16(code).ok()) + .unwrap_or(error.status) + }; + + return LanguageModelCompletionError::UpstreamProviderError { + message: cloud_error.message, + status, + retry_after: cloud_error.retry_after.map(Duration::from_secs_f64), + }; + } + } + let retry_after = None; LanguageModelCompletionError::from_http_status( PROVIDER_NAME, @@ -684,7 +710,7 @@ impl LanguageModel for CloudLanguageModel { } fn upstream_provider_id(&self) -> LanguageModelProviderId { - use zed_llm_client::LanguageModelProvider::*; + use cloud_llm_client::LanguageModelProvider::*; match self.model.provider { Anthropic => language_model::ANTHROPIC_PROVIDER_ID, OpenAi => language_model::OPEN_AI_PROVIDER_ID, @@ -693,7 +719,7 @@ impl LanguageModel for CloudLanguageModel { } fn upstream_provider_name(&self) -> LanguageModelProviderName { - use zed_llm_client::LanguageModelProvider::*; + use cloud_llm_client::LanguageModelProvider::*; match self.model.provider { Anthropic => language_model::ANTHROPIC_PROVIDER_NAME, OpenAi => language_model::OPEN_AI_PROVIDER_NAME, @@ -727,11 +753,11 @@ impl LanguageModel for CloudLanguageModel { fn tool_input_format(&self) -> LanguageModelToolSchemaFormat { match self.model.provider { - zed_llm_client::LanguageModelProvider::Anthropic - | zed_llm_client::LanguageModelProvider::OpenAi => { + cloud_llm_client::LanguageModelProvider::Anthropic + | cloud_llm_client::LanguageModelProvider::OpenAi => { LanguageModelToolSchemaFormat::JsonSchema } - zed_llm_client::LanguageModelProvider::Google => { + cloud_llm_client::LanguageModelProvider::Google => { LanguageModelToolSchemaFormat::JsonSchemaSubset } } @@ -750,15 +776,15 @@ impl LanguageModel for CloudLanguageModel { fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> { match &self.model.provider { - zed_llm_client::LanguageModelProvider::Anthropic => { + cloud_llm_client::LanguageModelProvider::Anthropic => { Some(LanguageModelCacheConfiguration { min_total_token: 2_048, should_speculate: true, max_cache_anchors: 4, }) } - 
zed_llm_client::LanguageModelProvider::OpenAi - | zed_llm_client::LanguageModelProvider::Google => None, + cloud_llm_client::LanguageModelProvider::OpenAi + | cloud_llm_client::LanguageModelProvider::Google => None, } } @@ -768,27 +794,28 @@ impl LanguageModel for CloudLanguageModel { cx: &App, ) -> BoxFuture<'static, Result<u64>> { match self.model.provider { - zed_llm_client::LanguageModelProvider::Anthropic => count_anthropic_tokens(request, cx), - zed_llm_client::LanguageModelProvider::OpenAi => { + cloud_llm_client::LanguageModelProvider::Anthropic => { + count_anthropic_tokens(request, cx) + } + cloud_llm_client::LanguageModelProvider::OpenAi => { let model = match open_ai::Model::from_id(&self.model.id.0) { Ok(model) => model, Err(err) => return async move { Err(anyhow!(err)) }.boxed(), }; count_open_ai_tokens(request, model, cx) } - zed_llm_client::LanguageModelProvider::Google => { + cloud_llm_client::LanguageModelProvider::Google => { let client = self.client.clone(); let llm_api_token = self.llm_api_token.clone(); let model_id = self.model.id.to_string(); let generate_content_request = into_google(request, model_id.clone(), GoogleModelMode::Default); - let use_cloud = cx.has_flag::<ZedCloudFeatureFlag>(); async move { let http_client = &client.http_client(); let token = llm_api_token.acquire(&client).await?; let request_body = CountTokensBody { - provider: zed_llm_client::LanguageModelProvider::Google, + provider: cloud_llm_client::LanguageModelProvider::Google, model: model_id, provider_request: serde_json::to_value(&google_ai::CountTokensRequest { generate_content_request, @@ -798,7 +825,7 @@ impl LanguageModel for CloudLanguageModel { .method(Method::POST) .uri( http_client - .build_zed_llm_url("/count_tokens", &[], use_cloud)? + .build_zed_llm_url("/count_tokens", &[])? 
.as_ref(), ) .header("Content-Type", "application/json") @@ -847,12 +874,9 @@ impl LanguageModel for CloudLanguageModel { let intent = request.intent; let mode = request.mode; let app_version = cx.update(|cx| AppVersion::global(cx)).ok(); - let use_cloud = cx - .update(|cx| cx.has_flag::<ZedCloudFeatureFlag>()) - .unwrap_or(false); let thinking_allowed = request.thinking_allowed; match self.model.provider { - zed_llm_client::LanguageModelProvider::Anthropic => { + cloud_llm_client::LanguageModelProvider::Anthropic => { let request = into_anthropic( request, self.model.id.to_string(), @@ -883,12 +907,11 @@ impl LanguageModel for CloudLanguageModel { prompt_id, intent, mode, - provider: zed_llm_client::LanguageModelProvider::Anthropic, + provider: cloud_llm_client::LanguageModelProvider::Anthropic, model: request.model.clone(), provider_request: serde_json::to_value(&request) .map_err(|e| anyhow!(e))?, }, - use_cloud, ) .await .map_err(|err| match err.downcast::<ApiError>() { @@ -908,7 +931,7 @@ impl LanguageModel for CloudLanguageModel { }); async move { Ok(future.await?.boxed()) }.boxed() } - zed_llm_client::LanguageModelProvider::OpenAi => { + cloud_llm_client::LanguageModelProvider::OpenAi => { let client = self.client.clone(); let model = match open_ai::Model::from_id(&self.model.id.0) { Ok(model) => model, @@ -936,12 +959,11 @@ impl LanguageModel for CloudLanguageModel { prompt_id, intent, mode, - provider: zed_llm_client::LanguageModelProvider::OpenAi, + provider: cloud_llm_client::LanguageModelProvider::OpenAi, model: request.model.clone(), provider_request: serde_json::to_value(&request) .map_err(|e| anyhow!(e))?, }, - use_cloud, ) .await?; @@ -957,7 +979,7 @@ impl LanguageModel for CloudLanguageModel { }); async move { Ok(future.await?.boxed()) }.boxed() } - zed_llm_client::LanguageModelProvider::Google => { + cloud_llm_client::LanguageModelProvider::Google => { let client = self.client.clone(); let request = into_google(request, self.model.id.to_string(), GoogleModelMode::Default); @@ -977,12 +999,11 @@ impl LanguageModel for CloudLanguageModel { prompt_id, intent, mode, - provider: zed_llm_client::LanguageModelProvider::Google, + provider: cloud_llm_client::LanguageModelProvider::Google, model: request.model.model_id.clone(), provider_request: serde_json::to_value(&request) .map_err(|e| anyhow!(e))?, }, - use_cloud, ) .await?; @@ -1002,15 +1023,8 @@ impl LanguageModel for CloudLanguageModel { } } -#[derive(Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum CloudCompletionEvent<T> { - Status(CompletionRequestStatus), - Event(T), -} - fn map_cloud_completion_events<T, F>( - stream: Pin<Box<dyn Stream<Item = Result<CloudCompletionEvent<T>>> + Send>>, + stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>, mut map_callback: F, ) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> where @@ -1025,10 +1039,10 @@ where Err(error) => { vec![Err(LanguageModelCompletionError::from(error))] } - Ok(CloudCompletionEvent::Status(event)) => { + Ok(CompletionEvent::Status(event)) => { vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))] } - Ok(CloudCompletionEvent::Event(event)) => map_callback(event), + Ok(CompletionEvent::Event(event)) => map_callback(event), }) }) .boxed() @@ -1036,9 +1050,9 @@ where fn usage_updated_event<T>( usage: Option<ModelRequestUsage>, -) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> { +) -> impl Stream<Item = Result<CompletionEvent<T>>> { 
futures::stream::iter(usage.map(|usage| { - Ok(CloudCompletionEvent::Status( + Ok(CompletionEvent::Status( CompletionRequestStatus::UsageUpdated { amount: usage.amount as usize, limit: usage.limit, @@ -1049,9 +1063,9 @@ fn usage_updated_event<T>( fn tool_use_limit_reached_event<T>( tool_use_limit_reached: bool, -) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> { +) -> impl Stream<Item = Result<CompletionEvent<T>>> { futures::stream::iter(tool_use_limit_reached.then(|| { - Ok(CloudCompletionEvent::Status( + Ok(CompletionEvent::Status( CompletionRequestStatus::ToolUseLimitReached, )) })) @@ -1060,7 +1074,7 @@ fn tool_use_limit_reached_event<T>( fn response_lines<T: DeserializeOwned>( response: Response<AsyncBody>, includes_status_messages: bool, -) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> { +) -> impl Stream<Item = Result<CompletionEvent<T>>> { futures::stream::try_unfold( (String::new(), BufReader::new(response.into_body())), move |(mut line, mut body)| async move { @@ -1068,9 +1082,9 @@ fn response_lines<T: DeserializeOwned>( Ok(0) => Ok(None), Ok(_) => { let event = if includes_status_messages { - serde_json::from_str::<CloudCompletionEvent<T>>(&line)? + serde_json::from_str::<CompletionEvent<T>>(&line)? } else { - CloudCompletionEvent::Event(serde_json::from_str::<T>(&line)?) + CompletionEvent::Event(serde_json::from_str::<T>(&line)?) }; line.clear(); @@ -1085,10 +1099,11 @@ fn response_lines<T: DeserializeOwned>( #[derive(IntoElement, RegisterComponent)] struct ZedAiConfiguration { is_connected: bool, - plan: Option<proto::Plan>, + plan: Option<Plan>, subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>, eligible_for_trial: bool, has_accepted_terms_of_service: bool, + account_too_young: bool, accept_terms_of_service_in_progress: bool, accept_terms_of_service_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>, sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>, @@ -1096,89 +1111,98 @@ struct ZedAiConfiguration { impl RenderOnce for ZedAiConfiguration { fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement { - const ZED_PRICING_URL: &str = "https://zed.dev/pricing"; + let young_account_banner = YoungAccountBanner; - let is_pro = self.plan == Some(proto::Plan::ZedPro); + let is_pro = self.plan == Some(Plan::ZedPro); let subscription_text = match (self.plan, self.subscription_period) { - (Some(proto::Plan::ZedPro), Some(_)) => { - "You have access to Zed's hosted LLMs through your Zed Pro subscription." + (Some(Plan::ZedPro), Some(_)) => { + "You have access to Zed's hosted models through your Pro subscription." } - (Some(proto::Plan::ZedProTrial), Some(_)) => { - "You have access to Zed's hosted LLMs through your Zed Pro trial." + (Some(Plan::ZedProTrial), Some(_)) => { + "You have access to Zed's hosted models through your Pro trial." } - (Some(proto::Plan::Free), Some(_)) => { - "You have basic access to Zed's hosted LLMs through your Zed Free subscription." + (Some(Plan::ZedFree), Some(_)) => { + "You have basic access to Zed's hosted models through the Free plan." } _ => { if self.eligible_for_trial { - "Subscribe for access to Zed's hosted LLMs. Start with a 14 day free trial." + "Subscribe for access to Zed's hosted models. Start with a 14 day free trial." } else { - "Subscribe for access to Zed's hosted LLMs." + "Subscribe for access to Zed's hosted models." 
} } }; + let manage_subscription_buttons = if is_pro { - h_flex().child( - Button::new("manage_settings", "Manage Subscription") - .style(ButtonStyle::Tinted(TintColor::Accent)) - .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx))), - ) + Button::new("manage_settings", "Manage Subscription") + .full_width() + .style(ButtonStyle::Tinted(TintColor::Accent)) + .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx))) + .into_any_element() + } else if self.plan.is_none() || self.eligible_for_trial { + Button::new("start_trial", "Start 14-day Free Pro Trial") + .full_width() + .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent)) + .on_click(|_, _, cx| cx.open_url(&zed_urls::start_trial_url(cx))) + .into_any_element() } else { - h_flex() - .gap_2() - .child( - Button::new("learn_more", "Learn more") - .style(ButtonStyle::Subtle) - .on_click(|_, _, cx| cx.open_url(ZED_PRICING_URL)), - ) - .child( - Button::new( - "upgrade", - if self.plan.is_none() && self.eligible_for_trial { - "Start Trial" - } else { - "Upgrade" - }, - ) - .style(ButtonStyle::Subtle) - .color(Color::Accent) - .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx))), - ) + Button::new("upgrade", "Upgrade to Pro") + .full_width() + .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent)) + .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))) + .into_any_element() }; - if self.is_connected { - v_flex() - .gap_3() - .w_full() - .when(!self.has_accepted_terms_of_service, |this| { - this.child(render_accept_terms( - LanguageModelProviderTosView::Configuration, - self.accept_terms_of_service_in_progress, - { - let callback = self.accept_terms_of_service_callback.clone(); - move |window, cx| (callback)(window, cx) - }, - )) - }) - .when(self.has_accepted_terms_of_service, |this| { - this.child(subscription_text) - .child(manage_subscription_buttons) - }) - } else { - v_flex() + if !self.is_connected { + return v_flex() .gap_2() - .child(Label::new("Use Zed AI to access hosted language models.")) + .child(Label::new("Sign in to have access to Zed's complete agentic experience with hosted models.")) .child( - Button::new("sign_in", "Sign In") + Button::new("sign_in", "Sign In to use Zed AI") .icon_color(Color::Muted) .icon(IconName::Github) + .icon_size(IconSize::Small) .icon_position(IconPosition::Start) + .full_width() .on_click({ let callback = self.sign_in_callback.clone(); move |_, window, cx| (callback)(window, cx) }), - ) + ); } + + v_flex() + .gap_2() + .w_full() + .when(!self.has_accepted_terms_of_service, |this| { + this.child(render_accept_terms( + LanguageModelProviderTosView::Configuration, + self.accept_terms_of_service_in_progress, + { + let callback = self.accept_terms_of_service_callback.clone(); + move |window, cx| (callback)(window, cx) + }, + )) + }) + .map(|this| { + if self.has_accepted_terms_of_service && self.account_too_young { + this.child(young_account_banner).child( + Button::new("upgrade", "Upgrade to Pro") + .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent)) + .full_width() + .on_click(|_, _, cx| { + cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)) + }), + ) + } else if self.has_accepted_terms_of_service { + this.text_sm() + .child(subscription_text) + .child(manage_subscription_buttons) + } else { + this + } + }) + .when(self.has_accepted_terms_of_service, |this| this) } } @@ -1222,11 +1246,12 @@ impl Render for ConfigurationView { let user_store = state.user_store.read(cx); ZedAiConfiguration { - is_connected: !state.is_signed_out(), - plan: 
user_store.current_plan(), + is_connected: !state.is_signed_out(cx), + plan: user_store.plan(), subscription_period: user_store.subscription_period(), eligible_for_trial: user_store.trial_started_at().is_none(), has_accepted_terms_of_service: state.has_accepted_terms_of_service(cx), + account_too_young: user_store.account_too_young(), accept_terms_of_service_in_progress: state.accept_terms_of_service_task.is_some(), accept_terms_of_service_callback: self.accept_terms_of_service_callback.clone(), sign_in_callback: self.sign_in_callback.clone(), @@ -1235,15 +1260,24 @@ impl Render for ConfigurationView { } impl Component for ZedAiConfiguration { + fn name() -> &'static str { + "AI Configuration Content" + } + + fn sort_name() -> &'static str { + "AI Configuration Content" + } + fn scope() -> ComponentScope { - ComponentScope::Agent + ComponentScope::Onboarding } fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> { fn configuration( is_connected: bool, - plan: Option<proto::Plan>, + plan: Option<Plan>, eligible_for_trial: bool, + account_too_young: bool, has_accepted_terms_of_service: bool, ) -> AnyElement { ZedAiConfiguration { @@ -1254,6 +1288,7 @@ impl Component for ZedAiConfiguration { .then(|| (Utc::now(), Utc::now() + chrono::Duration::days(7))), eligible_for_trial, has_accepted_terms_of_service, + account_too_young, accept_terms_of_service_in_progress: false, accept_terms_of_service_callback: Arc::new(|_, _| {}), sign_in_callback: Arc::new(|_, _| {}), @@ -1266,33 +1301,188 @@ impl Component for ZedAiConfiguration { .p_4() .gap_4() .children(vec![ - single_example("Not connected", configuration(false, None, false, true)), + single_example( + "Not connected", + configuration(false, None, false, false, true), + ), single_example( "Accept Terms of Service", - configuration(true, None, true, false), + configuration(true, None, true, false, false), ), single_example( "No Plan - Not eligible for trial", - configuration(true, None, false, true), + configuration(true, None, false, false, true), ), single_example( "No Plan - Eligible for trial", - configuration(true, None, true, true), + configuration(true, None, true, false, true), ), single_example( "Free Plan", - configuration(true, Some(proto::Plan::Free), true, true), + configuration(true, Some(Plan::ZedFree), true, false, true), ), single_example( "Zed Pro Trial Plan", - configuration(true, Some(proto::Plan::ZedProTrial), true, true), + configuration(true, Some(Plan::ZedProTrial), true, false, true), ), single_example( "Zed Pro Plan", - configuration(true, Some(proto::Plan::ZedPro), true, true), + configuration(true, Some(Plan::ZedPro), true, false, true), ), ]) .into_any_element(), ) } } + +#[cfg(test)] +mod tests { + use super::*; + use http_client::http::{HeaderMap, StatusCode}; + use language_model::LanguageModelCompletionError; + + #[test] + fn test_api_error_conversion_with_upstream_http_error() { + // upstream_http_error with 503 status should become ServerOverloaded + let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#; + + let api_error = ApiError { + status: StatusCode::INTERNAL_SERVER_ERROR, + body: error_body.to_string(), + headers: HeaderMap::new(), + }; + + let completion_error: LanguageModelCompletionError = api_error.into(); + + match completion_error { + LanguageModelCompletionError::UpstreamProviderError { message, .. 
} => { + assert_eq!( + message, + "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout" + ); + } + _ => panic!( + "Expected UpstreamProviderError for upstream 503, got: {:?}", + completion_error + ), + } + + // upstream_http_error with 500 status should become ApiInternalServerError + let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#; + + let api_error = ApiError { + status: StatusCode::INTERNAL_SERVER_ERROR, + body: error_body.to_string(), + headers: HeaderMap::new(), + }; + + let completion_error: LanguageModelCompletionError = api_error.into(); + + match completion_error { + LanguageModelCompletionError::UpstreamProviderError { message, .. } => { + assert_eq!( + message, + "Received an error from the OpenAI API: internal server error" + ); + } + _ => panic!( + "Expected UpstreamProviderError for upstream 500, got: {:?}", + completion_error + ), + } + + // upstream_http_error with 429 status should become RateLimitExceeded + let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#; + + let api_error = ApiError { + status: StatusCode::INTERNAL_SERVER_ERROR, + body: error_body.to_string(), + headers: HeaderMap::new(), + }; + + let completion_error: LanguageModelCompletionError = api_error.into(); + + match completion_error { + LanguageModelCompletionError::UpstreamProviderError { message, .. } => { + assert_eq!( + message, + "Received an error from the Google API: rate limit exceeded" + ); + } + _ => panic!( + "Expected UpstreamProviderError for upstream 429, got: {:?}", + completion_error + ), + } + + // Regular 500 error without upstream_http_error should remain ApiInternalServerError for Zed + let error_body = "Regular internal server error"; + + let api_error = ApiError { + status: StatusCode::INTERNAL_SERVER_ERROR, + body: error_body.to_string(), + headers: HeaderMap::new(), + }; + + let completion_error: LanguageModelCompletionError = api_error.into(); + + match completion_error { + LanguageModelCompletionError::ApiInternalServerError { provider, message } => { + assert_eq!(provider, PROVIDER_NAME); + assert_eq!(message, "Regular internal server error"); + } + _ => panic!( + "Expected ApiInternalServerError for regular 500, got: {:?}", + completion_error + ), + } + + // upstream_http_429 format should be converted to UpstreamProviderError + let error_body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#; + + let api_error = ApiError { + status: StatusCode::INTERNAL_SERVER_ERROR, + body: error_body.to_string(), + headers: HeaderMap::new(), + }; + + let completion_error: LanguageModelCompletionError = api_error.into(); + + match completion_error { + LanguageModelCompletionError::UpstreamProviderError { + message, + status, + retry_after, + } => { + assert_eq!(message, "Upstream Anthropic rate limit exceeded."); + assert_eq!(status, StatusCode::TOO_MANY_REQUESTS); + assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5))); + } + _ => panic!( + "Expected UpstreamProviderError for upstream_http_429, got: {:?}", + completion_error + ), + } + + // Invalid JSON in error body should fall back to regular error handling + let error_body = "Not JSON at all"; + + let api_error = ApiError { + status: StatusCode::INTERNAL_SERVER_ERROR, + body: error_body.to_string(), + 
headers: HeaderMap::new(), + }; + + let completion_error: LanguageModelCompletionError = api_error.into(); + + match completion_error { + LanguageModelCompletionError::ApiInternalServerError { provider, .. } => { + assert_eq!(provider, PROVIDER_NAME); + } + _ => panic!( + "Expected ApiInternalServerError for invalid JSON, got: {:?}", + completion_error + ), + } + } +} diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs index d9a84f1eb74465a0d5e72591d450802d5708cb20..73f73a9a313c764d45adfd14910efd801a472f1c 100644 --- a/crates/language_models/src/provider/copilot_chat.rs +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -3,6 +3,7 @@ use std::str::FromStr as _; use std::sync::Arc; use anyhow::{Result, anyhow}; +use cloud_llm_client::CompletionIntent; use collections::HashMap; use copilot::copilot_chat::{ ChatMessage, ChatMessageContent, ChatMessagePart, CopilotChat, ImageUrl, @@ -30,7 +31,6 @@ use settings::SettingsStore; use std::time::Duration; use ui::prelude::*; use util::debug_panic; -use zed_llm_client::CompletionIntent; use super::anthropic::count_anthropic_tokens; use super::google::count_google_tokens; @@ -706,7 +706,8 @@ impl Render for ConfigurationView { .child(svg().size_8().path(IconName::CopilotError.path())) } _ => { - const LABEL: &str = "To use Zed's assistant with GitHub Copilot, you need to be logged in to GitHub. Note that your GitHub account must have an active Copilot Chat subscription."; + const LABEL: &str = "To use Zed's agent with GitHub Copilot, you need to be logged in to GitHub. Note that your GitHub account must have an active Copilot Chat subscription."; + v_flex().gap_2().child(Label::new(LABEL)).child( Button::new("sign_in", "Sign in to use GitHub Copilot") .icon_color(Color::Muted) diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index d1539dd22cfb64b4ed194830f3f9c5babc2a6cea..b287e8181a2ac5d04650d799a0cd9b23d51749c2 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -94,6 +94,7 @@ pub struct State { _subscription: Subscription, } +const GEMINI_API_KEY_VAR: &str = "GEMINI_API_KEY"; const GOOGLE_AI_API_KEY_VAR: &str = "GOOGLE_AI_API_KEY"; impl State { @@ -151,6 +152,8 @@ impl State { cx.spawn(async move |this, cx| { let (api_key, from_env) = if let Ok(api_key) = std::env::var(GOOGLE_AI_API_KEY_VAR) { (api_key, true) + } else if let Ok(api_key) = std::env::var(GEMINI_API_KEY_VAR) { + (api_key, true) } else { let (_, api_key) = credentials_provider .read_credentials(&api_url, &cx) @@ -877,7 +880,7 @@ impl Render for ConfigurationView { v_flex() .size_full() .on_action(cx.listener(Self::save_api_key)) - .child(Label::new("To use Zed's assistant with Google AI, you need to add an API key. Follow these steps:")) + .child(Label::new("To use Zed's agent with Google AI, you need to add an API key. 
Follow these steps:")) .child( List::new() .child(InstructionListItem::new( @@ -903,7 +906,7 @@ impl Render for ConfigurationView { ) .child( Label::new( - format!("You can also assign the {GOOGLE_AI_API_KEY_VAR} environment variable and restart Zed."), + format!("You can also assign the {GEMINI_API_KEY_VAR} environment variable and restart Zed."), ) .size(LabelSize::Small).color(Color::Muted), ) @@ -922,7 +925,7 @@ impl Render for ConfigurationView { .gap_1() .child(Icon::new(IconName::Check).color(Color::Success)) .child(Label::new(if env_var_set { - format!("API key set in {GOOGLE_AI_API_KEY_VAR} environment variable.") + format!("API key set in {GEMINI_API_KEY_VAR} environment variable.") } else { "API key configured.".to_string() })), @@ -935,7 +938,7 @@ impl Render for ConfigurationView { .icon_position(IconPosition::Start) .disabled(env_var_set) .when(env_var_set, |this| { - this.tooltip(Tooltip::text(format!("To reset your API key, unset the {GOOGLE_AI_API_KEY_VAR} environment variable."))) + this.tooltip(Tooltip::text(format!("To reset your API key, make sure {GEMINI_API_KEY_VAR} and {GOOGLE_AI_API_KEY_VAR} environment variables are unset."))) }) .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))), ) diff --git a/crates/language_models/src/provider/lmstudio.rs b/crates/language_models/src/provider/lmstudio.rs index 01600f3646da5091796adacd90db2d18f0042b1e..9792b4f27b9990c1c9c64fbda971f7e45490fc49 100644 --- a/crates/language_models/src/provider/lmstudio.rs +++ b/crates/language_models/src/provider/lmstudio.rs @@ -744,7 +744,7 @@ impl Render for ConfigurationView { Button::new("retry_lmstudio_models", "Connect") .icon_position(IconPosition::Start) .icon_size(IconSize::XSmall) - .icon(IconName::Play) + .icon(IconName::PlayOutlined) .on_click(cx.listener(move |this, _, _window, cx| { this.retry_connection(cx) })), diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs index 11497fda350a02ec9433cb2311a28e1901dfeb4f..02e53cb99a846ab24ffdcd2f6c087b436d4f3789 100644 --- a/crates/language_models/src/provider/mistral.rs +++ b/crates/language_models/src/provider/mistral.rs @@ -410,8 +410,20 @@ pub fn into_mistral( .push_part(mistral::MessagePart::Text { text: text.clone() }); } MessageContent::RedactedThinking(_) => {} - MessageContent::ToolUse(_) | MessageContent::ToolResult(_) => { - // Tool content is not supported in User messages for Mistral + MessageContent::ToolUse(_) => { + // Tool use is not supported in User messages for Mistral + } + MessageContent::ToolResult(tool_result) => { + let tool_content = match &tool_result.content { + LanguageModelToolResultContent::Text(text) => text.to_string(), + LanguageModelToolResultContent::Image(_) => { + "[Tool responded with an image, but Zed doesn't support these in Mistral models yet]".to_string() + } + }; + messages.push(mistral::RequestMessage::Tool { + content: tool_content, + tool_call_id: tool_result.tool_use_id.to_string(), + }); } } } @@ -482,24 +494,6 @@ pub fn into_mistral( } } - for message in &request.messages { - for content in &message.content { - if let MessageContent::ToolResult(tool_result) = content { - let content = match &tool_result.content { - LanguageModelToolResultContent::Text(text) => text.to_string(), - LanguageModelToolResultContent::Image(_) => { - "[Tool responded with an image, but Zed doesn't support these in Mistral models yet]".to_string() - } - }; - - messages.push(mistral::RequestMessage::Tool { - content, - tool_call_id: 
tool_result.tool_use_id.to_string(), - }); - } - } - } - // The Mistral API requires that tool messages be followed by assistant messages, // not user messages. When we have a tool->user sequence in the conversation, // we need to insert a placeholder assistant message to maintain proper conversation @@ -813,7 +807,7 @@ impl Render for ConfigurationView { v_flex() .size_full() .on_action(cx.listener(Self::save_api_key)) - .child(Label::new("To use Zed's assistant with Mistral, you need to add an API key. Follow these steps:")) + .child(Label::new("To use Zed's agent with Mistral, you need to add an API key. Follow these steps:")) .child( List::new() .child(InstructionListItem::new( diff --git a/crates/language_models/src/provider/ollama.rs b/crates/language_models/src/provider/ollama.rs index dc81e8be1897aa3ae51b8d2cb26b7cdec0e55cbf..c845c97b09c5f7ca7205a43b889065de06c39550 100644 --- a/crates/language_models/src/provider/ollama.rs +++ b/crates/language_models/src/provider/ollama.rs @@ -192,12 +192,16 @@ impl LanguageModelProvider for OllamaLanguageModelProvider { IconName::AiOllama } - fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> { - self.provided_models(cx).into_iter().next() + fn default_model(&self, _: &App) -> Option<Arc<dyn LanguageModel>> { + // We shouldn't try to select default model, because it might lead to a load call for an unloaded model. + // In a constrained environment where user might not have enough resources it'll be a bad UX to select something + // to load by default. + None } - fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> { - self.default_model(cx) + fn default_fast_model(&self, _: &App) -> Option<Arc<dyn LanguageModel>> { + // See explanation for default_model. + None } fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> { @@ -627,7 +631,7 @@ impl Render for ConfigurationView { } }) .child( - Button::new("view-models", "All Models") + Button::new("view-models", "View All Models") .style(ButtonStyle::Subtle) .icon(IconName::ArrowUpRight) .icon_size(IconSize::XSmall) @@ -654,7 +658,7 @@ impl Render for ConfigurationView { Button::new("retry_ollama_models", "Connect") .icon_position(IconPosition::Start) .icon_size(IconSize::XSmall) - .icon(IconName::Play) + .icon(IconName::PlayOutlined) .on_click(cx.listener(move |this, _, _, cx| { this.retry_connection(cx) })), diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index 76f2fbe303c4bed0cfeefbfca6358667420aed51..ee74562687b5de4258328a1f7ffae082d5dc4931 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -2,7 +2,6 @@ use anyhow::{Context as _, Result, anyhow}; use collections::{BTreeMap, HashMap}; use credentials_provider::CredentialsProvider; -use fs::Fs; use futures::Stream; use futures::{FutureExt, StreamExt, future::BoxFuture}; use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window}; @@ -18,7 +17,7 @@ use menu; use open_ai::{ImageUrl, Model, ResponseStreamEvent, stream_completion}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use settings::{Settings, SettingsStore, update_settings_file}; +use settings::{Settings, SettingsStore}; use std::pin::Pin; use std::str::FromStr as _; use std::sync::Arc; @@ -28,7 +27,6 @@ use ui::{ElevationIndex, List, Tooltip, prelude::*}; use ui_input::SingleLineInput; use util::ResultExt; -use crate::OpenAiSettingsContent; use crate::{AllLanguageModelSettings, 
ui::InstructionListItem}; const PROVIDER_ID: LanguageModelProviderId = language_model::OPEN_AI_PROVIDER_ID; @@ -621,26 +619,32 @@ struct RawToolCall { arguments: String, } +pub(crate) fn collect_tiktoken_messages( + request: LanguageModelRequest, +) -> Vec<tiktoken_rs::ChatCompletionRequestMessage> { + request + .messages + .into_iter() + .map(|message| tiktoken_rs::ChatCompletionRequestMessage { + role: match message.role { + Role::User => "user".into(), + Role::Assistant => "assistant".into(), + Role::System => "system".into(), + }, + content: Some(message.string_contents()), + name: None, + function_call: None, + }) + .collect::<Vec<_>>() +} + pub fn count_open_ai_tokens( request: LanguageModelRequest, model: Model, cx: &App, ) -> BoxFuture<'static, Result<u64>> { cx.background_spawn(async move { - let messages = request - .messages - .into_iter() - .map(|message| tiktoken_rs::ChatCompletionRequestMessage { - role: match message.role { - Role::User => "user".into(), - Role::Assistant => "assistant".into(), - Role::System => "system".into(), - }, - content: Some(message.string_contents()), - name: None, - function_call: None, - }) - .collect::<Vec<_>>(); + let messages = collect_tiktoken_messages(request); match model { Model::Custom { max_tokens, .. } => { @@ -670,6 +674,10 @@ pub fn count_open_ai_tokens( | Model::O3 | Model::O3Mini | Model::O4Mini => tiktoken_rs::num_tokens_from_messages(model.id(), &messages), + // GPT-5 models don't have tiktoken support yet; fall back on gpt-4o tokenizer + Model::Five | Model::FiveMini | Model::FiveNano => { + tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages) + } } .map(|tokens| tokens as u64) }) @@ -678,7 +686,6 @@ pub fn count_open_ai_tokens( struct ConfigurationView { api_key_editor: Entity<SingleLineInput>, - api_url_editor: Entity<SingleLineInput>, state: gpui::Entity<State>, load_credentials_task: Option<Task<()>>, } @@ -691,23 +698,6 @@ impl ConfigurationView { cx, "sk-000000000000000000000000000000000000000000000000", ) - .label("API key") - }); - - let api_url = AllLanguageModelSettings::get_global(cx) - .openai - .api_url - .clone(); - - let api_url_editor = cx.new(|cx| { - let input = SingleLineInput::new(window, cx, open_ai::OPEN_AI_API_URL).label("API URL"); - - if !api_url.is_empty() { - input.editor.update(cx, |editor, cx| { - editor.set_text(&*api_url, window, cx); - }); - } - input }); cx.observe(&state, |_, _, cx| { @@ -735,7 +725,6 @@ impl ConfigurationView { Self { api_key_editor, - api_url_editor, state, load_credentials_task, } @@ -783,57 +772,6 @@ impl ConfigurationView { cx.notify(); } - fn save_api_url(&mut self, cx: &mut Context<Self>) { - let api_url = self - .api_url_editor - .read(cx) - .editor() - .read(cx) - .text(cx) - .trim() - .to_string(); - - let current_url = AllLanguageModelSettings::get_global(cx) - .openai - .api_url - .clone(); - - let effective_current_url = if current_url.is_empty() { - open_ai::OPEN_AI_API_URL - } else { - ¤t_url - }; - - if !api_url.is_empty() && api_url != effective_current_url { - let fs = <dyn Fs>::global(cx); - update_settings_file::<AllLanguageModelSettings>(fs, cx, move |settings, _| { - if let Some(settings) = settings.openai.as_mut() { - settings.api_url = Some(api_url.clone()); - } else { - settings.openai = Some(OpenAiSettingsContent { - api_url: Some(api_url.clone()), - available_models: None, - }); - } - }); - } - } - - fn reset_api_url(&mut self, window: &mut Window, cx: &mut Context<Self>) { - self.api_url_editor.update(cx, |input, cx| { - input.editor.update(cx, 
|editor, cx| { - editor.set_text("", window, cx); - }); - }); - let fs = <dyn Fs>::global(cx); - update_settings_file::<AllLanguageModelSettings>(fs, cx, |settings, _cx| { - if let Some(settings) = settings.openai.as_mut() { - settings.api_url = None; - } - }); - cx.notify(); - } - fn should_render_editor(&self, cx: &mut Context<Self>) -> bool { !self.state.read(cx).is_authenticated() } @@ -846,8 +784,7 @@ impl Render for ConfigurationView { let api_key_section = if self.should_render_editor(cx) { v_flex() .on_action(cx.listener(Self::save_api_key)) - - .child(Label::new("To use Zed's assistant with OpenAI, you need to add an API key. Follow these steps:")) + .child(Label::new("To use Zed's agent with OpenAI, you need to add an API key. Follow these steps:")) .child( List::new() .child(InstructionListItem::new( @@ -910,59 +847,34 @@ impl Render for ConfigurationView { .into_any() }; - let custom_api_url_set = - AllLanguageModelSettings::get_global(cx).openai.api_url != open_ai::OPEN_AI_API_URL; - - let api_url_section = if custom_api_url_set { - h_flex() - .mt_1() - .p_1() - .justify_between() - .rounded_md() - .border_1() - .border_color(cx.theme().colors().border) - .bg(cx.theme().colors().background) - .child( - h_flex() - .gap_1() - .child(Icon::new(IconName::Check).color(Color::Success)) - .child(Label::new("Custom API URL configured.")), - ) - .child( - Button::new("reset-api-url", "Reset API URL") - .label_size(LabelSize::Small) - .icon(IconName::Undo) - .icon_size(IconSize::Small) - .icon_position(IconPosition::Start) - .layer(ElevationIndex::ModalSurface) - .on_click( - cx.listener(|this, _, window, cx| this.reset_api_url(window, cx)), - ), - ) - .into_any() - } else { - v_flex() - .on_action(cx.listener(|this, _: &menu::Confirm, _window, cx| { - this.save_api_url(cx); - cx.notify(); - })) - .mt_2() - .pt_2() - .border_t_1() - .border_color(cx.theme().colors().border_variant) - .gap_1() - .child( - List::new() - .child(InstructionListItem::text_only( - "Optionally, you can change the base URL for the OpenAI API request.", - )) - .child(InstructionListItem::text_only( - "Paste the new API endpoint below and hit enter", - )), - ) - .child(self.api_url_editor.clone()) - .into_any() - }; + let compatible_api_section = h_flex() + .mt_1p5() + .gap_0p5() + .flex_wrap() + .when(self.should_render_editor(cx), |this| { + this.pt_1p5() + .border_t_1() + .border_color(cx.theme().colors().border_variant) + }) + .child( + h_flex() + .gap_2() + .child( + Icon::new(IconName::Info) + .size(IconSize::XSmall) + .color(Color::Muted), + ) + .child(Label::new("Zed also supports OpenAI-compatible models.")), + ) + .child( + Button::new("docs", "Learn More") + .icon(IconName::ArrowUpRight) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + .on_click(move |_, _window, cx| { + cx.open_url("https://zed.dev/docs/ai/llm-providers#openai-api-compatible") + }), + ); if self.load_credentials_task.is_some() { div().child(Label::new("Loading credentials…")).into_any() @@ -970,7 +882,7 @@ impl Render for ConfigurationView { v_flex() .size_full() .child(api_key_section) - .child(api_url_section) + .child(compatible_api_section) .into_any() } } diff --git a/crates/language_models/src/provider/open_ai_compatible.rs b/crates/language_models/src/provider/open_ai_compatible.rs new file mode 100644 index 0000000000000000000000000000000000000000..38bd7cee06db915cc3d6cc4d8a39309fab879b44 --- /dev/null +++ b/crates/language_models/src/provider/open_ai_compatible.rs @@ -0,0 +1,522 @@ +use anyhow::{Context as _, 
Result, anyhow}; +use credentials_provider::CredentialsProvider; + +use convert_case::{Case, Casing}; +use futures::{FutureExt, StreamExt, future::BoxFuture}; +use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window}; +use http_client::HttpClient; +use language_model::{ + AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, + LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, + LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, + LanguageModelToolChoice, RateLimiter, +}; +use menu; +use open_ai::{ResponseStreamEvent, stream_completion}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use settings::{Settings, SettingsStore}; +use std::sync::Arc; + +use ui::{ElevationIndex, Tooltip, prelude::*}; +use ui_input::SingleLineInput; +use util::ResultExt; + +use crate::AllLanguageModelSettings; +use crate::provider::open_ai::{OpenAiEventMapper, into_open_ai}; + +#[derive(Default, Clone, Debug, PartialEq)] +pub struct OpenAiCompatibleSettings { + pub api_url: String, + pub available_models: Vec<AvailableModel>, +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] +pub struct AvailableModel { + pub name: String, + pub display_name: Option<String>, + pub max_tokens: u64, + pub max_output_tokens: Option<u64>, + pub max_completion_tokens: Option<u64>, +} + +pub struct OpenAiCompatibleLanguageModelProvider { + id: LanguageModelProviderId, + name: LanguageModelProviderName, + http_client: Arc<dyn HttpClient>, + state: gpui::Entity<State>, +} + +pub struct State { + id: Arc<str>, + env_var_name: Arc<str>, + api_key: Option<String>, + api_key_from_env: bool, + settings: OpenAiCompatibleSettings, + _subscription: Subscription, +} + +impl State { + fn is_authenticated(&self) -> bool { + self.api_key.is_some() + } + + fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> { + let credentials_provider = <dyn CredentialsProvider>::global(cx); + let api_url = self.settings.api_url.clone(); + cx.spawn(async move |this, cx| { + credentials_provider + .delete_credentials(&api_url, &cx) + .await + .log_err(); + this.update(cx, |this, cx| { + this.api_key = None; + this.api_key_from_env = false; + cx.notify(); + }) + }) + } + + fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> { + let credentials_provider = <dyn CredentialsProvider>::global(cx); + let api_url = self.settings.api_url.clone(); + cx.spawn(async move |this, cx| { + credentials_provider + .write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx) + .await + .log_err(); + this.update(cx, |this, cx| { + this.api_key = Some(api_key); + cx.notify(); + }) + }) + } + + fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> { + if self.is_authenticated() { + return Task::ready(Ok(())); + } + + let credentials_provider = <dyn CredentialsProvider>::global(cx); + let env_var_name = self.env_var_name.clone(); + let api_url = self.settings.api_url.clone(); + cx.spawn(async move |this, cx| { + let (api_key, from_env) = if let Ok(api_key) = std::env::var(env_var_name.as_ref()) { + (api_key, true) + } else { + let (_, api_key) = credentials_provider + .read_credentials(&api_url, &cx) + .await? 
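                    // An absent stored credential becomes `CredentialsNotFound`,
                    // distinguishing "no key saved yet" from a credentials-store read failure.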
+ .ok_or(AuthenticateError::CredentialsNotFound)?; + ( + String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, + false, + ) + }; + this.update(cx, |this, cx| { + this.api_key = Some(api_key); + this.api_key_from_env = from_env; + cx.notify(); + })?; + + Ok(()) + }) + } +} + +impl OpenAiCompatibleLanguageModelProvider { + pub fn new(id: Arc<str>, http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self { + fn resolve_settings<'a>(id: &'a str, cx: &'a App) -> Option<&'a OpenAiCompatibleSettings> { + AllLanguageModelSettings::get_global(cx) + .openai_compatible + .get(id) + } + + let state = cx.new(|cx| State { + id: id.clone(), + env_var_name: format!("{}_API_KEY", id).to_case(Case::Constant).into(), + settings: resolve_settings(&id, cx).cloned().unwrap_or_default(), + api_key: None, + api_key_from_env: false, + _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| { + let Some(settings) = resolve_settings(&this.id, cx) else { + return; + }; + if &this.settings != settings { + this.settings = settings.clone(); + cx.notify(); + } + }), + }); + + Self { + id: id.clone().into(), + name: id.into(), + http_client, + state, + } + } + + fn create_language_model(&self, model: AvailableModel) -> Arc<dyn LanguageModel> { + Arc::new(OpenAiCompatibleLanguageModel { + id: LanguageModelId::from(model.name.clone()), + provider_id: self.id.clone(), + provider_name: self.name.clone(), + model, + state: self.state.clone(), + http_client: self.http_client.clone(), + request_limiter: RateLimiter::new(4), + }) + } +} + +impl LanguageModelProviderState for OpenAiCompatibleLanguageModelProvider { + type ObservableEntity = State; + + fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> { + Some(self.state.clone()) + } +} + +impl LanguageModelProvider for OpenAiCompatibleLanguageModelProvider { + fn id(&self) -> LanguageModelProviderId { + self.id.clone() + } + + fn name(&self) -> LanguageModelProviderName { + self.name.clone() + } + + fn icon(&self) -> IconName { + IconName::AiOpenAiCompat + } + + fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> { + self.state + .read(cx) + .settings + .available_models + .first() + .map(|model| self.create_language_model(model.clone())) + } + + fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> { + None + } + + fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> { + self.state + .read(cx) + .settings + .available_models + .iter() + .map(|model| self.create_language_model(model.clone())) + .collect() + } + + fn is_authenticated(&self, cx: &App) -> bool { + self.state.read(cx).is_authenticated() + } + + fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> { + self.state.update(cx, |state, cx| state.authenticate(cx)) + } + + fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { + cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) + .into() + } + + fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> { + self.state.update(cx, |state, cx| state.reset_api_key(cx)) + } +} + +pub struct OpenAiCompatibleLanguageModel { + id: LanguageModelId, + provider_id: LanguageModelProviderId, + provider_name: LanguageModelProviderName, + model: AvailableModel, + state: gpui::Entity<State>, + http_client: Arc<dyn HttpClient>, + request_limiter: RateLimiter, +} + +impl OpenAiCompatibleLanguageModel { + fn stream_completion( + &self, + request: open_ai::Request, + cx: &AsyncApp, + ) -> BoxFuture<'static, 
Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>> + { + let http_client = self.http_client.clone(); + let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, _| { + (state.api_key.clone(), state.settings.api_url.clone()) + }) else { + return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); + }; + + let provider = self.provider_name.clone(); + let future = self.request_limiter.stream(async move { + let Some(api_key) = api_key else { + return Err(LanguageModelCompletionError::NoApiKey { provider }); + }; + let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request); + let response = request.await?; + Ok(response) + }); + + async move { Ok(future.await?.boxed()) }.boxed() + } +} + +impl LanguageModel for OpenAiCompatibleLanguageModel { + fn id(&self) -> LanguageModelId { + self.id.clone() + } + + fn name(&self) -> LanguageModelName { + LanguageModelName::from( + self.model + .display_name + .clone() + .unwrap_or_else(|| self.model.name.clone()), + ) + } + + fn provider_id(&self) -> LanguageModelProviderId { + self.provider_id.clone() + } + + fn provider_name(&self) -> LanguageModelProviderName { + self.provider_name.clone() + } + + fn supports_tools(&self) -> bool { + true + } + + fn supports_images(&self) -> bool { + false + } + + fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { + match choice { + LanguageModelToolChoice::Auto => true, + LanguageModelToolChoice::Any => true, + LanguageModelToolChoice::None => true, + } + } + + fn telemetry_id(&self) -> String { + format!("openai/{}", self.model.name) + } + + fn max_token_count(&self) -> u64 { + self.model.max_tokens + } + + fn max_output_tokens(&self) -> Option<u64> { + self.model.max_output_tokens + } + + fn count_tokens( + &self, + request: LanguageModelRequest, + cx: &App, + ) -> BoxFuture<'static, Result<u64>> { + let max_token_count = self.max_token_count(); + cx.background_spawn(async move { + let messages = super::open_ai::collect_tiktoken_messages(request); + let model = if max_token_count >= 100_000 { + // If the max tokens is 100k or more, it is likely the o200k_base tokenizer from gpt4o + "gpt-4o" + } else { + // Otherwise fallback to gpt-4, since only cl100k_base and o200k_base are + // supported with this tiktoken method + "gpt-4" + }; + tiktoken_rs::num_tokens_from_messages(model, &messages).map(|tokens| tokens as u64) + }) + .boxed() + } + + fn stream_completion( + &self, + request: LanguageModelRequest, + cx: &AsyncApp, + ) -> BoxFuture< + 'static, + Result< + futures::stream::BoxStream< + 'static, + Result<LanguageModelCompletionEvent, LanguageModelCompletionError>, + >, + LanguageModelCompletionError, + >, + > { + let request = into_open_ai(request, &self.model.name, true, self.max_output_tokens()); + let completions = self.stream_completion(request, cx); + async move { + let mapper = OpenAiEventMapper::new(); + Ok(mapper.map_stream(completions.await?).boxed()) + } + .boxed() + } +} + +struct ConfigurationView { + api_key_editor: Entity<SingleLineInput>, + state: gpui::Entity<State>, + load_credentials_task: Option<Task<()>>, +} + +impl ConfigurationView { + fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self { + let api_key_editor = cx.new(|cx| { + SingleLineInput::new( + window, + cx, + "000000000000000000000000000000000000000000000000000", + ) + }); + + cx.observe(&state, |_, _, cx| { + cx.notify(); + }) + .detach(); + + let load_credentials_task = Some(cx.spawn_in(window, { + let state = 
state.clone(); + async move |this, cx| { + if let Some(task) = state + .update(cx, |state, cx| state.authenticate(cx)) + .log_err() + { + // We don't log an error, because "not signed in" is also an error. + let _ = task.await; + } + this.update(cx, |this, cx| { + this.load_credentials_task = None; + cx.notify(); + }) + .log_err(); + } + })); + + Self { + api_key_editor, + state, + load_credentials_task, + } + } + + fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) { + let api_key = self + .api_key_editor + .read(cx) + .editor() + .read(cx) + .text(cx) + .trim() + .to_string(); + + // Don't proceed if no API key is provided and we're not authenticated + if api_key.is_empty() && !self.state.read(cx).is_authenticated() { + return; + } + + let state = self.state.clone(); + cx.spawn_in(window, async move |_, cx| { + state + .update(cx, |state, cx| state.set_api_key(api_key, cx))? + .await + }) + .detach_and_log_err(cx); + + cx.notify(); + } + + fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) { + self.api_key_editor.update(cx, |input, cx| { + input.editor.update(cx, |editor, cx| { + editor.set_text("", window, cx); + }); + }); + + let state = self.state.clone(); + cx.spawn_in(window, async move |_, cx| { + state.update(cx, |state, cx| state.reset_api_key(cx))?.await + }) + .detach_and_log_err(cx); + + cx.notify(); + } + + fn should_render_editor(&self, cx: &mut Context<Self>) -> bool { + !self.state.read(cx).is_authenticated() + } +} + +impl Render for ConfigurationView { + fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { + let env_var_set = self.state.read(cx).api_key_from_env; + let env_var_name = self.state.read(cx).env_var_name.clone(); + + let api_key_section = if self.should_render_editor(cx) { + v_flex() + .on_action(cx.listener(Self::save_api_key)) + .child(Label::new("To use Zed's agent with an OpenAI-compatible provider, you need to add an API key.")) + .child( + div() + .pt(DynamicSpacing::Base04.rems(cx)) + .child(self.api_key_editor.clone()) + ) + .child( + Label::new( + format!("You can also assign the {env_var_name} environment variable and restart Zed."), + ) + .size(LabelSize::Small).color(Color::Muted), + ) + .into_any() + } else { + h_flex() + .mt_1() + .p_1() + .justify_between() + .rounded_md() + .border_1() + .border_color(cx.theme().colors().border) + .bg(cx.theme().colors().background) + .child( + h_flex() + .gap_1() + .child(Icon::new(IconName::Check).color(Color::Success)) + .child(Label::new(if env_var_set { + format!("API key set in {env_var_name} environment variable.") + } else { + "API key configured.".to_string() + })), + ) + .child( + Button::new("reset-api-key", "Reset API Key") + .label_size(LabelSize::Small) + .icon(IconName::Undo) + .icon_size(IconSize::Small) + .icon_position(IconPosition::Start) + .layer(ElevationIndex::ModalSurface) + .when(env_var_set, |this| { + this.tooltip(Tooltip::text(format!("To reset your API key, unset the {env_var_name} environment variable."))) + }) + .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))), + ) + .into_any() + }; + + if self.load_credentials_task.is_some() { + div().child(Label::new("Loading credentials…")).into_any() + } else { + v_flex().size_full().child(api_key_section).into_any() + } + } +} diff --git a/crates/language_models/src/provider/open_router.rs b/crates/language_models/src/provider/open_router.rs index 
c46135ff3eae704f5d54027457d8f86fbef4820a..3a492086f16e1f9b53a196b7bb2e9817a3cac0e7 100644 --- a/crates/language_models/src/provider/open_router.rs +++ b/crates/language_models/src/provider/open_router.rs @@ -376,7 +376,7 @@ impl LanguageModel for OpenRouterLanguageModel { fn tool_input_format(&self) -> LanguageModelToolSchemaFormat { let model_id = self.model.id().trim().to_lowercase(); - if model_id.contains("gemini") { + if model_id.contains("gemini") || model_id.contains("grok-4") { LanguageModelToolSchemaFormat::JsonSchemaSubset } else { LanguageModelToolSchemaFormat::JsonSchema @@ -855,7 +855,7 @@ impl Render for ConfigurationView { v_flex() .size_full() .on_action(cx.listener(Self::save_api_key)) - .child(Label::new("To use Zed's assistant with OpenRouter, you need to add an API key. Follow these steps:")) + .child(Label::new("To use Zed's agent with OpenRouter, you need to add an API key. Follow these steps:")) .child( List::new() .child(InstructionListItem::new( diff --git a/crates/language_models/src/provider/x_ai.rs b/crates/language_models/src/provider/x_ai.rs new file mode 100644 index 0000000000000000000000000000000000000000..5f6034571b54fd30baa6769881f5d27bbcaf162f --- /dev/null +++ b/crates/language_models/src/provider/x_ai.rs @@ -0,0 +1,571 @@ +use anyhow::{Context as _, Result, anyhow}; +use collections::BTreeMap; +use credentials_provider::CredentialsProvider; +use futures::{FutureExt, StreamExt, future::BoxFuture}; +use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window}; +use http_client::HttpClient; +use language_model::{ + AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, + LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, + LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, + LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter, Role, +}; +use menu; +use open_ai::ResponseStreamEvent; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use settings::{Settings, SettingsStore}; +use std::sync::Arc; +use strum::IntoEnumIterator; +use x_ai::Model; + +use ui::{ElevationIndex, List, Tooltip, prelude::*}; +use ui_input::SingleLineInput; +use util::ResultExt; + +use crate::{AllLanguageModelSettings, ui::InstructionListItem}; + +const PROVIDER_ID: &str = "x_ai"; +const PROVIDER_NAME: &str = "xAI"; + +#[derive(Default, Clone, Debug, PartialEq)] +pub struct XAiSettings { + pub api_url: String, + pub available_models: Vec<AvailableModel>, +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] +pub struct AvailableModel { + pub name: String, + pub display_name: Option<String>, + pub max_tokens: u64, + pub max_output_tokens: Option<u64>, + pub max_completion_tokens: Option<u64>, +} + +pub struct XAiLanguageModelProvider { + http_client: Arc<dyn HttpClient>, + state: gpui::Entity<State>, +} + +pub struct State { + api_key: Option<String>, + api_key_from_env: bool, + _subscription: Subscription, +} + +const XAI_API_KEY_VAR: &str = "XAI_API_KEY"; + +impl State { + fn is_authenticated(&self) -> bool { + self.api_key.is_some() + } + + fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> { + let credentials_provider = <dyn CredentialsProvider>::global(cx); + let settings = &AllLanguageModelSettings::get_global(cx).x_ai; + let api_url = if settings.api_url.is_empty() { + x_ai::XAI_API_URL.to_string() + } else { + settings.api_url.clone() + }; + cx.spawn(async move |this, cx| { + 
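            // Remove the key from the credential store first, then clear the
            // in-memory copy and notify observers of the change.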
credentials_provider + .delete_credentials(&api_url, &cx) + .await + .log_err(); + this.update(cx, |this, cx| { + this.api_key = None; + this.api_key_from_env = false; + cx.notify(); + }) + }) + } + + fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> { + let credentials_provider = <dyn CredentialsProvider>::global(cx); + let settings = &AllLanguageModelSettings::get_global(cx).x_ai; + let api_url = if settings.api_url.is_empty() { + x_ai::XAI_API_URL.to_string() + } else { + settings.api_url.clone() + }; + cx.spawn(async move |this, cx| { + credentials_provider + .write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx) + .await + .log_err(); + this.update(cx, |this, cx| { + this.api_key = Some(api_key); + cx.notify(); + }) + }) + } + + fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> { + if self.is_authenticated() { + return Task::ready(Ok(())); + } + + let credentials_provider = <dyn CredentialsProvider>::global(cx); + let settings = &AllLanguageModelSettings::get_global(cx).x_ai; + let api_url = if settings.api_url.is_empty() { + x_ai::XAI_API_URL.to_string() + } else { + settings.api_url.clone() + }; + cx.spawn(async move |this, cx| { + let (api_key, from_env) = if let Ok(api_key) = std::env::var(XAI_API_KEY_VAR) { + (api_key, true) + } else { + let (_, api_key) = credentials_provider + .read_credentials(&api_url, &cx) + .await? + .ok_or(AuthenticateError::CredentialsNotFound)?; + ( + String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, + false, + ) + }; + this.update(cx, |this, cx| { + this.api_key = Some(api_key); + this.api_key_from_env = from_env; + cx.notify(); + })?; + + Ok(()) + }) + } +} + +impl XAiLanguageModelProvider { + pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self { + let state = cx.new(|cx| State { + api_key: None, + api_key_from_env: false, + _subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| { + cx.notify(); + }), + }); + + Self { http_client, state } + } + + fn create_language_model(&self, model: x_ai::Model) -> Arc<dyn LanguageModel> { + Arc::new(XAiLanguageModel { + id: LanguageModelId::from(model.id().to_string()), + model, + state: self.state.clone(), + http_client: self.http_client.clone(), + request_limiter: RateLimiter::new(4), + }) + } +} + +impl LanguageModelProviderState for XAiLanguageModelProvider { + type ObservableEntity = State; + + fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> { + Some(self.state.clone()) + } +} + +impl LanguageModelProvider for XAiLanguageModelProvider { + fn id(&self) -> LanguageModelProviderId { + LanguageModelProviderId(PROVIDER_ID.into()) + } + + fn name(&self) -> LanguageModelProviderName { + LanguageModelProviderName(PROVIDER_NAME.into()) + } + + fn icon(&self) -> IconName { + IconName::AiXAi + } + + fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> { + Some(self.create_language_model(x_ai::Model::default())) + } + + fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> { + Some(self.create_language_model(x_ai::Model::default_fast())) + } + + fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> { + let mut models = BTreeMap::default(); + + for model in x_ai::Model::iter() { + if !matches!(model, x_ai::Model::Custom { .. 
}) { + models.insert(model.id().to_string(), model); + } + } + + for model in &AllLanguageModelSettings::get_global(cx) + .x_ai + .available_models + { + models.insert( + model.name.clone(), + x_ai::Model::Custom { + name: model.name.clone(), + display_name: model.display_name.clone(), + max_tokens: model.max_tokens, + max_output_tokens: model.max_output_tokens, + max_completion_tokens: model.max_completion_tokens, + }, + ); + } + + models + .into_values() + .map(|model| self.create_language_model(model)) + .collect() + } + + fn is_authenticated(&self, cx: &App) -> bool { + self.state.read(cx).is_authenticated() + } + + fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> { + self.state.update(cx, |state, cx| state.authenticate(cx)) + } + + fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { + cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) + .into() + } + + fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> { + self.state.update(cx, |state, cx| state.reset_api_key(cx)) + } +} + +pub struct XAiLanguageModel { + id: LanguageModelId, + model: x_ai::Model, + state: gpui::Entity<State>, + http_client: Arc<dyn HttpClient>, + request_limiter: RateLimiter, +} + +impl XAiLanguageModel { + fn stream_completion( + &self, + request: open_ai::Request, + cx: &AsyncApp, + ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>> + { + let http_client = self.http_client.clone(); + let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| { + let settings = &AllLanguageModelSettings::get_global(cx).x_ai; + let api_url = if settings.api_url.is_empty() { + x_ai::XAI_API_URL.to_string() + } else { + settings.api_url.clone() + }; + (state.api_key.clone(), api_url) + }) else { + return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); + }; + + let future = self.request_limiter.stream(async move { + let api_key = api_key.context("Missing xAI API Key")?; + let request = + open_ai::stream_completion(http_client.as_ref(), &api_url, &api_key, request); + let response = request.await?; + Ok(response) + }); + + async move { Ok(future.await?.boxed()) }.boxed() + } +} + +impl LanguageModel for XAiLanguageModel { + fn id(&self) -> LanguageModelId { + self.id.clone() + } + + fn name(&self) -> LanguageModelName { + LanguageModelName::from(self.model.display_name().to_string()) + } + + fn provider_id(&self) -> LanguageModelProviderId { + LanguageModelProviderId(PROVIDER_ID.into()) + } + + fn provider_name(&self) -> LanguageModelProviderName { + LanguageModelProviderName(PROVIDER_NAME.into()) + } + + fn supports_tools(&self) -> bool { + self.model.supports_tool() + } + + fn supports_images(&self) -> bool { + self.model.supports_images() + } + + fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { + match choice { + LanguageModelToolChoice::Auto + | LanguageModelToolChoice::Any + | LanguageModelToolChoice::None => true, + } + } + fn tool_input_format(&self) -> LanguageModelToolSchemaFormat { + let model_id = self.model.id().trim().to_lowercase(); + if model_id.eq(x_ai::Model::Grok4.id()) { + LanguageModelToolSchemaFormat::JsonSchemaSubset + } else { + LanguageModelToolSchemaFormat::JsonSchema + } + } + + fn telemetry_id(&self) -> String { + format!("x_ai/{}", self.model.id()) + } + + fn max_token_count(&self) -> u64 { + self.model.max_token_count() + } + + fn max_output_tokens(&self) -> Option<u64> { + self.model.max_output_tokens() + } + + fn count_tokens( + 
&self, + request: LanguageModelRequest, + cx: &App, + ) -> BoxFuture<'static, Result<u64>> { + count_xai_tokens(request, self.model.clone(), cx) + } + + fn stream_completion( + &self, + request: LanguageModelRequest, + cx: &AsyncApp, + ) -> BoxFuture< + 'static, + Result< + futures::stream::BoxStream< + 'static, + Result<LanguageModelCompletionEvent, LanguageModelCompletionError>, + >, + LanguageModelCompletionError, + >, + > { + let request = crate::provider::open_ai::into_open_ai( + request, + self.model.id(), + self.model.supports_parallel_tool_calls(), + self.max_output_tokens(), + ); + let completions = self.stream_completion(request, cx); + async move { + let mapper = crate::provider::open_ai::OpenAiEventMapper::new(); + Ok(mapper.map_stream(completions.await?).boxed()) + } + .boxed() + } +} + +pub fn count_xai_tokens( + request: LanguageModelRequest, + model: Model, + cx: &App, +) -> BoxFuture<'static, Result<u64>> { + cx.background_spawn(async move { + let messages = request + .messages + .into_iter() + .map(|message| tiktoken_rs::ChatCompletionRequestMessage { + role: match message.role { + Role::User => "user".into(), + Role::Assistant => "assistant".into(), + Role::System => "system".into(), + }, + content: Some(message.string_contents()), + name: None, + function_call: None, + }) + .collect::<Vec<_>>(); + + let model_name = if model.max_token_count() >= 100_000 { + "gpt-4o" + } else { + "gpt-4" + }; + tiktoken_rs::num_tokens_from_messages(model_name, &messages).map(|tokens| tokens as u64) + }) + .boxed() +} + +struct ConfigurationView { + api_key_editor: Entity<SingleLineInput>, + state: gpui::Entity<State>, + load_credentials_task: Option<Task<()>>, +} + +impl ConfigurationView { + fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self { + let api_key_editor = cx.new(|cx| { + SingleLineInput::new( + window, + cx, + "xai-0000000000000000000000000000000000000000000000000", + ) + .label("API key") + }); + + cx.observe(&state, |_, _, cx| { + cx.notify(); + }) + .detach(); + + let load_credentials_task = Some(cx.spawn_in(window, { + let state = state.clone(); + async move |this, cx| { + if let Some(task) = state + .update(cx, |state, cx| state.authenticate(cx)) + .log_err() + { + // We don't log an error, because "not signed in" is also an error. + let _ = task.await; + } + this.update(cx, |this, cx| { + this.load_credentials_task = None; + cx.notify(); + }) + .log_err(); + } + })); + + Self { + api_key_editor, + state, + load_credentials_task, + } + } + + fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) { + let api_key = self + .api_key_editor + .read(cx) + .editor() + .read(cx) + .text(cx) + .trim() + .to_string(); + + // Don't proceed if no API key is provided and we're not authenticated + if api_key.is_empty() && !self.state.read(cx).is_authenticated() { + return; + } + + let state = self.state.clone(); + cx.spawn_in(window, async move |_, cx| { + state + .update(cx, |state, cx| state.set_api_key(api_key, cx))? 
+ .await + }) + .detach_and_log_err(cx); + + cx.notify(); + } + + fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) { + self.api_key_editor.update(cx, |input, cx| { + input.editor.update(cx, |editor, cx| { + editor.set_text("", window, cx); + }); + }); + + let state = self.state.clone(); + cx.spawn_in(window, async move |_, cx| { + state.update(cx, |state, cx| state.reset_api_key(cx))?.await + }) + .detach_and_log_err(cx); + + cx.notify(); + } + + fn should_render_editor(&self, cx: &mut Context<Self>) -> bool { + !self.state.read(cx).is_authenticated() + } +} + +impl Render for ConfigurationView { + fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { + let env_var_set = self.state.read(cx).api_key_from_env; + + let api_key_section = if self.should_render_editor(cx) { + v_flex() + .on_action(cx.listener(Self::save_api_key)) + .child(Label::new("To use Zed's agent with xAI, you need to add an API key. Follow these steps:")) + .child( + List::new() + .child(InstructionListItem::new( + "Create one by visiting", + Some("xAI console"), + Some("https://console.x.ai/team/default/api-keys"), + )) + .child(InstructionListItem::text_only( + "Paste your API key below and hit enter to start using the agent", + )), + ) + .child(self.api_key_editor.clone()) + .child( + Label::new(format!( + "You can also assign the {XAI_API_KEY_VAR} environment variable and restart Zed." + )) + .size(LabelSize::Small) + .color(Color::Muted), + ) + .child( + Label::new("Note that xAI is a custom OpenAI-compatible provider.") + .size(LabelSize::Small) + .color(Color::Muted), + ) + .into_any() + } else { + h_flex() + .mt_1() + .p_1() + .justify_between() + .rounded_md() + .border_1() + .border_color(cx.theme().colors().border) + .bg(cx.theme().colors().background) + .child( + h_flex() + .gap_1() + .child(Icon::new(IconName::Check).color(Color::Success)) + .child(Label::new(if env_var_set { + format!("API key set in {XAI_API_KEY_VAR} environment variable.") + } else { + "API key configured.".to_string() + })), + ) + .child( + Button::new("reset-api-key", "Reset API Key") + .label_size(LabelSize::Small) + .icon(IconName::Undo) + .icon_size(IconSize::Small) + .icon_position(IconPosition::Start) + .layer(ElevationIndex::ModalSurface) + .when(env_var_set, |this| { + this.tooltip(Tooltip::text(format!("To reset your API key, unset the {XAI_API_KEY_VAR} environment variable."))) + }) + .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))), + ) + .into_any() + }; + + if self.load_credentials_task.is_some() { + div().child(Label::new("Loading credentials…")).into_any() + } else { + v_flex().size_full().child(api_key_section).into_any() + } + } +} diff --git a/crates/language_models/src/settings.rs b/crates/language_models/src/settings.rs index f96a2c0a66cfe698738deec177b5f82cde274df7..b163585aa7b745447381aa62f710e8c5dbdf469c 100644 --- a/crates/language_models/src/settings.rs +++ b/crates/language_models/src/settings.rs @@ -1,4 +1,7 @@ +use std::sync::Arc; + use anyhow::Result; +use collections::HashMap; use gpui::App; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -15,12 +18,14 @@ use crate::provider::{ mistral::MistralSettings, ollama::OllamaSettings, open_ai::OpenAiSettings, + open_ai_compatible::OpenAiCompatibleSettings, open_router::OpenRouterSettings, vercel::VercelSettings, + x_ai::XAiSettings, }; /// Initializes the language model settings. 
-pub fn init(cx: &mut App) { +pub fn init_settings(cx: &mut App) { AllLanguageModelSettings::register(cx); } @@ -28,33 +33,35 @@ pub fn init(cx: &mut App) { pub struct AllLanguageModelSettings { pub anthropic: AnthropicSettings, pub bedrock: AmazonBedrockSettings, - pub ollama: OllamaSettings, - pub openai: OpenAiSettings, - pub open_router: OpenRouterSettings, - pub zed_dot_dev: ZedDotDevSettings, + pub deepseek: DeepSeekSettings, pub google: GoogleSettings, - pub vercel: VercelSettings, - pub lmstudio: LmStudioSettings, - pub deepseek: DeepSeekSettings, pub mistral: MistralSettings, + pub ollama: OllamaSettings, + pub open_router: OpenRouterSettings, + pub openai: OpenAiSettings, + pub openai_compatible: HashMap<Arc<str>, OpenAiCompatibleSettings>, + pub vercel: VercelSettings, + pub x_ai: XAiSettings, + pub zed_dot_dev: ZedDotDevSettings, } #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] pub struct AllLanguageModelSettingsContent { pub anthropic: Option<AnthropicSettingsContent>, pub bedrock: Option<AmazonBedrockSettingsContent>, - pub ollama: Option<OllamaSettingsContent>, + pub deepseek: Option<DeepseekSettingsContent>, + pub google: Option<GoogleSettingsContent>, pub lmstudio: Option<LmStudioSettingsContent>, - pub openai: Option<OpenAiSettingsContent>, + pub mistral: Option<MistralSettingsContent>, + pub ollama: Option<OllamaSettingsContent>, pub open_router: Option<OpenRouterSettingsContent>, + pub openai: Option<OpenAiSettingsContent>, + pub openai_compatible: Option<HashMap<Arc<str>, OpenAiCompatibleSettingsContent>>, + pub vercel: Option<VercelSettingsContent>, + pub x_ai: Option<XAiSettingsContent>, #[serde(rename = "zed.dev")] pub zed_dot_dev: Option<ZedDotDevSettingsContent>, - pub google: Option<GoogleSettingsContent>, - pub deepseek: Option<DeepseekSettingsContent>, - pub vercel: Option<VercelSettingsContent>, - - pub mistral: Option<MistralSettingsContent>, } #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] @@ -102,6 +109,12 @@ pub struct OpenAiSettingsContent { pub available_models: Option<Vec<provider::open_ai::AvailableModel>>, } +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] +pub struct OpenAiCompatibleSettingsContent { + pub api_url: String, + pub available_models: Vec<provider::open_ai_compatible::AvailableModel>, +} + #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] pub struct VercelSettingsContent { pub api_url: Option<String>, @@ -114,6 +127,12 @@ pub struct GoogleSettingsContent { pub available_models: Option<Vec<provider::google::AvailableModel>>, } +#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] +pub struct XAiSettingsContent { + pub api_url: Option<String>, + pub available_models: Option<Vec<provider::x_ai::AvailableModel>>, +} + #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] pub struct ZedDotDevSettingsContent { available_models: Option<Vec<cloud::AvailableModel>>, @@ -219,6 +238,19 @@ impl settings::Settings for AllLanguageModelSettings { openai.as_ref().and_then(|s| s.available_models.clone()), ); + // OpenAI Compatible + if let Some(openai_compatible) = value.openai_compatible.clone() { + for (id, openai_compatible_settings) in openai_compatible { + settings.openai_compatible.insert( + id, + OpenAiCompatibleSettings { + api_url: openai_compatible_settings.api_url, + available_models: openai_compatible_settings.available_models, + }, + ); + } + } + // Vercel let vercel = 
value.vercel.clone(); merge( @@ -230,6 +262,18 @@ impl settings::Settings for AllLanguageModelSettings { vercel.as_ref().and_then(|s| s.available_models.clone()), ); + // XAI + let x_ai = value.x_ai.clone(); + merge( + &mut settings.x_ai.api_url, + x_ai.as_ref().and_then(|s| s.api_url.clone()), + ); + merge( + &mut settings.x_ai.available_models, + x_ai.as_ref().and_then(|s| s.available_models.clone()), + ); + + // ZedDotDev merge( &mut settings.zed_dot_dev.available_models, value diff --git a/crates/language_selector/src/language_selector.rs b/crates/language_selector/src/language_selector.rs index 4c034305537e51e752fcef90eeeb7668f1bb50b7..f6e2d75015560582b30453767b1a3b30f7cce82e 100644 --- a/crates/language_selector/src/language_selector.rs +++ b/crates/language_selector/src/language_selector.rs @@ -86,7 +86,10 @@ impl LanguageSelector { impl Render for LanguageSelector { fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement { - v_flex().w(rems(34.)).child(self.picker.clone()) + v_flex() + .key_context("LanguageSelector") + .w(rems(34.)) + .child(self.picker.clone()) } } @@ -121,13 +124,13 @@ impl LanguageSelectorDelegate { .into_iter() .filter_map(|name| { language_registry - .available_language_for_name(&name)? + .available_language_for_name(name.as_ref())? .hidden() .not() .then_some(name) }) .enumerate() - .map(|(candidate_id, name)| StringMatchCandidate::new(candidate_id, &name)) + .map(|(candidate_id, name)| StringMatchCandidate::new(candidate_id, name.as_ref())) .collect::<Vec<_>>(); Self { diff --git a/crates/language_tools/Cargo.toml b/crates/language_tools/Cargo.toml index 45af7518d589166e26788203c919d2267b544756..5aa914311a6eccc1cb68efa37e878ad12249d6fd 100644 --- a/crates/language_tools/Cargo.toml +++ b/crates/language_tools/Cargo.toml @@ -18,7 +18,6 @@ client.workspace = true collections.workspace = true copilot.workspace = true editor.workspace = true -feature_flags.workspace = true futures.workspace = true gpui.workspace = true itertools.workspace = true diff --git a/crates/language_tools/src/key_context_view.rs b/crates/language_tools/src/key_context_view.rs index c933872d8c513c21c2095b6b32d7a316fcb7f92f..88131781ec3af336d3ae793cf1820e5bcf731605 100644 --- a/crates/language_tools/src/key_context_view.rs +++ b/crates/language_tools/src/key_context_view.rs @@ -132,14 +132,7 @@ impl KeyContextView { } fn matches(&self, predicate: &KeyBindingContextPredicate) -> bool { - let mut stack = self.context_stack.clone(); - while !stack.is_empty() { - if predicate.eval(&stack) { - return true; - } - stack.pop(); - } - false + predicate.depth_of(&self.context_stack).is_some() } fn action_matches(&self, a: &Option<Box<dyn Action>>, b: &dyn Action) -> bool { diff --git a/crates/language_tools/src/lsp_log.rs b/crates/language_tools/src/lsp_log.rs index d1a90d7dbb64bbb8916f6abd8e44ed1141f56076..606f3a3f0e5f91b5fb8856cabce240d094f3cf49 100644 --- a/crates/language_tools/src/lsp_log.rs +++ b/crates/language_tools/src/lsp_log.rs @@ -253,8 +253,8 @@ impl LogStore { let copilot_subscription = Copilot::global(cx).map(|copilot| { let copilot = &copilot; - cx.subscribe(copilot, |this, copilot, inline_completion_event, cx| { - if let copilot::Event::CopilotLanguageServerStarted = inline_completion_event { + cx.subscribe(copilot, |this, copilot, edit_prediction_event, cx| { + if let copilot::Event::CopilotLanguageServerStarted = edit_prediction_event { if let Some(server) = copilot.read(cx).language_server() { let server_id = server.server_id(); let weak_this = 
cx.weak_entity(); @@ -867,7 +867,7 @@ impl LspLogView { BINARY = server.binary(), WORKSPACE_FOLDERS = server .workspace_folders() - .iter() + .into_iter() .filter_map(|path| path .to_file_path() .ok() diff --git a/crates/language_tools/src/lsp_tool.rs b/crates/language_tools/src/lsp_tool.rs index d14a6fb78187c0769459c8f51b47f12d944bf69c..50547253a92b8c23d0530326faf916e56363dcd9 100644 --- a/crates/language_tools/src/lsp_tool.rs +++ b/crates/language_tools/src/lsp_tool.rs @@ -1,13 +1,17 @@ -use std::{collections::hash_map, path::PathBuf, rc::Rc, time::Duration}; +use std::{ + collections::{BTreeMap, HashMap}, + path::{Path, PathBuf}, + rc::Rc, + time::Duration, +}; use client::proto; -use collections::{HashMap, HashSet}; +use collections::HashSet; use editor::{Editor, EditorEvent}; -use feature_flags::FeatureFlagAppExt as _; use gpui::{Corner, Entity, Subscription, Task, WeakEntity, actions}; -use language::{BinaryStatus, BufferId, LocalFile, ServerHealth}; +use language::{BinaryStatus, BufferId, ServerHealth}; use lsp::{LanguageServerId, LanguageServerName, LanguageServerSelector}; -use project::{LspStore, LspStoreEvent, project_settings::ProjectSettings}; +use project::{LspStore, LspStoreEvent, Worktree, project_settings::ProjectSettings}; use settings::{Settings as _, SettingsStore}; use ui::{ Context, ContextMenu, ContextMenuEntry, ContextMenuItem, DocumentationAside, DocumentationSide, @@ -36,8 +40,7 @@ pub struct LspTool { #[derive(Debug)] struct LanguageServerState { - items: Vec<LspItem>, - other_servers_start_index: Option<usize>, + items: Vec<LspMenuItem>, workspace: WeakEntity<Workspace>, lsp_store: WeakEntity<LspStore>, active_editor: Option<ActiveEditor>, @@ -63,8 +66,13 @@ impl std::fmt::Debug for ActiveEditor { struct LanguageServers { health_statuses: HashMap<LanguageServerId, LanguageServerHealthStatus>, binary_statuses: HashMap<LanguageServerName, LanguageServerBinaryStatus>, - servers_per_buffer_abs_path: - HashMap<PathBuf, HashMap<LanguageServerId, Option<LanguageServerName>>>, + servers_per_buffer_abs_path: HashMap<PathBuf, ServersForPath>, +} + +#[derive(Debug, Clone)] +struct ServersForPath { + servers: HashMap<LanguageServerId, Option<LanguageServerName>>, + worktree: Option<WeakEntity<Worktree>>, } #[derive(Debug, Clone)] @@ -119,8 +127,9 @@ impl LanguageServerState { return menu; }; - for (i, item) in self.items.iter().enumerate() { - if let LspItem::ToggleServersButton { restart } = item { + let mut first_button_encountered = false; + for item in &self.items { + if let LspMenuItem::ToggleServersButton { restart } = item { let label = if *restart { "Restart All Servers" } else { @@ -139,22 +148,19 @@ impl LanguageServerState { }; let project = workspace.read(cx).project().clone(); let buffer_store = project.read(cx).buffer_store().clone(); - let worktree_store = project.read(cx).worktree_store(); - let buffers = state .read(cx) .language_servers .servers_per_buffer_abs_path - .keys() - .filter_map(|abs_path| { - worktree_store.read(cx).find_worktree(abs_path, cx) - }) - .filter_map(|(worktree, relative_path)| { - let entry = - worktree.read(cx).entry_for_path(&relative_path)?; - project.read(cx).path_for_entry(entry.id, cx) - }) - .filter_map(|project_path| { + .iter() + .filter_map(|(abs_path, servers)| { + let worktree = + servers.worktree.as_ref()?.upgrade()?.read(cx); + let relative_path = + abs_path.strip_prefix(&worktree.abs_path()).ok()?; + let entry = worktree.entry_for_path(&relative_path)?; + let project_path = + project.read(cx).path_for_entry(entry.id, 
cx)?; buffer_store.read(cx).get_by_path(&project_path) }) .collect(); @@ -164,13 +170,16 @@ impl LanguageServerState { .iter() // Do not try to use IDs as we have stopped all servers already, when allowing to restart them all .flat_map(|item| match item { - LspItem::ToggleServersButton { .. } => None, - LspItem::WithHealthCheck(_, status, ..) => Some( - LanguageServerSelector::Name(status.name.clone()), - ), - LspItem::WithBinaryStatus(_, server_name, ..) => Some( - LanguageServerSelector::Name(server_name.clone()), + LspMenuItem::Header { .. } => None, + LspMenuItem::ToggleServersButton { .. } => None, + LspMenuItem::WithHealthCheck { health, .. } => Some( + LanguageServerSelector::Name(health.name.clone()), ), + LspMenuItem::WithBinaryStatus { + server_name, .. + } => Some(LanguageServerSelector::Name( + server_name.clone(), + )), }) .collect(); lsp_store.restart_language_servers_for_buffers( @@ -183,15 +192,23 @@ impl LanguageServerState { .ok(); } }); - menu = menu.separator().item(button); + if !first_button_encountered { + menu = menu.separator(); + first_button_encountered = true; + } + menu = menu.item(button); continue; - }; + } else if let LspMenuItem::Header { header, separator } = item { + menu = menu + .when(*separator, |menu| menu.separator()) + .when_some(header.as_ref(), |menu, header| menu.header(header)); + continue; + } let Some(server_info) = item.server_info() else { continue; }; - let workspace = self.workspace.clone(); let server_selector = server_info.server_selector(); // TODO currently, Zed remote does not work well with the LSP logs // https://github.com/zed-industries/zed/issues/28557 @@ -200,6 +217,7 @@ impl LanguageServerState { let status_color = server_info .binary_status + .as_ref() .and_then(|binary_status| match binary_status.status { BinaryStatus::None => None, BinaryStatus::CheckingForUpdate @@ -218,17 +236,20 @@ impl LanguageServerState { }) .unwrap_or(Color::Success); - if self - .other_servers_start_index - .is_some_and(|index| index == i) - { - menu = menu.separator().header("Other Buffers"); - } - - if i == 0 && self.other_servers_start_index.is_some() { - menu = menu.header("Current Buffer"); - } + let message = server_info + .message + .as_ref() + .or_else(|| server_info.binary_status.as_ref()?.message.as_ref()) + .cloned(); + let hover_label = if has_logs { + Some("View Logs") + } else if message.is_some() { + Some("View Message") + } else { + None + }; + let server_name = server_info.name.clone(); menu = menu.item(ContextMenuItem::custom_entry( move |_, _| { h_flex() @@ -240,42 +261,99 @@ impl LanguageServerState { h_flex() .gap_2() .child(Indicator::dot().color(status_color)) - .child(Label::new(server_info.name.0.clone())), - ) - .child( - h_flex() - .visible_on_hover("menu_item") - .child( - Label::new("View Logs") - .size(LabelSize::Small) - .color(Color::Muted), - ) - .child( - Icon::new(IconName::ChevronRight) - .size(IconSize::Small) - .color(Color::Muted), - ), + .child(Label::new(server_name.0.clone())), ) + .when_some(hover_label, |div, hover_label| { + div.child( + h_flex() + .visible_on_hover("menu_item") + .child( + Label::new(hover_label) + .size(LabelSize::Small) + .color(Color::Muted), + ) + .child( + Icon::new(IconName::ChevronRight) + .size(IconSize::Small) + .color(Color::Muted), + ), + ) + }) .into_any_element() }, { let lsp_logs = lsp_logs.clone(); + let message = message.clone(); + let server_selector = server_selector.clone(); + let server_name = server_info.name.clone(); + let workspace = self.workspace.clone(); move 
|window, cx| { - if !has_logs { + if has_logs { + lsp_logs.update(cx, |lsp_logs, cx| { + lsp_logs.open_server_trace( + workspace.clone(), + server_selector.clone(), + window, + cx, + ); + }); + } else if let Some(message) = &message { + let Some(create_buffer) = workspace + .update(cx, |workspace, cx| { + workspace + .project() + .update(cx, |project, cx| project.create_buffer(cx)) + }) + .ok() + else { + return; + }; + + let window = window.window_handle(); + let workspace = workspace.clone(); + let message = message.clone(); + let server_name = server_name.clone(); + cx.spawn(async move |cx| { + let buffer = create_buffer.await?; + buffer.update(cx, |buffer, cx| { + buffer.edit( + [( + 0..0, + format!("Language server {server_name}:\n\n{message}"), + )], + None, + cx, + ); + buffer.set_capability(language::Capability::ReadOnly, cx); + })?; + + workspace.update(cx, |workspace, cx| { + window.update(cx, |_, window, cx| { + workspace.add_item_to_active_pane( + Box::new(cx.new(|cx| { + let mut editor = + Editor::for_buffer(buffer, None, window, cx); + editor.set_read_only(true); + editor + })), + None, + true, + window, + cx, + ); + }) + })??; + + anyhow::Ok(()) + }) + .detach(); + } else { cx.propagate(); return; } - lsp_logs.update(cx, |lsp_logs, cx| { - lsp_logs.open_server_trace( - workspace.clone(), - server_selector.clone(), - window, - cx, - ); - }); } }, - server_info.message.map(|server_message| { + message.map(|server_message| { DocumentationAside::new( DocumentationSide::Right, Rc::new(move |_| Label::new(server_message.clone()).into_any_element()), @@ -340,81 +418,95 @@ impl LanguageServers { #[derive(Debug)] enum ServerData<'a> { - WithHealthCheck( - LanguageServerId, - &'a LanguageServerHealthStatus, - Option<&'a LanguageServerBinaryStatus>, - ), - WithBinaryStatus( - Option<LanguageServerId>, - &'a LanguageServerName, - &'a LanguageServerBinaryStatus, - ), + WithHealthCheck { + server_id: LanguageServerId, + health: &'a LanguageServerHealthStatus, + binary_status: Option<&'a LanguageServerBinaryStatus>, + }, + WithBinaryStatus { + server_id: Option<LanguageServerId>, + server_name: &'a LanguageServerName, + binary_status: &'a LanguageServerBinaryStatus, + }, } #[derive(Debug)] -enum LspItem { - WithHealthCheck( - LanguageServerId, - LanguageServerHealthStatus, - Option<LanguageServerBinaryStatus>, - ), - WithBinaryStatus( - Option<LanguageServerId>, - LanguageServerName, - LanguageServerBinaryStatus, - ), +enum LspMenuItem { + WithHealthCheck { + server_id: LanguageServerId, + health: LanguageServerHealthStatus, + binary_status: Option<LanguageServerBinaryStatus>, + }, + WithBinaryStatus { + server_id: Option<LanguageServerId>, + server_name: LanguageServerName, + binary_status: LanguageServerBinaryStatus, + }, ToggleServersButton { restart: bool, }, + Header { + header: Option<SharedString>, + separator: bool, + }, } -impl LspItem { +impl LspMenuItem { fn server_info(&self) -> Option<ServerInfo> { match self { - LspItem::ToggleServersButton { .. } => None, - LspItem::WithHealthCheck( - language_server_id, - language_server_health_status, - language_server_binary_status, - ) => Some(ServerInfo { - name: language_server_health_status.name.clone(), - id: Some(*language_server_id), - health: language_server_health_status.health(), - binary_status: language_server_binary_status.clone(), - message: language_server_health_status.message(), + Self::Header { .. } => None, + Self::ToggleServersButton { .. } => None, + Self::WithHealthCheck { + server_id, + health, + binary_status, + .. 
+ } => Some(ServerInfo { + name: health.name.clone(), + id: Some(*server_id), + health: health.health(), + binary_status: binary_status.clone(), + message: health.message(), }), - LspItem::WithBinaryStatus( + Self::WithBinaryStatus { server_id, - language_server_name, - language_server_binary_status, - ) => Some(ServerInfo { - name: language_server_name.clone(), + server_name, + binary_status, + .. + } => Some(ServerInfo { + name: server_name.clone(), id: *server_id, health: None, - binary_status: Some(language_server_binary_status.clone()), - message: language_server_binary_status.message.clone(), + binary_status: Some(binary_status.clone()), + message: binary_status.message.clone(), }), } } } impl ServerData<'_> { - fn name(&self) -> &LanguageServerName { - match self { - Self::WithHealthCheck(_, state, _) => &state.name, - Self::WithBinaryStatus(_, name, ..) => name, - } - } - - fn into_lsp_item(self) -> LspItem { + fn into_lsp_item(self) -> LspMenuItem { match self { - Self::WithHealthCheck(id, name, status) => { - LspItem::WithHealthCheck(id, name.clone(), status.cloned()) - } - Self::WithBinaryStatus(server_id, name, status) => { - LspItem::WithBinaryStatus(server_id, name.clone(), status.clone()) - } + Self::WithHealthCheck { + server_id, + health, + binary_status, + .. + } => LspMenuItem::WithHealthCheck { + server_id, + health: health.clone(), + binary_status: binary_status.cloned(), + }, + Self::WithBinaryStatus { + server_id, + server_name, + binary_status, + .. + } => LspMenuItem::WithBinaryStatus { + server_id, + server_name: server_name.clone(), + binary_status: binary_status.clone(), + }, } } } @@ -447,7 +539,6 @@ impl LspTool { let state = cx.new(|_| LanguageServerState { workspace: workspace.weak_handle(), items: Vec::new(), - other_servers_start_index: None, lsp_store: lsp_store.downgrade(), active_editor: None, language_servers: LanguageServers::default(), @@ -537,13 +628,28 @@ impl LspTool { message: proto::update_language_server::Variant::RegisteredForBuffer(update), .. } => { - self.server_state.update(cx, |state, _| { - state + self.server_state.update(cx, |state, cx| { + let Ok(worktree) = state.workspace.update(cx, |workspace, cx| { + workspace + .project() + .read(cx) + .find_worktree(Path::new(&update.buffer_abs_path), cx) + .map(|(worktree, _)| worktree.downgrade()) + }) else { + return; + }; + let entry = state .language_servers .servers_per_buffer_abs_path .entry(PathBuf::from(&update.buffer_abs_path)) - .or_default() - .insert(*language_server_id, name.clone()); + .or_insert_with(|| ServersForPath { + servers: HashMap::default(), + worktree: worktree.clone(), + }); + entry.servers.insert(*language_server_id, name.clone()); + if worktree.is_some() { + entry.worktree = worktree; + } }); updated = true; } @@ -557,94 +663,95 @@ impl LspTool { fn regenerate_items(&mut self, cx: &mut App) { self.server_state.update(cx, |state, cx| { - let editor_buffers = state + let active_worktrees = state .active_editor .as_ref() - .map(|active_editor| active_editor.editor_buffers.clone()) - .unwrap_or_default(); - let editor_buffer_paths = editor_buffers - .iter() - .filter_map(|buffer_id| { - let buffer_path = state - .lsp_store - .update(cx, |lsp_store, cx| { - Some( - project::File::from_dyn( - lsp_store - .buffer_store() - .read(cx) - .get(*buffer_id)? - .read(cx) - .file(), - )? 
- .abs_path(cx), - ) + .into_iter() + .flat_map(|active_editor| { + active_editor + .editor + .upgrade() + .into_iter() + .flat_map(|active_editor| { + active_editor + .read(cx) + .buffer() + .read(cx) + .all_buffers() + .into_iter() + .filter_map(|buffer| { + project::File::from_dyn(buffer.read(cx).file()) + }) + .map(|buffer_file| buffer_file.worktree.clone()) }) - .ok()??; - Some(buffer_path) }) - .collect::<Vec<_>>(); + .collect::<HashSet<_>>(); - let mut servers_with_health_checks = HashSet::default(); - let mut server_ids_with_health_checks = HashSet::default(); - let mut buffer_servers = - Vec::with_capacity(state.language_servers.health_statuses.len()); - let mut other_servers = - Vec::with_capacity(state.language_servers.health_statuses.len()); - let buffer_server_ids = editor_buffer_paths - .iter() - .filter_map(|buffer_path| { - state - .language_servers - .servers_per_buffer_abs_path - .get(buffer_path) - }) - .flatten() - .fold(HashMap::default(), |mut acc, (server_id, name)| { - match acc.entry(*server_id) { - hash_map::Entry::Occupied(mut o) => { - let old_name: &mut Option<&LanguageServerName> = o.get_mut(); - if old_name.is_none() { - *old_name = name.as_ref(); - } - } - hash_map::Entry::Vacant(v) => { - v.insert(name.as_ref()); + let mut server_ids_to_worktrees = + HashMap::<LanguageServerId, Entity<Worktree>>::default(); + let mut server_names_to_worktrees = HashMap::< + LanguageServerName, + HashSet<(Entity<Worktree>, LanguageServerId)>, + >::default(); + for servers_for_path in state.language_servers.servers_per_buffer_abs_path.values() { + if let Some(worktree) = servers_for_path + .worktree + .as_ref() + .and_then(|worktree| worktree.upgrade()) + { + for (server_id, server_name) in &servers_for_path.servers { + server_ids_to_worktrees.insert(*server_id, worktree.clone()); + if let Some(server_name) = server_name { + server_names_to_worktrees + .entry(server_name.clone()) + .or_default() + .insert((worktree.clone(), *server_id)); } } - acc + } + } + + let mut servers_per_worktree = BTreeMap::<SharedString, Vec<ServerData>>::new(); + let mut servers_without_worktree = Vec::<ServerData>::new(); + let mut servers_with_health_checks = HashSet::default(); + + for (server_id, health) in &state.language_servers.health_statuses { + let worktree = server_ids_to_worktrees.get(server_id).or_else(|| { + let worktrees = server_names_to_worktrees.get(&health.name)?; + worktrees + .iter() + .find(|(worktree, _)| active_worktrees.contains(worktree)) + .or_else(|| worktrees.iter().next()) + .map(|(worktree, _)| worktree) }); - for (server_id, server_state) in &state.language_servers.health_statuses { - let binary_status = state - .language_servers - .binary_statuses - .get(&server_state.name); - servers_with_health_checks.insert(&server_state.name); - server_ids_with_health_checks.insert(*server_id); - if buffer_server_ids.contains_key(server_id) { - buffer_servers.push(ServerData::WithHealthCheck( - *server_id, - server_state, - binary_status, - )); - } else { - other_servers.push(ServerData::WithHealthCheck( - *server_id, - server_state, - binary_status, - )); + servers_with_health_checks.insert(&health.name); + let worktree_name = + worktree.map(|worktree| SharedString::new(worktree.read(cx).root_name())); + + let binary_status = state.language_servers.binary_statuses.get(&health.name); + let server_data = ServerData::WithHealthCheck { + server_id: *server_id, + health, + binary_status, + }; + match worktree_name { + Some(worktree_name) => servers_per_worktree + 
.entry(worktree_name.clone()) + .or_default() + .push(server_data), + None => servers_without_worktree.push(server_data), } } let mut can_stop_all = !state.language_servers.health_statuses.is_empty(); let mut can_restart_all = state.language_servers.health_statuses.is_empty(); - for (server_name, status) in state + for (server_name, binary_status) in state .language_servers .binary_statuses .iter() .filter(|(name, _)| !servers_with_health_checks.contains(name)) { - match status.status { + match binary_status.status { BinaryStatus::None => { can_restart_all = false; can_stop_all |= true; @@ -669,51 +776,73 @@ impl LspTool { BinaryStatus::Failed { .. } => {} } - let matching_server_id = state - .language_servers - .servers_per_buffer_abs_path - .iter() - .filter(|(path, _)| editor_buffer_paths.contains(path)) - .flat_map(|(_, server_associations)| server_associations.iter()) - .find_map(|(id, name)| { - if name.as_ref() == Some(server_name) { - Some(*id) - } else { - None + match server_names_to_worktrees.get(server_name) { + Some(worktrees_for_name) => { + match worktrees_for_name + .iter() + .find(|(worktree, _)| active_worktrees.contains(worktree)) + .or_else(|| worktrees_for_name.iter().next()) + { + Some((worktree, server_id)) => { + let worktree_name = + SharedString::new(worktree.read(cx).root_name()); + servers_per_worktree + .entry(worktree_name.clone()) + .or_default() + .push(ServerData::WithBinaryStatus { + server_name, + binary_status, + server_id: Some(*server_id), + }); + } + None => servers_without_worktree.push(ServerData::WithBinaryStatus { + server_name, + binary_status, + server_id: None, + }), } - }); - if let Some(server_id) = matching_server_id { - buffer_servers.push(ServerData::WithBinaryStatus( - Some(server_id), + } + None => servers_without_worktree.push(ServerData::WithBinaryStatus { server_name, - status, - )); - } else { - other_servers.push(ServerData::WithBinaryStatus(None, server_name, status)); + binary_status, + server_id: None, + }), } } - buffer_servers.sort_by_key(|data| data.name().clone()); - other_servers.sort_by_key(|data| data.name().clone()); - - let mut other_servers_start_index = None; let mut new_lsp_items = - Vec::with_capacity(buffer_servers.len() + other_servers.len() + 1); - new_lsp_items.extend(buffer_servers.into_iter().map(ServerData::into_lsp_item)); - if !new_lsp_items.is_empty() { - other_servers_start_index = Some(new_lsp_items.len()); + Vec::with_capacity(servers_per_worktree.len() + servers_without_worktree.len() + 2); + for (worktree_name, worktree_servers) in servers_per_worktree { + if worktree_servers.is_empty() { + continue; + } + new_lsp_items.push(LspMenuItem::Header { + header: Some(worktree_name), + separator: false, + }); + new_lsp_items.extend(worktree_servers.into_iter().map(ServerData::into_lsp_item)); + } + if !servers_without_worktree.is_empty() { + new_lsp_items.push(LspMenuItem::Header { + header: Some(SharedString::from("Unknown worktree")), + separator: false, + }); + new_lsp_items.extend( + servers_without_worktree + .into_iter() + .map(ServerData::into_lsp_item), + ); } - new_lsp_items.extend(other_servers.into_iter().map(ServerData::into_lsp_item)); if !new_lsp_items.is_empty() { if can_stop_all { - new_lsp_items.push(LspItem::ToggleServersButton { restart: false }); + new_lsp_items.push(LspMenuItem::ToggleServersButton { restart: true }); + new_lsp_items.push(LspMenuItem::ToggleServersButton { restart: false }); } else if can_restart_all { - new_lsp_items.push(LspItem::ToggleServersButton { restart: true }); 
+ new_lsp_items.push(LspMenuItem::ToggleServersButton { restart: true }); } } state.items = new_lsp_items; - state.other_servers_start_index = other_servers_start_index; }); } @@ -835,10 +964,7 @@ impl StatusItemView for LspTool { impl Render for LspTool { fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl ui::IntoElement { - if !cx.is_staff() - || self.server_state.read(cx).language_servers.is_empty() - || self.lsp_menu.is_none() - { + if self.server_state.read(cx).language_servers.is_empty() || self.lsp_menu.is_none() { return div(); } @@ -846,12 +972,12 @@ impl Render for LspTool { let mut has_warnings = false; let mut has_other_notifications = false; let state = self.server_state.read(cx); - for server in state.language_servers.health_statuses.values() { - if let Some(binary_status) = &state.language_servers.binary_statuses.get(&server.name) { - has_errors |= matches!(binary_status.status, BinaryStatus::Failed { .. }); - has_other_notifications |= binary_status.message.is_some(); - } + for binary_status in state.language_servers.binary_statuses.values() { + has_errors |= matches!(binary_status.status, BinaryStatus::Failed { .. }); + has_other_notifications |= binary_status.message.is_some(); + } + for server in state.language_servers.health_statuses.values() { if let Some((message, health)) = &server.health { has_other_notifications |= message.is_some(); match health { @@ -889,7 +1015,7 @@ impl Render for LspTool { .anchor(Corner::BottomLeft) .with_handle(self.popover_menu_handle.clone()) .trigger_with_tooltip( - IconButton::new("zed-lsp-tool-button", IconName::BoltFilledAlt) + IconButton::new("zed-lsp-tool-button", IconName::BoltOutlined) .when_some(indicator, IconButton::indicator) .icon_size(IconSize::Small) .indicator_border_color(Some(cx.theme().colors().status_bar_background)), diff --git a/crates/languages/Cargo.toml b/crates/languages/Cargo.toml index 2e8f007cff9b3dcdbc1be9c8405c90369ed12413..8e258180702626bb3dd32b28bfb0e82722a1f12f 100644 --- a/crates/languages/Cargo.toml +++ b/crates/languages/Cargo.toml @@ -36,11 +36,13 @@ load-grammars = [ [dependencies] anyhow.workspace = true async-compression.workspace = true +async-fs.workspace = true async-tar.workspace = true async-trait.workspace = true chrono.workspace = true collections.workspace = true dap.workspace = true +feature_flags.workspace = true futures.workspace = true gpui.workspace = true http_client.workspace = true @@ -61,6 +63,7 @@ regex.workspace = true rope.workspace = true rust-embed.workspace = true schemars.workspace = true +sha2.workspace = true serde.workspace = true serde_json.workspace = true serde_json_lenient.workspace = true @@ -68,6 +71,7 @@ settings.workspace = true smol.workspace = true snippet_provider.workspace = true task.workspace = true +tempfile.workspace = true toml.workspace = true tree-sitter = { workspace = true, optional = true } tree-sitter-bash = { workspace = true, optional = true } diff --git a/crates/languages/src/bash/config.toml b/crates/languages/src/bash/config.toml index db9a2749e796e13f49806b59eaca648ff731b3b8..8ff4802aee5124201d013e0b2f0b01c7046e55a0 100644 --- a/crates/languages/src/bash/config.toml +++ b/crates/languages/src/bash/config.toml @@ -18,17 +18,20 @@ brackets = [ { start = "in", end = "esac", close = false, newline = true, not_in = ["comment", "string"] }, ] -### WARN: the following is not working when you insert an `elif` just before an else -### example: (^ is cursor after hitting enter) -### ``` -### if true; then -### foo -### elif -### ^ -### 
else
-### bar
-### fi
-### ```
-increase_indent_pattern = "(^|\\s+|;)(do|then|in|else|elif)\\b.*$"
-decrease_indent_pattern = "(^|\\s+|;)(fi|done|esac|else|elif)\\b.*$"
-# make sure to test each line mode & block mode
+auto_indent_using_last_non_empty_line = false
+increase_indent_pattern = "^\\s*(\\b(else|elif)\\b|([^#]+\\b(do|then|in)\\b)|([\\w\\*]+\\)))\\s*$"
+decrease_indent_patterns = [
+  { pattern = "^\\s*elif\\b.*", valid_after = ["if", "elif"] },
+  { pattern = "^\\s*else\\b.*", valid_after = ["if", "elif", "for", "while"] },
+  { pattern = "^\\s*fi\\b.*", valid_after = ["if", "elif", "else"] },
+  { pattern = "^\\s*done\\b.*", valid_after = ["for", "while"] },
+  { pattern = "^\\s*esac\\b.*", valid_after = ["case"] },
+  { pattern = "^\\s*[\\w\\*]+\\)\\s*$", valid_after = ["case_item"] },
+]
+
+# We can't rely on decrease_indent_patterns alone for `elif`, because a
+# tree-sitter bug emits an ERROR node when matching `if`.
+#
+# As a workaround, `elif` always outdents, even when the surrounding
+# context is wrong (e.g. an `elif` that follows an `else`).
+decrease_indent_pattern = "(^|\\s+|;)(elif)\\b.*$"
diff --git a/crates/languages/src/bash/indents.scm b/crates/languages/src/bash/indents.scm
index acdcddabfe20d4c1efdeacc23c8d097e3ca0b094..468fc595e56e2616547dc3e752318cd89df4a363 100644
--- a/crates/languages/src/bash/indents.scm
+++ b/crates/languages/src/bash/indents.scm
@@ -1,12 +1,12 @@
-(function_definition
-  "function"?
-  body: (
-    _
-    "{" @start
-    "}" @end
-  )) @indent
+(_ "[" "]" @end) @indent
+(_ "{" "}" @end) @indent
+(_ "(" ")" @end) @indent
 
-(array
-  "(" @start
-  ")" @end
-  ) @indent
+(function_definition) @start.function
+(if_statement) @start.if
+(elif_clause) @start.elif
+(else_clause) @start.else
+(for_statement) @start.for
+(while_statement) @start.while
+(case_statement) @start.case
+(case_item) @start.case_item
diff --git a/crates/languages/src/c.rs b/crates/languages/src/c.rs
index c06c35ee69e74ad0d7d802a3f40aeb2edc41f119..a55d8ff998a8d92625f5e4a319e01dd3a5735ec4 100644
--- a/crates/languages/src/c.rs
+++ b/crates/languages/src/c.rs
@@ -2,14 +2,16 @@ use anyhow::{Context as _, Result, bail};
 use async_trait::async_trait;
 use futures::StreamExt;
 use gpui::{App, AsyncApp};
-use http_client::github::{GitHubLspBinaryVersion, latest_github_release};
+use http_client::github::{AssetKind, GitHubLspBinaryVersion, latest_github_release};
 pub use language::*;
 use lsp::{InitializeParams, LanguageServerBinary, LanguageServerName};
 use project::lsp_store::clangd_ext;
 use serde_json::json;
 use smol::fs;
 use std::{any::Any, env::consts, path::PathBuf, sync::Arc};
-use util::{ResultExt, archive::extract_zip, fs::remove_matching, maybe, merge_json_value_into};
+use util::{ResultExt, fs::remove_matching, maybe, merge_json_value_into};
+
+use crate::github_download::{GithubBinaryMetadata, download_server_binary};
 
 pub struct CLspAdapter;
 
@@ -58,6 +60,7 @@ impl super::LspAdapter for CLspAdapter {
         let version = GitHubLspBinaryVersion {
             name: release.tag_name,
             url: asset.browser_download_url.clone(),
+            digest: asset.digest.clone(),
         };
         Ok(Box::new(version) as Box<_>)
     }
@@ -68,32 +71,67 @@ impl super::LspAdapter for CLspAdapter {
         container_dir: PathBuf,
         delegate: &dyn LspAdapterDelegate,
     ) -> Result<LanguageServerBinary> {
-        let version = version.downcast::<GitHubLspBinaryVersion>().unwrap();
-        let version_dir = container_dir.join(format!("clangd_{}", version.name));
+        let GitHubLspBinaryVersion { name, url, digest } =
+            &*version.downcast::<GitHubLspBinaryVersion>().unwrap();
+        let version_dir =
container_dir.join(format!("clangd_{name}")); let binary_path = version_dir.join("bin/clangd"); - if fs::metadata(&binary_path).await.is_err() { - let mut response = delegate - .http_client() - .get(&version.url, Default::default(), true) - .await - .context("error downloading release")?; - anyhow::ensure!( - response.status().is_success(), - "download failed with status {}", - response.status().to_string() - ); - extract_zip(&container_dir, response.body_mut()) - .await - .with_context(|| format!("unzipping clangd archive to {container_dir:?}"))?; - remove_matching(&container_dir, |entry| entry != version_dir).await; + let binary = LanguageServerBinary { + path: binary_path.clone(), + env: None, + arguments: Default::default(), + }; + + let metadata_path = version_dir.join("metadata"); + let metadata = GithubBinaryMetadata::read_from_file(&metadata_path) + .await + .ok(); + if let Some(metadata) = metadata { + let validity_check = async || { + delegate + .try_exec(LanguageServerBinary { + path: binary_path.clone(), + arguments: vec!["--version".into()], + env: None, + }) + .await + .inspect_err(|err| { + log::warn!("Unable to run {binary_path:?} asset, redownloading: {err}",) + }) + }; + if let (Some(actual_digest), Some(expected_digest)) = (&metadata.digest, digest) { + if actual_digest == expected_digest { + if validity_check().await.is_ok() { + return Ok(binary); + } + } else { + log::info!( + "SHA-256 mismatch for {binary_path:?} asset, downloading new asset. Expected: {expected_digest}, Got: {actual_digest}" + ); + } + } else if validity_check().await.is_ok() { + return Ok(binary); + } } + download_server_binary( + delegate, + url, + digest.as_deref(), + &container_dir, + AssetKind::Zip, + ) + .await?; + remove_matching(&container_dir, |entry| entry != version_dir).await; + GithubBinaryMetadata::write_to_file( + &GithubBinaryMetadata { + metadata_version: 1, + digest: digest.clone(), + }, + &metadata_path, + ) + .await?; - Ok(LanguageServerBinary { - path: binary_path, - env: None, - arguments: Vec::new(), - }) + Ok(binary) } async fn cached_server_binary( diff --git a/crates/languages/src/c/config.toml b/crates/languages/src/c/config.toml index 08cd100f4d4dcb7c00eee33a2491864283986a82..74290fd9e2b31db93bb62187ab707110c818fc44 100644 --- a/crates/languages/src/c/config.toml +++ b/crates/languages/src/c/config.toml @@ -2,6 +2,10 @@ name = "C" grammar = "c" path_suffixes = ["c"] line_comments = ["// "] +decrease_indent_patterns = [ + { pattern = "^\\s*\\{.*\\}?\\s*$", valid_after = ["if", "for", "while", "do", "switch", "else"] }, + { pattern = "^\\s*else\\s*$", valid_after = ["if"] } +] autoclose_before = ";:.,=}])>" brackets = [ { start = "{", end = "}", close = true, newline = true }, @@ -12,4 +16,4 @@ brackets = [ { start = "/*", end = " */", close = true, newline = false, not_in = ["string", "comment"] }, ] debuggers = ["CodeLLDB", "GDB"] -documentation = { start = "/*", end = "*/", prefix = "* ", tab_size = 1 } +documentation_comment = { start = "/*", prefix = "* ", end = "*/", tab_size = 1 } diff --git a/crates/languages/src/c/indents.scm b/crates/languages/src/c/indents.scm index fa40ce215e358a067f48997ab4d870174f1ea479..3b6d5135abe593656d4134b309bf5d43f54a8f59 100644 --- a/crates/languages/src/c/indents.scm +++ b/crates/languages/src/c/indents.scm @@ -3,7 +3,17 @@ (assignment_expression) (if_statement) (for_statement) + (while_statement) + (do_statement) + (else_clause) ] @indent (_ "{" "}" @end) @indent (_ "(" ")" @end) @indent + +(if_statement) @start.if +(for_statement) 
@start.for +(while_statement) @start.while +(do_statement) @start.do +(switch_statement) @start.switch +(else_clause) @start.else diff --git a/crates/languages/src/cpp/config.toml b/crates/languages/src/cpp/config.toml index a81cbe09cde970398719eef8af75864635b3e43b..fab88266d7444875e29d57a82a770c843d9b2faf 100644 --- a/crates/languages/src/cpp/config.toml +++ b/crates/languages/src/cpp/config.toml @@ -2,6 +2,10 @@ name = "C++" grammar = "cpp" path_suffixes = ["cc", "hh", "cpp", "h", "hpp", "cxx", "hxx", "c++", "ipp", "inl", "ixx", "cu", "cuh", "C", "H"] line_comments = ["// ", "/// ", "//! "] +decrease_indent_patterns = [ + { pattern = "^\\s*\\{.*\\}?\\s*$", valid_after = ["if", "for", "while", "do", "switch", "else"] }, + { pattern = "^\\s*else\\s*$", valid_after = ["if"] } +] autoclose_before = ";:.,=}])>" brackets = [ { start = "{", end = "}", close = true, newline = true }, @@ -12,4 +16,4 @@ brackets = [ { start = "/*", end = " */", close = true, newline = false, not_in = ["string", "comment"] }, ] debuggers = ["CodeLLDB", "GDB"] -documentation = { start = "/*", end = "*/", prefix = "* ", tab_size = 1 } +documentation_comment = { start = "/*", prefix = "* ", end = "*/", tab_size = 1 } diff --git a/crates/languages/src/cpp/indents.scm b/crates/languages/src/cpp/indents.scm index a17f4c4821e1ff096e9977c28801eeba581d8b17..d95dfe178cbada6836cb14bca997619fe2a319b3 100644 --- a/crates/languages/src/cpp/indents.scm +++ b/crates/languages/src/cpp/indents.scm @@ -1,7 +1,19 @@ [ (field_expression) (assignment_expression) + (if_statement) + (for_statement) + (while_statement) + (do_statement) + (else_clause) ] @indent (_ "{" "}" @end) @indent (_ "(" ")" @end) @indent + +(if_statement) @start.if +(for_statement) @start.for +(while_statement) @start.while +(do_statement) @start.do +(switch_statement) @start.switch +(else_clause) @start.else diff --git a/crates/languages/src/css.rs b/crates/languages/src/css.rs index f2a94809a0ea6425de4479fa7a18b33eb4e1c647..7725e079be31cacc3e3bc5e30ec48b6ab8d2d4d4 100644 --- a/crates/languages/src/css.rs +++ b/crates/languages/src/css.rs @@ -5,7 +5,7 @@ use gpui::AsyncApp; use language::{LanguageToolchainStore, LspAdapter, LspAdapterDelegate}; use lsp::{LanguageServerBinary, LanguageServerName}; use node_runtime::NodeRuntime; -use project::Fs; +use project::{Fs, lsp_store::language_server_settings}; use serde_json::json; use smol::fs; use std::{ @@ -14,7 +14,7 @@ use std::{ path::{Path, PathBuf}, sync::Arc, }; -use util::{ResultExt, maybe}; +use util::{ResultExt, maybe, merge_json_value_into}; const SERVER_PATH: &str = "node_modules/vscode-langservers-extracted/bin/vscode-css-language-server"; @@ -134,6 +134,37 @@ impl LspAdapter for CssLspAdapter { "provideFormatter": true }))) } + + async fn workspace_configuration( + self: Arc<Self>, + _: &dyn Fs, + delegate: &Arc<dyn LspAdapterDelegate>, + _: Arc<dyn LanguageToolchainStore>, + cx: &mut AsyncApp, + ) -> Result<serde_json::Value> { + let mut default_config = json!({ + "css": { + "lint": {} + }, + "less": { + "lint": {} + }, + "scss": { + "lint": {} + } + }); + + let project_options = cx.update(|cx| { + language_server_settings(delegate.as_ref(), &self.name(), cx) + .and_then(|s| s.settings.clone()) + })?; + + if let Some(override_options) = project_options { + merge_json_value_into(override_options, &mut default_config); + } + + Ok(default_config) + } } async fn get_cached_server_binary( diff --git a/crates/languages/src/css/config.toml b/crates/languages/src/css/config.toml index 
0e0b7315e0e1449641c428fc4397d5d39f92f131..a2ca96e76d3427c2ff2eb249d9a2f93a68d8f1c0 100644 --- a/crates/languages/src/css/config.toml +++ b/crates/languages/src/css/config.toml @@ -10,5 +10,5 @@ brackets = [ { start = "'", end = "'", close = true, newline = false, not_in = ["string", "comment"] }, ] completion_query_characters = ["-"] -block_comment = ["/* ", " */"] +block_comment = { start = "/*", prefix = "* ", end = "*/", tab_size = 1 } prettier_parser_name = "css" diff --git a/crates/languages/src/github_download.rs b/crates/languages/src/github_download.rs new file mode 100644 index 0000000000000000000000000000000000000000..a3cd0a964b31c17e73a72317efc0616474dc846d --- /dev/null +++ b/crates/languages/src/github_download.rs @@ -0,0 +1,190 @@ +use std::{path::Path, pin::Pin, task::Poll}; + +use anyhow::{Context, Result}; +use async_compression::futures::bufread::GzipDecoder; +use futures::{AsyncRead, AsyncSeek, AsyncSeekExt, AsyncWrite, io::BufReader}; +use http_client::github::AssetKind; +use language::LspAdapterDelegate; +use sha2::{Digest, Sha256}; + +#[derive(serde::Deserialize, serde::Serialize, Debug)] +pub(crate) struct GithubBinaryMetadata { + pub(crate) metadata_version: u64, + pub(crate) digest: Option<String>, +} + +impl GithubBinaryMetadata { + pub(crate) async fn read_from_file(metadata_path: &Path) -> Result<GithubBinaryMetadata> { + let metadata_content = async_fs::read_to_string(metadata_path) + .await + .with_context(|| format!("reading metadata file at {metadata_path:?}"))?; + let metadata: GithubBinaryMetadata = serde_json::from_str(&metadata_content) + .with_context(|| format!("parsing metadata file at {metadata_path:?}"))?; + Ok(metadata) + } + + pub(crate) async fn write_to_file(&self, metadata_path: &Path) -> Result<()> { + let metadata_content = serde_json::to_string(self) + .with_context(|| format!("serializing metadata for {metadata_path:?}"))?; + async_fs::write(metadata_path, metadata_content.as_bytes()) + .await + .with_context(|| format!("writing metadata file at {metadata_path:?}"))?; + Ok(()) + } +} + +pub(crate) async fn download_server_binary( + delegate: &dyn LspAdapterDelegate, + url: &str, + digest: Option<&str>, + destination_path: &Path, + asset_kind: AssetKind, +) -> Result<(), anyhow::Error> { + log::info!("downloading github artifact from {url}"); + let mut response = delegate + .http_client() + .get(url, Default::default(), true) + .await + .with_context(|| format!("downloading release from {url}"))?; + let body = response.body_mut(); + match digest { + Some(expected_sha_256) => { + let temp_asset_file = tempfile::NamedTempFile::new() + .with_context(|| format!("creating a temporary file for {url}"))?; + let (temp_asset_file, _temp_guard) = temp_asset_file.into_parts(); + let mut writer = HashingWriter { + writer: async_fs::File::from(temp_asset_file), + hasher: Sha256::new(), + }; + futures::io::copy(&mut BufReader::new(body), &mut writer) + .await + .with_context(|| { + format!("saving archive contents into the temporary file for {url}",) + })?; + let asset_sha_256 = format!("{:x}", writer.hasher.finalize()); + anyhow::ensure!( + asset_sha_256 == expected_sha_256, + "{url} asset got SHA-256 mismatch. 
Expected: {expected_sha_256}, Got: {asset_sha_256}", + ); + writer + .writer + .seek(std::io::SeekFrom::Start(0)) + .await + .with_context(|| format!("seeking temporary file {destination_path:?}",))?; + stream_file_archive(&mut writer.writer, url, destination_path, asset_kind) + .await + .with_context(|| { + format!("extracting downloaded asset for {url} into {destination_path:?}",) + })?; + } + None => stream_response_archive(body, url, destination_path, asset_kind) + .await + .with_context(|| { + format!("extracting response for asset {url} into {destination_path:?}",) + })?, + } + Ok(()) +} + +async fn stream_response_archive( + response: impl AsyncRead + Unpin, + url: &str, + destination_path: &Path, + asset_kind: AssetKind, +) -> Result<()> { + match asset_kind { + AssetKind::TarGz => extract_tar_gz(destination_path, url, response).await?, + AssetKind::Gz => extract_gz(destination_path, url, response).await?, + AssetKind::Zip => { + util::archive::extract_zip(&destination_path, response).await?; + } + }; + Ok(()) +} + +async fn stream_file_archive( + file_archive: impl AsyncRead + AsyncSeek + Unpin, + url: &str, + destination_path: &Path, + asset_kind: AssetKind, +) -> Result<()> { + match asset_kind { + AssetKind::TarGz => extract_tar_gz(destination_path, url, file_archive).await?, + AssetKind::Gz => extract_gz(destination_path, url, file_archive).await?, + #[cfg(not(windows))] + AssetKind::Zip => { + util::archive::extract_seekable_zip(&destination_path, file_archive).await?; + } + #[cfg(windows)] + AssetKind::Zip => { + util::archive::extract_zip(&destination_path, file_archive).await?; + } + }; + Ok(()) +} + +async fn extract_tar_gz( + destination_path: &Path, + url: &str, + from: impl AsyncRead + Unpin, +) -> Result<(), anyhow::Error> { + let decompressed_bytes = GzipDecoder::new(BufReader::new(from)); + let archive = async_tar::Archive::new(decompressed_bytes); + archive + .unpack(&destination_path) + .await + .with_context(|| format!("extracting {url} to {destination_path:?}"))?; + Ok(()) +} + +async fn extract_gz( + destination_path: &Path, + url: &str, + from: impl AsyncRead + Unpin, +) -> Result<(), anyhow::Error> { + let mut decompressed_bytes = GzipDecoder::new(BufReader::new(from)); + let mut file = smol::fs::File::create(&destination_path) + .await + .with_context(|| { + format!("creating a file {destination_path:?} for a download from {url}") + })?; + futures::io::copy(&mut decompressed_bytes, &mut file) + .await + .with_context(|| format!("extracting {url} to {destination_path:?}"))?; + Ok(()) +} + +struct HashingWriter<W: AsyncWrite + Unpin> { + writer: W, + hasher: Sha256, +} + +impl<W: AsyncWrite + Unpin> AsyncWrite for HashingWriter<W> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll<std::result::Result<usize, std::io::Error>> { + match Pin::new(&mut self.writer).poll_write(cx, buf) { + Poll::Ready(Ok(n)) => { + self.hasher.update(&buf[..n]); + Poll::Ready(Ok(n)) + } + other => other, + } + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll<Result<(), std::io::Error>> { + Pin::new(&mut self.writer).poll_flush(cx) + } + + fn poll_close( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll<std::result::Result<(), std::io::Error>> { + Pin::new(&mut self.writer).poll_close(cx) + } +} diff --git a/crates/languages/src/go.rs b/crates/languages/src/go.rs index 25aa5a67b909ab9898cc449f6132a0b2f077d707..16c1b67203e673ddb8c20c110d46ea7bf062ea43 100644 
--- a/crates/languages/src/go.rs +++ b/crates/languages/src/go.rs @@ -41,7 +41,7 @@ static VERSION_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\d+\.\d+\.\d+").expect("Failed to create VERSION_REGEX")); static GO_ESCAPE_SUBTEST_NAME_REGEX: LazyLock<Regex> = LazyLock::new(|| { - Regex::new(r#"[.*+?^${}()|\[\]\\]"#).expect("Failed to create GO_ESCAPE_SUBTEST_NAME_REGEX") + Regex::new(r#"[.*+?^${}()|\[\]\\"']"#).expect("Failed to create GO_ESCAPE_SUBTEST_NAME_REGEX") }); const BINARY: &str = if cfg!(target_os = "windows") { @@ -685,11 +685,20 @@ impl ContextProvider for GoContextProvider { } fn extract_subtest_name(input: &str) -> Option<String> { - let replaced_spaces = input.trim_matches('"').replace(' ', "_"); + let content = if input.starts_with('`') && input.ends_with('`') { + input.trim_matches('`') + } else { + input.trim_matches('"') + }; + + let processed = content + .chars() + .map(|c| if c.is_whitespace() { '_' } else { c }) + .collect::<String>(); Some( GO_ESCAPE_SUBTEST_NAME_REGEX - .replace_all(&replaced_spaces, |caps: ®ex::Captures| { + .replace_all(&processed, |caps: ®ex::Captures| { format!("\\{}", &caps[0]) }) .to_string(), @@ -700,7 +709,7 @@ fn extract_subtest_name(input: &str) -> Option<String> { mod tests { use super::*; use crate::language; - use gpui::Hsla; + use gpui::{AppContext, Hsla, TestAppContext}; use theme::SyntaxTheme; #[gpui::test] @@ -790,4 +799,108 @@ mod tests { }) ); } + + #[gpui::test] + fn test_go_runnable_detection(cx: &mut TestAppContext) { + let language = language("go", tree_sitter_go::LANGUAGE.into()); + + let interpreted_string_subtest = r#" + package main + + import "testing" + + func TestExample(t *testing.T) { + t.Run("subtest with double quotes", func(t *testing.T) { + // test code + }) + } + "#; + + let raw_string_subtest = r#" + package main + + import "testing" + + func TestExample(t *testing.T) { + t.Run(`subtest with + multiline + backticks`, func(t *testing.T) { + // test code + }) + } + "#; + + let buffer = cx.new(|cx| { + crate::Buffer::local(interpreted_string_subtest, cx).with_language(language.clone(), cx) + }); + cx.executor().run_until_parked(); + + let runnables: Vec<_> = buffer.update(cx, |buffer, _| { + let snapshot = buffer.snapshot(); + snapshot + .runnable_ranges(0..interpreted_string_subtest.len()) + .collect() + }); + + assert!( + runnables.len() == 2, + "Should find test function and subtest with double quotes, found: {}", + runnables.len() + ); + + let buffer = cx.new(|cx| { + crate::Buffer::local(raw_string_subtest, cx).with_language(language.clone(), cx) + }); + cx.executor().run_until_parked(); + + let runnables: Vec<_> = buffer.update(cx, |buffer, _| { + let snapshot = buffer.snapshot(); + snapshot + .runnable_ranges(0..raw_string_subtest.len()) + .collect() + }); + + assert!( + runnables.len() == 2, + "Should find test function and subtest with backticks, found: {}", + runnables.len() + ); + } + + #[test] + fn test_extract_subtest_name() { + // Interpreted string literal + let input_double_quoted = r#""subtest with double quotes""#; + let result = extract_subtest_name(input_double_quoted); + assert_eq!(result, Some(r#"subtest_with_double_quotes"#.to_string())); + + let input_double_quoted_with_backticks = r#""test with `backticks` inside""#; + let result = extract_subtest_name(input_double_quoted_with_backticks); + assert_eq!(result, Some(r#"test_with_`backticks`_inside"#.to_string())); + + // Raw string literal + let input_with_backticks = r#"`subtest with backticks`"#; + let result = 
extract_subtest_name(input_with_backticks); + assert_eq!(result, Some(r#"subtest_with_backticks"#.to_string())); + + let input_raw_with_quotes = r#"`test with "quotes" and other chars`"#; + let result = extract_subtest_name(input_raw_with_quotes); + assert_eq!( + result, + Some(r#"test_with_\"quotes\"_and_other_chars"#.to_string()) + ); + + let input_multiline = r#"`subtest with + multiline + backticks`"#; + let result = extract_subtest_name(input_multiline); + assert_eq!( + result, + Some(r#"subtest_with_________multiline_________backticks"#.to_string()) + ); + + let input_with_double_quotes = r#"`test with "double quotes"`"#; + let result = extract_subtest_name(input_with_double_quotes); + assert_eq!(result, Some(r#"test_with_\"double_quotes\""#.to_string())); + } } diff --git a/crates/languages/src/go/config.toml b/crates/languages/src/go/config.toml index 84e35d8f0f7e268c32b9838fd0f6b2907aff909d..0a5122c038e1e38e0c963c3d22581f794656c276 100644 --- a/crates/languages/src/go/config.toml +++ b/crates/languages/src/go/config.toml @@ -15,4 +15,4 @@ brackets = [ tab_size = 4 hard_tabs = true debuggers = ["Delve"] -documentation = { start = "/*", end = "*/", prefix = "* ", tab_size = 1 } +documentation_comment = { start = "/*", prefix = "* ", end = "*/", tab_size = 1 } diff --git a/crates/languages/src/go/runnables.scm b/crates/languages/src/go/runnables.scm index 8d5f4375c137e8e804f4af8647eccddef09bdb42..6418cd04d8d69c2bb97434053c93f591aed55c68 100644 --- a/crates/languages/src/go/runnables.scm +++ b/crates/languages/src/go/runnables.scm @@ -1,9 +1,21 @@ ; Functions names start with `Test` ( - ( + [ (function_declaration name: (_) @run (#match? @run "^Test.*")) - ) @_ + (method_declaration + receiver: (parameter_list + (parameter_declaration + name: (identifier) @_receiver_name + type: [ + (pointer_type (type_identifier) @_receiver_type) + (type_identifier) @_receiver_type + ] + ) + ) + name: (field_identifier) @run @_method_name + (#match? @_method_name "^Test.*")) + ] @_ (#set! tag go-test) ) @@ -26,7 +38,10 @@ arguments: ( argument_list . - (interpreted_string_literal) @_subtest_name + [ + (interpreted_string_literal) + (raw_string_literal) + ] @_subtest_name . (func_literal parameters: ( @@ -54,7 +69,7 @@ ( ( (function_declaration name: (_) @run @_name - (#match? @_name "^Benchmark.+")) + (#match? @_name "^Benchmark.*")) ) @_ (#set! 
tag go-benchmark) ) diff --git a/crates/languages/src/javascript/config.toml b/crates/languages/src/javascript/config.toml index ac87a9befd7af1abcd8153cda07ce10b577cceb8..0df57d985e82595bdabb97517f56e79591343e7b 100644 --- a/crates/languages/src/javascript/config.toml +++ b/crates/languages/src/javascript/config.toml @@ -4,7 +4,8 @@ path_suffixes = ["js", "jsx", "mjs", "cjs"] # [/ ] is so we match "env node" or "/node" but not "ts-node" first_line_pattern = '^#!.*\b(?:[/ ]node|deno run.*--ext[= ]js)\b' line_comments = ["// "] -block_comment = ["/*", "*/"] +block_comment = { start = "/*", prefix = "* ", end = "*/", tab_size = 1 } +documentation_comment = { start = "/**", prefix = "* ", end = "*/", tab_size = 1 } autoclose_before = ";:.,=}])>" brackets = [ { start = "{", end = "}", close = true, newline = true }, @@ -21,7 +22,6 @@ tab_size = 2 scope_opt_in_language_servers = ["tailwindcss-language-server", "emmet-language-server"] prettier_parser_name = "babel" debuggers = ["JavaScript"] -documentation = { start = "/**", end = "*/", prefix = "* ", tab_size = 1 } [jsx_tag_auto_close] open_tag_node_name = "jsx_opening_element" @@ -31,7 +31,7 @@ tag_name_node_name = "identifier" [overrides.element] line_comments = { remove = true } -block_comment = ["{/* ", " */}"] +block_comment = { start = "{/* ", prefix = "", end = "*/}", tab_size = 0 } opt_into_language_servers = ["emmet-language-server"] [overrides.string] diff --git a/crates/languages/src/javascript/outline.scm b/crates/languages/src/javascript/outline.scm index 99aa4bdfd5ad530505ebb90dc075e5ca405a5451..026c71e1f91d323ff2370828f330e4a4944e74db 100644 --- a/crates/languages/src/javascript/outline.scm +++ b/crates/languages/src/javascript/outline.scm @@ -14,6 +14,15 @@ "(" @context ")" @context)) @item +(generator_function_declaration + "async"? 
@context + "function" @context + "*" @context + name: (_) @name + parameters: (formal_parameters + "(" @context + ")" @context)) @item + (interface_declaration "interface" @context name: (_) @name) @item diff --git a/crates/languages/src/json.rs b/crates/languages/src/json.rs index 7a3300eb010d9da30111023e660ef56a2070ea9e..ca82bb2431f5408e948005ffc9fb705809a93087 100644 --- a/crates/languages/src/json.rs +++ b/crates/languages/src/json.rs @@ -8,8 +8,8 @@ use futures::StreamExt; use gpui::{App, AsyncApp, Task}; use http_client::github::{GitHubLspBinaryVersion, latest_github_release}; use language::{ - ContextProvider, LanguageRegistry, LanguageToolchainStore, LocalFile as _, LspAdapter, - LspAdapterDelegate, + ContextProvider, LanguageName, LanguageRegistry, LanguageToolchainStore, LocalFile as _, + LspAdapter, LspAdapterDelegate, }; use lsp::{LanguageServerBinary, LanguageServerName}; use node_runtime::NodeRuntime; @@ -231,6 +231,13 @@ impl JsonLspAdapter { )) } + schemas + .as_array_mut() + .unwrap() + .extend(cx.all_action_names().into_iter().map(|&name| { + project::lsp_store::json_language_server_ext::url_schema_for_action(name) + })); + // This can be viewed via `dev: open language server logs` -> `json-language-server` -> // `Server Info` serde_json::json!({ @@ -262,7 +269,15 @@ impl JsonLspAdapter { .await; let config = cx.update(|cx| { - Self::get_workspace_config(self.languages.language_names().clone(), adapter_schemas, cx) + Self::get_workspace_config( + self.languages + .language_names() + .into_iter() + .map(|name| name.to_string()) + .collect(), + adapter_schemas, + cx, + ) })?; writer.replace(config.clone()); return Ok(config); @@ -401,10 +416,10 @@ impl LspAdapter for JsonLspAdapter { Ok(config) } - fn language_ids(&self) -> HashMap<String, String> { + fn language_ids(&self) -> HashMap<LanguageName, String> { [ - ("JSON".into(), "json".into()), - ("JSONC".into(), "jsonc".into()), + (LanguageName::new("JSON"), "json".into()), + (LanguageName::new("JSONC"), "jsonc".into()), ] .into_iter() .collect() @@ -502,6 +517,7 @@ impl LspAdapter for NodeVersionAdapter { Ok(Box::new(GitHubLspBinaryVersion { name: release.tag_name, url: asset.browser_download_url.clone(), + digest: asset.digest.clone(), })) } diff --git a/crates/languages/src/lib.rs b/crates/languages/src/lib.rs index 3db015a24182ff9f210958558c868da9e7168be6..195ba79e1d0e96acea7ac1a53590c1a947334069 100644 --- a/crates/languages/src/lib.rs +++ b/crates/languages/src/lib.rs @@ -1,4 +1,5 @@ use anyhow::Context as _; +use feature_flags::{FeatureFlag, FeatureFlagAppExt as _}; use gpui::{App, UpdateGlobal}; use node_runtime::NodeRuntime; use python::PyprojectTomlManifestProvider; @@ -11,11 +12,12 @@ use util::{ResultExt, asset_str}; pub use language::*; -use crate::json::JsonTaskProvider; +use crate::{json::JsonTaskProvider, python::BasedPyrightLspAdapter}; mod bash; mod c; mod css; +mod github_download; mod go; mod json; mod package_json; @@ -52,6 +54,12 @@ pub static LANGUAGE_GIT_COMMIT: std::sync::LazyLock<Arc<Language>> = )) }); +struct BasedPyrightFeatureFlag; + +impl FeatureFlag for BasedPyrightFeatureFlag { + const NAME: &'static str = "basedpyright"; +} + pub fn init(languages: Arc<LanguageRegistry>, node: NodeRuntime, cx: &mut App) { #[cfg(feature = "load-grammars")] languages.register_native_grammars([ @@ -88,6 +96,7 @@ pub fn init(languages: Arc<LanguageRegistry>, node: NodeRuntime, cx: &mut App) { let py_lsp_adapter = Arc::new(python::PyLspAdapter::new()); let python_context_provider = 
Arc::new(python::PythonContextProvider); let python_lsp_adapter = Arc::new(python::PythonLspAdapter::new(node.clone())); + let basedpyright_lsp_adapter = Arc::new(BasedPyrightLspAdapter::new()); let python_toolchain_provider = Arc::new(python::PythonToolchainProvider::default()); let rust_context_provider = Arc::new(rust::RustContextProvider); let rust_lsp_adapter = Arc::new(rust::RustLspAdapter); @@ -212,6 +221,10 @@ pub fn init(languages: Arc<LanguageRegistry>, node: NodeRuntime, cx: &mut App) { name: "gitcommit", ..Default::default() }, + LanguageInfo { + name: "zed-keybind-context", + ..Default::default() + }, ]; for registration in built_in_languages { @@ -224,6 +237,20 @@ pub fn init(languages: Arc<LanguageRegistry>, node: NodeRuntime, cx: &mut App) { ); } + let mut basedpyright_lsp_adapter = Some(basedpyright_lsp_adapter); + cx.observe_flag::<BasedPyrightFeatureFlag, _>({ + let languages = languages.clone(); + move |enabled, _| { + if enabled { + if let Some(adapter) = basedpyright_lsp_adapter.take() { + languages + .register_available_lsp_adapter(adapter.name(), move || adapter.clone()); + } + } + } + }) + .detach(); + // Register globally available language servers. // // This will allow users to add support for a built-in language server (e.g., Tailwind) @@ -269,6 +296,7 @@ pub fn init(languages: Arc<LanguageRegistry>, node: NodeRuntime, cx: &mut App) { "Astro", "CSS", "ERB", + "HTML/ERB", "HEEX", "HTML", "JavaScript", diff --git a/crates/languages/src/markdown/config.toml b/crates/languages/src/markdown/config.toml index 059e52de9444b10cb8d6b089a2bdf8ec6d49485d..926dcd70d9f9207c03154690e7d4e9866f9aacea 100644 --- a/crates/languages/src/markdown/config.toml +++ b/crates/languages/src/markdown/config.toml @@ -2,7 +2,7 @@ name = "Markdown" grammar = "markdown" path_suffixes = ["md", "mdx", "mdwn", "markdown", "MD"] completion_query_characters = ["-"] -block_comment = ["<!-- ", " -->"] +block_comment = { start = "<!--", prefix = "", end = "-->", tab_size = 0 } autoclose_before = ";:.,=}])>" brackets = [ { start = "{", end = "}", close = true, newline = true }, diff --git a/crates/languages/src/python.rs b/crates/languages/src/python.rs index dc6996d3999a0d6bd366f8d444b19e046e0150a8..0524c02fd5b95c4d8ccc2fbcbd2286a53a900fa2 100644 --- a/crates/languages/src/python.rs +++ b/crates/languages/src/python.rs @@ -4,13 +4,13 @@ use async_trait::async_trait; use collections::HashMap; use gpui::{App, Task}; use gpui::{AsyncApp, SharedString}; -use language::Toolchain; use language::ToolchainList; use language::ToolchainLister; use language::language_settings::language_settings; use language::{ContextLocation, LanguageToolchainStore}; use language::{ContextProvider, LspAdapter, LspAdapterDelegate}; use language::{LanguageName, ManifestName, ManifestProvider, ManifestQuery}; +use language::{Toolchain, WorkspaceFoldersContent}; use lsp::LanguageServerBinary; use lsp::LanguageServerName; use node_runtime::NodeRuntime; @@ -400,6 +400,9 @@ impl LspAdapter for PythonLspAdapter { fn manifest_name(&self) -> Option<ManifestName> { Some(SharedString::new_static("pyproject.toml").into()) } + fn workspace_folders_content(&self) -> WorkspaceFoldersContent { + WorkspaceFoldersContent::WorktreeRoot + } } async fn get_cached_server_binary( @@ -1282,6 +1285,350 @@ impl LspAdapter for PyLspAdapter { fn manifest_name(&self) -> Option<ManifestName> { Some(SharedString::new_static("pyproject.toml").into()) } + fn workspace_folders_content(&self) -> WorkspaceFoldersContent { + WorkspaceFoldersContent::WorktreeRoot 
+ } +} + +pub(crate) struct BasedPyrightLspAdapter { + python_venv_base: OnceCell<Result<Arc<Path>, String>>, +} + +impl BasedPyrightLspAdapter { + const SERVER_NAME: LanguageServerName = LanguageServerName::new_static("basedpyright"); + const BINARY_NAME: &'static str = "basedpyright-langserver"; + + pub(crate) fn new() -> Self { + Self { + python_venv_base: OnceCell::new(), + } + } + + async fn ensure_venv(delegate: &dyn LspAdapterDelegate) -> Result<Arc<Path>> { + let python_path = Self::find_base_python(delegate) + .await + .context("Could not find Python installation for basedpyright")?; + let work_dir = delegate + .language_server_download_dir(&Self::SERVER_NAME) + .await + .context("Could not get working directory for basedpyright")?; + let mut path = PathBuf::from(work_dir.as_ref()); + path.push("basedpyright-venv"); + if !path.exists() { + util::command::new_smol_command(python_path) + .arg("-m") + .arg("venv") + .arg("basedpyright-venv") + .current_dir(work_dir) + .spawn()? + .output() + .await?; + } + + Ok(path.into()) + } + + // Find "baseline", user python version from which we'll create our own venv. + async fn find_base_python(delegate: &dyn LspAdapterDelegate) -> Option<PathBuf> { + for path in ["python3", "python"] { + if let Some(path) = delegate.which(path.as_ref()).await { + return Some(path); + } + } + None + } + + async fn base_venv(&self, delegate: &dyn LspAdapterDelegate) -> Result<Arc<Path>, String> { + self.python_venv_base + .get_or_init(move || async move { + Self::ensure_venv(delegate) + .await + .map_err(|e| format!("{e}")) + }) + .await + .clone() + } +} + +#[async_trait(?Send)] +impl LspAdapter for BasedPyrightLspAdapter { + fn name(&self) -> LanguageServerName { + Self::SERVER_NAME.clone() + } + + async fn initialization_options( + self: Arc<Self>, + _: &dyn Fs, + _: &Arc<dyn LspAdapterDelegate>, + ) -> Result<Option<Value>> { + // Provide minimal initialization options + // Virtual environment configuration will be handled through workspace configuration + Ok(Some(json!({ + "python": { + "analysis": { + "autoSearchPaths": true, + "useLibraryCodeForTypes": true, + "autoImportCompletions": true + } + } + }))) + } + + async fn check_if_user_installed( + &self, + delegate: &dyn LspAdapterDelegate, + toolchains: Arc<dyn LanguageToolchainStore>, + cx: &AsyncApp, + ) -> Option<LanguageServerBinary> { + if let Some(bin) = delegate.which(Self::BINARY_NAME.as_ref()).await { + let env = delegate.shell_env().await; + Some(LanguageServerBinary { + path: bin, + env: Some(env), + arguments: vec!["--stdio".into()], + }) + } else { + let venv = toolchains + .active_toolchain( + delegate.worktree_id(), + Arc::from("".as_ref()), + LanguageName::new("Python"), + &mut cx.clone(), + ) + .await?; + let path = Path::new(venv.path.as_ref()) + .parent()? 
+ .join(Self::BINARY_NAME); + path.exists().then(|| LanguageServerBinary { + path, + arguments: vec!["--stdio".into()], + env: None, + }) + } + } + + async fn fetch_latest_server_version( + &self, + _: &dyn LspAdapterDelegate, + ) -> Result<Box<dyn 'static + Any + Send>> { + Ok(Box::new(()) as Box<_>) + } + + async fn fetch_server_binary( + &self, + _latest_version: Box<dyn 'static + Send + Any>, + _container_dir: PathBuf, + delegate: &dyn LspAdapterDelegate, + ) -> Result<LanguageServerBinary> { + let venv = self.base_venv(delegate).await.map_err(|e| anyhow!(e))?; + let pip_path = venv.join(BINARY_DIR).join("pip3"); + ensure!( + util::command::new_smol_command(pip_path.as_path()) + .arg("install") + .arg("basedpyright") + .arg("-U") + .output() + .await? + .status + .success(), + "basedpyright installation failed" + ); + let pylsp = venv.join(BINARY_DIR).join(Self::BINARY_NAME); + Ok(LanguageServerBinary { + path: pylsp, + env: None, + arguments: vec!["--stdio".into()], + }) + } + + async fn cached_server_binary( + &self, + _container_dir: PathBuf, + delegate: &dyn LspAdapterDelegate, + ) -> Option<LanguageServerBinary> { + let venv = self.base_venv(delegate).await.ok()?; + let pylsp = venv.join(BINARY_DIR).join(Self::BINARY_NAME); + Some(LanguageServerBinary { + path: pylsp, + env: None, + arguments: vec!["--stdio".into()], + }) + } + + async fn process_completions(&self, items: &mut [lsp::CompletionItem]) { + // Pyright assigns each completion item a `sortText` of the form `XX.YYYY.name`. + // Where `XX` is the sorting category, `YYYY` is based on most recent usage, + // and `name` is the symbol name itself. + // + // Because the symbol name is included, there generally are not ties when + // sorting by the `sortText`, so the symbol's fuzzy match score is not taken + // into account. Here, we remove the symbol name from the sortText in order + // to allow our own fuzzy score to be used to break ties. + // + // see https://github.com/microsoft/pyright/blob/95ef4e103b9b2f129c9320427e51b73ea7cf78bd/packages/pyright-internal/src/languageService/completionProvider.ts#LL2873 + for item in items { + let Some(sort_text) = &mut item.sort_text else { + continue; + }; + let mut parts = sort_text.split('.'); + let Some(first) = parts.next() else { continue }; + let Some(second) = parts.next() else { continue }; + let Some(_) = parts.next() else { continue }; + sort_text.replace_range(first.len() + second.len() + 1.., ""); + } + } + + async fn label_for_completion( + &self, + item: &lsp::CompletionItem, + language: &Arc<language::Language>, + ) -> Option<language::CodeLabel> { + let label = &item.label; + let grammar = language.grammar()?; + let highlight_id = match item.kind? 
{ + lsp::CompletionItemKind::METHOD => grammar.highlight_id_for_name("function.method")?, + lsp::CompletionItemKind::FUNCTION => grammar.highlight_id_for_name("function")?, + lsp::CompletionItemKind::CLASS => grammar.highlight_id_for_name("type")?, + lsp::CompletionItemKind::CONSTANT => grammar.highlight_id_for_name("constant")?, + _ => return None, + }; + let filter_range = item + .filter_text + .as_deref() + .and_then(|filter| label.find(filter).map(|ix| ix..ix + filter.len())) + .unwrap_or(0..label.len()); + Some(language::CodeLabel { + text: label.clone(), + runs: vec![(0..label.len(), highlight_id)], + filter_range, + }) + } + + async fn label_for_symbol( + &self, + name: &str, + kind: lsp::SymbolKind, + language: &Arc<language::Language>, + ) -> Option<language::CodeLabel> { + let (text, filter_range, display_range) = match kind { + lsp::SymbolKind::METHOD | lsp::SymbolKind::FUNCTION => { + let text = format!("def {}():\n", name); + let filter_range = 4..4 + name.len(); + let display_range = 0..filter_range.end; + (text, filter_range, display_range) + } + lsp::SymbolKind::CLASS => { + let text = format!("class {}:", name); + let filter_range = 6..6 + name.len(); + let display_range = 0..filter_range.end; + (text, filter_range, display_range) + } + lsp::SymbolKind::CONSTANT => { + let text = format!("{} = 0", name); + let filter_range = 0..name.len(); + let display_range = 0..filter_range.end; + (text, filter_range, display_range) + } + _ => return None, + }; + + Some(language::CodeLabel { + runs: language.highlight_text(&text.as_str().into(), display_range.clone()), + text: text[display_range].to_string(), + filter_range, + }) + } + + async fn workspace_configuration( + self: Arc<Self>, + _: &dyn Fs, + adapter: &Arc<dyn LspAdapterDelegate>, + toolchains: Arc<dyn LanguageToolchainStore>, + cx: &mut AsyncApp, + ) -> Result<Value> { + let toolchain = toolchains + .active_toolchain( + adapter.worktree_id(), + Arc::from("".as_ref()), + LanguageName::new("Python"), + cx, + ) + .await; + cx.update(move |cx| { + let mut user_settings = + language_server_settings(adapter.as_ref(), &Self::SERVER_NAME, cx) + .and_then(|s| s.settings.clone()) + .unwrap_or_default(); + + // If we have a detected toolchain, configure Pyright to use it + if let Some(toolchain) = toolchain { + if user_settings.is_null() { + user_settings = Value::Object(serde_json::Map::default()); + } + let object = user_settings.as_object_mut().unwrap(); + + let interpreter_path = toolchain.path.to_string(); + + // Detect if this is a virtual environment + if let Some(interpreter_dir) = Path::new(&interpreter_path).parent() { + if let Some(venv_dir) = interpreter_dir.parent() { + // Check if this looks like a virtual environment + if venv_dir.join("pyvenv.cfg").exists() + || venv_dir.join("bin/activate").exists() + || venv_dir.join("Scripts/activate.bat").exists() + { + // Set venvPath and venv at the root level + // This matches the format of a pyrightconfig.json file + if let Some(parent) = venv_dir.parent() { + // Use relative path if the venv is inside the workspace + let venv_path = if parent == adapter.worktree_root_path() { + ".".to_string() + } else { + parent.to_string_lossy().into_owned() + }; + object.insert("venvPath".to_string(), Value::String(venv_path)); + } + + if let Some(venv_name) = venv_dir.file_name() { + object.insert( + "venv".to_owned(), + Value::String(venv_name.to_string_lossy().into_owned()), + ); + } + } + } + } + + // Always set the python interpreter path + // Get or create the python section + let 
python = object + .entry("python") + .or_insert(Value::Object(serde_json::Map::default())) + .as_object_mut() + .unwrap(); + + // Set both pythonPath and defaultInterpreterPath for compatibility + python.insert( + "pythonPath".to_owned(), + Value::String(interpreter_path.clone()), + ); + python.insert( + "defaultInterpreterPath".to_owned(), + Value::String(interpreter_path), + ); + } + + user_settings + }) + } + + fn manifest_name(&self) -> Option<ManifestName> { + Some(SharedString::new_static("pyproject.toml").into()) + } + + fn workspace_folders_content(&self) -> WorkspaceFoldersContent { + WorkspaceFoldersContent::WorktreeRoot + } } #[cfg(test)] diff --git a/crates/languages/src/rust.rs b/crates/languages/src/rust.rs index 3f83c9c000e40436f1215cebae02a03ffab1c0c1..b6567c6e33ff93a8bc00fdbb8ed1674cc31ddaa0 100644 --- a/crates/languages/src/rust.rs +++ b/crates/languages/src/rust.rs @@ -1,8 +1,7 @@ use anyhow::{Context as _, Result}; -use async_compression::futures::bufread::GzipDecoder; use async_trait::async_trait; use collections::HashMap; -use futures::{StreamExt, io::BufReader}; +use futures::StreamExt; use gpui::{App, AppContext, AsyncApp, SharedString, Task}; use http_client::github::AssetKind; use http_client::github::{GitHubLspBinaryVersion, latest_github_release}; @@ -23,14 +22,11 @@ use std::{ sync::{Arc, LazyLock}, }; use task::{TaskTemplate, TaskTemplates, TaskVariables, VariableName}; -use util::archive::extract_zip; +use util::fs::make_file_executable; use util::merge_json_value_into; -use util::{ - ResultExt, - fs::{make_file_executable, remove_matching}, - maybe, -}; +use util::{ResultExt, maybe}; +use crate::github_download::{GithubBinaryMetadata, download_server_binary}; use crate::language_settings::language_settings; pub struct RustLspAdapter; @@ -163,7 +159,6 @@ impl LspAdapter for RustLspAdapter { ) .await?; let asset_name = Self::build_asset_name(); - let asset = release .assets .iter() @@ -172,6 +167,7 @@ impl LspAdapter for RustLspAdapter { Ok(Box::new(GitHubLspBinaryVersion { name: release.tag_name, url: asset.browser_download_url.clone(), + digest: asset.digest.clone(), })) } @@ -181,58 +177,76 @@ impl LspAdapter for RustLspAdapter { container_dir: PathBuf, delegate: &dyn LspAdapterDelegate, ) -> Result<LanguageServerBinary> { - let version = version.downcast::<GitHubLspBinaryVersion>().unwrap(); - let destination_path = container_dir.join(format!("rust-analyzer-{}", version.name)); + let GitHubLspBinaryVersion { name, url, digest } = + &*version.downcast::<GitHubLspBinaryVersion>().unwrap(); + let expected_digest = digest + .as_ref() + .and_then(|digest| digest.strip_prefix("sha256:")); + let destination_path = container_dir.join(format!("rust-analyzer-{name}")); let server_path = match Self::GITHUB_ASSET_KIND { AssetKind::TarGz | AssetKind::Gz => destination_path.clone(), // Tar and gzip extract in place. 
AssetKind::Zip => destination_path.clone().join("rust-analyzer.exe"), // zip contains a .exe }; - if fs::metadata(&server_path).await.is_err() { - remove_matching(&container_dir, |entry| entry != destination_path).await; - - let mut response = delegate - .http_client() - .get(&version.url, Default::default(), true) - .await - .with_context(|| format!("downloading release from {}", version.url))?; - match Self::GITHUB_ASSET_KIND { - AssetKind::TarGz => { - let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut())); - let archive = async_tar::Archive::new(decompressed_bytes); - archive.unpack(&destination_path).await.with_context(|| { - format!("extracting {} to {:?}", version.url, destination_path) - })?; - } - AssetKind::Gz => { - let mut decompressed_bytes = - GzipDecoder::new(BufReader::new(response.body_mut())); - let mut file = - fs::File::create(&destination_path).await.with_context(|| { - format!( - "creating a file {:?} for a download from {}", - destination_path, version.url, - ) - })?; - futures::io::copy(&mut decompressed_bytes, &mut file) - .await - .with_context(|| { - format!("extracting {} to {:?}", version.url, destination_path) - })?; - } - AssetKind::Zip => { - extract_zip(&destination_path, response.body_mut()) - .await - .with_context(|| { - format!("unzipping {} to {:?}", version.url, destination_path) - })?; - } - }; + let binary = LanguageServerBinary { + path: server_path.clone(), + env: None, + arguments: Default::default(), + }; - // todo("windows") - make_file_executable(&server_path).await?; + let metadata_path = destination_path.with_extension("metadata"); + let metadata = GithubBinaryMetadata::read_from_file(&metadata_path) + .await + .ok(); + if let Some(metadata) = metadata { + let validity_check = async || { + delegate + .try_exec(LanguageServerBinary { + path: server_path.clone(), + arguments: vec!["--version".into()], + env: None, + }) + .await + .inspect_err(|err| { + log::warn!("Unable to run {server_path:?} asset, redownloading: {err}",) + }) + }; + if let (Some(actual_digest), Some(expected_digest)) = + (&metadata.digest, expected_digest) + { + if actual_digest == expected_digest { + if validity_check().await.is_ok() { + return Ok(binary); + } + } else { + log::info!( + "SHA-256 mismatch for {destination_path:?} asset, downloading new asset. 
Expected: {expected_digest}, Got: {actual_digest}" + ); + } + } else if validity_check().await.is_ok() { + return Ok(binary); + } } + _ = fs::remove_dir_all(&destination_path).await; + download_server_binary( + delegate, + url, + expected_digest, + &destination_path, + Self::GITHUB_ASSET_KIND, + ) + .await?; + make_file_executable(&server_path).await?; + GithubBinaryMetadata::write_to_file( + &GithubBinaryMetadata { + metadata_version: 1, + digest: expected_digest.map(ToString::to_string), + }, + &metadata_path, + ) + .await?; + Ok(LanguageServerBinary { path: server_path, env: None, @@ -291,66 +305,63 @@ impl LspAdapter for RustLspAdapter { completion: &lsp::CompletionItem, language: &Arc<Language>, ) -> Option<CodeLabel> { - let detail = completion + // rust-analyzer calls these detail left and detail right in terms of where it expects things to be rendered + // this usually contains signatures of the thing to be completed + let detail_right = completion .label_details .as_ref() - .and_then(|detail| detail.detail.as_ref()) + .and_then(|detail| detail.description.as_ref()) .or(completion.detail.as_ref()) .map(|detail| detail.trim()); - let function_signature = completion + // this tends to contain alias and import information + let detail_left = completion .label_details .as_ref() - .and_then(|detail| detail.description.as_deref()) - .or(completion.detail.as_deref()); - match (detail, completion.kind) { - (Some(detail), Some(lsp::CompletionItemKind::FIELD)) => { + .and_then(|detail| detail.detail.as_deref()); + let mk_label = |text: String, runs| { + let filter_range = completion + .filter_text + .as_deref() + .and_then(|filter| { + completion + .label + .find(filter) + .map(|ix| ix..ix + filter.len()) + }) + .unwrap_or(0..completion.label.len()); + + CodeLabel { + text, + runs, + filter_range, + } + }; + let mut label = match (detail_right, completion.kind) { + (Some(signature), Some(lsp::CompletionItemKind::FIELD)) => { let name = &completion.label; - let text = format!("{name}: {detail}"); + let text = format!("{name}: {signature}"); let prefix = "struct S { "; - let source = Rope::from(format!("{prefix}{text} }}")); + let source = Rope::from_iter([prefix, &text, " }"]); let runs = language.highlight_text(&source, prefix.len()..prefix.len() + text.len()); - let filter_range = completion - .filter_text - .as_deref() - .and_then(|filter| text.find(filter).map(|ix| ix..ix + filter.len())) - .unwrap_or(0..name.len()); - return Some(CodeLabel { - text, - runs, - filter_range, - }); + mk_label(text, runs) } ( - Some(detail), + Some(signature), Some(lsp::CompletionItemKind::CONSTANT | lsp::CompletionItemKind::VARIABLE), ) if completion.insert_text_format != Some(lsp::InsertTextFormat::SNIPPET) => { let name = &completion.label; - let text = format!( - "{}: {}", - name, - completion.detail.as_deref().unwrap_or(detail) - ); + let text = format!("{name}: {signature}",); let prefix = "let "; - let source = Rope::from(format!("{prefix}{text} = ();")); + let source = Rope::from_iter([prefix, &text, " = ();"]); let runs = language.highlight_text(&source, prefix.len()..prefix.len() + text.len()); - let filter_range = completion - .filter_text - .as_deref() - .and_then(|filter| text.find(filter).map(|ix| ix..ix + filter.len())) - .unwrap_or(0..name.len()); - return Some(CodeLabel { - text, - runs, - filter_range, - }); + mk_label(text, runs) } ( - Some(detail), + function_signature, Some(lsp::CompletionItemKind::FUNCTION | lsp::CompletionItemKind::METHOD), ) => { - static REGEX: LazyLock<Regex> = 
LazyLock::new(|| Regex::new("\\(…?\\)").unwrap()); const FUNCTION_PREFIXES: [&str; 6] = [ "async fn", "async unsafe fn", @@ -359,34 +370,27 @@ impl LspAdapter for RustLspAdapter { "unsafe fn", "fn", ]; - // Is it function `async`? - let fn_keyword = FUNCTION_PREFIXES.iter().find_map(|prefix| { - function_signature.as_ref().and_then(|signature| { - signature - .strip_prefix(*prefix) - .map(|suffix| (*prefix, suffix)) - }) + let fn_prefixed = FUNCTION_PREFIXES.iter().find_map(|&prefix| { + function_signature? + .strip_prefix(prefix) + .map(|suffix| (prefix, suffix)) }); // fn keyword should be followed by opening parenthesis. - if let Some((prefix, suffix)) = fn_keyword { - let mut text = REGEX.replace(&completion.label, suffix).to_string(); - let source = Rope::from(format!("{prefix} {text} {{}}")); + if let Some((prefix, suffix)) = fn_prefixed { + let label = if let Some(label) = completion + .label + .strip_suffix("(…)") + .or_else(|| completion.label.strip_suffix("()")) + { + label + } else { + &completion.label + }; + let text = format!("{label}{suffix}"); + let source = Rope::from_iter([prefix, " ", &text, " {}"]); let run_start = prefix.len() + 1; let runs = language.highlight_text(&source, run_start..run_start + text.len()); - if detail.starts_with("(") { - text.push(' '); - text.push_str(&detail); - } - let filter_range = completion - .filter_text - .as_deref() - .and_then(|filter| text.find(filter).map(|ix| ix..ix + filter.len())) - .unwrap_or(0..completion.label.find('(').unwrap_or(text.len())); - return Some(CodeLabel { - filter_range, - text, - runs, - }); + mk_label(text, runs) } else if completion .detail .as_ref() @@ -396,20 +400,13 @@ impl LspAdapter for RustLspAdapter { let len = text.len(); let source = Rope::from(text.as_str()); let runs = language.highlight_text(&source, 0..len); - let filter_range = completion - .filter_text - .as_deref() - .and_then(|filter| text.find(filter).map(|ix| ix..ix + filter.len())) - .unwrap_or(0..len); - return Some(CodeLabel { - filter_range, - text, - runs, - }); + mk_label(text, runs) + } else { + mk_label(completion.label.clone(), vec![]) } } - (_, Some(kind)) => { - let highlight_name = match kind { + (_, kind) => { + let highlight_name = kind.and_then(|kind| match kind { lsp::CompletionItemKind::STRUCT | lsp::CompletionItemKind::INTERFACE | lsp::CompletionItemKind::ENUM => Some("type"), @@ -419,27 +416,32 @@ impl LspAdapter for RustLspAdapter { Some("constant") } _ => None, - }; + }); - let mut label = completion.label.clone(); - if let Some(detail) = detail.filter(|detail| detail.starts_with("(")) { - label.push(' '); - label.push_str(detail); - } - let mut label = CodeLabel::plain(label, completion.filter_text.as_deref()); + let label = completion.label.clone(); + let mut runs = vec![]; if let Some(highlight_name) = highlight_name { let highlight_id = language.grammar()?.highlight_id_for_name(highlight_name)?; - label.runs.push(( - 0..label.text.rfind('(').unwrap_or(completion.label.len()), + runs.push(( + 0..label.rfind('(').unwrap_or(completion.label.len()), highlight_id, )); } + mk_label(label, runs) + } + }; - return Some(label); + if let Some(detail_left) = detail_left { + label.text.push(' '); + if !detail_left.starts_with('(') { + label.text.push('('); + } + label.text.push_str(detail_left); + if !detail_left.ends_with(')') { + label.text.push(')'); } - _ => {} } - None + Some(label) } async fn label_for_symbol( @@ -448,55 +450,22 @@ impl LspAdapter for RustLspAdapter { kind: lsp::SymbolKind, language: &Arc<Language>, ) -> 
Option<CodeLabel> { - let (text, filter_range, display_range) = match kind { - lsp::SymbolKind::METHOD | lsp::SymbolKind::FUNCTION => { - let text = format!("fn {} () {{}}", name); - let filter_range = 3..3 + name.len(); - let display_range = 0..filter_range.end; - (text, filter_range, display_range) - } - lsp::SymbolKind::STRUCT => { - let text = format!("struct {} {{}}", name); - let filter_range = 7..7 + name.len(); - let display_range = 0..filter_range.end; - (text, filter_range, display_range) - } - lsp::SymbolKind::ENUM => { - let text = format!("enum {} {{}}", name); - let filter_range = 5..5 + name.len(); - let display_range = 0..filter_range.end; - (text, filter_range, display_range) - } - lsp::SymbolKind::INTERFACE => { - let text = format!("trait {} {{}}", name); - let filter_range = 6..6 + name.len(); - let display_range = 0..filter_range.end; - (text, filter_range, display_range) - } - lsp::SymbolKind::CONSTANT => { - let text = format!("const {}: () = ();", name); - let filter_range = 6..6 + name.len(); - let display_range = 0..filter_range.end; - (text, filter_range, display_range) - } - lsp::SymbolKind::MODULE => { - let text = format!("mod {} {{}}", name); - let filter_range = 4..4 + name.len(); - let display_range = 0..filter_range.end; - (text, filter_range, display_range) - } - lsp::SymbolKind::TYPE_PARAMETER => { - let text = format!("type {} {{}}", name); - let filter_range = 5..5 + name.len(); - let display_range = 0..filter_range.end; - (text, filter_range, display_range) - } + let (prefix, suffix) = match kind { + lsp::SymbolKind::METHOD | lsp::SymbolKind::FUNCTION => ("fn ", " () {}"), + lsp::SymbolKind::STRUCT => ("struct ", " {}"), + lsp::SymbolKind::ENUM => ("enum ", " {}"), + lsp::SymbolKind::INTERFACE => ("trait ", " {}"), + lsp::SymbolKind::CONSTANT => ("const ", ": () = ();"), + lsp::SymbolKind::MODULE => ("mod ", " {}"), + lsp::SymbolKind::TYPE_PARAMETER => ("type ", " {}"), _ => return None, }; + let filter_range = prefix.len()..prefix.len() + name.len(); + let display_range = 0..filter_range.end; Some(CodeLabel { - runs: language.highlight_text(&text.as_str().into(), display_range.clone()), - text: text[display_range].to_string(), + runs: language.highlight_text(&Rope::from_iter([prefix, name, suffix]), display_range), + text: format!("{prefix}{name}"), filter_range, }) } @@ -1025,7 +994,11 @@ async fn get_cached_server_binary(container_dir: PathBuf) -> Option<LanguageServ let mut last = None; let mut entries = fs::read_dir(&container_dir).await?; while let Some(entry) = entries.next().await { - last = Some(entry?.path()); + let path = entry?.path(); + if path.extension().is_some_and(|ext| ext == "metadata") { + continue; + } + last = Some(path); } anyhow::Ok(LanguageServerBinary { @@ -1151,7 +1124,7 @@ mod tests { .await, Some(CodeLabel { text: "hello(&mut Option<T>) -> Vec<T> (use crate::foo)".to_string(), - filter_range: 0..5, + filter_range: 0..10, runs: vec![ (0..5, highlight_function), (7..10, highlight_keyword), @@ -1169,7 +1142,7 @@ mod tests { kind: Some(lsp::CompletionItemKind::FUNCTION), label: "hello(…)".to_string(), label_details: Some(CompletionItemLabelDetails { - detail: Some(" (use crate::foo)".into()), + detail: Some("(use crate::foo)".into()), description: Some("async fn(&mut Option<T>) -> Vec<T>".to_string()), }), ..Default::default() @@ -1179,7 +1152,7 @@ mod tests { .await, Some(CodeLabel { text: "hello(&mut Option<T>) -> Vec<T> (use crate::foo)".to_string(), - filter_range: 0..5, + filter_range: 0..10, runs: vec![ (0..5, 
highlight_function), (7..10, highlight_keyword), @@ -1216,7 +1189,7 @@ mod tests { kind: Some(lsp::CompletionItemKind::FUNCTION), label: "hello(…)".to_string(), label_details: Some(CompletionItemLabelDetails { - detail: Some(" (use crate::foo)".to_string()), + detail: Some("(use crate::foo)".to_string()), description: Some("fn(&mut Option<T>) -> Vec<T>".to_string()), }), @@ -1225,6 +1198,35 @@ mod tests { &language ) .await, + Some(CodeLabel { + text: "hello(&mut Option<T>) -> Vec<T> (use crate::foo)".to_string(), + filter_range: 0..10, + runs: vec![ + (0..5, highlight_function), + (7..10, highlight_keyword), + (11..17, highlight_type), + (18..19, highlight_type), + (25..28, highlight_type), + (29..30, highlight_type), + ], + }) + ); + + assert_eq!( + adapter + .label_for_completion( + &lsp::CompletionItem { + kind: Some(lsp::CompletionItemKind::FUNCTION), + label: "hello".to_string(), + label_details: Some(CompletionItemLabelDetails { + detail: Some("(use crate::foo)".to_string()), + description: Some("fn(&mut Option<T>) -> Vec<T>".to_string()), + }), + ..Default::default() + }, + &language + ) + .await, Some(CodeLabel { text: "hello(&mut Option<T>) -> Vec<T> (use crate::foo)".to_string(), filter_range: 0..5, @@ -1256,9 +1258,14 @@ mod tests { ) .await, Some(CodeLabel { - text: "await.as_deref_mut()".to_string(), + text: "await.as_deref_mut(&mut self) -> IterMut<'_, T>".to_string(), filter_range: 6..18, - runs: vec![], + runs: vec![ + (6..18, HighlightId(2)), + (20..23, HighlightId(1)), + (33..40, HighlightId(0)), + (45..46, HighlightId(0)) + ], }) ); diff --git a/crates/languages/src/rust/config.toml b/crates/languages/src/rust/config.toml index b55b6da4abdf0cd2eb3da8d5388c172169f53ff9..fe8b4ffdcba4f8b7949b6fe9187d16c8504d6688 100644 --- a/crates/languages/src/rust/config.toml +++ b/crates/languages/src/rust/config.toml @@ -16,4 +16,4 @@ brackets = [ ] collapsed_placeholder = " /* ... 
*/ " debuggers = ["CodeLLDB", "GDB"] -documentation = { start = "/*", end = "*/", prefix = "* ", tab_size = 1 } +documentation_comment = { start = "/*", prefix = "* ", end = "*/", tab_size = 1 } diff --git a/crates/languages/src/tailwind.rs b/crates/languages/src/tailwind.rs index 04f30b624615da9432719e5236833b5277ff1ef2..a7edbb148cd807cf404a80aa6552c211252ec25b 100644 --- a/crates/languages/src/tailwind.rs +++ b/crates/languages/src/tailwind.rs @@ -3,7 +3,7 @@ use async_trait::async_trait; use collections::HashMap; use futures::StreamExt; use gpui::AsyncApp; -use language::{LanguageToolchainStore, LspAdapter, LspAdapterDelegate}; +use language::{LanguageName, LanguageToolchainStore, LspAdapter, LspAdapterDelegate}; use lsp::{LanguageServerBinary, LanguageServerName}; use node_runtime::NodeRuntime; use project::{Fs, lsp_store::language_server_settings}; @@ -168,19 +168,20 @@ impl LspAdapter for TailwindLspAdapter { })) } - fn language_ids(&self) -> HashMap<String, String> { + fn language_ids(&self) -> HashMap<LanguageName, String> { HashMap::from_iter([ - ("Astro".to_string(), "astro".to_string()), - ("HTML".to_string(), "html".to_string()), - ("CSS".to_string(), "css".to_string()), - ("JavaScript".to_string(), "javascript".to_string()), - ("TSX".to_string(), "typescriptreact".to_string()), - ("Svelte".to_string(), "svelte".to_string()), - ("Elixir".to_string(), "phoenix-heex".to_string()), - ("HEEX".to_string(), "phoenix-heex".to_string()), - ("ERB".to_string(), "erb".to_string()), - ("PHP".to_string(), "php".to_string()), - ("Vue.js".to_string(), "vue".to_string()), + (LanguageName::new("Astro"), "astro".to_string()), + (LanguageName::new("HTML"), "html".to_string()), + (LanguageName::new("CSS"), "css".to_string()), + (LanguageName::new("JavaScript"), "javascript".to_string()), + (LanguageName::new("TSX"), "typescriptreact".to_string()), + (LanguageName::new("Svelte"), "svelte".to_string()), + (LanguageName::new("Elixir"), "phoenix-heex".to_string()), + (LanguageName::new("HEEX"), "phoenix-heex".to_string()), + (LanguageName::new("ERB"), "erb".to_string()), + (LanguageName::new("HTML/ERB"), "erb".to_string()), + (LanguageName::new("PHP"), "php".to_string()), + (LanguageName::new("Vue.js"), "vue".to_string()), ]) } } diff --git a/crates/languages/src/tsx/config.toml b/crates/languages/src/tsx/config.toml index 4176e622158089b44cc393a83d25727a2e6efd98..5849b9842fd7f3483f89bbedbdb7b74b3fc1572d 100644 --- a/crates/languages/src/tsx/config.toml +++ b/crates/languages/src/tsx/config.toml @@ -2,7 +2,8 @@ name = "TSX" grammar = "tsx" path_suffixes = ["tsx"] line_comments = ["// "] -block_comment = ["/*", "*/"] +block_comment = { start = "/*", prefix = "* ", end = "*/", tab_size = 1 } +documentation_comment = { start = "/**", prefix = "* ", end = "*/", tab_size = 1 } autoclose_before = ";:.,=}])>" brackets = [ { start = "{", end = "}", close = true, newline = true }, @@ -19,7 +20,6 @@ scope_opt_in_language_servers = ["tailwindcss-language-server", "emmet-language- prettier_parser_name = "typescript" tab_size = 2 debuggers = ["JavaScript"] -documentation = { start = "/**", end = "*/", prefix = "* ", tab_size = 1 } [jsx_tag_auto_close] open_tag_node_name = "jsx_opening_element" @@ -30,7 +30,7 @@ tag_name_node_name_alternates = ["member_expression"] [overrides.element] line_comments = { remove = true } -block_comment = ["{/* ", " */}"] +block_comment = { start = "{/*", prefix = "", end = "*/}", tab_size = 0 } opt_into_language_servers = ["emmet-language-server"] [overrides.string] diff --git 
a/crates/languages/src/tsx/outline.scm b/crates/languages/src/tsx/outline.scm index df6ffa5aec8aa1b23b8179d0c341231feea5c0b5..5dafe791e493d03f6a73fa7c155ebb03072dc4d5 100644 --- a/crates/languages/src/tsx/outline.scm +++ b/crates/languages/src/tsx/outline.scm @@ -18,6 +18,15 @@ "(" @context ")" @context)) @item +(generator_function_declaration + "async"? @context + "function" @context + "*" @context + name: (_) @name + parameters: (formal_parameters + "(" @context + ")" @context)) @item + (interface_declaration "interface" @context name: (_) @name) @item diff --git a/crates/languages/src/typescript.rs b/crates/languages/src/typescript.rs index 32c45dfa886358124e4331420c2bcc3c8a349514..f976b6261480a65106c510369a107c8c078e5a33 100644 --- a/crates/languages/src/typescript.rs +++ b/crates/languages/src/typescript.rs @@ -1,6 +1,4 @@ use anyhow::{Context as _, Result}; -use async_compression::futures::bufread::GzipDecoder; -use async_tar::Archive; use async_trait::async_trait; use chrono::{DateTime, Local}; use collections::HashMap; @@ -8,13 +6,14 @@ use futures::future::join_all; use gpui::{App, AppContext, AsyncApp, Task}; use http_client::github::{AssetKind, GitHubLspBinaryVersion, build_asset_url}; use language::{ - ContextLocation, ContextProvider, File, LanguageToolchainStore, LspAdapter, LspAdapterDelegate, + ContextLocation, ContextProvider, File, LanguageName, LanguageToolchainStore, LspAdapter, + LspAdapterDelegate, }; use lsp::{CodeActionKind, LanguageServerBinary, LanguageServerName}; use node_runtime::NodeRuntime; use project::{Fs, lsp_store::language_server_settings}; use serde_json::{Value, json}; -use smol::{fs, io::BufReader, lock::RwLock, stream::StreamExt}; +use smol::{fs, lock::RwLock, stream::StreamExt}; use std::{ any::Any, borrow::Cow, @@ -23,11 +22,10 @@ use std::{ sync::Arc, }; use task::{TaskTemplate, TaskTemplates, VariableName}; -use util::archive::extract_zip; use util::merge_json_value_into; use util::{ResultExt, fs::remove_matching, maybe}; -use crate::{PackageJson, PackageJsonData}; +use crate::{PackageJson, PackageJsonData, github_download::download_server_binary}; #[derive(Debug)] pub(crate) struct TypeScriptContextProvider { @@ -512,7 +510,7 @@ fn eslint_server_binary_arguments(server_path: &Path) -> Vec<OsString> { fn replace_test_name_parameters(test_name: &str) -> String { let pattern = regex::Regex::new(r"(%|\$)[0-9a-zA-Z]+").unwrap(); - pattern.replace_all(test_name, "(.+?)").to_string() + regex::escape(&pattern.replace_all(test_name, "(.+?)")) } pub struct TypeScriptLspAdapter { @@ -741,11 +739,11 @@ impl LspAdapter for TypeScriptLspAdapter { })) } - fn language_ids(&self) -> HashMap<String, String> { + fn language_ids(&self) -> HashMap<LanguageName, String> { HashMap::from_iter([ - ("TypeScript".into(), "typescript".into()), - ("JavaScript".into(), "javascript".into()), - ("TSX".into(), "typescriptreact".into()), + (LanguageName::new("TypeScript"), "typescript".into()), + (LanguageName::new("JavaScript"), "javascript".into()), + (LanguageName::new("TSX"), "typescriptreact".into()), ]) } } @@ -863,7 +861,7 @@ impl LspAdapter for EsLintLspAdapter { }, "experimental": { "useFlatConfig": use_flat_config, - }, + } }); let override_options = cx.update(|cx| { @@ -896,6 +894,7 @@ impl LspAdapter for EsLintLspAdapter { Ok(Box::new(GitHubLspBinaryVersion { name: Self::CURRENT_VERSION.into(), + digest: None, url, })) } @@ -913,43 +912,14 @@ impl LspAdapter for EsLintLspAdapter { if fs::metadata(&server_path).await.is_err() { remove_matching(&container_dir, |entry| 
entry != destination_path).await; - let mut response = delegate - .http_client() - .get(&version.url, Default::default(), true) - .await - .context("downloading release")?; - match Self::GITHUB_ASSET_KIND { - AssetKind::TarGz => { - let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut())); - let archive = Archive::new(decompressed_bytes); - archive.unpack(&destination_path).await.with_context(|| { - format!("extracting {} to {:?}", version.url, destination_path) - })?; - } - AssetKind::Gz => { - let mut decompressed_bytes = - GzipDecoder::new(BufReader::new(response.body_mut())); - let mut file = - fs::File::create(&destination_path).await.with_context(|| { - format!( - "creating a file {:?} for a download from {}", - destination_path, version.url, - ) - })?; - futures::io::copy(&mut decompressed_bytes, &mut file) - .await - .with_context(|| { - format!("extracting {} to {:?}", version.url, destination_path) - })?; - } - AssetKind::Zip => { - extract_zip(&destination_path, response.body_mut()) - .await - .with_context(|| { - format!("unzipping {} to {:?}", version.url, destination_path) - })?; - } - } + download_server_binary( + delegate, + &version.url, + None, + &destination_path, + Self::GITHUB_ASSET_KIND, + ) + .await?; let mut dir = fs::read_dir(&destination_path).await?; let first = dir.next().await.context("missing first file")??; @@ -1075,6 +1045,62 @@ mod tests { ); } + #[gpui::test] + async fn test_generator_function_outline(cx: &mut TestAppContext) { + let language = crate::language("javascript", tree_sitter_typescript::LANGUAGE_TSX.into()); + + let text = r#" + function normalFunction() { + console.log("normal"); + } + + function* simpleGenerator() { + yield 1; + yield 2; + } + + async function* asyncGenerator() { + yield await Promise.resolve(1); + } + + function* generatorWithParams(start, end) { + for (let i = start; i <= end; i++) { + yield i; + } + } + + class TestClass { + *methodGenerator() { + yield "method"; + } + + async *asyncMethodGenerator() { + yield "async method"; + } + } + "# + .unindent(); + + let buffer = cx.new(|cx| language::Buffer::local(text, cx).with_language(language, cx)); + let outline = buffer.read_with(cx, |buffer, _| buffer.snapshot().outline(None).unwrap()); + assert_eq!( + outline + .items + .iter() + .map(|item| (item.text.as_str(), item.depth)) + .collect::<Vec<_>>(), + &[ + ("function normalFunction()", 0), + ("function* simpleGenerator()", 0), + ("async function* asyncGenerator()", 0), + ("function* generatorWithParams( )", 0), + ("class TestClass", 0), + ("*methodGenerator()", 1), + ("async *asyncMethodGenerator()", 1), + ] + ); + } + #[gpui::test] async fn test_package_json_discovery(executor: BackgroundExecutor, cx: &mut TestAppContext) { cx.update(|cx| { diff --git a/crates/languages/src/typescript/config.toml b/crates/languages/src/typescript/config.toml index db0f32aa0d767ef2735189df0e520dc566e2c5c6..d7e3e4bd3d1569f96636b7f7572deea306b46df7 100644 --- a/crates/languages/src/typescript/config.toml +++ b/crates/languages/src/typescript/config.toml @@ -3,7 +3,8 @@ grammar = "typescript" path_suffixes = ["ts", "cts", "mts"] first_line_pattern = '^#!.*\b(?:deno run|ts-node|bun|tsx|[/ ]node)\b' line_comments = ["// "] -block_comment = ["/*", "*/"] +block_comment = { start = "/*", prefix = "* ", end = "*/", tab_size = 1 } +documentation_comment = { start = "/**", prefix = "* ", end = "*/", tab_size = 1 } autoclose_before = ";:.,=}])>" brackets = [ { start = "{", end = "}", close = true, newline = true }, @@ -19,7 +20,6 @@ 
word_characters = ["#", "$"] prettier_parser_name = "typescript" tab_size = 2 debuggers = ["JavaScript"] -documentation = { start = "/**", end = "*/", prefix = "* ", tab_size = 1 } [overrides.string] completion_query_characters = ["."] diff --git a/crates/languages/src/typescript/outline.scm b/crates/languages/src/typescript/outline.scm index df6ffa5aec8aa1b23b8179d0c341231feea5c0b5..5dafe791e493d03f6a73fa7c155ebb03072dc4d5 100644 --- a/crates/languages/src/typescript/outline.scm +++ b/crates/languages/src/typescript/outline.scm @@ -18,6 +18,15 @@ "(" @context ")" @context)) @item +(generator_function_declaration + "async"? @context + "function" @context + "*" @context + name: (_) @name + parameters: (formal_parameters + "(" @context + ")" @context)) @item + (interface_declaration "interface" @context name: (_) @name) @item diff --git a/crates/languages/src/typescript/runnables.scm b/crates/languages/src/typescript/runnables.scm index 85702cf99d9968b29f9375bfd8215ecba53f2eb5..6bfc53632910ce8212f739d310e3d560d05cffc1 100644 --- a/crates/languages/src/typescript/runnables.scm +++ b/crates/languages/src/typescript/runnables.scm @@ -1,4 +1,4 @@ -; Add support for (node:test, bun:test and Jest) runnable +; Add support for (node:test, bun:test, Jest and Deno.test) runnable ; Function expression that has `it`, `test` or `describe` as the function name ( (call_expression @@ -44,3 +44,42 @@ (#set! tag js-test) ) + +; Add support for Deno.test with string names +( + (call_expression + function: (member_expression + object: (identifier) @_namespace + property: (property_identifier) @_method + ) + (#eq? @_namespace "Deno") + (#eq? @_method "test") + arguments: ( + arguments . [ + (string (string_fragment) @run @DENO_TEST_NAME) + (identifier) @run @DENO_TEST_NAME + ] + ) + ) @_js-test + + (#set! tag js-test) +) + +; Add support for Deno.test with named function expressions +( + (call_expression + function: (member_expression + object: (identifier) @_namespace + property: (property_identifier) @_method + ) + (#eq? @_namespace "Deno") + (#eq? @_method "test") + arguments: ( + arguments . (function_expression + name: (identifier) @run @DENO_TEST_NAME + ) + ) + ) @_js-test + + (#set! 
tag js-test) +) diff --git a/crates/languages/src/vtsls.rs b/crates/languages/src/vtsls.rs index ca07673d5f460c38163fb81757179796ebff3a7a..33751f733e5b3f81e7fb145de8af6633673a4e0f 100644 --- a/crates/languages/src/vtsls.rs +++ b/crates/languages/src/vtsls.rs @@ -2,7 +2,7 @@ use anyhow::Result; use async_trait::async_trait; use collections::HashMap; use gpui::AsyncApp; -use language::{LanguageToolchainStore, LspAdapter, LspAdapterDelegate}; +use language::{LanguageName, LanguageToolchainStore, LspAdapter, LspAdapterDelegate}; use lsp::{CodeActionKind, LanguageServerBinary, LanguageServerName}; use node_runtime::NodeRuntime; use project::{Fs, lsp_store::language_server_settings}; @@ -273,11 +273,11 @@ impl LspAdapter for VtslsLspAdapter { Ok(default_workspace_configuration) } - fn language_ids(&self) -> HashMap<String, String> { + fn language_ids(&self) -> HashMap<LanguageName, String> { HashMap::from_iter([ - ("TypeScript".into(), "typescript".into()), - ("JavaScript".into(), "javascript".into()), - ("TSX".into(), "typescriptreact".into()), + (LanguageName::new("TypeScript"), "typescript".into()), + (LanguageName::new("JavaScript"), "javascript".into()), + (LanguageName::new("TSX"), "typescriptreact".into()), ]) } } diff --git a/crates/languages/src/yaml/config.toml b/crates/languages/src/yaml/config.toml index 4dfb890c5481c3814722b9d143c17d7d8399b478..e54bceda1ae01eff4a5c917a5b8fc74282165bea 100644 --- a/crates/languages/src/yaml/config.toml +++ b/crates/languages/src/yaml/config.toml @@ -1,6 +1,6 @@ name = "YAML" grammar = "yaml" -path_suffixes = ["yml", "yaml"] +path_suffixes = ["yml", "yaml", "pixi.lock"] line_comments = ["# "] autoclose_before = ",]}" brackets = [ diff --git a/crates/languages/src/yaml/outline.scm b/crates/languages/src/yaml/outline.scm index 7ab007835f3ee181cb792ca4f2d2f8e6a92f5223..c5a7f8e5d40388c020ec9dab83d6cee02746b581 100644 --- a/crates/languages/src/yaml/outline.scm +++ b/crates/languages/src/yaml/outline.scm @@ -1 +1,9 @@ -(block_mapping_pair key: (flow_node (plain_scalar (string_scalar) @name))) @item +(block_mapping_pair + key: + (flow_node + (plain_scalar + (string_scalar) @name)) + value: + (flow_node + (plain_scalar + (string_scalar) @context))?) 
@item diff --git a/crates/languages/src/zed-keybind-context/brackets.scm b/crates/languages/src/zed-keybind-context/brackets.scm new file mode 100644 index 0000000000000000000000000000000000000000..d086b2e98df0837208a13f6c6f79db84c204fb99 --- /dev/null +++ b/crates/languages/src/zed-keybind-context/brackets.scm @@ -0,0 +1 @@ +("(" @open ")" @close) diff --git a/crates/languages/src/zed-keybind-context/config.toml b/crates/languages/src/zed-keybind-context/config.toml new file mode 100644 index 0000000000000000000000000000000000000000..a999c70f6679843d07521c75a6a14bef26af67bb --- /dev/null +++ b/crates/languages/src/zed-keybind-context/config.toml @@ -0,0 +1,6 @@ +name = "Zed Keybind Context" +grammar = "rust" +autoclose_before = ")" +brackets = [ + { start = "(", end = ")", close = true, newline = false }, +] diff --git a/crates/languages/src/zed-keybind-context/highlights.scm b/crates/languages/src/zed-keybind-context/highlights.scm new file mode 100644 index 0000000000000000000000000000000000000000..9c5ec58eaeb7084bf79f31b280197b57bfe64b54 --- /dev/null +++ b/crates/languages/src/zed-keybind-context/highlights.scm @@ -0,0 +1,23 @@ +(identifier) @variable + +[ + "(" + ")" +] @punctuation.bracket + +[ + (integer_literal) + (float_literal) +] @number + +(boolean_literal) @boolean + +[ + "!=" + "==" + "=>" + ">" + "&&" + "||" + "!" +] @operator diff --git a/crates/livekit_client/Cargo.toml b/crates/livekit_client/Cargo.toml index a0c11d46e6a9317c88fde9bda081e64cf09bff27..821fd5d39006b517d264687d7fb9a25fb570d0c2 100644 --- a/crates/livekit_client/Cargo.toml +++ b/crates/livekit_client/Cargo.toml @@ -40,8 +40,8 @@ util.workspace = true workspace-hack.workspace = true [target.'cfg(not(any(all(target_os = "windows", target_env = "gnu"), target_os = "freebsd")))'.dependencies] -libwebrtc = { rev = "d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4", git = "https://github.com/zed-industries/livekit-rust-sdks" } -livekit = { rev = "d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4", git = "https://github.com/zed-industries/livekit-rust-sdks", features = [ +libwebrtc = { rev = "5f04705ac3f356350ae31534ffbc476abc9ea83d", git = "https://github.com/zed-industries/livekit-rust-sdks" } +livekit = { rev = "5f04705ac3f356350ae31534ffbc476abc9ea83d", git = "https://github.com/zed-industries/livekit-rust-sdks", features = [ "__rustls-tls" ] } diff --git a/crates/livekit_client/src/lib.rs b/crates/livekit_client/src/lib.rs index f94181b8f8b143c68d2269260d8e01dfbaaaf946..149859fdc8ecd8533332c9462a090adb5496f100 100644 --- a/crates/livekit_client/src/lib.rs +++ b/crates/livekit_client/src/lib.rs @@ -3,16 +3,41 @@ use collections::HashMap; mod remote_video_track_view; pub use remote_video_track_view::{RemoteVideoTrackView, RemoteVideoTrackViewEvent}; -#[cfg(not(any(test, feature = "test-support", target_os = "freebsd")))] +#[cfg(not(any( + test, + feature = "test-support", + all(target_os = "windows", target_env = "gnu"), + target_os = "freebsd" +)))] mod livekit_client; -#[cfg(not(any(test, feature = "test-support", target_os = "freebsd")))] +#[cfg(not(any( + test, + feature = "test-support", + all(target_os = "windows", target_env = "gnu"), + target_os = "freebsd" +)))] pub use livekit_client::*; -#[cfg(any(test, feature = "test-support", target_os = "freebsd"))] +#[cfg(any( + test, + feature = "test-support", + all(target_os = "windows", target_env = "gnu"), + target_os = "freebsd" +))] mod mock_client; -#[cfg(any(test, feature = "test-support", target_os = "freebsd"))] +#[cfg(any( + test, + feature = "test-support", + 
all(target_os = "windows", target_env = "gnu"), + target_os = "freebsd" +))] pub mod test; -#[cfg(any(test, feature = "test-support", target_os = "freebsd"))] +#[cfg(any( + test, + feature = "test-support", + all(target_os = "windows", target_env = "gnu"), + target_os = "freebsd" +))] pub use mock_client::*; #[derive(Debug, Clone)] diff --git a/crates/livekit_client/src/livekit_client/playback.rs b/crates/livekit_client/src/livekit_client/playback.rs index 7e36314c12f24fcc696cb5d66f57717ed052a81b..f14e156125f6da815fe24aabd798e53c6c3e82b8 100644 --- a/crates/livekit_client/src/livekit_client/playback.rs +++ b/crates/livekit_client/src/livekit_client/playback.rs @@ -1,6 +1,7 @@ use anyhow::{Context as _, Result}; use cpal::traits::{DeviceTrait, HostTrait, StreamTrait as _}; +use cpal::{Data, FromSample, I24, SampleFormat, SizedSample}; use futures::channel::mpsc::UnboundedSender; use futures::{Stream, StreamExt as _}; use gpui::{ @@ -258,9 +259,15 @@ impl AudioStack { let stream = device .build_input_stream_raw( &config.config(), - cpal::SampleFormat::I16, + config.sample_format(), move |data, _: &_| { - let mut data = data.as_slice::<i16>().unwrap(); + let data = + Self::get_sample_data(config.sample_format(), data).log_err(); + let Some(data) = data else { + return; + }; + let mut data = data.as_slice(); + while data.len() > 0 { let remainder = (buf.capacity() - buf.len()).min(data.len()); buf.extend_from_slice(&data[..remainder]); @@ -313,6 +320,33 @@ impl AudioStack { drop(end_on_drop_tx) } } + + fn get_sample_data(sample_format: SampleFormat, data: &Data) -> Result<Vec<i16>> { + match sample_format { + SampleFormat::I8 => Ok(Self::convert_sample_data::<i8, i16>(data)), + SampleFormat::I16 => Ok(data.as_slice::<i16>().unwrap().to_vec()), + SampleFormat::I24 => Ok(Self::convert_sample_data::<I24, i16>(data)), + SampleFormat::I32 => Ok(Self::convert_sample_data::<i32, i16>(data)), + SampleFormat::I64 => Ok(Self::convert_sample_data::<i64, i16>(data)), + SampleFormat::U8 => Ok(Self::convert_sample_data::<u8, i16>(data)), + SampleFormat::U16 => Ok(Self::convert_sample_data::<u16, i16>(data)), + SampleFormat::U32 => Ok(Self::convert_sample_data::<u32, i16>(data)), + SampleFormat::U64 => Ok(Self::convert_sample_data::<u64, i16>(data)), + SampleFormat::F32 => Ok(Self::convert_sample_data::<f32, i16>(data)), + SampleFormat::F64 => Ok(Self::convert_sample_data::<f64, i16>(data)), + _ => anyhow::bail!("Unsupported sample format"), + } + } + + fn convert_sample_data<TSource: SizedSample, TDest: SizedSample + FromSample<TSource>>( + data: &Data, + ) -> Vec<TDest> { + data.as_slice::<TSource>() + .unwrap() + .iter() + .map(|e| e.to_sample::<TDest>()) + .collect() + } } use super::LocalVideoTrack; @@ -326,11 +360,11 @@ pub(crate) async fn capture_local_video_track( capture_source: &dyn ScreenCaptureSource, cx: &mut gpui::AsyncApp, ) -> Result<(crate::LocalVideoTrack, Box<dyn ScreenCaptureStream>)> { - let resolution = capture_source.resolution()?; + let metadata = capture_source.metadata()?; let track_source = gpui_tokio::Tokio::spawn(cx, async move { NativeVideoSource::new(VideoResolution { - width: resolution.width.0 as u32, - height: resolution.height.0 as u32, + width: metadata.resolution.width.0 as u32, + height: metadata.resolution.height.0 as u32, }) })? 
.await?; diff --git a/crates/livekit_client/src/mock_client/participant.rs b/crates/livekit_client/src/mock_client/participant.rs index 1f4168b8e04058f00af3b3117ba17dfa90947736..033808cbb54189fa2a7841264097751da4deb027 100644 --- a/crates/livekit_client/src/mock_client/participant.rs +++ b/crates/livekit_client/src/mock_client/participant.rs @@ -5,7 +5,9 @@ use crate::{ }; use anyhow::Result; use collections::HashMap; -use gpui::{AsyncApp, ScreenCaptureSource, ScreenCaptureStream}; +use gpui::{ + AsyncApp, DevicePixels, ScreenCaptureSource, ScreenCaptureStream, SourceMetadata, size, +}; #[derive(Clone, Debug)] pub struct LocalParticipant { @@ -122,4 +124,13 @@ impl RemoteParticipant { struct TestScreenCaptureStream; -impl gpui::ScreenCaptureStream for TestScreenCaptureStream {} +impl ScreenCaptureStream for TestScreenCaptureStream { + fn metadata(&self) -> Result<SourceMetadata> { + Ok(SourceMetadata { + id: 0, + is_main: None, + label: None, + resolution: size(DevicePixels(1), DevicePixels(1)), + }) + } +} diff --git a/crates/lmstudio/src/lmstudio.rs b/crates/lmstudio/src/lmstudio.rs index a5477994ff844c31be536c6910b66c41fca54b25..43c78115cdd4f517a51052991121620a0a93c363 100644 --- a/crates/lmstudio/src/lmstudio.rs +++ b/crates/lmstudio/src/lmstudio.rs @@ -1,4 +1,4 @@ -use anyhow::{Context as _, Result}; +use anyhow::{Context as _, Result, anyhow}; use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream}; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http}; use serde::{Deserialize, Serialize}; @@ -275,11 +275,16 @@ impl Capabilities { } } +#[derive(Serialize, Deserialize, Debug)] +pub struct LmStudioError { + pub message: String, +} + #[derive(Serialize, Deserialize, Debug)] #[serde(untagged)] pub enum ResponseStreamResult { Ok(ResponseStreamEvent), - Err { error: String }, + Err { error: LmStudioError }, } #[derive(Serialize, Deserialize, Debug)] @@ -392,7 +397,6 @@ pub async fn stream_chat_completion( let mut response = client.send(request).await?; if response.status().is_success() { let reader = BufReader::new(response.into_body()); - Ok(reader .lines() .filter_map(|line| async move { @@ -402,18 +406,16 @@ pub async fn stream_chat_completion( if line == "[DONE]" { None } else { - let result = serde_json::from_str(&line) - .context("Unable to parse chat completions response"); - if let Err(ref e) = result { - eprintln!("Error parsing line: {e}\nLine content: '{line}'"); + match serde_json::from_str(line) { + Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)), + Ok(ResponseStreamResult::Err { error, .. }) => { + Some(Err(anyhow!(error.message))) + } + Err(error) => Some(Err(anyhow!(error))), } - Some(result) } } - Err(e) => { - eprintln!("Error reading line: {e}"); - Some(Err(e.into())) - } + Err(error) => Some(Err(anyhow!(error))), } }) .boxed()) diff --git a/crates/lsp/src/input_handler.rs b/crates/lsp/src/input_handler.rs index db3f1190fc60d8b2b7e1c9d16c9a35ed31b02870..001ebf1fc988ebb30301887d3dadbed76326857c 100644 --- a/crates/lsp/src/input_handler.rs +++ b/crates/lsp/src/input_handler.rs @@ -13,14 +13,15 @@ use parking_lot::Mutex; use smol::io::BufReader; use crate::{ - AnyNotification, AnyResponse, CONTENT_LEN_HEADER, IoHandler, IoKind, RequestId, ResponseHandler, + AnyResponse, CONTENT_LEN_HEADER, IoHandler, IoKind, NotificationOrRequest, RequestId, + ResponseHandler, }; const HEADER_DELIMITER: &[u8; 4] = b"\r\n\r\n"; /// Handler for stdout of language server. 
pub struct LspStdoutHandler { pub(super) loop_handle: Task<Result<()>>, - pub(super) notifications_channel: UnboundedReceiver<AnyNotification>, + pub(super) incoming_messages: UnboundedReceiver<NotificationOrRequest>, } async fn read_headers<Stdout>(reader: &mut BufReader<Stdout>, buffer: &mut Vec<u8>) -> Result<()> @@ -54,13 +55,13 @@ impl LspStdoutHandler { let loop_handle = cx.spawn(Self::handler(stdout, tx, response_handlers, io_handlers)); Self { loop_handle, - notifications_channel, + incoming_messages: notifications_channel, } } async fn handler<Input>( stdout: Input, - notifications_sender: UnboundedSender<AnyNotification>, + notifications_sender: UnboundedSender<NotificationOrRequest>, response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>, io_handlers: Arc<Mutex<HashMap<i32, IoHandler>>>, ) -> anyhow::Result<()> @@ -96,7 +97,7 @@ impl LspStdoutHandler { } } - if let Ok(msg) = serde_json::from_slice::<AnyNotification>(&buffer) { + if let Ok(msg) = serde_json::from_slice::<NotificationOrRequest>(&buffer) { notifications_sender.unbounded_send(msg)?; } else if let Ok(AnyResponse { id, error, result, .. diff --git a/crates/lsp/src/lsp.rs b/crates/lsp/src/lsp.rs index 53dc24a21a93fecee9a320a44a9b9c46655f31be..a92787cd3e74d3a8b1e18f1d4cd41d4cf300d484 100644 --- a/crates/lsp/src/lsp.rs +++ b/crates/lsp/src/lsp.rs @@ -4,7 +4,7 @@ pub use lsp_types::request::*; pub use lsp_types::*; use anyhow::{Context as _, Result, anyhow}; -use collections::HashMap; +use collections::{BTreeMap, HashMap}; use futures::{ AsyncRead, AsyncWrite, Future, FutureExt, channel::oneshot::{self, Canceled}, @@ -29,7 +29,7 @@ use std::{ ffi::{OsStr, OsString}, fmt, io::Write, - ops::{Deref, DerefMut}, + ops::DerefMut, path::PathBuf, pin::Pin, sync::{ @@ -40,7 +40,7 @@ use std::{ time::{Duration, Instant}, }; use std::{path::Path, process::Stdio}; -use util::{ConnectionResult, ResultExt, TryFutureExt}; +use util::{ConnectionResult, ResultExt, TryFutureExt, redact}; const JSON_RPC_VERSION: &str = "2.0"; const CONTENT_LEN_HEADER: &str = "Content-Length: "; @@ -62,7 +62,7 @@ pub enum IoKind { /// Represents a launchable language server. This can either be a standalone binary or the path /// to a runtime with arguments to instruct it to launch the actual language server file. -#[derive(Debug, Clone, Deserialize)] +#[derive(Clone, Deserialize)] pub struct LanguageServerBinary { pub path: PathBuf, pub arguments: Vec<OsString>, @@ -100,7 +100,7 @@ pub struct LanguageServer { io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>, output_done_rx: Mutex<Option<barrier::Receiver>>, server: Arc<Mutex<Option<Child>>>, - workspace_folders: Arc<Mutex<BTreeSet<Url>>>, + workspace_folders: Option<Arc<Mutex<BTreeSet<Url>>>>, root_uri: Url, } @@ -242,7 +242,7 @@ struct Notification<'a, T> { /// Language server RPC notification message before it is deserialized into a concrete type. 
#[derive(Debug, Clone, Deserialize)] -struct AnyNotification { +struct NotificationOrRequest { #[serde(default)] id: Option<RequestId>, method: String, @@ -252,7 +252,10 @@ struct AnyNotification { #[derive(Debug, Serialize, Deserialize)] struct Error { + code: i64, message: String, + #[serde(default)] + data: Option<serde_json::Value>, } pub trait LspRequestFuture<O>: Future<Output = ConnectionResult<O>> { @@ -307,7 +310,7 @@ impl LanguageServer { binary: LanguageServerBinary, root_path: &Path, code_action_kinds: Option<Vec<CodeActionKind>>, - workspace_folders: Arc<Mutex<BTreeSet<Url>>>, + workspace_folders: Option<Arc<Mutex<BTreeSet<Url>>>>, cx: &mut AsyncApp, ) -> Result<Self> { let working_dir = if root_path.is_dir() { @@ -364,6 +367,7 @@ impl LanguageServer { notification.method, serde_json::to_string_pretty(¬ification.params).unwrap(), ); + false }, ); @@ -381,7 +385,7 @@ impl LanguageServer { code_action_kinds: Option<Vec<CodeActionKind>>, binary: LanguageServerBinary, root_uri: Url, - workspace_folders: Arc<Mutex<BTreeSet<Url>>>, + workspace_folders: Option<Arc<Mutex<BTreeSet<Url>>>>, cx: &mut AsyncApp, on_unhandled_notification: F, ) -> Self @@ -389,7 +393,7 @@ impl LanguageServer { Stdin: AsyncWrite + Unpin + Send + 'static, Stdout: AsyncRead + Unpin + Send + 'static, Stderr: AsyncRead + Unpin + Send + 'static, - F: FnMut(AnyNotification) + 'static + Send + Sync + Clone, + F: Fn(&NotificationOrRequest) -> bool + 'static + Send + Sync + Clone, { let (outbound_tx, outbound_rx) = channel::unbounded::<String>(); let (output_done_tx, output_done_rx) = barrier::channel(); @@ -400,14 +404,34 @@ impl LanguageServer { let io_handlers = Arc::new(Mutex::new(HashMap::default())); let stdout_input_task = cx.spawn({ - let on_unhandled_notification = on_unhandled_notification.clone(); + let unhandled_notification_wrapper = { + let response_channel = outbound_tx.clone(); + async move |msg: NotificationOrRequest| { + let did_handle = on_unhandled_notification(&msg); + if !did_handle && let Some(message_id) = msg.id { + let response = AnyResponse { + jsonrpc: JSON_RPC_VERSION, + id: message_id, + error: Some(Error { + code: -32601, + message: format!("Unrecognized method `{}`", msg.method), + data: None, + }), + result: None, + }; + if let Ok(response) = serde_json::to_string(&response) { + response_channel.send(response).await.ok(); + } + } + } + }; let notification_handlers = notification_handlers.clone(); let response_handlers = response_handlers.clone(); let io_handlers = io_handlers.clone(); async move |cx| { - Self::handle_input( + Self::handle_incoming_messages( stdout, - on_unhandled_notification, + unhandled_notification_wrapper, notification_handlers, response_handlers, io_handlers, @@ -421,19 +445,19 @@ impl LanguageServer { .map(|stderr| { let io_handlers = io_handlers.clone(); let stderr_captures = stderr_capture.clone(); - cx.spawn(async move |_| { + cx.background_spawn(async move { Self::handle_stderr(stderr, io_handlers, stderr_captures) .log_err() .await }) }) .unwrap_or_else(|| Task::ready(None)); - let input_task = cx.spawn(async move |_| { + let input_task = cx.background_spawn(async move { let (stdout, stderr) = futures::join!(stdout_input_task, stderr_input_task); stdout.or(stderr) }); let output_task = cx.background_spawn({ - Self::handle_output( + Self::handle_outgoing_messages( stdin, outbound_rx, output_done_tx, @@ -479,9 +503,9 @@ impl LanguageServer { self.code_action_kinds.clone() } - async fn handle_input<Stdout, F>( + async fn handle_incoming_messages<Stdout>( 
stdout: Stdout, - mut on_unhandled_notification: F, + on_unhandled_notification: impl AsyncFn(NotificationOrRequest) + 'static + Send, notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>, response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>, io_handlers: Arc<Mutex<HashMap<i32, IoHandler>>>, @@ -489,7 +513,6 @@ impl LanguageServer { ) -> anyhow::Result<()> where Stdout: AsyncRead + Unpin + Send + 'static, - F: FnMut(AnyNotification) + 'static + Send, { use smol::stream::StreamExt; let stdout = BufReader::new(stdout); @@ -506,15 +529,19 @@ impl LanguageServer { cx.background_executor().clone(), ); - while let Some(msg) = input_handler.notifications_channel.next().await { - { + while let Some(msg) = input_handler.incoming_messages.next().await { + let unhandled_message = { let mut notification_handlers = notification_handlers.lock(); if let Some(handler) = notification_handlers.get_mut(msg.method.as_str()) { handler(msg.id, msg.params.unwrap_or(Value::Null), cx); + None } else { - drop(notification_handlers); - on_unhandled_notification(msg); + Some(msg) } + }; + + if let Some(msg) = unhandled_message { + on_unhandled_notification(msg).await; } // Don't starve the main thread when receiving lots of notifications at once. @@ -558,7 +585,7 @@ impl LanguageServer { } } - async fn handle_output<Stdin>( + async fn handle_outgoing_messages<Stdin>( stdin: Stdin, outbound_rx: channel::Receiver<String>, output_done_tx: barrier::Sender, @@ -595,16 +622,26 @@ impl LanguageServer { } pub fn default_initialize_params(&self, pull_diagnostics: bool, cx: &App) -> InitializeParams { - let workspace_folders = self - .workspace_folders - .lock() - .iter() - .cloned() - .map(|uri| WorkspaceFolder { - name: Default::default(), - uri, - }) - .collect::<Vec<_>>(); + let workspace_folders = self.workspace_folders.as_ref().map_or_else( + || { + vec![WorkspaceFolder { + name: Default::default(), + uri: self.root_uri.clone(), + }] + }, + |folders| { + folders + .lock() + .iter() + .cloned() + .map(|uri| WorkspaceFolder { + name: Default::default(), + uri, + }) + .collect() + }, + ); + #[allow(deprecated)] InitializeParams { process_id: None, @@ -633,7 +670,7 @@ impl LanguageServer { inlay_hint: Some(InlayHintWorkspaceClientCapabilities { refresh_support: Some(true), }), - diagnostic: Some(DiagnosticWorkspaceClientCapabilities { + diagnostics: Some(DiagnosticWorkspaceClientCapabilities { refresh_support: Some(true), }) .filter(|_| pull_diagnostics), @@ -710,6 +747,10 @@ impl LanguageServer { InsertTextMode::ADJUST_INDENTATION, ], }), + documentation_format: Some(vec![ + MarkupKind::Markdown, + MarkupKind::PlainText, + ]), ..Default::default() }), insert_text_mode: Some(InsertTextMode::ADJUST_INDENTATION), @@ -836,7 +877,7 @@ impl LanguageServer { configuration: Arc<DidChangeConfigurationParams>, cx: &App, ) -> Task<Result<Arc<Self>>> { - cx.spawn(async move |_| { + cx.background_spawn(async move { let response = self .request::<request::Initialize>(params) .await @@ -874,43 +915,44 @@ impl LanguageServer { &executor, (), ); - let exit = Self::notify_internal::<notification::Exit>(&outbound_tx, &()); - outbound_tx.close(); let server = self.server.clone(); let name = self.name.clone(); + let server_id = self.server_id; let mut timer = self.executor.timer(SERVER_SHUTDOWN_TIMEOUT).fuse(); - Some( - async move { - log::debug!("language server shutdown started"); - - select! 
{ - request_result = shutdown_request.fuse() => { - match request_result { - ConnectionResult::Timeout => { - log::warn!("timeout waiting for language server {name} to shutdown"); - }, - ConnectionResult::ConnectionReset => {}, - ConnectionResult::Result(r) => r?, - } + Some(async move { + log::debug!("language server shutdown started"); + + select! { + request_result = shutdown_request.fuse() => { + match request_result { + ConnectionResult::Timeout => { + log::warn!("timeout waiting for language server {name} (id {server_id}) to shutdown"); + }, + ConnectionResult::ConnectionReset => { + log::warn!("language server {name} (id {server_id}) closed the shutdown request connection"); + }, + ConnectionResult::Result(Err(e)) => { + log::error!("Shutdown request failure, server {name} (id {server_id}): {e:#}"); + }, + ConnectionResult::Result(Ok(())) => {} } - - _ = timer => { - log::info!("timeout waiting for language server {name} to shutdown"); - }, } - response_handlers.lock().take(); - exit?; - output_done.recv().await; - server.lock().take().map(|mut child| child.kill()); - log::debug!("language server shutdown finished"); - - drop(tasks); - anyhow::Ok(()) + _ = timer => { + log::info!("timeout waiting for language server {name} (id {server_id}) to shutdown"); + }, } - .log_err(), - ) + + response_handlers.lock().take(); + Self::notify_internal::<notification::Exit>(&outbound_tx, &()).ok(); + outbound_tx.close(); + output_done.recv().await; + server.lock().take().map(|mut child| child.kill()); + drop(tasks); + log::debug!("language server shutdown finished"); + Some(()) + }) } else { None } @@ -1025,7 +1067,9 @@ impl LanguageServer { jsonrpc: JSON_RPC_VERSION, id, value: LspResult::Error(Some(Error { + code: lsp_types::error_codes::REQUEST_FAILED, message: error.to_string(), + data: None, })), }, }; @@ -1046,7 +1090,9 @@ impl LanguageServer { id, result: None, error: Some(Error { + code: -32700, // Parse error message: error.to_string(), + data: None, }), }; if let Some(response) = serde_json::to_string(&response).log_err() { @@ -1107,6 +1153,7 @@ impl LanguageServer { pub fn binary(&self) -> &LanguageServerBinary { &self.binary } + /// Sends a RPC request to the language server. /// /// [LSP Specification](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#requestMessage) @@ -1126,16 +1173,40 @@ impl LanguageServer { ) } - fn request_internal<T>( + /// Sends a RPC request to the language server, with a custom timer, a future which when becoming + /// ready causes the request to be timed out with the future's output message. 
+ /// + /// [LSP Specification](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#requestMessage) + pub fn request_with_timer<T: request::Request, U: Future<Output = String>>( + &self, + params: T::Params, + timer: U, + ) -> impl LspRequestFuture<T::Result> + use<T, U> + where + T::Result: 'static + Send, + { + Self::request_internal_with_timer::<T, U>( + &self.next_id, + &self.response_handlers, + &self.outbound_tx, + &self.executor, + timer, + params, + ) + } + + fn request_internal_with_timer<T, U>( next_id: &AtomicI32, response_handlers: &Mutex<Option<HashMap<RequestId, ResponseHandler>>>, outbound_tx: &channel::Sender<String>, executor: &BackgroundExecutor, + timer: U, params: T::Params, - ) -> impl LspRequestFuture<T::Result> + use<T> + ) -> impl LspRequestFuture<T::Result> + use<T, U> where T::Result: 'static + Send, T: request::Request, + U: Future<Output = String>, { let id = next_id.fetch_add(1, SeqCst); let message = serde_json::to_string(&Request { @@ -1180,7 +1251,6 @@ impl LanguageServer { .context("failed to write to language server's stdin"); let outbound_tx = outbound_tx.downgrade(); - let mut timeout = executor.timer(LSP_REQUEST_TIMEOUT).fuse(); let started = Instant::now(); LspRequest::new(id, async move { if let Err(e) = handle_response { @@ -1217,14 +1287,41 @@ impl LanguageServer { } } - _ = timeout => { - log::error!("Cancelled LSP request task for {method:?} id {id} which took over {LSP_REQUEST_TIMEOUT:?}"); + message = timer.fuse() => { + log::error!("Cancelled LSP request task for {method:?} id {id} {message}"); ConnectionResult::Timeout } } }) } + fn request_internal<T>( + next_id: &AtomicI32, + response_handlers: &Mutex<Option<HashMap<RequestId, ResponseHandler>>>, + outbound_tx: &channel::Sender<String>, + executor: &BackgroundExecutor, + params: T::Params, + ) -> impl LspRequestFuture<T::Result> + use<T> + where + T::Result: 'static + Send, + T: request::Request, + { + Self::request_internal_with_timer::<T, _>( + next_id, + response_handlers, + outbound_tx, + executor, + Self::default_request_timer(executor.clone()), + params, + ) + } + + pub fn default_request_timer(executor: BackgroundExecutor) -> impl Future<Output = String> { + executor + .timer(LSP_REQUEST_TIMEOUT) + .map(|_| format!("which took over {LSP_REQUEST_TIMEOUT:?}")) + } + /// Sends a RPC notification to the language server. 
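// A hypothetical sketch of how a caller might use the new `request_with_timer` together
// with `default_request_timer`: the timer is any future yielding a `String`, and its output
// becomes the log message if the request is cancelled. The 5-second budget, the Rename
// request, and the helper function name below are illustrative choices, not from the patch.
use std::{sync::Arc, time::Duration};

use futures::FutureExt as _;
use gpui::BackgroundExecutor;
use util::ConnectionResult;

async fn rename_with_short_budget(
    server: Arc<lsp::LanguageServer>,
    params: lsp::RenameParams,
    executor: BackgroundExecutor,
) {
    // Either reuse the server-wide default...
    let _default = lsp::LanguageServer::default_request_timer(executor.clone());
    // ...or supply a request-specific budget with a custom timeout message.
    let timer = executor
        .timer(Duration::from_secs(5))
        .map(|_| "which exceeded the 5s rename budget".to_string());

    match server
        .request_with_timer::<lsp::request::Rename, _>(params, timer)
        .await
    {
        ConnectionResult::Result(Ok(edit)) => {
            // Apply the returned `Option<WorkspaceEdit>` here.
            let _ = edit;
        }
        ConnectionResult::Result(Err(error)) => log::error!("rename failed: {error:#}"),
        ConnectionResult::Timeout => { /* the custom message above was already logged */ }
        ConnectionResult::ConnectionReset => { /* the server went away mid-request */ }
    }
}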
/// /// [LSP Specification](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#notificationMessage) @@ -1263,7 +1360,10 @@ impl LanguageServer { return; } - let is_new_folder = self.workspace_folders.lock().insert(uri.clone()); + let Some(workspace_folders) = self.workspace_folders.as_ref() else { + return; + }; + let is_new_folder = workspace_folders.lock().insert(uri.clone()); if is_new_folder { let params = DidChangeWorkspaceFoldersParams { event: WorkspaceFoldersChangeEvent { @@ -1293,7 +1393,10 @@ impl LanguageServer { { return; } - let was_removed = self.workspace_folders.lock().remove(&uri); + let Some(workspace_folders) = self.workspace_folders.as_ref() else { + return; + }; + let was_removed = workspace_folders.lock().remove(&uri); if was_removed { let params = DidChangeWorkspaceFoldersParams { event: WorkspaceFoldersChangeEvent { @@ -1308,7 +1411,10 @@ impl LanguageServer { } } pub fn set_workspace_folders(&self, folders: BTreeSet<Url>) { - let mut workspace_folders = self.workspace_folders.lock(); + let Some(workspace_folders) = self.workspace_folders.as_ref() else { + return; + }; + let mut workspace_folders = workspace_folders.lock(); let old_workspace_folders = std::mem::take(&mut *workspace_folders); let added: Vec<_> = folders @@ -1337,8 +1443,11 @@ impl LanguageServer { } } - pub fn workspace_folders(&self) -> impl Deref<Target = BTreeSet<Url>> + '_ { - self.workspace_folders.lock() + pub fn workspace_folders(&self) -> BTreeSet<Url> { + self.workspace_folders.as_ref().map_or_else( + || BTreeSet::from_iter([self.root_uri.clone()]), + |folders| folders.lock().clone(), + ) } pub fn register_buffer( @@ -1398,6 +1507,33 @@ impl fmt::Debug for LanguageServer { } } +impl fmt::Debug for LanguageServerBinary { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut debug = f.debug_struct("LanguageServerBinary"); + debug.field("path", &self.path); + debug.field("arguments", &self.arguments); + + if let Some(env) = &self.env { + let redacted_env: BTreeMap<String, String> = env + .iter() + .map(|(key, value)| { + let redacted_value = if redact::should_redact(key) { + "REDACTED".to_string() + } else { + value.clone() + }; + (key.clone(), redacted_value) + }) + .collect(); + debug.field("env", &Some(redacted_env)); + } else { + debug.field("env", &self.env); + } + + debug.finish() + } +} + impl Drop for Subscription { fn drop(&mut self) { match self { @@ -1456,9 +1592,9 @@ impl FakeLanguageServer { None, binary.clone(), root, - workspace_folders.clone(), + Some(workspace_folders.clone()), cx, - |_| {}, + |_| false, ); server.process_name = process_name; let fake = FakeLanguageServer { @@ -1475,15 +1611,16 @@ impl FakeLanguageServer { None, binary, Self::root_path(), - workspace_folders, + Some(workspace_folders), cx, move |msg| { notifications_tx .try_send(( msg.method.to_string(), - msg.params.unwrap_or(Value::Null).to_string(), + msg.params.as_ref().unwrap_or(&Value::Null).to_string(), )) .ok(); + true }, ); server.process_name = name.as_str().into(); @@ -1508,6 +1645,8 @@ impl FakeLanguageServer { } }); + fake.set_request_handler::<request::Shutdown, _, _>(|_, _| async move { Ok(()) }); + (server, fake) } #[cfg(target_os = "windows")] @@ -1759,7 +1898,7 @@ mod tests { #[gpui::test] fn test_deserialize_string_digit_id() { let json = r#"{"jsonrpc":"2.0","id":"2","method":"workspace/configuration","params":{"items":[{"scopeUri":"file:///Users/mph/Devel/personal/hello-scala/","section":"metals"}]}}"#; - let notification = 
serde_json::from_str::<AnyNotification>(json) + let notification = serde_json::from_str::<NotificationOrRequest>(json) .expect("message with string id should be parsed"); let expected_id = RequestId::Str("2".to_string()); assert_eq!(notification.id, Some(expected_id)); @@ -1768,7 +1907,7 @@ mod tests { #[gpui::test] fn test_deserialize_string_id() { let json = r#"{"jsonrpc":"2.0","id":"anythingAtAll","method":"workspace/configuration","params":{"items":[{"scopeUri":"file:///Users/mph/Devel/personal/hello-scala/","section":"metals"}]}}"#; - let notification = serde_json::from_str::<AnyNotification>(json) + let notification = serde_json::from_str::<NotificationOrRequest>(json) .expect("message with string id should be parsed"); let expected_id = RequestId::Str("anythingAtAll".to_string()); assert_eq!(notification.id, Some(expected_id)); @@ -1777,7 +1916,7 @@ mod tests { #[gpui::test] fn test_deserialize_int_id() { let json = r#"{"jsonrpc":"2.0","id":2,"method":"workspace/configuration","params":{"items":[{"scopeUri":"file:///Users/mph/Devel/personal/hello-scala/","section":"metals"}]}}"#; - let notification = serde_json::from_str::<AnyNotification>(json) + let notification = serde_json::from_str::<NotificationOrRequest>(json) .expect("message with string id should be parsed"); let expected_id = RequestId::Int(2); assert_eq!(notification.id, Some(expected_id)); diff --git a/crates/markdown_preview/src/markdown_preview_view.rs b/crates/markdown_preview/src/markdown_preview_view.rs index 03cfd7ee8211118fc70d19c948cf4231b6cb23ca..a0c8819991d68336a306af85a4dd709353222fa1 100644 --- a/crates/markdown_preview/src/markdown_preview_view.rs +++ b/crates/markdown_preview/src/markdown_preview_view.rs @@ -18,6 +18,7 @@ use workspace::item::{Item, ItemHandle}; use workspace::{Pane, Workspace}; use crate::markdown_elements::ParsedMarkdownElement; +use crate::markdown_renderer::CheckboxClickedEvent; use crate::{ MovePageDown, MovePageUp, OpenFollowingPreview, OpenPreview, OpenPreviewToTheSide, markdown_elements::ParsedMarkdown, @@ -203,114 +204,7 @@ impl MarkdownPreviewView { cx: &mut Context<Workspace>, ) -> Entity<Self> { cx.new(|cx| { - let view = cx.entity().downgrade(); - - let list_state = ListState::new( - 0, - gpui::ListAlignment::Top, - px(1000.), - move |ix, window, cx| { - if let Some(view) = view.upgrade() { - view.update(cx, |this: &mut Self, cx| { - let Some(contents) = &this.contents else { - return div().into_any(); - }; - - let mut render_cx = - RenderContext::new(Some(this.workspace.clone()), window, cx) - .with_checkbox_clicked_callback({ - let view = view.clone(); - move |checked, source_range, window, cx| { - view.update(cx, |view, cx| { - if let Some(editor) = view - .active_editor - .as_ref() - .map(|s| s.editor.clone()) - { - editor.update(cx, |editor, cx| { - let task_marker = - if checked { "[x]" } else { "[ ]" }; - - editor.edit( - vec![(source_range, task_marker)], - cx, - ); - }); - view.parse_markdown_from_active_editor( - false, window, cx, - ); - cx.notify(); - } - }) - } - }); - - let block = contents.children.get(ix).unwrap(); - let rendered_block = render_markdown_block(block, &mut render_cx); - - let should_apply_padding = Self::should_apply_padding_between( - block, - contents.children.get(ix + 1), - ); - - div() - .id(ix) - .when(should_apply_padding, |this| { - this.pb(render_cx.scaled_rems(0.75)) - }) - .group("markdown-block") - .on_click(cx.listener( - move |this, event: &ClickEvent, window, cx| { - if event.down.click_count == 2 { - if let Some(source_range) = this 
- .contents - .as_ref() - .and_then(|c| c.children.get(ix)) - .and_then(|block| block.source_range()) - { - this.move_cursor_to_block( - window, - cx, - source_range.start..source_range.start, - ); - } - } - }, - )) - .map(move |container| { - let indicator = div() - .h_full() - .w(px(4.0)) - .when(ix == this.selected_block, |this| { - this.bg(cx.theme().colors().border) - }) - .group_hover("markdown-block", |s| { - if ix == this.selected_block { - s - } else { - s.bg(cx.theme().colors().border_variant) - } - }) - .rounded_xs(); - - container.child( - div() - .relative() - .child( - div() - .pl(render_cx.scaled_rems(1.0)) - .child(rendered_block), - ) - .child(indicator.absolute().left_0().top_0()), - ) - }) - .into_any() - }) - } else { - div().into_any() - } - }, - ); + let list_state = ListState::new(0, gpui::ListAlignment::Top, px(1000.)); let mut this = Self { selected_block: 0, @@ -607,10 +501,107 @@ impl Render for MarkdownPreviewView { .p_4() .text_size(buffer_size) .line_height(relative(buffer_line_height.value())) - .child( - div() - .flex_grow() - .map(|this| this.child(list(self.list_state.clone()).size_full())), - ) + .child(div().flex_grow().map(|this| { + this.child( + list( + self.list_state.clone(), + cx.processor(|this, ix, window, cx| { + let Some(contents) = &this.contents else { + return div().into_any(); + }; + + let mut render_cx = + RenderContext::new(Some(this.workspace.clone()), window, cx) + .with_checkbox_clicked_callback(cx.listener( + move |this, e: &CheckboxClickedEvent, window, cx| { + if let Some(editor) = this + .active_editor + .as_ref() + .map(|s| s.editor.clone()) + { + editor.update(cx, |editor, cx| { + let task_marker = + if e.checked() { "[x]" } else { "[ ]" }; + + editor.edit( + vec![(e.source_range(), task_marker)], + cx, + ); + }); + this.parse_markdown_from_active_editor( + false, window, cx, + ); + cx.notify(); + } + }, + )); + + let block = contents.children.get(ix).unwrap(); + let rendered_block = render_markdown_block(block, &mut render_cx); + + let should_apply_padding = Self::should_apply_padding_between( + block, + contents.children.get(ix + 1), + ); + + div() + .id(ix) + .when(should_apply_padding, |this| { + this.pb(render_cx.scaled_rems(0.75)) + }) + .group("markdown-block") + .on_click(cx.listener( + move |this, event: &ClickEvent, window, cx| { + if event.click_count() == 2 { + if let Some(source_range) = this + .contents + .as_ref() + .and_then(|c| c.children.get(ix)) + .and_then(|block: &ParsedMarkdownElement| { + block.source_range() + }) + { + this.move_cursor_to_block( + window, + cx, + source_range.start..source_range.start, + ); + } + } + }, + )) + .map(move |container| { + let indicator = div() + .h_full() + .w(px(4.0)) + .when(ix == this.selected_block, |this| { + this.bg(cx.theme().colors().border) + }) + .group_hover("markdown-block", |s| { + if ix == this.selected_block { + s + } else { + s.bg(cx.theme().colors().border_variant) + } + }) + .rounded_xs(); + + container.child( + div() + .relative() + .child( + div() + .pl(render_cx.scaled_rems(1.0)) + .child(rendered_block), + ) + .child(indicator.absolute().left_0().top_0()), + ) + }) + .into_any() + }), + ) + .size_full(), + ) + })) } } diff --git a/crates/markdown_preview/src/markdown_renderer.rs b/crates/markdown_preview/src/markdown_renderer.rs index 80bed8a6e80ec92e62ea6c0d06b6447fd87b366f..37d2ca21105566f1e2e3271f49c75a3ce1d7846b 100644 --- a/crates/markdown_preview/src/markdown_renderer.rs +++ b/crates/markdown_preview/src/markdown_renderer.rs @@ -26,7 +26,22 @@ use 
ui::{ }; use workspace::{OpenOptions, OpenVisible, Workspace}; -type CheckboxClickedCallback = Arc<Box<dyn Fn(bool, Range<usize>, &mut Window, &mut App)>>; +pub struct CheckboxClickedEvent { + pub checked: bool, + pub source_range: Range<usize>, +} + +impl CheckboxClickedEvent { + pub fn source_range(&self) -> Range<usize> { + self.source_range.clone() + } + + pub fn checked(&self) -> bool { + self.checked + } +} + +type CheckboxClickedCallback = Arc<Box<dyn Fn(&CheckboxClickedEvent, &mut Window, &mut App)>>; #[derive(Clone)] pub struct RenderContext { @@ -80,7 +95,7 @@ impl RenderContext { pub fn with_checkbox_clicked_callback( mut self, - callback: impl Fn(bool, Range<usize>, &mut Window, &mut App) + 'static, + callback: impl Fn(&CheckboxClickedEvent, &mut Window, &mut App) + 'static, ) -> Self { self.checkbox_clicked_callback = Some(Arc::new(Box::new(callback))); self @@ -229,7 +244,14 @@ fn render_markdown_list_item( }; if window.modifiers().secondary() { - callback(checked, range.clone(), window, cx); + callback( + &CheckboxClickedEvent { + checked, + source_range: range.clone(), + }, + window, + cx, + ); } } }) diff --git a/crates/migrator/src/migrations/m_2025_01_29/keymap.rs b/crates/migrator/src/migrations/m_2025_01_29/keymap.rs index c32da88229b429ad206168eeee30f401863b39bd..646af8f63dc90b6ebe3faef9432eecc54140b438 100644 --- a/crates/migrator/src/migrations/m_2025_01_29/keymap.rs +++ b/crates/migrator/src/migrations/m_2025_01_29/keymap.rs @@ -242,22 +242,22 @@ static STRING_REPLACE: LazyLock<HashMap<&str, &str>> = LazyLock::new(|| { "inline_completion::ToggleMenu", "edit_prediction::ToggleMenu", ), - ("editor::NextInlineCompletion", "editor::NextEditPrediction"), + ("editor::NextEditPrediction", "editor::NextEditPrediction"), ( - "editor::PreviousInlineCompletion", + "editor::PreviousEditPrediction", "editor::PreviousEditPrediction", ), ( - "editor::AcceptPartialInlineCompletion", + "editor::AcceptPartialEditPrediction", "editor::AcceptPartialEditPrediction", ), - ("editor::ShowInlineCompletion", "editor::ShowEditPrediction"), + ("editor::ShowEditPrediction", "editor::ShowEditPrediction"), ( - "editor::AcceptInlineCompletion", + "editor::AcceptEditPrediction", "editor::AcceptEditPrediction", ), ( - "editor::ToggleInlineCompletions", + "editor::ToggleEditPredictions", "editor::ToggleEditPrediction", ), ]) diff --git a/crates/migrator/src/migrations/m_2025_04_15/keymap.rs b/crates/migrator/src/migrations/m_2025_04_15/keymap.rs index d1443a922afc52a37912d8aa78b5c9f0d4b4e017..efbdc6b1c64443c4733c73568a13a70cc3fa1f97 100644 --- a/crates/migrator/src/migrations/m_2025_04_15/keymap.rs +++ b/crates/migrator/src/migrations/m_2025_04_15/keymap.rs @@ -25,7 +25,7 @@ fn replace_string_action( None } -/// "ctrl-k ctrl-1": "inline_completion::ToggleMenu" -> "edit_prediction::ToggleMenu" +/// "space": "outline_panel::Open" -> "outline_panel::OpenSelectedEntry" static STRING_REPLACE: LazyLock<HashMap<&str, &str>> = LazyLock::new(|| { HashMap::from_iter([("outline_panel::Open", "outline_panel::OpenSelectedEntry")]) }); diff --git a/crates/mistral/src/mistral.rs b/crates/mistral/src/mistral.rs index a3a017be83544ce60dc1b8bac09cde5e5058b4f4..c466a598a0ca509654ec501614169c24c2094409 100644 --- a/crates/mistral/src/mistral.rs +++ b/crates/mistral/src/mistral.rs @@ -48,18 +48,29 @@ pub enum Model { #[serde(rename = "codestral-latest", alias = "codestral-latest")] #[default] CodestralLatest, + #[serde(rename = "mistral-large-latest", alias = "mistral-large-latest")] MistralLargeLatest, #[serde(rename = 
"mistral-medium-latest", alias = "mistral-medium-latest")] MistralMediumLatest, #[serde(rename = "mistral-small-latest", alias = "mistral-small-latest")] MistralSmallLatest, + + #[serde(rename = "magistral-medium-latest", alias = "magistral-medium-latest")] + MagistralMediumLatest, + #[serde(rename = "magistral-small-latest", alias = "magistral-small-latest")] + MagistralSmallLatest, + #[serde(rename = "open-mistral-nemo", alias = "open-mistral-nemo")] OpenMistralNemo, #[serde(rename = "open-codestral-mamba", alias = "open-codestral-mamba")] OpenCodestralMamba, + + #[serde(rename = "devstral-medium-latest", alias = "devstral-medium-latest")] + DevstralMediumLatest, #[serde(rename = "devstral-small-latest", alias = "devstral-small-latest")] DevstralSmallLatest, + #[serde(rename = "pixtral-12b-latest", alias = "pixtral-12b-latest")] Pixtral12BLatest, #[serde(rename = "pixtral-large-latest", alias = "pixtral-large-latest")] @@ -89,8 +100,11 @@ impl Model { "mistral-large-latest" => Ok(Self::MistralLargeLatest), "mistral-medium-latest" => Ok(Self::MistralMediumLatest), "mistral-small-latest" => Ok(Self::MistralSmallLatest), + "magistral-medium-latest" => Ok(Self::MagistralMediumLatest), + "magistral-small-latest" => Ok(Self::MagistralSmallLatest), "open-mistral-nemo" => Ok(Self::OpenMistralNemo), "open-codestral-mamba" => Ok(Self::OpenCodestralMamba), + "devstral-medium-latest" => Ok(Self::DevstralMediumLatest), "devstral-small-latest" => Ok(Self::DevstralSmallLatest), "pixtral-12b-latest" => Ok(Self::Pixtral12BLatest), "pixtral-large-latest" => Ok(Self::PixtralLargeLatest), @@ -104,8 +118,11 @@ impl Model { Self::MistralLargeLatest => "mistral-large-latest", Self::MistralMediumLatest => "mistral-medium-latest", Self::MistralSmallLatest => "mistral-small-latest", + Self::MagistralMediumLatest => "magistral-medium-latest", + Self::MagistralSmallLatest => "magistral-small-latest", Self::OpenMistralNemo => "open-mistral-nemo", Self::OpenCodestralMamba => "open-codestral-mamba", + Self::DevstralMediumLatest => "devstral-medium-latest", Self::DevstralSmallLatest => "devstral-small-latest", Self::Pixtral12BLatest => "pixtral-12b-latest", Self::PixtralLargeLatest => "pixtral-large-latest", @@ -119,8 +136,11 @@ impl Model { Self::MistralLargeLatest => "mistral-large-latest", Self::MistralMediumLatest => "mistral-medium-latest", Self::MistralSmallLatest => "mistral-small-latest", + Self::MagistralMediumLatest => "magistral-medium-latest", + Self::MagistralSmallLatest => "magistral-small-latest", Self::OpenMistralNemo => "open-mistral-nemo", Self::OpenCodestralMamba => "open-codestral-mamba", + Self::DevstralMediumLatest => "devstral-medium-latest", Self::DevstralSmallLatest => "devstral-small-latest", Self::Pixtral12BLatest => "pixtral-12b-latest", Self::PixtralLargeLatest => "pixtral-large-latest", @@ -136,8 +156,11 @@ impl Model { Self::MistralLargeLatest => 131000, Self::MistralMediumLatest => 128000, Self::MistralSmallLatest => 32000, + Self::MagistralMediumLatest => 40000, + Self::MagistralSmallLatest => 40000, Self::OpenMistralNemo => 131000, Self::OpenCodestralMamba => 256000, + Self::DevstralMediumLatest => 128000, Self::DevstralSmallLatest => 262144, Self::Pixtral12BLatest => 128000, Self::PixtralLargeLatest => 128000, @@ -160,8 +183,11 @@ impl Model { | Self::MistralLargeLatest | Self::MistralMediumLatest | Self::MistralSmallLatest + | Self::MagistralMediumLatest + | Self::MagistralSmallLatest | Self::OpenMistralNemo | Self::OpenCodestralMamba + | Self::DevstralMediumLatest | 
Self::DevstralSmallLatest | Self::Pixtral12BLatest | Self::PixtralLargeLatest => true, @@ -177,8 +203,11 @@ impl Model { | Self::MistralSmallLatest => true, Self::CodestralLatest | Self::MistralLargeLatest + | Self::MagistralMediumLatest + | Self::MagistralSmallLatest | Self::OpenMistralNemo | Self::OpenCodestralMamba + | Self::DevstralMediumLatest | Self::DevstralSmallLatest => false, Self::Custom { supports_images, .. diff --git a/crates/multi_buffer/src/anchor.rs b/crates/multi_buffer/src/anchor.rs index 9e28295c5612a843add726d2c248882577d4ee04..1305328d384023517dbb80d25e210b44e632eed8 100644 --- a/crates/multi_buffer/src/anchor.rs +++ b/crates/multi_buffer/src/anchor.rs @@ -167,10 +167,10 @@ impl Anchor { if *self == Anchor::min() || *self == Anchor::max() { true } else if let Some(excerpt) = snapshot.excerpt(self.excerpt_id) { - excerpt.contains(self) - && (self.text_anchor == excerpt.range.context.start - || self.text_anchor == excerpt.range.context.end - || self.text_anchor.is_valid(&excerpt.buffer)) + (self.text_anchor == excerpt.range.context.start + || self.text_anchor == excerpt.range.context.end + || self.text_anchor.is_valid(&excerpt.buffer)) + && excerpt.contains(self) } else { false } diff --git a/crates/multi_buffer/src/multi_buffer.rs b/crates/multi_buffer/src/multi_buffer.rs index e22fdb1ed5a978211d4dc6fd071107600ccf789f..eb12e6929cbc4bf74f44a2cb6eb9970c825d0fe3 100644 --- a/crates/multi_buffer/src/multi_buffer.rs +++ b/crates/multi_buffer/src/multi_buffer.rs @@ -43,7 +43,7 @@ use std::{ sync::Arc, time::{Duration, Instant}, }; -use sum_tree::{Bias, Cursor, Dimension, SumTree, Summary, TreeMap}; +use sum_tree::{Bias, Cursor, Dimension, Dimensions, SumTree, Summary, TreeMap}; use text::{ BufferId, Edit, LineIndent, TextSummary, locator::Locator, @@ -474,7 +474,7 @@ pub struct MultiBufferRows<'a> { pub struct MultiBufferChunks<'a> { excerpts: Cursor<'a, Excerpt, ExcerptOffset>, - diff_transforms: Cursor<'a, DiffTransform, (usize, ExcerptOffset)>, + diff_transforms: Cursor<'a, DiffTransform, Dimensions<usize, ExcerptOffset>>, diffs: &'a TreeMap<BufferId, BufferDiffSnapshot>, diff_base_chunks: Option<(BufferId, BufferChunks<'a>)>, buffer_chunk: Option<Chunk<'a>>, @@ -1211,7 +1211,7 @@ impl MultiBuffer { let buffer = buffer_state.buffer.read(cx); for range in buffer.edited_ranges_for_transaction_id::<D>(*buffer_transaction) { for excerpt_id in &buffer_state.excerpts { - cursor.seek(excerpt_id, Bias::Left, &()); + cursor.seek(excerpt_id, Bias::Left); if let Some(excerpt) = cursor.item() { if excerpt.locator == *excerpt_id { let excerpt_buffer_start = @@ -1322,7 +1322,7 @@ impl MultiBuffer { let start_locator = snapshot.excerpt_locator_for_id(selection.start.excerpt_id); let end_locator = snapshot.excerpt_locator_for_id(selection.end.excerpt_id); - cursor.seek(&Some(start_locator), Bias::Left, &()); + cursor.seek(&Some(start_locator), Bias::Left); while let Some(excerpt) = cursor.item() { if excerpt.locator > *end_locator { break; @@ -1347,7 +1347,7 @@ impl MultiBuffer { goal: selection.goal, }); - cursor.next(&()); + cursor.next(); } } @@ -1769,13 +1769,13 @@ impl MultiBuffer { let mut next_excerpt_id = move || ExcerptId(post_inc(&mut next_excerpt_id)); let mut excerpts_cursor = snapshot.excerpts.cursor::<Option<&Locator>>(&()); - excerpts_cursor.next(&()); + excerpts_cursor.next(); loop { let new = new_iter.peek(); let existing = if let Some(existing_id) = existing_iter.peek() { let locator = snapshot.excerpt_locator_for_id(*existing_id); - 
excerpts_cursor.seek_forward(&Some(locator), Bias::Left, &()); + excerpts_cursor.seek_forward(&Some(locator), Bias::Left); if let Some(excerpt) = excerpts_cursor.item() { if excerpt.buffer_id != buffer_snapshot.remote_id() { to_remove.push(*existing_id); @@ -1970,7 +1970,7 @@ impl MultiBuffer { let mut prev_locator = snapshot.excerpt_locator_for_id(prev_excerpt_id).clone(); let mut new_excerpt_ids = mem::take(&mut snapshot.excerpt_ids); let mut cursor = snapshot.excerpts.cursor::<Option<&Locator>>(&()); - let mut new_excerpts = cursor.slice(&prev_locator, Bias::Right, &()); + let mut new_excerpts = cursor.slice(&prev_locator, Bias::Right); prev_locator = cursor.start().unwrap_or(Locator::min_ref()).clone(); let edit_start = ExcerptOffset::new(new_excerpts.summary().text.len); @@ -2019,7 +2019,7 @@ impl MultiBuffer { let edit_end = ExcerptOffset::new(new_excerpts.summary().text.len); - let suffix = cursor.suffix(&()); + let suffix = cursor.suffix(); let changed_trailing_excerpt = suffix.is_empty(); new_excerpts.append(suffix, &()); drop(cursor); @@ -2104,7 +2104,7 @@ impl MultiBuffer { .into_iter() .flatten() { - cursor.seek_forward(&Some(locator), Bias::Left, &()); + cursor.seek_forward(&Some(locator), Bias::Left); if let Some(excerpt) = cursor.item() { if excerpt.locator == *locator { excerpts.push((excerpt.id, excerpt.range.clone())); @@ -2120,29 +2120,29 @@ impl MultiBuffer { let buffers = self.buffers.borrow(); let mut excerpts = snapshot .excerpts - .cursor::<(Option<&Locator>, ExcerptDimension<Point>)>(&()); + .cursor::<Dimensions<Option<&Locator>, ExcerptDimension<Point>>>(&()); let mut diff_transforms = snapshot .diff_transforms - .cursor::<(ExcerptDimension<Point>, OutputDimension<Point>)>(&()); - diff_transforms.next(&()); + .cursor::<Dimensions<ExcerptDimension<Point>, OutputDimension<Point>>>(&()); + diff_transforms.next(); let locators = buffers .get(&buffer_id) .into_iter() .flat_map(|state| &state.excerpts); let mut result = Vec::new(); for locator in locators { - excerpts.seek_forward(&Some(locator), Bias::Left, &()); + excerpts.seek_forward(&Some(locator), Bias::Left); if let Some(excerpt) = excerpts.item() { if excerpt.locator == *locator { let excerpt_start = excerpts.start().1.clone(); let excerpt_end = ExcerptDimension(excerpt_start.0 + excerpt.text_summary.lines); - diff_transforms.seek_forward(&excerpt_start, Bias::Left, &()); + diff_transforms.seek_forward(&excerpt_start, Bias::Left); let overshoot = excerpt_start.0 - diff_transforms.start().0.0; let start = diff_transforms.start().1.0 + overshoot; - diff_transforms.seek_forward(&excerpt_end, Bias::Right, &()); + diff_transforms.seek_forward(&excerpt_end, Bias::Right); let overshoot = excerpt_end.0 - diff_transforms.start().0.0; let end = diff_transforms.start().1.0 + overshoot; @@ -2281,7 +2281,7 @@ impl MultiBuffer { let mut new_excerpts = SumTree::default(); let mut cursor = snapshot .excerpts - .cursor::<(Option<&Locator>, ExcerptOffset)>(&()); + .cursor::<Dimensions<Option<&Locator>, ExcerptOffset>>(&()); let mut edits = Vec::new(); let mut excerpt_ids = ids.iter().copied().peekable(); let mut removed_buffer_ids = Vec::new(); @@ -2290,7 +2290,7 @@ impl MultiBuffer { self.paths_by_excerpt.remove(&excerpt_id); // Seek to the next excerpt to remove, preserving any preceding excerpts. 
let locator = snapshot.excerpt_locator_for_id(excerpt_id); - new_excerpts.append(cursor.slice(&Some(locator), Bias::Left, &()), &()); + new_excerpts.append(cursor.slice(&Some(locator), Bias::Left), &()); if let Some(mut excerpt) = cursor.item() { if excerpt.id != excerpt_id { @@ -2311,7 +2311,7 @@ impl MultiBuffer { removed_buffer_ids.push(excerpt.buffer_id); } } - cursor.next(&()); + cursor.next(); // Skip over any subsequent excerpts that are also removed. if let Some(&next_excerpt_id) = excerpt_ids.peek() { @@ -2344,7 +2344,7 @@ impl MultiBuffer { }); } } - let suffix = cursor.suffix(&()); + let suffix = cursor.suffix(); let changed_trailing_excerpt = suffix.is_empty(); new_excerpts.append(suffix, &()); drop(cursor); @@ -2492,8 +2492,8 @@ impl MultiBuffer { for locator in &buffer_state.excerpts { let mut cursor = snapshot .excerpts - .cursor::<(Option<&Locator>, ExcerptOffset)>(&()); - cursor.seek_forward(&Some(locator), Bias::Left, &()); + .cursor::<Dimensions<Option<&Locator>, ExcerptOffset>>(&()); + cursor.seek_forward(&Some(locator), Bias::Left); if let Some(excerpt) = cursor.item() { if excerpt.locator == *locator { let excerpt_buffer_range = excerpt.range.context.to_offset(&excerpt.buffer); @@ -2724,7 +2724,7 @@ impl MultiBuffer { let snapshot = self.read(cx); let mut cursor = snapshot.diff_transforms.cursor::<usize>(&()); let offset_range = range.to_offset(&snapshot); - cursor.seek(&offset_range.start, Bias::Left, &()); + cursor.seek(&offset_range.start, Bias::Left); while let Some(item) = cursor.item() { if *cursor.start() >= offset_range.end && *cursor.start() > offset_range.start { break; @@ -2732,7 +2732,7 @@ impl MultiBuffer { if item.hunk_info().is_some() { return true; } - cursor.next(&()); + cursor.next(); } false } @@ -2746,7 +2746,7 @@ impl MultiBuffer { let end = snapshot.point_to_offset(Point::new(range.end.row + 1, 0)); let start = start.saturating_sub(1); let end = snapshot.len().min(end + 1); - cursor.seek(&start, Bias::Right, &()); + cursor.seek(&start, Bias::Right); while let Some(item) = cursor.item() { if *cursor.start() >= end { break; @@ -2754,7 +2754,7 @@ impl MultiBuffer { if item.hunk_info().is_some() { return true; } - cursor.next(&()); + cursor.next(); } } false @@ -2845,10 +2845,10 @@ impl MultiBuffer { let mut new_excerpts = SumTree::default(); let mut cursor = snapshot .excerpts - .cursor::<(Option<&Locator>, ExcerptOffset)>(&()); + .cursor::<Dimensions<Option<&Locator>, ExcerptOffset>>(&()); let mut edits = Vec::<Edit<ExcerptOffset>>::new(); - let prefix = cursor.slice(&Some(locator), Bias::Left, &()); + let prefix = cursor.slice(&Some(locator), Bias::Left); new_excerpts.append(prefix, &()); let mut excerpt = cursor.item().unwrap().clone(); @@ -2883,9 +2883,9 @@ impl MultiBuffer { new_excerpts.push(excerpt, &()); - cursor.next(&()); + cursor.next(); - new_excerpts.append(cursor.suffix(&()), &()); + new_excerpts.append(cursor.suffix(), &()); drop(cursor); snapshot.excerpts = new_excerpts; @@ -2921,11 +2921,11 @@ impl MultiBuffer { let mut new_excerpts = SumTree::default(); let mut cursor = snapshot .excerpts - .cursor::<(Option<&Locator>, ExcerptOffset)>(&()); + .cursor::<Dimensions<Option<&Locator>, ExcerptOffset>>(&()); let mut edits = Vec::<Edit<ExcerptOffset>>::new(); for locator in &locators { - let prefix = cursor.slice(&Some(locator), Bias::Left, &()); + let prefix = cursor.slice(&Some(locator), Bias::Left); new_excerpts.append(prefix, &()); let mut excerpt = cursor.item().unwrap().clone(); @@ -2987,10 +2987,10 @@ impl MultiBuffer { 
new_excerpts.push(excerpt, &()); - cursor.next(&()); + cursor.next(); } - new_excerpts.append(cursor.suffix(&()), &()); + new_excerpts.append(cursor.suffix(), &()); drop(cursor); snapshot.excerpts = new_excerpts; @@ -3067,10 +3067,10 @@ impl MultiBuffer { let mut new_excerpts = SumTree::default(); let mut cursor = snapshot .excerpts - .cursor::<(Option<&Locator>, ExcerptOffset)>(&()); + .cursor::<Dimensions<Option<&Locator>, ExcerptOffset>>(&()); for (locator, buffer, buffer_edited) in excerpts_to_edit { - new_excerpts.append(cursor.slice(&Some(locator), Bias::Left, &()), &()); + new_excerpts.append(cursor.slice(&Some(locator), Bias::Left), &()); let old_excerpt = cursor.item().unwrap(); let buffer = buffer.read(cx); let buffer_id = buffer.remote_id(); @@ -3112,9 +3112,9 @@ impl MultiBuffer { } new_excerpts.push(new_excerpt, &()); - cursor.next(&()); + cursor.next(); } - new_excerpts.append(cursor.suffix(&()), &()); + new_excerpts.append(cursor.suffix(), &()); drop(cursor); snapshot.excerpts = new_excerpts; @@ -3135,7 +3135,7 @@ impl MultiBuffer { let mut excerpts = snapshot.excerpts.cursor::<ExcerptOffset>(&()); let mut old_diff_transforms = snapshot .diff_transforms - .cursor::<(ExcerptOffset, usize)>(&()); + .cursor::<Dimensions<ExcerptOffset, usize>>(&()); let mut new_diff_transforms = SumTree::default(); let mut old_expanded_hunks = HashSet::default(); let mut output_edits = Vec::new(); @@ -3145,23 +3145,22 @@ impl MultiBuffer { let mut excerpt_edits = excerpt_edits.into_iter().peekable(); while let Some(edit) = excerpt_edits.next() { - excerpts.seek_forward(&edit.new.start, Bias::Right, &()); + excerpts.seek_forward(&edit.new.start, Bias::Right); if excerpts.item().is_none() && *excerpts.start() == edit.new.start { - excerpts.prev(&()); + excerpts.prev(); } // Keep any transforms that are before the edit. if at_transform_boundary { at_transform_boundary = false; - let transforms_before_edit = - old_diff_transforms.slice(&edit.old.start, Bias::Left, &()); + let transforms_before_edit = old_diff_transforms.slice(&edit.old.start, Bias::Left); self.append_diff_transforms(&mut new_diff_transforms, transforms_before_edit); if let Some(transform) = old_diff_transforms.item() { - if old_diff_transforms.end(&()).0 == edit.old.start + if old_diff_transforms.end().0 == edit.old.start && old_diff_transforms.start().0 < edit.old.start { self.push_diff_transform(&mut new_diff_transforms, transform.clone()); - old_diff_transforms.next(&()); + old_diff_transforms.next(); } } } @@ -3203,7 +3202,7 @@ impl MultiBuffer { // then recreate the content up to the end of this transform, to prepare // for reusing additional slices of the old transforms. if excerpt_edits.peek().map_or(true, |next_edit| { - next_edit.old.start >= old_diff_transforms.end(&()).0 + next_edit.old.start >= old_diff_transforms.end().0 }) { let keep_next_old_transform = (old_diff_transforms.start().0 >= edit.old.end) && match old_diff_transforms.item() { @@ -3218,8 +3217,8 @@ impl MultiBuffer { let mut excerpt_offset = edit.new.end; if !keep_next_old_transform { - excerpt_offset += old_diff_transforms.end(&()).0 - edit.old.end; - old_diff_transforms.next(&()); + excerpt_offset += old_diff_transforms.end().0 - edit.old.end; + old_diff_transforms.next(); } old_expanded_hunks.clear(); @@ -3234,7 +3233,7 @@ impl MultiBuffer { } // Keep any transforms that are after the last edit. 
- self.append_diff_transforms(&mut new_diff_transforms, old_diff_transforms.suffix(&())); + self.append_diff_transforms(&mut new_diff_transforms, old_diff_transforms.suffix()); // Ensure there's always at least one buffer content transform. if new_diff_transforms.is_empty() { @@ -3261,7 +3260,7 @@ impl MultiBuffer { &self, edit: &Edit<TypedOffset<Excerpt>>, excerpts: &mut Cursor<Excerpt, TypedOffset<Excerpt>>, - old_diff_transforms: &mut Cursor<DiffTransform, (TypedOffset<Excerpt>, usize)>, + old_diff_transforms: &mut Cursor<DiffTransform, Dimensions<TypedOffset<Excerpt>, usize>>, new_diff_transforms: &mut SumTree<DiffTransform>, end_of_current_insert: &mut Option<(TypedOffset<Excerpt>, DiffTransformHunkInfo)>, old_expanded_hunks: &mut HashSet<DiffTransformHunkInfo>, @@ -3283,10 +3282,10 @@ impl MultiBuffer { ); old_expanded_hunks.insert(hunk_info); } - if old_diff_transforms.end(&()).0 > edit.old.end { + if old_diff_transforms.end().0 > edit.old.end { break; } - old_diff_transforms.next(&()); + old_diff_transforms.next(); } // Avoid querying diff hunks if there's no possibility of hunks being expanded. @@ -3413,8 +3412,8 @@ impl MultiBuffer { } } - if excerpts.end(&()) <= edit.new.end { - excerpts.next(&()); + if excerpts.end() <= edit.new.end { + excerpts.next(); } else { break; } @@ -3439,9 +3438,9 @@ impl MultiBuffer { *summary, ) { let mut cursor = subtree.cursor::<()>(&()); - cursor.next(&()); - cursor.next(&()); - new_transforms.append(cursor.suffix(&()), &()); + cursor.next(); + cursor.next(); + new_transforms.append(cursor.suffix(), &()); return; } } @@ -4714,15 +4713,17 @@ impl MultiBufferSnapshot { O: ToOffset, { let range = range.start.to_offset(self)..range.end.to_offset(self); - let mut cursor = self.diff_transforms.cursor::<(usize, ExcerptOffset)>(&()); - cursor.seek(&range.start, Bias::Right, &()); + let mut cursor = self + .diff_transforms + .cursor::<Dimensions<usize, ExcerptOffset>>(&()); + cursor.seek(&range.start, Bias::Right); let Some(first_transform) = cursor.item() else { return D::from_text_summary(&TextSummary::default()); }; let diff_transform_start = cursor.start().0; - let diff_transform_end = cursor.end(&()).0; + let diff_transform_end = cursor.end().0; let diff_start = range.start; let start_overshoot = diff_start - diff_transform_start; let end_overshoot = std::cmp::min(range.end, diff_transform_end) - diff_transform_start; @@ -4765,12 +4766,10 @@ impl MultiBufferSnapshot { return result; } - cursor.next(&()); - result.add_assign(&D::from_text_summary(&cursor.summary( - &range.end, - Bias::Right, - &(), - ))); + cursor.next(); + result.add_assign(&D::from_text_summary( + &cursor.summary(&range.end, Bias::Right), + )); let Some(last_transform) = cursor.item() else { return result; @@ -4813,9 +4812,9 @@ impl MultiBufferSnapshot { // let mut range = range.start..range.end; let mut summary = D::zero(&()); let mut cursor = self.excerpts.cursor::<ExcerptOffset>(&()); - cursor.seek(&range.start, Bias::Right, &()); + cursor.seek(&range.start, Bias::Right); if let Some(excerpt) = cursor.item() { - let mut end_before_newline = cursor.end(&()); + let mut end_before_newline = cursor.end(); if excerpt.has_trailing_newline { end_before_newline -= ExcerptOffset::new(1); } @@ -4834,13 +4833,13 @@ impl MultiBufferSnapshot { summary.add_assign(&D::from_text_summary(&TextSummary::from("\n"))); } - cursor.next(&()); + cursor.next(); } if range.end > *cursor.start() { summary.add_assign( &cursor - .summary::<_, ExcerptDimension<D>>(&range.end, Bias::Right, &()) + .summary::<_, 
ExcerptDimension<D>>(&range.end, Bias::Right) .0, ); if let Some(excerpt) = cursor.item() { @@ -4870,17 +4869,20 @@ impl MultiBufferSnapshot { &self, anchor: &Anchor, excerpt_position: D, - diff_transforms: &mut Cursor<DiffTransform, (ExcerptDimension<D>, OutputDimension<D>)>, + diff_transforms: &mut Cursor< + DiffTransform, + Dimensions<ExcerptDimension<D>, OutputDimension<D>>, + >, ) -> D where D: TextDimension + Ord + Sub<D, Output = D>, { loop { - let transform_end_position = diff_transforms.end(&()).0.0; + let transform_end_position = diff_transforms.end().0.0; let at_transform_end = excerpt_position == transform_end_position && diff_transforms.item().is_some(); if at_transform_end && anchor.text_anchor.bias == Bias::Right { - diff_transforms.next(&()); + diff_transforms.next(); continue; } @@ -4906,7 +4908,7 @@ impl MultiBufferSnapshot { ); position.add_assign(&position_in_hunk); } else if at_transform_end { - diff_transforms.next(&()); + diff_transforms.next(); continue; } } @@ -4915,7 +4917,7 @@ impl MultiBufferSnapshot { } _ => { if at_transform_end && anchor.diff_base_anchor.is_some() { - diff_transforms.next(&()); + diff_transforms.next(); continue; } let overshoot = excerpt_position - diff_transforms.start().0.0; @@ -4930,12 +4932,12 @@ impl MultiBufferSnapshot { fn excerpt_offset_for_anchor(&self, anchor: &Anchor) -> ExcerptOffset { let mut cursor = self .excerpts - .cursor::<(Option<&Locator>, ExcerptOffset)>(&()); + .cursor::<Dimensions<Option<&Locator>, ExcerptOffset>>(&()); let locator = self.excerpt_locator_for_id(anchor.excerpt_id); - cursor.seek(&Some(locator), Bias::Left, &()); + cursor.seek(&Some(locator), Bias::Left); if cursor.item().is_none() { - cursor.next(&()); + cursor.next(); } let mut position = cursor.start().1; @@ -4974,8 +4976,8 @@ impl MultiBufferSnapshot { let mut cursor = self.excerpts.cursor::<ExcerptSummary>(&()); let mut diff_transforms_cursor = self .diff_transforms - .cursor::<(ExcerptDimension<D>, OutputDimension<D>)>(&()); - diff_transforms_cursor.next(&()); + .cursor::<Dimensions<ExcerptDimension<D>, OutputDimension<D>>>(&()); + diff_transforms_cursor.next(); let mut summaries = Vec::new(); while let Some(anchor) = anchors.peek() { @@ -4990,9 +4992,9 @@ impl MultiBufferSnapshot { }); let locator = self.excerpt_locator_for_id(excerpt_id); - cursor.seek_forward(locator, Bias::Left, &()); + cursor.seek_forward(locator, Bias::Left); if cursor.item().is_none() { - cursor.next(&()); + cursor.next(); } let excerpt_start_position = D::from_text_summary(&cursor.start().text); @@ -5022,11 +5024,8 @@ impl MultiBufferSnapshot { } if position > diff_transforms_cursor.start().0.0 { - diff_transforms_cursor.seek_forward( - &ExcerptDimension(position), - Bias::Left, - &(), - ); + diff_transforms_cursor + .seek_forward(&ExcerptDimension(position), Bias::Left); } summaries.push(self.resolve_summary_for_anchor( @@ -5036,11 +5035,8 @@ impl MultiBufferSnapshot { )); } } else { - diff_transforms_cursor.seek_forward( - &ExcerptDimension(excerpt_start_position), - Bias::Left, - &(), - ); + diff_transforms_cursor + .seek_forward(&ExcerptDimension(excerpt_start_position), Bias::Left); let position = self.resolve_summary_for_anchor( &Anchor::max(), excerpt_start_position, @@ -5099,7 +5095,7 @@ impl MultiBufferSnapshot { { let mut anchors = anchors.into_iter().enumerate().peekable(); let mut cursor = self.excerpts.cursor::<Option<&Locator>>(&()); - cursor.next(&()); + cursor.next(); let mut result = Vec::new(); @@ -5108,10 +5104,10 @@ impl MultiBufferSnapshot { // Find the 
location where this anchor's excerpt should be. let old_locator = self.excerpt_locator_for_id(old_excerpt_id); - cursor.seek_forward(&Some(old_locator), Bias::Left, &()); + cursor.seek_forward(&Some(old_locator), Bias::Left); if cursor.item().is_none() { - cursor.next(&()); + cursor.next(); } let next_excerpt = cursor.item(); @@ -5210,14 +5206,16 @@ impl MultiBufferSnapshot { // Find the given position in the diff transforms. Determine the corresponding // offset in the excerpts, and whether the position is within a deleted hunk. - let mut diff_transforms = self.diff_transforms.cursor::<(usize, ExcerptOffset)>(&()); - diff_transforms.seek(&offset, Bias::Right, &()); + let mut diff_transforms = self + .diff_transforms + .cursor::<Dimensions<usize, ExcerptOffset>>(&()); + diff_transforms.seek(&offset, Bias::Right); if offset == diff_transforms.start().0 && bias == Bias::Left { if let Some(prev_item) = diff_transforms.prev_item() { match prev_item { DiffTransform::DeletedHunk { .. } => { - diff_transforms.prev(&()); + diff_transforms.prev(); } _ => {} } @@ -5259,14 +5257,14 @@ impl MultiBufferSnapshot { let mut excerpts = self .excerpts - .cursor::<(ExcerptOffset, Option<ExcerptId>)>(&()); - excerpts.seek(&excerpt_offset, Bias::Right, &()); + .cursor::<Dimensions<ExcerptOffset, Option<ExcerptId>>>(&()); + excerpts.seek(&excerpt_offset, Bias::Right); if excerpts.item().is_none() && excerpt_offset == excerpts.start().0 && bias == Bias::Left { - excerpts.prev(&()); + excerpts.prev(); } if let Some(excerpt) = excerpts.item() { let mut overshoot = excerpt_offset.saturating_sub(excerpts.start().0).value; - if excerpt.has_trailing_newline && excerpt_offset == excerpts.end(&()).0 { + if excerpt.has_trailing_newline && excerpt_offset == excerpts.end().0 { overshoot -= 1; bias = Bias::Right; } @@ -5297,7 +5295,7 @@ impl MultiBufferSnapshot { let excerpt_id = self.latest_excerpt_id(excerpt_id); let locator = self.excerpt_locator_for_id(excerpt_id); let mut cursor = self.excerpts.cursor::<Option<&Locator>>(&()); - cursor.seek(locator, Bias::Left, &()); + cursor.seek(locator, Bias::Left); if let Some(excerpt) = cursor.item() { if excerpt.id == excerpt_id { let text_anchor = excerpt.clip_anchor(text_anchor); @@ -5350,14 +5348,14 @@ impl MultiBufferSnapshot { let start_locator = self.excerpt_locator_for_id(id); let mut excerpts = self .excerpts - .cursor::<(Option<&Locator>, ExcerptDimension<usize>)>(&()); - excerpts.seek(&Some(start_locator), Bias::Left, &()); - excerpts.prev(&()); + .cursor::<Dimensions<Option<&Locator>, ExcerptDimension<usize>>>(&()); + excerpts.seek(&Some(start_locator), Bias::Left); + excerpts.prev(); let mut diff_transforms = self.diff_transforms.cursor::<DiffTransforms<usize>>(&()); - diff_transforms.seek(&excerpts.start().1, Bias::Left, &()); - if diff_transforms.end(&()).excerpt_dimension < excerpts.start().1 { - diff_transforms.next(&()); + diff_transforms.seek(&excerpts.start().1, Bias::Left); + if diff_transforms.end().excerpt_dimension < excerpts.start().1 { + diff_transforms.next(); } let excerpt = excerpts.item()?; @@ -5905,7 +5903,6 @@ impl MultiBufferSnapshot { let depth = if found_indent { line_indent.len(tab_size) / tab_size - + ((line_indent.len(tab_size) % tab_size) > 0) as u32 } else { 0 }; @@ -6194,7 +6191,7 @@ impl MultiBufferSnapshot { Locator::max_ref() } else { let mut cursor = self.excerpt_ids.cursor::<ExcerptId>(&()); - cursor.seek(&id, Bias::Left, &()); + cursor.seek(&id, Bias::Left); if let Some(entry) = cursor.item() { if entry.id == id { return 
&entry.locator; @@ -6230,7 +6227,7 @@ impl MultiBufferSnapshot { let mut cursor = self.excerpt_ids.cursor::<ExcerptId>(&()); for id in sorted_ids { - if cursor.seek_forward(&id, Bias::Left, &()) { + if cursor.seek_forward(&id, Bias::Left) { locators.push(cursor.item().unwrap().locator.clone()); } else { panic!("invalid excerpt id {:?}", id); @@ -6252,18 +6249,18 @@ impl MultiBufferSnapshot { pub fn range_for_excerpt(&self, excerpt_id: ExcerptId) -> Option<Range<Point>> { let mut cursor = self .excerpts - .cursor::<(Option<&Locator>, ExcerptDimension<Point>)>(&()); + .cursor::<Dimensions<Option<&Locator>, ExcerptDimension<Point>>>(&()); let locator = self.excerpt_locator_for_id(excerpt_id); - if cursor.seek(&Some(locator), Bias::Left, &()) { + if cursor.seek(&Some(locator), Bias::Left) { let start = cursor.start().1.clone(); - let end = cursor.end(&()).1; + let end = cursor.end().1; let mut diff_transforms = self .diff_transforms - .cursor::<(ExcerptDimension<Point>, OutputDimension<Point>)>(&()); - diff_transforms.seek(&start, Bias::Left, &()); + .cursor::<Dimensions<ExcerptDimension<Point>, OutputDimension<Point>>>(&()); + diff_transforms.seek(&start, Bias::Left); let overshoot = start.0 - diff_transforms.start().0.0; let start = diff_transforms.start().1.0 + overshoot; - diff_transforms.seek(&end, Bias::Right, &()); + diff_transforms.seek(&end, Bias::Right); let overshoot = end.0 - diff_transforms.start().0.0; let end = diff_transforms.start().1.0 + overshoot; Some(start..end) @@ -6275,7 +6272,7 @@ impl MultiBufferSnapshot { pub fn buffer_range_for_excerpt(&self, excerpt_id: ExcerptId) -> Option<Range<text::Anchor>> { let mut cursor = self.excerpts.cursor::<Option<&Locator>>(&()); let locator = self.excerpt_locator_for_id(excerpt_id); - if cursor.seek(&Some(locator), Bias::Left, &()) { + if cursor.seek(&Some(locator), Bias::Left) { if let Some(excerpt) = cursor.item() { return Some(excerpt.range.context.clone()); } @@ -6286,7 +6283,7 @@ impl MultiBufferSnapshot { fn excerpt(&self, excerpt_id: ExcerptId) -> Option<&Excerpt> { let mut cursor = self.excerpts.cursor::<Option<&Locator>>(&()); let locator = self.excerpt_locator_for_id(excerpt_id); - cursor.seek(&Some(locator), Bias::Left, &()); + cursor.seek(&Some(locator), Bias::Left); if let Some(excerpt) = cursor.item() { if excerpt.id == excerpt_id { return Some(excerpt); @@ -6334,7 +6331,7 @@ impl MultiBufferSnapshot { let mut cursor = self.excerpts.cursor::<ExcerptSummary>(&()); let start_locator = self.excerpt_locator_for_id(range.start.excerpt_id); let end_locator = self.excerpt_locator_for_id(range.end.excerpt_id); - cursor.seek(start_locator, Bias::Left, &()); + cursor.seek(start_locator, Bias::Left); cursor .take_while(move |excerpt| excerpt.locator <= *end_locator) .flat_map(move |excerpt| { @@ -6473,11 +6470,11 @@ where fn seek(&mut self, position: &D) { self.cached_region.take(); self.diff_transforms - .seek(&OutputDimension(*position), Bias::Right, &()); + .seek(&OutputDimension(*position), Bias::Right); if self.diff_transforms.item().is_none() && *position == self.diff_transforms.start().output_dimension.0 { - self.diff_transforms.prev(&()); + self.diff_transforms.prev(); } let mut excerpt_position = self.diff_transforms.start().excerpt_dimension.0; @@ -6487,20 +6484,20 @@ where } self.excerpts - .seek(&ExcerptDimension(excerpt_position), Bias::Right, &()); + .seek(&ExcerptDimension(excerpt_position), Bias::Right); if self.excerpts.item().is_none() && excerpt_position == self.excerpts.start().0 { - self.excerpts.prev(&()); + 
self.excerpts.prev(); } } fn seek_forward(&mut self, position: &D) { self.cached_region.take(); self.diff_transforms - .seek_forward(&OutputDimension(*position), Bias::Right, &()); + .seek_forward(&OutputDimension(*position), Bias::Right); if self.diff_transforms.item().is_none() && *position == self.diff_transforms.start().output_dimension.0 { - self.diff_transforms.prev(&()); + self.diff_transforms.prev(); } let overshoot = *position - self.diff_transforms.start().output_dimension.0; @@ -6510,31 +6507,30 @@ where } self.excerpts - .seek_forward(&ExcerptDimension(excerpt_position), Bias::Right, &()); + .seek_forward(&ExcerptDimension(excerpt_position), Bias::Right); if self.excerpts.item().is_none() && excerpt_position == self.excerpts.start().0 { - self.excerpts.prev(&()); + self.excerpts.prev(); } } fn next_excerpt(&mut self) { - self.excerpts.next(&()); + self.excerpts.next(); self.seek_to_start_of_current_excerpt(); } fn prev_excerpt(&mut self) { - self.excerpts.prev(&()); + self.excerpts.prev(); self.seek_to_start_of_current_excerpt(); } fn seek_to_start_of_current_excerpt(&mut self) { self.cached_region.take(); - self.diff_transforms - .seek(self.excerpts.start(), Bias::Left, &()); - if self.diff_transforms.end(&()).excerpt_dimension == *self.excerpts.start() + self.diff_transforms.seek(self.excerpts.start(), Bias::Left); + if self.diff_transforms.end().excerpt_dimension == *self.excerpts.start() && self.diff_transforms.start().excerpt_dimension < *self.excerpts.start() && self.diff_transforms.next_item().is_some() { - self.diff_transforms.next(&()); + self.diff_transforms.next(); } } @@ -6542,18 +6538,18 @@ where self.cached_region.take(); match self .diff_transforms - .end(&()) + .end() .excerpt_dimension - .cmp(&self.excerpts.end(&())) + .cmp(&self.excerpts.end()) { - cmp::Ordering::Less => self.diff_transforms.next(&()), - cmp::Ordering::Greater => self.excerpts.next(&()), + cmp::Ordering::Less => self.diff_transforms.next(), + cmp::Ordering::Greater => self.excerpts.next(), cmp::Ordering::Equal => { - self.diff_transforms.next(&()); - if self.diff_transforms.end(&()).excerpt_dimension > self.excerpts.end(&()) + self.diff_transforms.next(); + if self.diff_transforms.end().excerpt_dimension > self.excerpts.end() || self.diff_transforms.item().is_none() { - self.excerpts.next(&()); + self.excerpts.next(); } else if let Some(DiffTransform::DeletedHunk { hunk_info, .. }) = self.diff_transforms.item() { @@ -6562,7 +6558,7 @@ where .item() .map_or(false, |excerpt| excerpt.id != hunk_info.excerpt_id) { - self.excerpts.next(&()); + self.excerpts.next(); } } } @@ -6577,14 +6573,14 @@ where .excerpt_dimension .cmp(self.excerpts.start()) { - cmp::Ordering::Less => self.excerpts.prev(&()), - cmp::Ordering::Greater => self.diff_transforms.prev(&()), + cmp::Ordering::Less => self.excerpts.prev(), + cmp::Ordering::Greater => self.diff_transforms.prev(), cmp::Ordering::Equal => { - self.diff_transforms.prev(&()); + self.diff_transforms.prev(); if self.diff_transforms.start().excerpt_dimension < *self.excerpts.start() || self.diff_transforms.item().is_none() { - self.excerpts.prev(&()); + self.excerpts.prev(); } } } @@ -6604,9 +6600,9 @@ where return true; } - self.diff_transforms.prev(&()); + self.diff_transforms.prev(); let prev_transform = self.diff_transforms.item(); - self.diff_transforms.next(&()); + self.diff_transforms.next(); prev_transform.map_or(true, |next_transform| { matches!(next_transform, DiffTransform::BufferContent { .. 
}) @@ -6614,9 +6610,9 @@ where } fn is_at_end_of_excerpt(&mut self) -> bool { - if self.diff_transforms.end(&()).excerpt_dimension < self.excerpts.end(&()) { + if self.diff_transforms.end().excerpt_dimension < self.excerpts.end() { return false; - } else if self.diff_transforms.end(&()).excerpt_dimension > self.excerpts.end(&()) + } else if self.diff_transforms.end().excerpt_dimension > self.excerpts.end() || self.diff_transforms.item().is_none() { return true; @@ -6637,7 +6633,7 @@ where let buffer = &excerpt.buffer; let buffer_context_start = excerpt.range.context.start.summary::<D>(buffer); let mut buffer_start = buffer_context_start; - let overshoot = self.diff_transforms.end(&()).excerpt_dimension.0 - self.excerpts.start().0; + let overshoot = self.diff_transforms.end().excerpt_dimension.0 - self.excerpts.start().0; buffer_start.add_assign(&overshoot); Some(buffer_start) } @@ -6660,7 +6656,7 @@ where let mut buffer_end = buffer_start; buffer_end.add_assign(&buffer_range_len); let start = self.diff_transforms.start().output_dimension.0; - let end = self.diff_transforms.end(&()).output_dimension.0; + let end = self.diff_transforms.end().output_dimension.0; return Some(MultiBufferRegion { buffer, excerpt, @@ -6694,16 +6690,16 @@ where let mut end; let mut buffer_end; let has_trailing_newline; - if self.diff_transforms.end(&()).excerpt_dimension.0 < self.excerpts.end(&()).0 { + if self.diff_transforms.end().excerpt_dimension.0 < self.excerpts.end().0 { let overshoot = - self.diff_transforms.end(&()).excerpt_dimension.0 - self.excerpts.start().0; - end = self.diff_transforms.end(&()).output_dimension.0; + self.diff_transforms.end().excerpt_dimension.0 - self.excerpts.start().0; + end = self.diff_transforms.end().output_dimension.0; buffer_end = buffer_context_start; buffer_end.add_assign(&overshoot); has_trailing_newline = false; } else { let overshoot = - self.excerpts.end(&()).0 - self.diff_transforms.start().excerpt_dimension.0; + self.excerpts.end().0 - self.diff_transforms.start().excerpt_dimension.0; end = self.diff_transforms.start().output_dimension.0; end.add_assign(&overshoot); buffer_end = excerpt.range.context.end.summary::<D>(buffer); @@ -7087,11 +7083,11 @@ impl<'a> MultiBufferExcerpt<'a> { /// Maps a range within the [`MultiBuffer`] to a range within the [`Buffer`] pub fn map_range_to_buffer(&mut self, range: Range<usize>) -> Range<usize> { self.diff_transforms - .seek(&OutputDimension(range.start), Bias::Right, &()); + .seek(&OutputDimension(range.start), Bias::Right); let start = self.map_offset_to_buffer_internal(range.start); let end = if range.end > range.start { self.diff_transforms - .seek_forward(&OutputDimension(range.end), Bias::Right, &()); + .seek_forward(&OutputDimension(range.end), Bias::Right); self.map_offset_to_buffer_internal(range.end) } else { start @@ -7124,7 +7120,7 @@ impl<'a> MultiBufferExcerpt<'a> { } let overshoot = buffer_range.start - self.buffer_offset; let excerpt_offset = ExcerptDimension(self.excerpt_offset.0 + overshoot); - self.diff_transforms.seek(&excerpt_offset, Bias::Right, &()); + self.diff_transforms.seek(&excerpt_offset, Bias::Right); if excerpt_offset.0 < self.diff_transforms.start().excerpt_dimension.0 { log::warn!( "Attempting to map a range from a buffer offset that starts before the current buffer offset" @@ -7138,7 +7134,7 @@ impl<'a> MultiBufferExcerpt<'a> { let overshoot = buffer_range.end - self.buffer_offset; let excerpt_offset = ExcerptDimension(self.excerpt_offset.0 + overshoot); self.diff_transforms - 
.seek_forward(&excerpt_offset, Bias::Right, &()); + .seek_forward(&excerpt_offset, Bias::Right); let overshoot = excerpt_offset.0 - self.diff_transforms.start().excerpt_dimension.0; self.diff_transforms.start().output_dimension.0 + overshoot } else { @@ -7510,7 +7506,7 @@ impl Iterator for MultiBufferRows<'_> { if let Some(next_region) = self.cursor.region() { region = next_region; } else { - if self.point == self.cursor.diff_transforms.end(&()).output_dimension.0 { + if self.point == self.cursor.diff_transforms.end().output_dimension.0 { let multibuffer_row = MultiBufferRow(self.point.row); let last_excerpt = self .cursor @@ -7616,14 +7612,14 @@ impl<'a> MultiBufferChunks<'a> { } pub fn seek(&mut self, range: Range<usize>) { - self.diff_transforms.seek(&range.end, Bias::Right, &()); + self.diff_transforms.seek(&range.end, Bias::Right); let mut excerpt_end = self.diff_transforms.start().1; if let Some(DiffTransform::BufferContent { .. }) = self.diff_transforms.item() { let overshoot = range.end - self.diff_transforms.start().0; excerpt_end.value += overshoot; } - self.diff_transforms.seek(&range.start, Bias::Right, &()); + self.diff_transforms.seek(&range.start, Bias::Right); let mut excerpt_start = self.diff_transforms.start().1; if let Some(DiffTransform::BufferContent { .. }) = self.diff_transforms.item() { let overshoot = range.start - self.diff_transforms.start().0; @@ -7637,7 +7633,7 @@ impl<'a> MultiBufferChunks<'a> { fn seek_to_excerpt_offset_range(&mut self, new_range: Range<ExcerptOffset>) { self.excerpt_offset_range = new_range.clone(); - self.excerpts.seek(&new_range.start, Bias::Right, &()); + self.excerpts.seek(&new_range.start, Bias::Right); if let Some(excerpt) = self.excerpts.item() { let excerpt_start = *self.excerpts.start(); if let Some(excerpt_chunks) = self @@ -7670,7 +7666,7 @@ impl<'a> MultiBufferChunks<'a> { self.excerpt_offset_range.start.value += chunk.text.len(); return Some(chunk); } else { - self.excerpts.next(&()); + self.excerpts.next(); let excerpt = self.excerpts.item()?; self.excerpt_chunks = Some(excerpt.chunks_in_range( 0..(self.excerpt_offset_range.end - *self.excerpts.start()).value, @@ -7713,12 +7709,12 @@ impl<'a> Iterator for MultiBufferChunks<'a> { if self.range.start >= self.range.end { return None; } - if self.range.start == self.diff_transforms.end(&()).0 { - self.diff_transforms.next(&()); + if self.range.start == self.diff_transforms.end().0 { + self.diff_transforms.next(); } let diff_transform_start = self.diff_transforms.start().0; - let diff_transform_end = self.diff_transforms.end(&()).0; + let diff_transform_end = self.diff_transforms.end().0; debug_assert!(self.range.start < diff_transform_end); let diff_transform = self.diff_transforms.item()?; diff --git a/crates/nc/Cargo.toml b/crates/nc/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..46ef2d3c62e233cc8693b3fdb3082749c05d9ed5 --- /dev/null +++ b/crates/nc/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "nc" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/nc.rs" +doctest = false + +[dependencies] +anyhow.workspace = true +futures.workspace = true +net.workspace = true +smol.workspace = true +workspace-hack.workspace = true diff --git a/crates/nc/LICENSE-GPL b/crates/nc/LICENSE-GPL new file mode 120000 index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4 --- /dev/null +++ b/crates/nc/LICENSE-GPL @@ -0,0 +1 @@ 
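Aside: the multi_buffer and notification-store hunks above repeatedly seek a cursor with a Bias and then step back with prev() when the seek overshoots (item() is None). The snippet below is a minimal, std-only sketch of what Left versus Right bias means on an exact boundary. It is illustrative only; the Bias enum and seek helper here are local stand-ins, not the sum_tree cursor API used in the diff.

#[derive(Clone, Copy, PartialEq, Eq)]
enum Bias {
    Left,
    Right,
}

/// Resolve a position to an item index, given each item's cumulative end offset.
fn seek(item_ends: &[usize], position: usize, bias: Bias) -> usize {
    match bias {
        // Left: a position exactly on a boundary resolves to the item that ends there.
        Bias::Left => item_ends.partition_point(|&end| end < position),
        // Right: the same position resolves to the item that starts there.
        Bias::Right => item_ends.partition_point(|&end| end <= position),
    }
}

fn main() {
    // Three items covering offsets 0..10, 10..25, 25..40.
    let item_ends = [10, 25, 40];
    assert_eq!(seek(&item_ends, 10, Bias::Left), 0); // boundary -> first item
    assert_eq!(seek(&item_ends, 10, Bias::Right), 1); // boundary -> second item
    // Seeking to the very end "overshoots" past the last item; the callers above
    // handle this case by checking item().is_none() and stepping back with prev().
    assert_eq!(seek(&item_ends, 40, Bias::Right), 3);
    println!("bias sketch ok");
}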
+../../LICENSE-GPL \ No newline at end of file diff --git a/crates/nc/src/nc.rs b/crates/nc/src/nc.rs new file mode 100644 index 0000000000000000000000000000000000000000..fccb4d726c49258d75323bf59389bfeb5baafa6a --- /dev/null +++ b/crates/nc/src/nc.rs @@ -0,0 +1,56 @@ +use anyhow::Result; + +#[cfg(windows)] +pub fn main(_socket: &str) -> Result<()> { + // It looks like we can't get an async stdio stream on Windows from smol. + // + // We decided to merge this with a panic on Windows since this is only used + // by the experimental Claude Code Agent Server. + // + // We're tracking this internally, and we will address it before shipping the integration. + panic!("--nc isn't yet supported on Windows"); +} + +/// The main function for when Zed is running in netcat mode +#[cfg(not(windows))] +pub fn main(socket: &str) -> Result<()> { + use futures::{AsyncReadExt as _, AsyncWriteExt as _, FutureExt as _, io::BufReader, select}; + use net::async_net::UnixStream; + use smol::{Async, io::AsyncBufReadExt}; + + smol::block_on(async { + let socket_stream = UnixStream::connect(socket).await?; + let (socket_read, mut socket_write) = socket_stream.split(); + let mut socket_reader = BufReader::new(socket_read); + + let mut stdout = Async::new(std::io::stdout())?; + let stdin = Async::new(std::io::stdin())?; + let mut stdin_reader = BufReader::new(stdin); + + let mut socket_line = Vec::new(); + let mut stdin_line = Vec::new(); + + loop { + select! { + bytes_read = socket_reader.read_until(b'\n', &mut socket_line).fuse() => { + if bytes_read? == 0 { + break + } + stdout.write_all(&socket_line).await?; + stdout.flush().await?; + socket_line.clear(); + } + bytes_read = stdin_reader.read_until(b'\n', &mut stdin_line).fuse() => { + if bytes_read? == 0 { + break + } + socket_write.write_all(&stdin_line).await?; + socket_write.flush().await?; + stdin_line.clear(); + } + } + } + + anyhow::Ok(()) + }) +} diff --git a/crates/notifications/src/notification_store.rs b/crates/notifications/src/notification_store.rs index c2f18e57001bdc748738850c2f87948ea6196cef..29653748e4873a271f58f932ee71c820aa755b9a 100644 --- a/crates/notifications/src/notification_store.rs +++ b/crates/notifications/src/notification_store.rs @@ -6,7 +6,7 @@ use db::smol::stream::StreamExt; use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, Task}; use rpc::{Notification, TypedEnvelope, proto}; use std::{ops::Range, sync::Arc}; -use sum_tree::{Bias, SumTree}; +use sum_tree::{Bias, Dimensions, SumTree}; use time::OffsetDateTime; use util::ResultExt; @@ -132,12 +132,12 @@ impl NotificationStore { } let ix = count - 1 - ix; let mut cursor = self.notifications.cursor::<Count>(&()); - cursor.seek(&Count(ix), Bias::Right, &()); + cursor.seek(&Count(ix), Bias::Right); cursor.item() } pub fn notification_for_id(&self, id: u64) -> Option<&NotificationEntry> { let mut cursor = self.notifications.cursor::<NotificationId>(&()); - cursor.seek(&NotificationId(id), Bias::Left, &()); + cursor.seek(&NotificationId(id), Bias::Left); if let Some(item) = cursor.item() { if item.id == id { return Some(item); @@ -360,12 +360,14 @@ impl NotificationStore { is_new: bool, cx: &mut Context<NotificationStore>, ) { - let mut cursor = self.notifications.cursor::<(NotificationId, Count)>(&()); + let mut cursor = self + .notifications + .cursor::<Dimensions<NotificationId, Count>>(&()); let mut new_notifications = SumTree::default(); let mut old_range = 0..0; for (i, (id, new_notification)) in notifications.into_iter().enumerate() { - 
new_notifications.append(cursor.slice(&NotificationId(id), Bias::Left, &()), &()); + new_notifications.append(cursor.slice(&NotificationId(id), Bias::Left), &()); if i == 0 { old_range.start = cursor.start().1.0; @@ -374,7 +376,7 @@ impl NotificationStore { let old_notification = cursor.item(); if let Some(old_notification) = old_notification { if old_notification.id == id { - cursor.next(&()); + cursor.next(); if let Some(new_notification) = &new_notification { if new_notification.is_read { @@ -403,7 +405,7 @@ impl NotificationStore { old_range.end = cursor.start().1.0; let new_count = new_notifications.summary().count - old_range.start; - new_notifications.append(cursor.suffix(&()), &()); + new_notifications.append(cursor.suffix(), &()); drop(cursor); self.notifications = new_notifications; diff --git a/crates/ollama/src/ollama.rs b/crates/ollama/src/ollama.rs index 109fea7353d8e35ccaabb3fce4fd76e9b3529a9b..64cd1cc0cbc06607ee9b3b72ee81cbeb9489c344 100644 --- a/crates/ollama/src/ollama.rs +++ b/crates/ollama/src/ollama.rs @@ -55,9 +55,10 @@ fn get_max_tokens(name: &str) -> u64 { "codellama" | "starcoder2" => 16384, "mistral" | "codestral" | "mixstral" | "llava" | "qwen2" | "qwen2.5-coder" | "dolphin-mixtral" => 32768, + "magistral" => 40000, "llama3.1" | "llama3.2" | "llama3.3" | "phi3" | "phi3.5" | "phi4" | "command-r" | "qwen3" | "gemma3" | "deepseek-coder-v2" | "deepseek-v3" | "deepseek-r1" | "yi-coder" - | "devstral" => 128000, + | "devstral" | "gpt-oss" => 128000, _ => DEFAULT_TOKENS, } .clamp(1, MAXIMUM_TOKENS) diff --git a/crates/onboarding/Cargo.toml b/crates/onboarding/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..436c714cf311bfd46e0eb5b961a8b929a5f09b58 --- /dev/null +++ b/crates/onboarding/Cargo.toml @@ -0,0 +1,47 @@ +[package] +name = "onboarding" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/onboarding.rs" + +[features] +default = [] + +[dependencies] +ai_onboarding.workspace = true +anyhow.workspace = true +client.workspace = true +command_palette_hooks.workspace = true +component.workspace = true +db.workspace = true +documented.workspace = true +editor.workspace = true +feature_flags.workspace = true +fs.workspace = true +fuzzy.workspace = true +gpui.workspace = true +itertools.workspace = true +language.workspace = true +language_model.workspace = true +menu.workspace = true +notifications.workspace = true +picker.workspace = true +project.workspace = true +schemars.workspace = true +serde.workspace = true +settings.workspace = true +theme.workspace = true +ui.workspace = true +util.workspace = true +vim_mode_setting.workspace = true +workspace-hack.workspace = true +workspace.workspace = true +zed_actions.workspace = true +zlog.workspace = true diff --git a/crates/onboarding/LICENSE-GPL b/crates/onboarding/LICENSE-GPL new file mode 120000 index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4 --- /dev/null +++ b/crates/onboarding/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/onboarding/src/ai_setup_page.rs b/crates/onboarding/src/ai_setup_page.rs new file mode 100644 index 0000000000000000000000000000000000000000..00f2d5fc8ba8ed904da4ecae5555792e9b54bf49 --- /dev/null +++ b/crates/onboarding/src/ai_setup_page.rs @@ -0,0 +1,432 @@ +use std::sync::Arc; + +use ai_onboarding::AiUpsellCard; +use client::{Client, UserStore}; +use fs::Fs; +use gpui::{ + Action, 
AnyView, App, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, WeakEntity, + Window, prelude::*, +}; +use itertools; +use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry}; +use project::DisableAiSettings; +use settings::{Settings, update_settings_file}; +use ui::{ + Badge, ButtonLike, Divider, KeyBinding, Modal, ModalFooter, ModalHeader, Section, SwitchField, + ToggleState, prelude::*, tooltip_container, +}; +use util::ResultExt; +use workspace::{ModalView, Workspace}; +use zed_actions::agent::OpenSettings; + +const FEATURED_PROVIDERS: [&'static str; 4] = ["anthropic", "google", "openai", "ollama"]; + +fn render_llm_provider_section( + tab_index: &mut isize, + workspace: WeakEntity<Workspace>, + disabled: bool, + window: &mut Window, + cx: &mut App, +) -> impl IntoElement { + v_flex() + .gap_4() + .child( + v_flex() + .child(Label::new("Or use other LLM providers").size(LabelSize::Large)) + .child( + Label::new("Bring your API keys to use the available providers with Zed's UI for free.") + .color(Color::Muted), + ), + ) + .child(render_llm_provider_card(tab_index, workspace, disabled, window, cx)) +} + +fn render_privacy_card(tab_index: &mut isize, disabled: bool, cx: &mut App) -> impl IntoElement { + let privacy_badge = || { + Badge::new("Privacy") + .icon(IconName::ShieldCheck) + .tooltip(move |_, cx| cx.new(|_| AiPrivacyTooltip::new()).into()) + }; + + v_flex() + .relative() + .pt_2() + .pb_2p5() + .pl_3() + .pr_2() + .border_1() + .border_dashed() + .border_color(cx.theme().colors().border.opacity(0.5)) + .bg(cx.theme().colors().surface_background.opacity(0.3)) + .rounded_lg() + .overflow_hidden() + .map(|this| { + if disabled { + this.child( + h_flex() + .gap_2() + .justify_between() + .child( + h_flex() + .gap_1() + .child(Label::new("AI is disabled across Zed")) + .child( + Icon::new(IconName::Check) + .color(Color::Success) + .size(IconSize::XSmall), + ), + ) + .child(privacy_badge()), + ) + .child( + Label::new("Re-enable it any time in Settings.") + .size(LabelSize::Small) + .color(Color::Muted), + ) + } else { + this.child( + h_flex() + .gap_2() + .justify_between() + .child(Label::new("Privacy is the default for Zed")) + .child( + h_flex().gap_1().child(privacy_badge()).child( + Button::new("learn_more", "Learn More") + .style(ButtonStyle::Outlined) + .label_size(LabelSize::Small) + .icon(IconName::ArrowUpRight) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + .on_click(|_, _, cx| { + cx.open_url("https://zed.dev/docs/ai/privacy-and-security"); + }) + .tab_index({ + *tab_index += 1; + *tab_index - 1 + }), + ), + ), + ) + .child( + Label::new( + "Any use or storage of your data is with your explicit, single-use, opt-in consent.", + ) + .size(LabelSize::Small) + .color(Color::Muted), + ) + } + }) +} + +fn render_llm_provider_card( + tab_index: &mut isize, + workspace: WeakEntity<Workspace>, + disabled: bool, + _: &mut Window, + cx: &mut App, +) -> impl IntoElement { + let registry = LanguageModelRegistry::read_global(cx); + + v_flex() + .border_1() + .border_color(cx.theme().colors().border) + .bg(cx.theme().colors().surface_background.opacity(0.5)) + .rounded_lg() + .overflow_hidden() + .children(itertools::intersperse_with( + FEATURED_PROVIDERS + .into_iter() + .flat_map(|provider_name| { + registry.provider(&LanguageModelProviderId::new(provider_name)) + }) + .enumerate() + .map(|(index, provider)| { + let group_name = SharedString::new(format!("onboarding-hover-group-{}", index)); + let is_authenticated = 
provider.is_authenticated(cx); + + ButtonLike::new(("onboarding-ai-setup-buttons", index)) + .size(ButtonSize::Large) + .tab_index({ + *tab_index += 1; + *tab_index - 1 + }) + .child( + h_flex() + .group(&group_name) + .px_0p5() + .w_full() + .gap_2() + .justify_between() + .child( + h_flex() + .gap_1() + .child( + Icon::new(provider.icon()) + .color(Color::Muted) + .size(IconSize::XSmall), + ) + .child(Label::new(provider.name().0)), + ) + .child( + h_flex() + .gap_1() + .when(!is_authenticated, |el| { + el.visible_on_hover(group_name.clone()) + .child( + Icon::new(IconName::Settings) + .color(Color::Muted) + .size(IconSize::XSmall), + ) + .child( + Label::new("Configure") + .color(Color::Muted) + .size(LabelSize::Small), + ) + }) + .when(is_authenticated && !disabled, |el| { + el.child( + Icon::new(IconName::Check) + .color(Color::Success) + .size(IconSize::XSmall), + ) + .child( + Label::new("Configured") + .color(Color::Muted) + .size(LabelSize::Small), + ) + }), + ), + ) + .on_click({ + let workspace = workspace.clone(); + move |_, window, cx| { + workspace + .update(cx, |workspace, cx| { + workspace.toggle_modal(window, cx, |window, cx| { + let modal = AiConfigurationModal::new( + provider.clone(), + window, + cx, + ); + window.focus(&modal.focus_handle(cx)); + modal + }); + }) + .log_err(); + } + }) + .into_any_element() + }), + || Divider::horizontal().into_any_element(), + )) + .child(Divider::horizontal()) + .child( + Button::new("agent_settings", "Add Many Others") + .size(ButtonSize::Large) + .icon(IconName::Plus) + .icon_position(IconPosition::Start) + .icon_color(Color::Muted) + .icon_size(IconSize::XSmall) + .on_click(|_event, window, cx| { + window.dispatch_action(OpenSettings.boxed_clone(), cx) + }) + .tab_index({ + *tab_index += 1; + *tab_index - 1 + }), + ) +} + +pub(crate) fn render_ai_setup_page( + workspace: WeakEntity<Workspace>, + user_store: Entity<UserStore>, + client: Arc<Client>, + window: &mut Window, + cx: &mut App, +) -> impl IntoElement { + let mut tab_index = 0; + let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai; + + v_flex() + .gap_2() + .child( + SwitchField::new( + "enable_ai", + "Enable AI features", + None, + if is_ai_disabled { + ToggleState::Unselected + } else { + ToggleState::Selected + }, + |&toggle_state, _, cx| { + let fs = <dyn Fs>::global(cx); + update_settings_file::<DisableAiSettings>( + fs, + cx, + move |ai_settings: &mut Option<bool>, _| { + *ai_settings = match toggle_state { + ToggleState::Indeterminate => None, + ToggleState::Unselected => Some(true), + ToggleState::Selected => Some(false), + }; + }, + ); + }, + ) + .tab_index({ + tab_index += 1; + tab_index - 1 + }), + ) + .child(render_privacy_card(&mut tab_index, is_ai_disabled, cx)) + .child( + v_flex() + .mt_2() + .gap_6() + .child({ + let mut ai_upsell_card = + AiUpsellCard::new(client, &user_store, user_store.read(cx).plan(), cx); + + ai_upsell_card.tab_index = Some({ + tab_index += 1; + tab_index - 1 + }); + + ai_upsell_card + }) + .child(render_llm_provider_section( + &mut tab_index, + workspace, + is_ai_disabled, + window, + cx, + )) + .when(is_ai_disabled, |this| { + this.child( + div() + .id("backdrop") + .size_full() + .absolute() + .inset_0() + .bg(cx.theme().colors().editor_background) + .opacity(0.8) + .block_mouse_except_scroll(), + ) + }), + ) +} + +struct AiConfigurationModal { + focus_handle: FocusHandle, + selected_provider: Arc<dyn LanguageModelProvider>, + configuration_view: AnyView, +} + +impl AiConfigurationModal { + fn new( + 
selected_provider: Arc<dyn LanguageModelProvider>, + window: &mut Window, + cx: &mut Context<Self>, + ) -> Self { + let focus_handle = cx.focus_handle(); + let configuration_view = selected_provider.configuration_view(window, cx); + + Self { + focus_handle, + configuration_view, + selected_provider, + } + } + + fn cancel(&mut self, _: &menu::Cancel, cx: &mut Context<Self>) { + cx.emit(DismissEvent); + } +} + +impl ModalView for AiConfigurationModal {} + +impl EventEmitter<DismissEvent> for AiConfigurationModal {} + +impl Focusable for AiConfigurationModal { + fn focus_handle(&self, _cx: &App) -> FocusHandle { + self.focus_handle.clone() + } +} + +impl Render for AiConfigurationModal { + fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { + v_flex() + .key_context("OnboardingAiConfigurationModal") + .w(rems(34.)) + .elevation_3(cx) + .track_focus(&self.focus_handle) + .on_action( + cx.listener(|this, _: &menu::Cancel, _window, cx| this.cancel(&menu::Cancel, cx)), + ) + .child( + Modal::new("onboarding-ai-setup-modal", None) + .header( + ModalHeader::new() + .icon( + Icon::new(self.selected_provider.icon()) + .color(Color::Muted) + .size(IconSize::Small), + ) + .headline(self.selected_provider.name().0), + ) + .section(Section::new().child(self.configuration_view.clone())) + .footer( + ModalFooter::new().end_slot( + Button::new("ai-onb-modal-Done", "Done") + .key_binding( + KeyBinding::for_action_in( + &menu::Cancel, + &self.focus_handle.clone(), + window, + cx, + ) + .map(|kb| kb.size(rems_from_px(12.))), + ) + .on_click(cx.listener(|this, _event, _window, cx| { + this.cancel(&menu::Cancel, cx) + })), + ), + ), + ) + } +} + +pub struct AiPrivacyTooltip {} + +impl AiPrivacyTooltip { + pub fn new() -> Self { + Self {} + } +} + +impl Render for AiPrivacyTooltip { + fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { + const DESCRIPTION: &'static str = "We believe in opt-in data sharing as the default for building AI products, rather than opt-out. We'll only use or store your data if you affirmatively send it to us. 
"; + + tooltip_container(window, cx, move |this, _, _| { + this.child( + h_flex() + .gap_1() + .child( + Icon::new(IconName::ShieldCheck) + .size(IconSize::Small) + .color(Color::Muted), + ) + .child(Label::new("Privacy First")), + ) + .child( + div().max_w_64().child( + Label::new(DESCRIPTION) + .size(LabelSize::Small) + .color(Color::Muted), + ), + ) + }) + } +} diff --git a/crates/onboarding/src/basics_page.rs b/crates/onboarding/src/basics_page.rs new file mode 100644 index 0000000000000000000000000000000000000000..a19a21fddf309b71d275a28d5ce8dcabea0fadc6 --- /dev/null +++ b/crates/onboarding/src/basics_page.rs @@ -0,0 +1,361 @@ +use std::sync::Arc; + +use client::TelemetrySettings; +use fs::Fs; +use gpui::{App, IntoElement}; +use settings::{BaseKeymap, Settings, update_settings_file}; +use theme::{ + Appearance, SystemAppearance, ThemeMode, ThemeName, ThemeRegistry, ThemeSelection, + ThemeSettings, +}; +use ui::{ + ParentElement as _, StatefulInteractiveElement, SwitchField, ToggleButtonGroup, + ToggleButtonSimple, ToggleButtonWithIcon, prelude::*, rems_from_px, +}; +use vim_mode_setting::VimModeSetting; + +use crate::theme_preview::{ThemePreviewStyle, ThemePreviewTile}; + +fn render_theme_section(tab_index: &mut isize, cx: &mut App) -> impl IntoElement { + let theme_selection = ThemeSettings::get_global(cx).theme_selection.clone(); + let system_appearance = theme::SystemAppearance::global(cx); + let theme_selection = theme_selection.unwrap_or_else(|| ThemeSelection::Dynamic { + mode: match *system_appearance { + Appearance::Light => ThemeMode::Light, + Appearance::Dark => ThemeMode::Dark, + }, + light: ThemeName("One Light".into()), + dark: ThemeName("One Dark".into()), + }); + + let theme_mode = theme_selection + .mode() + .unwrap_or_else(|| match *system_appearance { + Appearance::Light => ThemeMode::Light, + Appearance::Dark => ThemeMode::Dark, + }); + + return v_flex() + .gap_2() + .child( + h_flex().justify_between().child(Label::new("Theme")).child( + ToggleButtonGroup::single_row( + "theme-selector-onboarding-dark-light", + [ThemeMode::Light, ThemeMode::Dark, ThemeMode::System].map(|mode| { + const MODE_NAMES: [SharedString; 3] = [ + SharedString::new_static("Light"), + SharedString::new_static("Dark"), + SharedString::new_static("System"), + ]; + ToggleButtonSimple::new( + MODE_NAMES[mode as usize].clone(), + move |_, _, cx| { + write_mode_change(mode, cx); + }, + ) + }), + ) + .tab_index(tab_index) + .selected_index(theme_mode as usize) + .style(ui::ToggleButtonGroupStyle::Outlined) + .button_width(rems_from_px(64.)), + ), + ) + .child( + h_flex() + .gap_4() + .justify_between() + .children(render_theme_previews(tab_index, &theme_selection, cx)), + ); + + fn render_theme_previews( + tab_index: &mut isize, + theme_selection: &ThemeSelection, + cx: &mut App, + ) -> [impl IntoElement; 3] { + let system_appearance = SystemAppearance::global(cx); + let theme_registry = ThemeRegistry::global(cx); + + let theme_seed = 0xBEEF as f32; + let theme_mode = theme_selection + .mode() + .unwrap_or_else(|| match *system_appearance { + Appearance::Light => ThemeMode::Light, + Appearance::Dark => ThemeMode::Dark, + }); + let appearance = match theme_mode { + ThemeMode::Light => Appearance::Light, + ThemeMode::Dark => Appearance::Dark, + ThemeMode::System => *system_appearance, + }; + let current_theme_name = theme_selection.theme(appearance); + + const LIGHT_THEMES: [&'static str; 3] = ["One Light", "Ayu Light", "Gruvbox Light"]; + const DARK_THEMES: [&'static str; 3] = ["One Dark", "Ayu 
Dark", "Gruvbox Dark"]; + const FAMILY_NAMES: [SharedString; 3] = [ + SharedString::new_static("One"), + SharedString::new_static("Ayu"), + SharedString::new_static("Gruvbox"), + ]; + + let theme_names = match appearance { + Appearance::Light => LIGHT_THEMES, + Appearance::Dark => DARK_THEMES, + }; + + let themes = theme_names.map(|theme| theme_registry.get(theme).unwrap()); + + let theme_previews = [0, 1, 2].map(|index| { + let theme = &themes[index]; + let is_selected = theme.name == current_theme_name; + let name = theme.name.clone(); + let colors = cx.theme().colors(); + + v_flex() + .w_full() + .items_center() + .gap_1() + .child( + h_flex() + .id(name.clone()) + .relative() + .w_full() + .border_2() + .border_color(colors.border_transparent) + .rounded(ThemePreviewTile::ROOT_RADIUS) + .map(|this| { + if is_selected { + this.border_color(colors.border_selected) + } else { + this.opacity(0.8).hover(|s| s.border_color(colors.border)) + } + }) + .tab_index({ + *tab_index += 1; + *tab_index - 1 + }) + .focus(|mut style| { + style.border_color = Some(colors.border_focused); + style + }) + .on_click({ + let theme_name = theme.name.clone(); + move |_, _, cx| { + write_theme_change(theme_name.clone(), theme_mode, cx); + } + }) + .map(|this| { + if theme_mode == ThemeMode::System { + let (light, dark) = ( + theme_registry.get(LIGHT_THEMES[index]).unwrap(), + theme_registry.get(DARK_THEMES[index]).unwrap(), + ); + this.child( + ThemePreviewTile::new(light, theme_seed) + .style(ThemePreviewStyle::SideBySide(dark)), + ) + } else { + this.child( + ThemePreviewTile::new(theme.clone(), theme_seed) + .style(ThemePreviewStyle::Bordered), + ) + } + }), + ) + .child( + Label::new(FAMILY_NAMES[index].clone()) + .color(Color::Muted) + .size(LabelSize::Small), + ) + }); + + theme_previews + } + + fn write_mode_change(mode: ThemeMode, cx: &mut App) { + let fs = <dyn Fs>::global(cx); + update_settings_file::<ThemeSettings>(fs, cx, move |settings, _cx| { + settings.set_mode(mode); + }); + } + + fn write_theme_change(theme: impl Into<Arc<str>>, theme_mode: ThemeMode, cx: &mut App) { + let fs = <dyn Fs>::global(cx); + let theme = theme.into(); + update_settings_file::<ThemeSettings>(fs, cx, move |settings, cx| { + if theme_mode == ThemeMode::System { + settings.theme = Some(ThemeSelection::Dynamic { + mode: ThemeMode::System, + light: ThemeName(theme.clone()), + dark: ThemeName(theme.clone()), + }); + } else { + let appearance = *SystemAppearance::global(cx); + settings.set_theme(theme.clone(), appearance); + } + }); + } +} + +fn render_telemetry_section(tab_index: &mut isize, cx: &App) -> impl IntoElement { + let fs = <dyn Fs>::global(cx); + + v_flex() + .pt_6() + .gap_4() + .border_t_1() + .border_color(cx.theme().colors().border_variant.opacity(0.5)) + .child(Label::new("Telemetry").size(LabelSize::Large)) + .child(SwitchField::new( + "onboarding-telemetry-metrics", + "Help Improve Zed", + Some("Anonymous usage data helps us build the right features and improve your experience.".into()), + if TelemetrySettings::get_global(cx).metrics { + ui::ToggleState::Selected + } else { + ui::ToggleState::Unselected + }, + { + let fs = fs.clone(); + move |selection, _, cx| { + let enabled = match selection { + ToggleState::Selected => true, + ToggleState::Unselected => false, + ToggleState::Indeterminate => { return; }, + }; + + update_settings_file::<TelemetrySettings>( + fs.clone(), + cx, + move |setting, _| setting.metrics = Some(enabled), + ); + }}, + ).tab_index({ + *tab_index += 1; + *tab_index + })) + 
.child(SwitchField::new( + "onboarding-telemetry-crash-reports", + "Help Fix Zed", + Some("Send crash reports so we can fix critical issues fast.".into()), + if TelemetrySettings::get_global(cx).diagnostics { + ui::ToggleState::Selected + } else { + ui::ToggleState::Unselected + }, + { + let fs = fs.clone(); + move |selection, _, cx| { + let enabled = match selection { + ToggleState::Selected => true, + ToggleState::Unselected => false, + ToggleState::Indeterminate => { return; }, + }; + + update_settings_file::<TelemetrySettings>( + fs.clone(), + cx, + move |setting, _| setting.diagnostics = Some(enabled), + ); + } + } + ).tab_index({ + *tab_index += 1; + *tab_index + })) +} + +fn render_base_keymap_section(tab_index: &mut isize, cx: &mut App) -> impl IntoElement { + let base_keymap = match BaseKeymap::get_global(cx) { + BaseKeymap::VSCode => Some(0), + BaseKeymap::JetBrains => Some(1), + BaseKeymap::SublimeText => Some(2), + BaseKeymap::Atom => Some(3), + BaseKeymap::Emacs => Some(4), + BaseKeymap::Cursor => Some(5), + BaseKeymap::TextMate | BaseKeymap::None => None, + }; + + return v_flex().gap_2().child(Label::new("Base Keymap")).child( + ToggleButtonGroup::two_rows( + "base_keymap_selection", + [ + ToggleButtonWithIcon::new("VS Code", IconName::EditorVsCode, |_, _, cx| { + write_keymap_base(BaseKeymap::VSCode, cx); + }), + ToggleButtonWithIcon::new("Jetbrains", IconName::EditorJetBrains, |_, _, cx| { + write_keymap_base(BaseKeymap::JetBrains, cx); + }), + ToggleButtonWithIcon::new("Sublime Text", IconName::EditorSublime, |_, _, cx| { + write_keymap_base(BaseKeymap::SublimeText, cx); + }), + ], + [ + ToggleButtonWithIcon::new("Atom", IconName::EditorAtom, |_, _, cx| { + write_keymap_base(BaseKeymap::Atom, cx); + }), + ToggleButtonWithIcon::new("Emacs", IconName::EditorEmacs, |_, _, cx| { + write_keymap_base(BaseKeymap::Emacs, cx); + }), + ToggleButtonWithIcon::new("Cursor", IconName::EditorCursor, |_, _, cx| { + write_keymap_base(BaseKeymap::Cursor, cx); + }), + ], + ) + .when_some(base_keymap, |this, base_keymap| { + this.selected_index(base_keymap) + }) + .tab_index(tab_index) + .button_width(rems_from_px(216.)) + .size(ui::ToggleButtonGroupSize::Medium) + .style(ui::ToggleButtonGroupStyle::Outlined), + ); + + fn write_keymap_base(keymap_base: BaseKeymap, cx: &App) { + let fs = <dyn Fs>::global(cx); + + update_settings_file::<BaseKeymap>(fs, cx, move |setting, _| { + *setting = Some(keymap_base); + }); + } +} + +fn render_vim_mode_switch(tab_index: &mut isize, cx: &mut App) -> impl IntoElement { + let toggle_state = if VimModeSetting::get_global(cx).0 { + ui::ToggleState::Selected + } else { + ui::ToggleState::Unselected + }; + SwitchField::new( + "onboarding-vim-mode", + "Vim Mode", + Some("Coming from Neovim? 
Use our first-class implementation of Vim Mode.".into()), + toggle_state, + { + let fs = <dyn Fs>::global(cx); + move |&selection, _, cx| { + update_settings_file::<VimModeSetting>(fs.clone(), cx, move |setting, _| { + *setting = match selection { + ToggleState::Selected => Some(true), + ToggleState::Unselected => Some(false), + ToggleState::Indeterminate => None, + } + }); + } + }, + ) + .tab_index({ + *tab_index += 1; + *tab_index - 1 + }) +} + +pub(crate) fn render_basics_page(cx: &mut App) -> impl IntoElement { + let mut tab_index = 0; + v_flex() + .gap_6() + .child(render_theme_section(&mut tab_index, cx)) + .child(render_base_keymap_section(&mut tab_index, cx)) + .child(render_vim_mode_switch(&mut tab_index, cx)) + .child(render_telemetry_section(&mut tab_index, cx)) +} diff --git a/crates/onboarding/src/editing_page.rs b/crates/onboarding/src/editing_page.rs new file mode 100644 index 0000000000000000000000000000000000000000..8b4293db0dba387c4e38d324f3b57edbc0173d4c --- /dev/null +++ b/crates/onboarding/src/editing_page.rs @@ -0,0 +1,720 @@ +use std::sync::Arc; + +use editor::{EditorSettings, ShowMinimap}; +use fs::Fs; +use fuzzy::{StringMatch, StringMatchCandidate}; +use gpui::{ + Action, AnyElement, App, Context, FontFeatures, IntoElement, Pixels, SharedString, Task, Window, +}; +use language::language_settings::{AllLanguageSettings, FormatOnSave}; +use picker::{Picker, PickerDelegate}; +use project::project_settings::ProjectSettings; +use settings::{Settings as _, update_settings_file}; +use theme::{FontFamilyCache, FontFamilyName, ThemeSettings}; +use ui::{ + ButtonLike, ListItem, ListItemSpacing, NumericStepper, PopoverMenu, SwitchField, + ToggleButtonGroup, ToggleButtonGroupStyle, ToggleButtonSimple, ToggleState, Tooltip, + prelude::*, +}; + +use crate::{ImportCursorSettings, ImportVsCodeSettings, SettingsImportState}; + +fn read_show_mini_map(cx: &App) -> ShowMinimap { + editor::EditorSettings::get_global(cx).minimap.show +} + +fn write_show_mini_map(show: ShowMinimap, cx: &mut App) { + let fs = <dyn Fs>::global(cx); + + // This is used to speed up the UI + // the UI reads the current values to get what toggle state to show on buttons + // there's a slight delay if we just call update_settings_file so we manually set + // the value here then call update_settings file to get around the delay + let mut curr_settings = EditorSettings::get_global(cx).clone(); + curr_settings.minimap.show = show; + EditorSettings::override_global(curr_settings, cx); + + update_settings_file::<EditorSettings>(fs, cx, move |editor_settings, _| { + editor_settings.minimap.get_or_insert_default().show = Some(show); + }); +} + +fn read_inlay_hints(cx: &App) -> bool { + AllLanguageSettings::get_global(cx) + .defaults + .inlay_hints + .enabled +} + +fn write_inlay_hints(enabled: bool, cx: &mut App) { + let fs = <dyn Fs>::global(cx); + + let mut curr_settings = AllLanguageSettings::get_global(cx).clone(); + curr_settings.defaults.inlay_hints.enabled = enabled; + AllLanguageSettings::override_global(curr_settings, cx); + + update_settings_file::<AllLanguageSettings>(fs, cx, move |all_language_settings, cx| { + all_language_settings + .defaults + .inlay_hints + .get_or_insert_with(|| { + AllLanguageSettings::get_global(cx) + .clone() + .defaults + .inlay_hints + }) + .enabled = enabled; + }); +} + +fn read_git_blame(cx: &App) -> bool { + ProjectSettings::get_global(cx).git.inline_blame_enabled() +} + +fn set_git_blame(enabled: bool, cx: &mut App) { + let fs = <dyn Fs>::global(cx); + + let mut curr_settings 
= ProjectSettings::get_global(cx).clone(); + curr_settings + .git + .inline_blame + .get_or_insert_default() + .enabled = enabled; + ProjectSettings::override_global(curr_settings, cx); + + update_settings_file::<ProjectSettings>(fs, cx, move |project_settings, _| { + project_settings + .git + .inline_blame + .get_or_insert_default() + .enabled = enabled; + }); +} + +fn write_ui_font_family(font: SharedString, cx: &mut App) { + let fs = <dyn Fs>::global(cx); + + update_settings_file::<ThemeSettings>(fs, cx, move |theme_settings, _| { + theme_settings.ui_font_family = Some(FontFamilyName(font.into())); + }); +} + +fn write_ui_font_size(size: Pixels, cx: &mut App) { + let fs = <dyn Fs>::global(cx); + + update_settings_file::<ThemeSettings>(fs, cx, move |theme_settings, _| { + theme_settings.ui_font_size = Some(size.into()); + }); +} + +fn write_buffer_font_size(size: Pixels, cx: &mut App) { + let fs = <dyn Fs>::global(cx); + + update_settings_file::<ThemeSettings>(fs, cx, move |theme_settings, _| { + theme_settings.buffer_font_size = Some(size.into()); + }); +} + +fn write_buffer_font_family(font_family: SharedString, cx: &mut App) { + let fs = <dyn Fs>::global(cx); + + update_settings_file::<ThemeSettings>(fs, cx, move |theme_settings, _| { + theme_settings.buffer_font_family = Some(FontFamilyName(font_family.into())); + }); +} + +fn read_font_ligatures(cx: &App) -> bool { + ThemeSettings::get_global(cx) + .buffer_font + .features + .is_calt_enabled() + .unwrap_or(true) +} + +fn write_font_ligatures(enabled: bool, cx: &mut App) { + let fs = <dyn Fs>::global(cx); + let bit = if enabled { 1 } else { 0 }; + + update_settings_file::<ThemeSettings>(fs, cx, move |theme_settings, _| { + let mut features = theme_settings + .buffer_font_features + .as_mut() + .map(|features| features.tag_value_list().to_vec()) + .unwrap_or_default(); + + if let Some(calt_index) = features.iter().position(|(tag, _)| tag == "calt") { + features[calt_index].1 = bit; + } else { + features.push(("calt".into(), bit)); + } + + theme_settings.buffer_font_features = Some(FontFeatures(Arc::new(features))); + }); +} + +fn read_format_on_save(cx: &App) -> bool { + match AllLanguageSettings::get_global(cx).defaults.format_on_save { + FormatOnSave::On | FormatOnSave::List(_) => true, + FormatOnSave::Off => false, + } +} + +fn write_format_on_save(format_on_save: bool, cx: &mut App) { + let fs = <dyn Fs>::global(cx); + + update_settings_file::<AllLanguageSettings>(fs, cx, move |language_settings, _| { + language_settings.defaults.format_on_save = Some(match format_on_save { + true => FormatOnSave::On, + false => FormatOnSave::Off, + }); + }); +} + +fn render_setting_import_button( + tab_index: isize, + label: SharedString, + icon_name: IconName, + action: &dyn Action, + imported: bool, +) -> impl IntoElement { + let action = action.boxed_clone(); + h_flex().w_full().child( + ButtonLike::new(label.clone()) + .full_width() + .style(ButtonStyle::Outlined) + .size(ButtonSize::Large) + .tab_index(tab_index) + .child( + h_flex() + .w_full() + .justify_between() + .child( + h_flex() + .gap_1p5() + .px_1() + .child( + Icon::new(icon_name) + .color(Color::Muted) + .size(IconSize::XSmall), + ) + .child(Label::new(label)), + ) + .when(imported, |this| { + this.child( + h_flex() + .gap_1p5() + .child( + Icon::new(IconName::Check) + .color(Color::Success) + .size(IconSize::XSmall), + ) + .child(Label::new("Imported").size(LabelSize::Small)), + ) + }), + ) + .on_click(move |_, window, cx| window.dispatch_action(action.boxed_clone(), cx)), + ) 
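The import buttons built just above capture a boxed action and hand out a fresh copy on every click via boxed_clone. The snippet below is a self-contained, std-only sketch of that object-safe clone pattern; the trait and struct here are local stand-ins for illustration, not GPUI's actual Action trait or the ImportVsCodeSettings action defined later in this diff.

trait Action {
    fn name(&self) -> &'static str;
    fn boxed_clone(&self) -> Box<dyn Action>;
}

#[derive(Clone)]
struct ImportVsCodeSettings {
    skip_prompt: bool,
}

impl Action for ImportVsCodeSettings {
    fn name(&self) -> &'static str {
        "zed::ImportVsCodeSettings"
    }
    fn boxed_clone(&self) -> Box<dyn Action> {
        // Clone the concrete type, then erase it behind the trait object again.
        Box::new(self.clone())
    }
}

fn main() {
    let action = ImportVsCodeSettings { skip_prompt: false };
    println!("skip_prompt = {}", action.skip_prompt);
    let boxed: Box<dyn Action> = Box::new(action);
    // An on_click closure can keep the boxed action and re-clone it on every
    // invocation, which is what the click handlers above do.
    let on_click = move || boxed.boxed_clone();
    assert_eq!(on_click().name(), "zed::ImportVsCodeSettings");
}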
+} + +fn render_import_settings_section(tab_index: &mut isize, cx: &App) -> impl IntoElement { + let import_state = SettingsImportState::global(cx); + let imports: [(SharedString, IconName, &dyn Action, bool); 2] = [ + ( + "VS Code".into(), + IconName::EditorVsCode, + &ImportVsCodeSettings { skip_prompt: false }, + import_state.vscode, + ), + ( + "Cursor".into(), + IconName::EditorCursor, + &ImportCursorSettings { skip_prompt: false }, + import_state.cursor, + ), + ]; + + let [vscode, cursor] = imports.map(|(label, icon_name, action, imported)| { + *tab_index += 1; + render_setting_import_button(*tab_index - 1, label, icon_name, action, imported) + }); + + v_flex() + .gap_4() + .child( + v_flex() + .child(Label::new("Import Settings").size(LabelSize::Large)) + .child( + Label::new("Automatically pull your settings from other editors.") + .color(Color::Muted), + ), + ) + .child(h_flex().w_full().gap_4().child(vscode).child(cursor)) +} + +fn render_font_customization_section( + tab_index: &mut isize, + window: &mut Window, + cx: &mut App, +) -> impl IntoElement { + let theme_settings = ThemeSettings::get_global(cx); + let ui_font_size = theme_settings.ui_font_size(cx); + let ui_font_family = theme_settings.ui_font.family.clone(); + let buffer_font_family = theme_settings.buffer_font.family.clone(); + let buffer_font_size = theme_settings.buffer_font_size(cx); + + let ui_font_picker = + cx.new(|cx| font_picker(ui_font_family.clone(), write_ui_font_family, window, cx)); + + let buffer_font_picker = cx.new(|cx| { + font_picker( + buffer_font_family.clone(), + write_buffer_font_family, + window, + cx, + ) + }); + + let ui_font_handle = ui::PopoverMenuHandle::default(); + let buffer_font_handle = ui::PopoverMenuHandle::default(); + + h_flex() + .w_full() + .gap_4() + .child( + v_flex() + .w_full() + .gap_1() + .child(Label::new("UI Font")) + .child( + h_flex() + .w_full() + .justify_between() + .gap_2() + .child( + PopoverMenu::new("ui-font-picker") + .menu({ + let ui_font_picker = ui_font_picker.clone(); + move |_window, _cx| Some(ui_font_picker.clone()) + }) + .trigger( + ButtonLike::new("ui-font-family-button") + .style(ButtonStyle::Outlined) + .size(ButtonSize::Medium) + .full_width() + .tab_index({ + *tab_index += 1; + *tab_index - 1 + }) + .child( + h_flex() + .w_full() + .justify_between() + .child(Label::new(ui_font_family)) + .child( + Icon::new(IconName::ChevronUpDown) + .color(Color::Muted) + .size(IconSize::XSmall), + ), + ), + ) + .full_width(true) + .anchor(gpui::Corner::TopLeft) + .offset(gpui::Point { + x: px(0.0), + y: px(4.0), + }) + .with_handle(ui_font_handle), + ) + .child( + NumericStepper::new( + "ui-font-size", + ui_font_size.to_string(), + move |_, _, cx| { + write_ui_font_size(ui_font_size - px(1.), cx); + }, + move |_, _, cx| { + write_ui_font_size(ui_font_size + px(1.), cx); + }, + ) + .style(ui::NumericStepperStyle::Outlined) + .tab_index({ + *tab_index += 2; + *tab_index - 2 + }), + ), + ), + ) + .child( + v_flex() + .w_full() + .gap_1() + .child(Label::new("Editor Font")) + .child( + h_flex() + .w_full() + .justify_between() + .gap_2() + .child( + PopoverMenu::new("buffer-font-picker") + .menu({ + let buffer_font_picker = buffer_font_picker.clone(); + move |_window, _cx| Some(buffer_font_picker.clone()) + }) + .trigger( + ButtonLike::new("buffer-font-family-button") + .style(ButtonStyle::Outlined) + .size(ButtonSize::Medium) + .full_width() + .tab_index({ + *tab_index += 1; + *tab_index - 1 + }) + .child( + h_flex() + .w_full() + .justify_between() + 
.child(Label::new(buffer_font_family)) + .child( + Icon::new(IconName::ChevronUpDown) + .color(Color::Muted) + .size(IconSize::XSmall), + ), + ), + ) + .full_width(true) + .anchor(gpui::Corner::TopLeft) + .offset(gpui::Point { + x: px(0.0), + y: px(4.0), + }) + .with_handle(buffer_font_handle), + ) + .child( + NumericStepper::new( + "buffer-font-size", + buffer_font_size.to_string(), + move |_, _, cx| { + write_buffer_font_size(buffer_font_size - px(1.), cx); + }, + move |_, _, cx| { + write_buffer_font_size(buffer_font_size + px(1.), cx); + }, + ) + .style(ui::NumericStepperStyle::Outlined) + .tab_index({ + *tab_index += 2; + *tab_index - 2 + }), + ), + ), + ) +} + +type FontPicker = Picker<FontPickerDelegate>; + +pub struct FontPickerDelegate { + fonts: Vec<SharedString>, + filtered_fonts: Vec<StringMatch>, + selected_index: usize, + current_font: SharedString, + on_font_changed: Arc<dyn Fn(SharedString, &mut App) + 'static>, +} + +impl FontPickerDelegate { + fn new( + current_font: SharedString, + on_font_changed: impl Fn(SharedString, &mut App) + 'static, + cx: &mut Context<FontPicker>, + ) -> Self { + let font_family_cache = FontFamilyCache::global(cx); + + let fonts: Vec<SharedString> = font_family_cache + .list_font_families(cx) + .into_iter() + .collect(); + + let selected_index = fonts + .iter() + .position(|font| *font == current_font) + .unwrap_or(0); + + Self { + fonts: fonts.clone(), + filtered_fonts: fonts + .iter() + .enumerate() + .map(|(index, font)| StringMatch { + candidate_id: index, + string: font.to_string(), + positions: Vec::new(), + score: 0.0, + }) + .collect(), + selected_index, + current_font, + on_font_changed: Arc::new(on_font_changed), + } + } +} + +impl PickerDelegate for FontPickerDelegate { + type ListItem = AnyElement; + + fn match_count(&self) -> usize { + self.filtered_fonts.len() + } + + fn selected_index(&self) -> usize { + self.selected_index + } + + fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<FontPicker>) { + self.selected_index = ix.min(self.filtered_fonts.len().saturating_sub(1)); + cx.notify(); + } + + fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> { + "Search fonts…".into() + } + + fn update_matches( + &mut self, + query: String, + _window: &mut Window, + cx: &mut Context<FontPicker>, + ) -> Task<()> { + let fonts = self.fonts.clone(); + let current_font = self.current_font.clone(); + + let matches: Vec<StringMatch> = if query.is_empty() { + fonts + .iter() + .enumerate() + .map(|(index, font)| StringMatch { + candidate_id: index, + string: font.to_string(), + positions: Vec::new(), + score: 0.0, + }) + .collect() + } else { + let _candidates: Vec<StringMatchCandidate> = fonts + .iter() + .enumerate() + .map(|(id, font)| StringMatchCandidate::new(id, font.as_ref())) + .collect(); + + fonts + .iter() + .enumerate() + .filter(|(_, font)| font.to_lowercase().contains(&query.to_lowercase())) + .map(|(index, font)| StringMatch { + candidate_id: index, + string: font.to_string(), + positions: Vec::new(), + score: 0.0, + }) + .collect() + }; + + let selected_index = if query.is_empty() { + fonts + .iter() + .position(|font| *font == current_font) + .unwrap_or(0) + } else { + matches + .iter() + .position(|m| fonts[m.candidate_id] == current_font) + .unwrap_or(0) + }; + + self.filtered_fonts = matches; + self.selected_index = selected_index; + cx.notify(); + + Task::ready(()) + } + + fn confirm(&mut self, _secondary: bool, _window: &mut Window, cx: &mut Context<FontPicker>) { + if let 
Some(font_match) = self.filtered_fonts.get(self.selected_index) { + let font = font_match.string.clone(); + (self.on_font_changed)(font.into(), cx); + } + } + + fn dismissed(&mut self, _window: &mut Window, _cx: &mut Context<FontPicker>) {} + + fn render_match( + &self, + ix: usize, + selected: bool, + _window: &mut Window, + _cx: &mut Context<FontPicker>, + ) -> Option<Self::ListItem> { + let font_match = self.filtered_fonts.get(ix)?; + + Some( + ListItem::new(ix) + .inset(true) + .spacing(ListItemSpacing::Sparse) + .toggle_state(selected) + .child(Label::new(font_match.string.clone())) + .into_any_element(), + ) + } +} + +fn font_picker( + current_font: SharedString, + on_font_changed: impl Fn(SharedString, &mut App) + 'static, + window: &mut Window, + cx: &mut Context<FontPicker>, +) -> FontPicker { + let delegate = FontPickerDelegate::new(current_font, on_font_changed, cx); + + Picker::list(delegate, window, cx) + .show_scrollbar(true) + .width(rems_from_px(210.)) + .max_height(Some(rems(20.).into())) +} + +fn render_popular_settings_section( + tab_index: &mut isize, + window: &mut Window, + cx: &mut App, +) -> impl IntoElement { + const LIGATURE_TOOLTIP: &'static str = + "Font ligatures combine two characters into one. For example, turning =/= into ≠."; + + v_flex() + .pt_6() + .gap_4() + .border_t_1() + .border_color(cx.theme().colors().border_variant.opacity(0.5)) + .child(Label::new("Popular Settings").size(LabelSize::Large)) + .child(render_font_customization_section(tab_index, window, cx)) + .child( + SwitchField::new( + "onboarding-font-ligatures", + "Font Ligatures", + Some("Combine text characters into their associated symbols.".into()), + if read_font_ligatures(cx) { + ui::ToggleState::Selected + } else { + ui::ToggleState::Unselected + }, + |toggle_state, _, cx| { + write_font_ligatures(toggle_state == &ToggleState::Selected, cx); + }, + ) + .tab_index({ + *tab_index += 1; + *tab_index - 1 + }) + .tooltip(Tooltip::text(LIGATURE_TOOLTIP)), + ) + .child( + SwitchField::new( + "onboarding-format-on-save", + "Format on Save", + Some("Format code automatically when saving.".into()), + if read_format_on_save(cx) { + ui::ToggleState::Selected + } else { + ui::ToggleState::Unselected + }, + |toggle_state, _, cx| { + write_format_on_save(toggle_state == &ToggleState::Selected, cx); + }, + ) + .tab_index({ + *tab_index += 1; + *tab_index - 1 + }), + ) + .child( + SwitchField::new( + "onboarding-enable-inlay-hints", + "Inlay Hints", + Some("See parameter names for function and method calls inline.".into()), + if read_inlay_hints(cx) { + ui::ToggleState::Selected + } else { + ui::ToggleState::Unselected + }, + |toggle_state, _, cx| { + write_inlay_hints(toggle_state == &ToggleState::Selected, cx); + }, + ) + .tab_index({ + *tab_index += 1; + *tab_index - 1 + }), + ) + .child( + SwitchField::new( + "onboarding-git-blame-switch", + "Git Blame", + Some("See who committed each line on a given file.".into()), + if read_git_blame(cx) { + ui::ToggleState::Selected + } else { + ui::ToggleState::Unselected + }, + |toggle_state, _, cx| { + set_git_blame(toggle_state == &ToggleState::Selected, cx); + }, + ) + .tab_index({ + *tab_index += 1; + *tab_index - 1 + }), + ) + .child( + h_flex() + .items_start() + .justify_between() + .child( + v_flex().child(Label::new("Mini Map")).child( + Label::new("See a high-level overview of your source code.") + .color(Color::Muted), + ), + ) + .child( + ToggleButtonGroup::single_row( + "onboarding-show-mini-map", + [ + ToggleButtonSimple::new("Auto", |_, _, cx| { 
+ write_show_mini_map(ShowMinimap::Auto, cx); + }) + .tooltip(Tooltip::text( + "Show the minimap if the editor's scrollbar is visible.", + )), + ToggleButtonSimple::new("Always", |_, _, cx| { + write_show_mini_map(ShowMinimap::Always, cx); + }), + ToggleButtonSimple::new("Never", |_, _, cx| { + write_show_mini_map(ShowMinimap::Never, cx); + }), + ], + ) + .selected_index(match read_show_mini_map(cx) { + ShowMinimap::Auto => 0, + ShowMinimap::Always => 1, + ShowMinimap::Never => 2, + }) + .tab_index(tab_index) + .style(ToggleButtonGroupStyle::Outlined) + .button_width(ui::rems_from_px(64.)), + ), + ) +} + +pub(crate) fn render_editing_page(window: &mut Window, cx: &mut App) -> impl IntoElement { + let mut tab_index = 0; + v_flex() + .gap_6() + .child(render_import_settings_section(&mut tab_index, cx)) + .child(render_popular_settings_section(&mut tab_index, window, cx)) +} diff --git a/crates/onboarding/src/onboarding.rs b/crates/onboarding/src/onboarding.rs new file mode 100644 index 0000000000000000000000000000000000000000..98f61df97b8e9476b974dd88ad31ff256c61cfa0 --- /dev/null +++ b/crates/onboarding/src/onboarding.rs @@ -0,0 +1,843 @@ +use crate::welcome::{ShowWelcome, WelcomePage}; +use client::{Client, UserStore}; +use command_palette_hooks::CommandPaletteFilter; +use db::kvp::KEY_VALUE_STORE; +use feature_flags::{FeatureFlag, FeatureFlagViewExt as _}; +use fs::Fs; +use gpui::{ + Action, AnyElement, App, AppContext, AsyncWindowContext, Context, Entity, EventEmitter, + FocusHandle, Focusable, Global, IntoElement, KeyContext, Render, SharedString, Subscription, + Task, WeakEntity, Window, actions, +}; +use notifications::status_toast::{StatusToast, ToastIcon}; +use schemars::JsonSchema; +use serde::Deserialize; +use settings::{SettingsStore, VsCodeSettingsSource}; +use std::sync::Arc; +use ui::{ + Avatar, ButtonLike, FluentBuilder, Headline, KeyBinding, ParentElement as _, + StatefulInteractiveElement, Vector, VectorName, prelude::*, rems_from_px, +}; +use workspace::{ + AppState, Workspace, WorkspaceId, + dock::DockPosition, + item::{Item, ItemEvent}, + notifications::NotifyResultExt as _, + open_new, register_serializable_item, with_active_or_new_workspace, +}; + +mod ai_setup_page; +mod basics_page; +mod editing_page; +mod theme_preview; +mod welcome; + +pub struct OnBoardingFeatureFlag {} + +impl FeatureFlag for OnBoardingFeatureFlag { + const NAME: &'static str = "onboarding"; +} + +/// Imports settings from Visual Studio Code. +#[derive(Copy, Clone, Debug, Default, PartialEq, Deserialize, JsonSchema, Action)] +#[action(namespace = zed)] +#[serde(deny_unknown_fields)] +pub struct ImportVsCodeSettings { + #[serde(default)] + pub skip_prompt: bool, +} + +/// Imports settings from Cursor editor. +#[derive(Copy, Clone, Debug, Default, PartialEq, Deserialize, JsonSchema, Action)] +#[action(namespace = zed)] +#[serde(deny_unknown_fields)] +pub struct ImportCursorSettings { + #[serde(default)] + pub skip_prompt: bool, +} + +pub const FIRST_OPEN: &str = "first_open"; + +actions!( + zed, + [ + /// Opens the onboarding view. + OpenOnboarding + ] +); + +actions!( + onboarding, + [ + /// Activates the Basics page. + ActivateBasicsPage, + /// Activates the Editing page. + ActivateEditingPage, + /// Activates the AI Setup page. + ActivateAISetupPage, + /// Finish the onboarding process. + Finish, + /// Sign in while in the onboarding flow. 
+ SignIn + ] +); + +pub fn init(cx: &mut App) { + cx.on_action(|_: &OpenOnboarding, cx| { + with_active_or_new_workspace(cx, |workspace, window, cx| { + workspace + .with_local_workspace(window, cx, |workspace, window, cx| { + let existing = workspace + .active_pane() + .read(cx) + .items() + .find_map(|item| item.downcast::<Onboarding>()); + + if let Some(existing) = existing { + workspace.activate_item(&existing, true, true, window, cx); + } else { + let settings_page = Onboarding::new(workspace, cx); + workspace.add_item_to_active_pane( + Box::new(settings_page), + None, + true, + window, + cx, + ) + } + }) + .detach(); + }); + }); + + cx.on_action(|_: &ShowWelcome, cx| { + with_active_or_new_workspace(cx, |workspace, window, cx| { + workspace + .with_local_workspace(window, cx, |workspace, window, cx| { + let existing = workspace + .active_pane() + .read(cx) + .items() + .find_map(|item| item.downcast::<WelcomePage>()); + + if let Some(existing) = existing { + workspace.activate_item(&existing, true, true, window, cx); + } else { + let settings_page = WelcomePage::new(window, cx); + workspace.add_item_to_active_pane( + Box::new(settings_page), + None, + true, + window, + cx, + ) + } + }) + .detach(); + }); + }); + + cx.observe_new(|workspace: &mut Workspace, _window, _cx| { + workspace.register_action(|_workspace, action: &ImportVsCodeSettings, window, cx| { + let fs = <dyn Fs>::global(cx); + let action = *action; + + let workspace = cx.weak_entity(); + + window + .spawn(cx, async move |cx: &mut AsyncWindowContext| { + handle_import_vscode_settings( + workspace, + VsCodeSettingsSource::VsCode, + action.skip_prompt, + fs, + cx, + ) + .await + }) + .detach(); + }); + + workspace.register_action(|_workspace, action: &ImportCursorSettings, window, cx| { + let fs = <dyn Fs>::global(cx); + let action = *action; + + let workspace = cx.weak_entity(); + + window + .spawn(cx, async move |cx: &mut AsyncWindowContext| { + handle_import_vscode_settings( + workspace, + VsCodeSettingsSource::Cursor, + action.skip_prompt, + fs, + cx, + ) + .await + }) + .detach(); + }); + }) + .detach(); + + cx.observe_new::<Workspace>(|_, window, cx| { + let Some(window) = window else { + return; + }; + + let onboarding_actions = [ + std::any::TypeId::of::<OpenOnboarding>(), + std::any::TypeId::of::<ShowWelcome>(), + ]; + + CommandPaletteFilter::update_global(cx, |filter, _cx| { + filter.hide_action_types(&onboarding_actions); + }); + + cx.observe_flag::<OnBoardingFeatureFlag, _>(window, move |is_enabled, _, _, cx| { + if is_enabled { + CommandPaletteFilter::update_global(cx, |filter, _cx| { + filter.show_action_types(onboarding_actions.iter()); + }); + } else { + CommandPaletteFilter::update_global(cx, |filter, _cx| { + filter.hide_action_types(&onboarding_actions); + }); + } + }) + .detach(); + }) + .detach(); + register_serializable_item::<Onboarding>(cx); +} + +pub fn show_onboarding_view(app_state: Arc<AppState>, cx: &mut App) -> Task<anyhow::Result<()>> { + open_new( + Default::default(), + app_state, + cx, + |workspace, window, cx| { + { + workspace.toggle_dock(DockPosition::Left, window, cx); + let onboarding_page = Onboarding::new(workspace, cx); + workspace.add_item_to_center(Box::new(onboarding_page.clone()), window, cx); + + window.focus(&onboarding_page.focus_handle(cx)); + + cx.notify(); + }; + db::write_and_log(cx, || { + KEY_VALUE_STORE.write_kvp(FIRST_OPEN.to_string(), "false".to_string()) + }); + }, + ) +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum SelectedPage { + Basics, + Editing, + 
AiSetup, +} + +struct Onboarding { + workspace: WeakEntity<Workspace>, + focus_handle: FocusHandle, + selected_page: SelectedPage, + user_store: Entity<UserStore>, + _settings_subscription: Subscription, +} + +impl Onboarding { + fn new(workspace: &Workspace, cx: &mut App) -> Entity<Self> { + cx.new(|cx| Self { + workspace: workspace.weak_handle(), + focus_handle: cx.focus_handle(), + selected_page: SelectedPage::Basics, + user_store: workspace.user_store().clone(), + _settings_subscription: cx.observe_global::<SettingsStore>(move |_, cx| cx.notify()), + }) + } + + fn set_page(&mut self, page: SelectedPage, cx: &mut Context<Self>) { + self.selected_page = page; + cx.notify(); + cx.emit(ItemEvent::UpdateTab); + } + + fn render_nav_buttons( + &mut self, + window: &mut Window, + cx: &mut Context<Self>, + ) -> [impl IntoElement; 3] { + let pages = [ + SelectedPage::Basics, + SelectedPage::Editing, + SelectedPage::AiSetup, + ]; + + let text = ["Basics", "Editing", "AI Setup"]; + + let actions: [&dyn Action; 3] = [ + &ActivateBasicsPage, + &ActivateEditingPage, + &ActivateAISetupPage, + ]; + + let mut binding = actions.map(|action| { + KeyBinding::for_action_in(action, &self.focus_handle, window, cx) + .map(|kb| kb.size(rems_from_px(12.))) + }); + + pages.map(|page| { + let i = page as usize; + let selected = self.selected_page == page; + h_flex() + .id(text[i]) + .relative() + .w_full() + .gap_2() + .px_2() + .py_0p5() + .justify_between() + .rounded_sm() + .when(selected, |this| { + this.child( + div() + .h_4() + .w_px() + .bg(cx.theme().colors().text_accent) + .absolute() + .left_0(), + ) + }) + .hover(|style| style.bg(cx.theme().colors().element_hover)) + .child(Label::new(text[i]).map(|this| { + if selected { + this.color(Color::Default) + } else { + this.color(Color::Muted) + } + })) + .child(binding[i].take().map_or( + gpui::Empty.into_any_element(), + IntoElement::into_any_element, + )) + .on_click(cx.listener(move |this, _, _, cx| { + this.set_page(page, cx); + })) + }) + } + + fn render_nav(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { + let ai_setup_page = matches!(self.selected_page, SelectedPage::AiSetup); + + v_flex() + .h_full() + .w(rems_from_px(220.)) + .flex_shrink_0() + .gap_4() + .justify_between() + .child( + v_flex() + .gap_6() + .child( + h_flex() + .px_2() + .gap_4() + .child(Vector::square(VectorName::ZedLogo, rems(2.5))) + .child( + v_flex() + .child( + Headline::new("Welcome to Zed").size(HeadlineSize::Small), + ) + .child( + Label::new("The editor for what's next") + .color(Color::Muted) + .size(LabelSize::Small) + .italic(), + ), + ), + ) + .child( + v_flex() + .gap_4() + .child( + v_flex() + .py_4() + .border_y_1() + .border_color(cx.theme().colors().border_variant.opacity(0.5)) + .gap_1() + .children(self.render_nav_buttons(window, cx)), + ) + .map(|this| { + let keybinding = KeyBinding::for_action_in( + &Finish, + &self.focus_handle, + window, + cx, + ) + .map(|kb| kb.size(rems_from_px(12.))); + + if ai_setup_page { + this.child( + ButtonLike::new("start_building") + .style(ButtonStyle::Outlined) + .size(ButtonSize::Medium) + .child( + h_flex() + .ml_1() + .w_full() + .justify_between() + .child(Label::new("Start Building")) + .children(keybinding), + ) + .on_click(|_, window, cx| { + window.dispatch_action(Finish.boxed_clone(), cx); + }), + ) + } else { + this.child( + ButtonLike::new("skip_all") + .size(ButtonSize::Medium) + .child( + h_flex() + .ml_1() + .w_full() + .justify_between() + .child( + Label::new("Skip 
All").color(Color::Muted), + ) + .children(keybinding), + ) + .on_click(|_, window, cx| { + window.dispatch_action(Finish.boxed_clone(), cx); + }), + ) + } + }), + ), + ) + .child( + if let Some(user) = self.user_store.read(cx).current_user() { + h_flex() + .pl_1p5() + .gap_2() + .child(Avatar::new(user.avatar_uri.clone())) + .child(Label::new(user.github_login.clone())) + .into_any_element() + } else { + Button::new("sign_in", "Sign In") + .full_width() + .style(ButtonStyle::Outlined) + .size(ButtonSize::Medium) + .key_binding( + KeyBinding::for_action_in(&SignIn, &self.focus_handle, window, cx) + .map(|kb| kb.size(rems_from_px(12.))), + ) + .on_click(|_, window, cx| { + window.dispatch_action(SignIn.boxed_clone(), cx); + }) + .into_any_element() + }, + ) + } + + fn on_finish(_: &Finish, _: &mut Window, cx: &mut App) { + go_to_welcome_page(cx); + } + + fn handle_sign_in(_: &SignIn, window: &mut Window, cx: &mut App) { + let client = Client::global(cx); + + window + .spawn(cx, async move |cx| { + client + .sign_in_with_optional_connect(true, &cx) + .await + .notify_async_err(cx); + }) + .detach(); + } + + fn render_page(&mut self, window: &mut Window, cx: &mut Context<Self>) -> AnyElement { + let client = Client::global(cx); + + match self.selected_page { + SelectedPage::Basics => crate::basics_page::render_basics_page(cx).into_any_element(), + SelectedPage::Editing => { + crate::editing_page::render_editing_page(window, cx).into_any_element() + } + SelectedPage::AiSetup => crate::ai_setup_page::render_ai_setup_page( + self.workspace.clone(), + self.user_store.clone(), + client, + window, + cx, + ) + .into_any_element(), + } + } +} + +impl Render for Onboarding { + fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { + h_flex() + .image_cache(gpui::retain_all("onboarding-page")) + .key_context({ + let mut ctx = KeyContext::new_with_defaults(); + ctx.add("Onboarding"); + ctx.add("menu"); + ctx + }) + .track_focus(&self.focus_handle) + .size_full() + .bg(cx.theme().colors().editor_background) + .on_action(Self::on_finish) + .on_action(Self::handle_sign_in) + .on_action(cx.listener(|this, _: &ActivateBasicsPage, _, cx| { + this.set_page(SelectedPage::Basics, cx); + })) + .on_action(cx.listener(|this, _: &ActivateEditingPage, _, cx| { + this.set_page(SelectedPage::Editing, cx); + })) + .on_action(cx.listener(|this, _: &ActivateAISetupPage, _, cx| { + this.set_page(SelectedPage::AiSetup, cx); + })) + .on_action(cx.listener(|_, _: &menu::SelectNext, window, cx| { + window.focus_next(); + cx.notify(); + })) + .on_action(cx.listener(|_, _: &menu::SelectPrevious, window, cx| { + window.focus_prev(); + cx.notify(); + })) + .child( + h_flex() + .max_w(rems_from_px(1100.)) + .size_full() + .m_auto() + .py_20() + .px_12() + .items_start() + .gap_12() + .child(self.render_nav(window, cx)) + .child( + v_flex() + .max_w_full() + .min_w_0() + .pl_12() + .border_l_1() + .border_color(cx.theme().colors().border_variant.opacity(0.5)) + .size_full() + .child(self.render_page(window, cx)), + ), + ) + } +} + +impl EventEmitter<ItemEvent> for Onboarding {} + +impl Focusable for Onboarding { + fn focus_handle(&self, _: &App) -> gpui::FocusHandle { + self.focus_handle.clone() + } +} + +impl Item for Onboarding { + type Event = ItemEvent; + + fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString { + "Onboarding".into() + } + + fn telemetry_event_text(&self) -> Option<&'static str> { + Some("Onboarding Page Opened") + } + + fn show_toolbar(&self) -> bool { + false + 
} + + fn clone_on_split( + &self, + _workspace_id: Option<WorkspaceId>, + _: &mut Window, + cx: &mut Context<Self>, + ) -> Option<Entity<Self>> { + self.workspace + .update(cx, |workspace, cx| Onboarding::new(workspace, cx)) + .ok() + } + + fn to_item_events(event: &Self::Event, mut f: impl FnMut(workspace::item::ItemEvent)) { + f(*event) + } +} + +fn go_to_welcome_page(cx: &mut App) { + with_active_or_new_workspace(cx, |workspace, window, cx| { + let Some((onboarding_id, onboarding_idx)) = workspace + .active_pane() + .read(cx) + .items() + .enumerate() + .find_map(|(idx, item)| { + let _ = item.downcast::<Onboarding>()?; + Some((item.item_id(), idx)) + }) + else { + return; + }; + + workspace.active_pane().update(cx, |pane, cx| { + // Get the index here to get around the borrow checker + let idx = pane.items().enumerate().find_map(|(idx, item)| { + let _ = item.downcast::<WelcomePage>()?; + Some(idx) + }); + + if let Some(idx) = idx { + pane.activate_item(idx, true, true, window, cx); + } else { + let item = Box::new(WelcomePage::new(window, cx)); + pane.add_item(item, true, true, Some(onboarding_idx), window, cx); + } + + pane.remove_item(onboarding_id, false, false, window, cx); + }); + }); +} + +pub async fn handle_import_vscode_settings( + workspace: WeakEntity<Workspace>, + source: VsCodeSettingsSource, + skip_prompt: bool, + fs: Arc<dyn Fs>, + cx: &mut AsyncWindowContext, +) { + use util::truncate_and_remove_front; + + let vscode_settings = + match settings::VsCodeSettings::load_user_settings(source, fs.clone()).await { + Ok(vscode_settings) => vscode_settings, + Err(err) => { + zlog::error!("{err}"); + let _ = cx.prompt( + gpui::PromptLevel::Info, + &format!("Could not find or load a {source} settings file"), + None, + &["Ok"], + ); + return; + } + }; + + if !skip_prompt { + let prompt = cx.prompt( + gpui::PromptLevel::Warning, + &format!( + "Importing {} settings may overwrite your existing settings. \ + Will import settings from {}", + vscode_settings.source, + truncate_and_remove_front(&vscode_settings.path.to_string_lossy(), 128), + ), + None, + &["Ok", "Cancel"], + ); + let result = cx.spawn(async move |_| prompt.await.ok()).await; + if result != Some(0) { + return; + } + }; + + let Ok(result_channel) = cx.update(|_, cx| { + let source = vscode_settings.source; + let path = vscode_settings.path.clone(); + let result_channel = cx + .global::<SettingsStore>() + .import_vscode_settings(fs, vscode_settings); + zlog::info!("Imported {source} settings from {}", path.display()); + result_channel + }) else { + return; + }; + + let result = result_channel.await; + workspace + .update_in(cx, |workspace, _, cx| match result { + Ok(_) => { + let confirmation_toast = StatusToast::new( + format!("Your {} settings were successfully imported.", source), + cx, + |this, _| { + this.icon(ToastIcon::new(IconName::Check).color(Color::Success)) + .dismiss_button(true) + }, + ); + SettingsImportState::update(cx, |state, _| match source { + VsCodeSettingsSource::VsCode => { + state.vscode = true; + } + VsCodeSettingsSource::Cursor => { + state.cursor = true; + } + }); + workspace.toggle_status_toast(confirmation_toast, cx); + } + Err(_) => { + let error_toast = StatusToast::new( + "Failed to import settings. 
See log for details", + cx, + |this, _| { + this.icon(ToastIcon::new(IconName::X).color(Color::Error)) + .action("Open Log", |window, cx| { + window.dispatch_action(workspace::OpenLog.boxed_clone(), cx) + }) + .dismiss_button(true) + }, + ); + workspace.toggle_status_toast(error_toast, cx); + } + }) + .ok(); +} + +#[derive(Default, Copy, Clone)] +pub struct SettingsImportState { + pub cursor: bool, + pub vscode: bool, +} + +impl Global for SettingsImportState {} + +impl SettingsImportState { + pub fn global(cx: &App) -> Self { + cx.try_global().cloned().unwrap_or_default() + } + pub fn update<R>(cx: &mut App, f: impl FnOnce(&mut Self, &mut App) -> R) -> R { + cx.update_default_global(f) + } +} + +impl workspace::SerializableItem for Onboarding { + fn serialized_item_kind() -> &'static str { + "OnboardingPage" + } + + fn cleanup( + workspace_id: workspace::WorkspaceId, + alive_items: Vec<workspace::ItemId>, + _window: &mut Window, + cx: &mut App, + ) -> gpui::Task<gpui::Result<()>> { + workspace::delete_unloaded_items( + alive_items, + workspace_id, + "onboarding_pages", + &persistence::ONBOARDING_PAGES, + cx, + ) + } + + fn deserialize( + _project: Entity<project::Project>, + workspace: WeakEntity<Workspace>, + workspace_id: workspace::WorkspaceId, + item_id: workspace::ItemId, + window: &mut Window, + cx: &mut App, + ) -> gpui::Task<gpui::Result<Entity<Self>>> { + window.spawn(cx, async move |cx| { + if let Some(page_number) = + persistence::ONBOARDING_PAGES.get_onboarding_page(item_id, workspace_id)? + { + let page = match page_number { + 0 => Some(SelectedPage::Basics), + 1 => Some(SelectedPage::Editing), + 2 => Some(SelectedPage::AiSetup), + _ => None, + }; + workspace.update(cx, |workspace, cx| { + let onboarding_page = Onboarding::new(workspace, cx); + if let Some(page) = page { + zlog::info!("Onboarding page {page:?} loaded"); + onboarding_page.update(cx, |onboarding_page, cx| { + onboarding_page.set_page(page, cx); + }) + } + onboarding_page + }) + } else { + Err(anyhow::anyhow!("No onboarding page to deserialize")) + } + }) + } + + fn serialize( + &mut self, + workspace: &mut Workspace, + item_id: workspace::ItemId, + _closing: bool, + _window: &mut Window, + cx: &mut ui::Context<Self>, + ) -> Option<gpui::Task<gpui::Result<()>>> { + let workspace_id = workspace.database_id()?; + let page_number = self.selected_page as u16; + Some(cx.background_spawn(async move { + persistence::ONBOARDING_PAGES + .save_onboarding_page(item_id, workspace_id, page_number) + .await + })) + } + + fn should_serialize(&self, event: &Self::Event) -> bool { + event == &ItemEvent::UpdateTab + } +} + +mod persistence { + use db::{define_connection, query, sqlez_macros::sql}; + use workspace::WorkspaceDb; + + define_connection! { + pub static ref ONBOARDING_PAGES: OnboardingPagesDb<WorkspaceDb> = + &[ + sql!( + CREATE TABLE onboarding_pages ( + workspace_id INTEGER, + item_id INTEGER UNIQUE, + page_number INTEGER, + + PRIMARY KEY(workspace_id, item_id), + FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id) + ON DELETE CASCADE + ) STRICT; + ), + ]; + } + + impl OnboardingPagesDb { + query! { + pub async fn save_onboarding_page( + item_id: workspace::ItemId, + workspace_id: workspace::WorkspaceId, + page_number: u16 + ) -> Result<()> { + INSERT OR REPLACE INTO onboarding_pages(item_id, workspace_id, page_number) + VALUES (?, ?, ?) + } + } + + query! 
{ + pub fn get_onboarding_page( + item_id: workspace::ItemId, + workspace_id: workspace::WorkspaceId + ) -> Result<Option<u16>> { + SELECT page_number + FROM onboarding_pages + WHERE item_id = ? AND workspace_id = ? + } + } + } +} diff --git a/crates/onboarding/src/theme_preview.rs b/crates/onboarding/src/theme_preview.rs new file mode 100644 index 0000000000000000000000000000000000000000..81eb14ec4b0cb502618c8bd06d4cd92a63186c71 --- /dev/null +++ b/crates/onboarding/src/theme_preview.rs @@ -0,0 +1,378 @@ +#![allow(unused, dead_code)] +use gpui::{Hsla, Length}; +use std::sync::Arc; +use theme::{Theme, ThemeColors, ThemeRegistry}; +use ui::{ + IntoElement, RenderOnce, component_prelude::Documented, prelude::*, utils::inner_corner_radius, +}; + +#[derive(Clone, PartialEq)] +pub enum ThemePreviewStyle { + Bordered, + Borderless, + SideBySide(Arc<Theme>), +} + +/// Shows a preview of a theme as an abstract illustration +/// of a thumbnail-sized editor. +#[derive(IntoElement, RegisterComponent, Documented)] +pub struct ThemePreviewTile { + theme: Arc<Theme>, + seed: f32, + style: ThemePreviewStyle, +} + +impl ThemePreviewTile { + pub const SKELETON_HEIGHT_DEFAULT: Pixels = px(2.); + pub const SIDEBAR_SKELETON_ITEM_COUNT: usize = 8; + pub const SIDEBAR_WIDTH_DEFAULT: DefiniteLength = relative(0.25); + pub const ROOT_RADIUS: Pixels = px(8.0); + pub const ROOT_BORDER: Pixels = px(2.0); + pub const ROOT_PADDING: Pixels = px(2.0); + pub const CHILD_BORDER: Pixels = px(1.0); + pub const CHILD_RADIUS: std::cell::LazyCell<Pixels> = std::cell::LazyCell::new(|| { + inner_corner_radius( + Self::ROOT_RADIUS, + Self::ROOT_BORDER, + Self::ROOT_PADDING, + Self::CHILD_BORDER, + ) + }); + + pub fn new(theme: Arc<Theme>, seed: f32) -> Self { + Self { + theme, + seed, + style: ThemePreviewStyle::Bordered, + } + } + + pub fn style(mut self, style: ThemePreviewStyle) -> Self { + self.style = style; + self + } + + pub fn item_skeleton(w: Length, h: Length, bg: Hsla) -> impl IntoElement { + div().w(w).h(h).rounded_full().bg(bg) + } + + pub fn render_sidebar_skeleton_items( + seed: f32, + colors: &ThemeColors, + skeleton_height: impl Into<Length> + Clone, + ) -> [impl IntoElement; Self::SIDEBAR_SKELETON_ITEM_COUNT] { + let skeleton_height = skeleton_height.into(); + std::array::from_fn(|index| { + let width = { + let value = (seed * 1000.0 + index as f32 * 10.0).sin() * 0.5 + 0.5; + 0.5 + value * 0.45 + }; + Self::item_skeleton( + relative(width).into(), + skeleton_height, + colors.text.alpha(0.45), + ) + }) + } + + pub fn render_pseudo_code_skeleton( + seed: f32, + theme: Arc<Theme>, + skeleton_height: impl Into<Length>, + ) -> impl IntoElement { + let colors = theme.colors(); + let syntax = theme.syntax(); + + let keyword_color = syntax.get("keyword").color; + let function_color = syntax.get("function").color; + let string_color = syntax.get("string").color; + let comment_color = syntax.get("comment").color; + let variable_color = syntax.get("variable").color; + let type_color = syntax.get("type").color; + let punctuation_color = syntax.get("punctuation").color; + + let syntax_colors = [ + keyword_color, + function_color, + string_color, + variable_color, + type_color, + punctuation_color, + comment_color, + ]; + + let skeleton_height = skeleton_height.into(); + + let line_width = |line_idx: usize, block_idx: usize| -> f32 { + let val = + (seed * 100.0 + line_idx as f32 * 20.0 + block_idx as f32 * 5.0).sin() * 0.5 + 0.5; + 0.05 + val * 0.2 + }; + + let indentation = |line_idx: usize| -> f32 { + let step = line_idx % 
6; + if step < 3 { + step as f32 * 0.1 + } else { + (5 - step) as f32 * 0.1 + } + }; + + let pick_color = |line_idx: usize, block_idx: usize| -> Hsla { + let idx = ((seed * 10.0 + line_idx as f32 * 7.0 + block_idx as f32 * 3.0).sin() * 3.5) + .abs() as usize + % syntax_colors.len(); + syntax_colors[idx].unwrap_or(colors.text) + }; + + let line_count = 13; + + let lines = (0..line_count) + .map(|line_idx| { + let block_count = (((seed * 30.0 + line_idx as f32 * 12.0).sin() * 0.5 + 0.5) * 3.0) + .round() as usize + + 2; + + let indent = indentation(line_idx); + + let blocks = (0..block_count) + .map(|block_idx| { + let width = line_width(line_idx, block_idx); + let color = pick_color(line_idx, block_idx); + Self::item_skeleton(relative(width).into(), skeleton_height, color) + }) + .collect::<Vec<_>>(); + + h_flex().gap(px(2.)).ml(relative(indent)).children(blocks) + }) + .collect::<Vec<_>>(); + + v_flex().size_full().p_1().gap_1p5().children(lines) + } + + pub fn render_sidebar( + seed: f32, + colors: &ThemeColors, + width: impl Into<Length> + Clone, + skeleton_height: impl Into<Length>, + ) -> impl IntoElement { + div() + .h_full() + .w(width) + .border_r(px(1.)) + .border_color(colors.border_transparent) + .bg(colors.panel_background) + .child(v_flex().p_2().size_full().gap_1().children( + Self::render_sidebar_skeleton_items(seed, colors, skeleton_height.into()), + )) + } + + pub fn render_pane( + seed: f32, + theme: Arc<Theme>, + skeleton_height: impl Into<Length>, + ) -> impl IntoElement { + v_flex().h_full().flex_grow().child( + div() + .size_full() + .overflow_hidden() + .bg(theme.colors().editor_background) + .p_2() + .child(Self::render_pseudo_code_skeleton( + seed, + theme, + skeleton_height.into(), + )), + ) + } + + pub fn render_editor( + seed: f32, + theme: Arc<Theme>, + sidebar_width: impl Into<Length> + Clone, + skeleton_height: impl Into<Length> + Clone, + ) -> impl IntoElement { + div() + .size_full() + .flex() + .bg(theme.colors().background.alpha(1.00)) + .child(Self::render_sidebar( + seed, + theme.colors(), + sidebar_width, + skeleton_height.clone(), + )) + .child(Self::render_pane(seed, theme, skeleton_height.clone())) + } + + fn render_borderless(seed: f32, theme: Arc<Theme>) -> impl IntoElement { + return Self::render_editor( + seed, + theme, + Self::SIDEBAR_WIDTH_DEFAULT, + Self::SKELETON_HEIGHT_DEFAULT, + ); + } + + fn render_border(seed: f32, theme: Arc<Theme>) -> impl IntoElement { + div() + .size_full() + .p(Self::ROOT_PADDING) + .rounded(Self::ROOT_RADIUS) + .child( + div() + .size_full() + .rounded(*Self::CHILD_RADIUS) + .border(Self::CHILD_BORDER) + .border_color(theme.colors().border) + .child(Self::render_editor( + seed, + theme.clone(), + Self::SIDEBAR_WIDTH_DEFAULT, + Self::SKELETON_HEIGHT_DEFAULT, + )), + ) + } + + fn render_side_by_side( + seed: f32, + theme: Arc<Theme>, + other_theme: Arc<Theme>, + border_color: Hsla, + ) -> impl IntoElement { + let sidebar_width = relative(0.20); + + return div() + .size_full() + .p(Self::ROOT_PADDING) + .rounded(Self::ROOT_RADIUS) + .child( + h_flex() + .size_full() + .relative() + .rounded(*Self::CHILD_RADIUS) + .border(Self::CHILD_BORDER) + .border_color(border_color) + .overflow_hidden() + .child(div().size_full().child(Self::render_editor( + seed, + theme.clone(), + sidebar_width, + Self::SKELETON_HEIGHT_DEFAULT, + ))) + .child( + div() + .size_full() + .absolute() + .left_1_2() + .bg(other_theme.colors().editor_background) + .child(Self::render_editor( + seed, + other_theme, + sidebar_width, + 
Self::SKELETON_HEIGHT_DEFAULT, + )), + ), + ) + .into_any_element(); + } +} + +impl RenderOnce for ThemePreviewTile { + fn render(self, _window: &mut ui::Window, _cx: &mut ui::App) -> impl IntoElement { + match self.style { + ThemePreviewStyle::Bordered => { + Self::render_border(self.seed, self.theme).into_any_element() + } + ThemePreviewStyle::Borderless => { + Self::render_borderless(self.seed, self.theme).into_any_element() + } + ThemePreviewStyle::SideBySide(other_theme) => Self::render_side_by_side( + self.seed, + self.theme, + other_theme, + _cx.theme().colors().border, + ) + .into_any_element(), + } + } +} + +impl Component for ThemePreviewTile { + fn scope() -> ComponentScope { + ComponentScope::Onboarding + } + + fn name() -> &'static str { + "Theme Preview Tile" + } + + fn sort_name() -> &'static str { + "Theme Preview Tile" + } + + fn description() -> Option<&'static str> { + Some(Self::DOCS) + } + + fn preview(_window: &mut Window, cx: &mut App) -> Option<AnyElement> { + let theme_registry = ThemeRegistry::global(cx); + + let one_dark = theme_registry.get("One Dark"); + let one_light = theme_registry.get("One Light"); + let gruvbox_dark = theme_registry.get("Gruvbox Dark"); + let gruvbox_light = theme_registry.get("Gruvbox Light"); + + let themes_to_preview = vec![ + one_dark.clone().ok(), + one_light.clone().ok(), + gruvbox_dark.clone().ok(), + gruvbox_light.clone().ok(), + ] + .into_iter() + .flatten() + .collect::<Vec<_>>(); + + Some( + v_flex() + .gap_6() + .p_4() + .children({ + if let Some(one_dark) = one_dark.ok() { + vec![example_group(vec![single_example( + "Default", + div() + .w(px(240.)) + .h(px(180.)) + .child(ThemePreviewTile::new(one_dark.clone(), 0.42)) + .into_any_element(), + )])] + } else { + vec![] + } + }) + .child( + example_group(vec![single_example( + "Default Themes", + h_flex() + .gap_4() + .children( + themes_to_preview + .iter() + .enumerate() + .map(|(_, theme)| { + div() + .w(px(200.)) + .h(px(140.)) + .child(ThemePreviewTile::new(theme.clone(), 0.42)) + }) + .collect::<Vec<_>>(), + ) + .into_any_element(), + )]) + .grow(), + ) + .into_any_element(), + ) + } +} diff --git a/crates/onboarding/src/welcome.rs b/crates/onboarding/src/welcome.rs new file mode 100644 index 0000000000000000000000000000000000000000..d4d6c3f701dea4f381160b8d76b43c1378685656 --- /dev/null +++ b/crates/onboarding/src/welcome.rs @@ -0,0 +1,355 @@ +use gpui::{ + Action, App, Context, Entity, EventEmitter, FocusHandle, Focusable, InteractiveElement, + NoAction, ParentElement, Render, Styled, Window, actions, +}; +use menu::{SelectNext, SelectPrevious}; +use ui::{ButtonLike, Divider, DividerColor, KeyBinding, Vector, VectorName, prelude::*}; +use workspace::{ + NewFile, Open, WorkspaceId, + item::{Item, ItemEvent}, + with_active_or_new_workspace, +}; +use zed_actions::{Extensions, OpenSettings, agent, command_palette}; + +use crate::{Onboarding, OpenOnboarding}; + +actions!( + zed, + [ + /// Show the Zed welcome screen + ShowWelcome + ] +); + +const CONTENT: (Section<4>, Section<3>) = ( + Section { + title: "Get Started", + entries: [ + SectionEntry { + icon: IconName::Plus, + title: "New File", + action: &NewFile, + }, + SectionEntry { + icon: IconName::FolderOpen, + title: "Open Project", + action: &Open, + }, + SectionEntry { + icon: IconName::CloudDownload, + title: "Clone a Repo", + // TODO: use proper action + action: &NoAction, + }, + SectionEntry { + icon: IconName::ListCollapse, + title: "Open Command Palette", + action: &command_palette::Toggle, + }, + ], + }, + 
Section { + title: "Configure", + entries: [ + SectionEntry { + icon: IconName::Settings, + title: "Open Settings", + action: &OpenSettings, + }, + SectionEntry { + icon: IconName::ZedAssistant, + title: "View AI Settings", + action: &agent::OpenSettings, + }, + SectionEntry { + icon: IconName::Blocks, + title: "Explore Extensions", + action: &Extensions { + category_filter: None, + id: None, + }, + }, + ], + }, +); + +struct Section<const COLS: usize> { + title: &'static str, + entries: [SectionEntry; COLS], +} + +impl<const COLS: usize> Section<COLS> { + fn render( + self, + index_offset: usize, + focus: &FocusHandle, + window: &mut Window, + cx: &mut App, + ) -> impl IntoElement { + v_flex() + .min_w_full() + .child( + h_flex() + .px_1() + .mb_2() + .gap_2() + .child( + Label::new(self.title.to_ascii_uppercase()) + .buffer_font(cx) + .color(Color::Muted) + .size(LabelSize::XSmall), + ) + .child(Divider::horizontal().color(DividerColor::BorderVariant)), + ) + .children( + self.entries + .iter() + .enumerate() + .map(|(index, entry)| entry.render(index_offset + index, &focus, window, cx)), + ) + } +} + +struct SectionEntry { + icon: IconName, + title: &'static str, + action: &'static dyn Action, +} + +impl SectionEntry { + fn render( + &self, + button_index: usize, + focus: &FocusHandle, + window: &Window, + cx: &App, + ) -> impl IntoElement { + ButtonLike::new(("onboarding-button-id", button_index)) + .tab_index(button_index as isize) + .full_width() + .size(ButtonSize::Medium) + .child( + h_flex() + .w_full() + .justify_between() + .child( + h_flex() + .gap_2() + .child( + Icon::new(self.icon) + .color(Color::Muted) + .size(IconSize::XSmall), + ) + .child(Label::new(self.title)), + ) + .children( + KeyBinding::for_action_in(self.action, focus, window, cx) + .map(|s| s.size(rems_from_px(12.))), + ), + ) + .on_click(|_, window, cx| window.dispatch_action(self.action.boxed_clone(), cx)) + } +} + +pub struct WelcomePage { + focus_handle: FocusHandle, +} + +impl WelcomePage { + fn select_next(&mut self, _: &SelectNext, window: &mut Window, cx: &mut Context<Self>) { + window.focus_next(); + cx.notify(); + } + + fn select_previous(&mut self, _: &SelectPrevious, window: &mut Window, cx: &mut Context<Self>) { + window.focus_prev(); + cx.notify(); + } +} + +impl Render for WelcomePage { + fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { + let (first_section, second_section) = CONTENT; + let first_section_entries = first_section.entries.len(); + let last_index = first_section_entries + second_section.entries.len(); + + h_flex() + .size_full() + .justify_center() + .overflow_hidden() + .bg(cx.theme().colors().editor_background) + .key_context("Welcome") + .track_focus(&self.focus_handle(cx)) + .on_action(cx.listener(Self::select_previous)) + .on_action(cx.listener(Self::select_next)) + .child( + h_flex() + .px_12() + .py_40() + .size_full() + .relative() + .max_w(px(1100.)) + .child( + div() + .size_full() + .max_w_128() + .mx_auto() + .child( + h_flex() + .w_full() + .justify_center() + .gap_4() + .child(Vector::square(VectorName::ZedLogo, rems(2.))) + .child( + div().child(Headline::new("Welcome to Zed")).child( + Label::new("The editor for what's next") + .size(LabelSize::Small) + .color(Color::Muted) + .italic(), + ), + ), + ) + .child( + v_flex() + .mt_10() + .gap_6() + .child(first_section.render( + Default::default(), + &self.focus_handle, + window, + cx, + )) + .child(second_section.render( + first_section_entries, + &self.focus_handle, + window, + cx, 
+ )) + .child( + h_flex() + .w_full() + .pt_4() + .justify_center() + // We call this a hack + .rounded_b_xs() + .border_t_1() + .border_color(cx.theme().colors().border.opacity(0.6)) + .border_dashed() + .child( + Button::new("welcome-exit", "Return to Setup") + .tab_index(last_index as isize) + .full_width() + .label_size(LabelSize::XSmall) + .on_click(|_, window, cx| { + window.dispatch_action( + OpenOnboarding.boxed_clone(), + cx, + ); + + with_active_or_new_workspace(cx, |workspace, window, cx| { + let Some((welcome_id, welcome_idx)) = workspace + .active_pane() + .read(cx) + .items() + .enumerate() + .find_map(|(idx, item)| { + let _ = item.downcast::<WelcomePage>()?; + Some((item.item_id(), idx)) + }) + else { + return; + }; + + workspace.active_pane().update(cx, |pane, cx| { + // Get the index here to get around the borrow checker + let idx = pane.items().enumerate().find_map( + |(idx, item)| { + let _ = + item.downcast::<Onboarding>()?; + Some(idx) + }, + ); + + if let Some(idx) = idx { + pane.activate_item( + idx, true, true, window, cx, + ); + } else { + let item = + Box::new(Onboarding::new(workspace, cx)); + pane.add_item( + item, + true, + true, + Some(welcome_idx), + window, + cx, + ); + } + + pane.remove_item( + welcome_id, + false, + false, + window, + cx, + ); + }); + }); + }), + ), + ), + ), + ), + ) + } +} + +impl WelcomePage { + pub fn new(window: &mut Window, cx: &mut App) -> Entity<Self> { + cx.new(|cx| { + let focus_handle = cx.focus_handle(); + cx.on_focus(&focus_handle, window, |_, _, cx| cx.notify()) + .detach(); + + WelcomePage { focus_handle } + }) + } +} + +impl EventEmitter<ItemEvent> for WelcomePage {} + +impl Focusable for WelcomePage { + fn focus_handle(&self, _: &App) -> gpui::FocusHandle { + self.focus_handle.clone() + } +} + +impl Item for WelcomePage { + type Event = ItemEvent; + + fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString { + "Welcome".into() + } + + fn telemetry_event_text(&self) -> Option<&'static str> { + Some("New Welcome Page Opened") + } + + fn show_toolbar(&self) -> bool { + false + } + + fn clone_on_split( + &self, + _workspace_id: Option<WorkspaceId>, + _: &mut Window, + _: &mut Context<Self>, + ) -> Option<Entity<Self>> { + None + } + + fn to_item_events(event: &Self::Event, mut f: impl FnMut(workspace::item::ItemEvent)) { + f(*event) + } +} diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs index 12a5cf52d2efe7bf1d94bfc45ed629e38bc94382..4697d71ed337a94ad3c2a80cda3e50923042aa23 100644 --- a/crates/open_ai/src/open_ai.rs +++ b/crates/open_ai/src/open_ai.rs @@ -74,6 +74,12 @@ pub enum Model { O3, #[serde(rename = "o4-mini")] O4Mini, + #[serde(rename = "gpt-5")] + Five, + #[serde(rename = "gpt-5-mini")] + FiveMini, + #[serde(rename = "gpt-5-nano")] + FiveNano, #[serde(rename = "custom")] Custom { @@ -105,6 +111,9 @@ impl Model { "o3-mini" => Ok(Self::O3Mini), "o3" => Ok(Self::O3), "o4-mini" => Ok(Self::O4Mini), + "gpt-5" => Ok(Self::Five), + "gpt-5-mini" => Ok(Self::FiveMini), + "gpt-5-nano" => Ok(Self::FiveNano), invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"), } } @@ -123,6 +132,9 @@ impl Model { Self::O3Mini => "o3-mini", Self::O3 => "o3", Self::O4Mini => "o4-mini", + Self::Five => "gpt-5", + Self::FiveMini => "gpt-5-mini", + Self::FiveNano => "gpt-5-nano", Self::Custom { name, .. 
} => name, } } @@ -141,6 +153,9 @@ impl Model { Self::O3Mini => "o3-mini", Self::O3 => "o3", Self::O4Mini => "o4-mini", + Self::Five => "gpt-5", + Self::FiveMini => "gpt-5-mini", + Self::FiveNano => "gpt-5-nano", Self::Custom { name, display_name, .. } => display_name.as_ref().unwrap_or(name), @@ -161,6 +176,9 @@ impl Model { Self::O3Mini => 200_000, Self::O3 => 200_000, Self::O4Mini => 200_000, + Self::Five => 272_000, + Self::FiveMini => 272_000, + Self::FiveNano => 272_000, Self::Custom { max_tokens, .. } => *max_tokens, } } @@ -182,6 +200,9 @@ impl Model { Self::O3Mini => Some(100_000), Self::O3 => Some(100_000), Self::O4Mini => Some(100_000), + Self::Five => Some(128_000), + Self::FiveMini => Some(128_000), + Self::FiveNano => Some(128_000), } } @@ -197,7 +218,10 @@ impl Model { | Self::FourOmniMini | Self::FourPointOne | Self::FourPointOneMini - | Self::FourPointOneNano => true, + | Self::FourPointOneNano + | Self::Five + | Self::FiveMini + | Self::FiveNano => true, Self::O1 | Self::O3 | Self::O3Mini | Self::O4Mini | Model::Custom { .. } => false, } } diff --git a/crates/open_router/src/open_router.rs b/crates/open_router/src/open_router.rs index 4128426a7fb429337028b03c12e90c6395651f0e..3e6e406d9842d5996f2e866d534094ded23fd61c 100644 --- a/crates/open_router/src/open_router.rs +++ b/crates/open_router/src/open_router.rs @@ -153,11 +153,12 @@ pub struct RequestUsage { } #[derive(Debug, Serialize, Deserialize)] -#[serde(untagged)] +#[serde(rename_all = "lowercase")] pub enum ToolChoice { Auto, Required, None, + #[serde(untagged)] Other(ToolDefinition), } diff --git a/crates/outline_panel/src/outline_panel.rs b/crates/outline_panel/src/outline_panel.rs index 12dcab9e8702a98dbcecd8549ce40fe86fa45e0f..1cda3897ec356c76b8abf4751bad6c35873c1300 100644 --- a/crates/outline_panel/src/outline_panel.rs +++ b/crates/outline_panel/src/outline_panel.rs @@ -1,19 +1,5 @@ mod outline_panel_settings; -use std::{ - cmp, - collections::BTreeMap, - hash::Hash, - ops::Range, - path::{MAIN_SEPARATOR_STR, Path, PathBuf}, - sync::{ - Arc, OnceLock, - atomic::{self, AtomicBool}, - }, - time::Duration, - u32, -}; - use anyhow::Context as _; use collections::{BTreeSet, HashMap, HashSet, hash_map}; use db::kvp::KEY_VALUE_STORE; @@ -36,8 +22,21 @@ use gpui::{ uniform_list, }; use itertools::Itertools; -use language::{BufferId, BufferSnapshot, OffsetRangeExt, OutlineItem}; +use language::{Anchor, BufferId, BufferSnapshot, OffsetRangeExt, OutlineItem}; use menu::{Cancel, SelectFirst, SelectLast, SelectNext, SelectPrevious}; +use std::{ + cmp, + collections::BTreeMap, + hash::Hash, + ops::Range, + path::{MAIN_SEPARATOR_STR, Path, PathBuf}, + sync::{ + Arc, OnceLock, + atomic::{self, AtomicBool}, + }, + time::Duration, + u32, +}; use outline_panel_settings::{OutlinePanelDockPosition, OutlinePanelSettings, ShowIndentGuides}; use project::{File, Fs, GitEntry, GitTraversal, Project, ProjectItem}; @@ -132,6 +131,8 @@ pub struct OutlinePanel { hide_scrollbar_task: Option<Task<()>>, max_width_item_index: Option<usize>, preserve_selection_on_buffer_fold_toggles: HashSet<BufferId>, + pending_default_expansion_depth: Option<usize>, + outline_children_cache: HashMap<BufferId, HashMap<(Range<Anchor>, usize), bool>>, } #[derive(Debug)] @@ -318,12 +319,13 @@ struct CachedEntry { entry: PanelEntry, } -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] enum CollapsedEntry { Dir(WorktreeId, ProjectEntryId), File(WorktreeId, BufferId), ExternalFile(BufferId), Excerpt(BufferId, ExcerptId), 
+ Outline(BufferId, ExcerptId, Range<Anchor>), } #[derive(Debug)] @@ -803,8 +805,56 @@ impl OutlinePanel { outline_panel.update_cached_entries(Some(UPDATE_DEBOUNCE), window, cx); } } else if &outline_panel_settings != new_settings { + let old_expansion_depth = outline_panel_settings.expand_outlines_with_depth; outline_panel_settings = *new_settings; - cx.notify(); + + if old_expansion_depth != new_settings.expand_outlines_with_depth { + let old_collapsed_entries = outline_panel.collapsed_entries.clone(); + outline_panel + .collapsed_entries + .retain(|entry| !matches!(entry, CollapsedEntry::Outline(..))); + + let new_depth = new_settings.expand_outlines_with_depth; + + for (buffer_id, excerpts) in &outline_panel.excerpts { + for (excerpt_id, excerpt) in excerpts { + if let ExcerptOutlines::Outlines(outlines) = &excerpt.outlines { + for outline in outlines { + if outline_panel + .outline_children_cache + .get(buffer_id) + .and_then(|children_map| { + let key = + (outline.range.clone(), outline.depth); + children_map.get(&key) + }) + .copied() + .unwrap_or(false) + && (new_depth == 0 || outline.depth >= new_depth) + { + outline_panel.collapsed_entries.insert( + CollapsedEntry::Outline( + *buffer_id, + *excerpt_id, + outline.range.clone(), + ), + ); + } + } + } + } + } + + if old_collapsed_entries != outline_panel.collapsed_entries { + outline_panel.update_cached_entries( + Some(UPDATE_DEBOUNCE), + window, + cx, + ); + } + } else { + cx.notify(); + } } }); @@ -841,6 +891,7 @@ impl OutlinePanel { updating_cached_entries: false, new_entries_for_fs_update: HashSet::default(), preserve_selection_on_buffer_fold_toggles: HashSet::default(), + pending_default_expansion_depth: None, fs_entries_update_task: Task::ready(()), cached_entries_update_task: Task::ready(()), reveal_selection_task: Task::ready(Ok(())), @@ -855,6 +906,7 @@ impl OutlinePanel { workspace_subscription, filter_update_subscription, ], + outline_children_cache: HashMap::default(), }; if let Some((item, editor)) = workspace_active_editor(workspace, cx) { outline_panel.replace_active_editor(item, editor, window, cx); @@ -989,7 +1041,7 @@ impl OutlinePanel { fn open_excerpts( &mut self, - action: &editor::OpenExcerpts, + action: &editor::actions::OpenExcerpts, window: &mut Window, cx: &mut Context<Self>, ) { @@ -1005,7 +1057,7 @@ impl OutlinePanel { fn open_excerpts_split( &mut self, - action: &editor::OpenExcerptsSplit, + action: &editor::actions::OpenExcerptsSplit, window: &mut Window, cx: &mut Context<Self>, ) { @@ -1462,7 +1514,12 @@ impl OutlinePanel { PanelEntry::Outline(OutlineEntry::Excerpt(excerpt)) => { Some(CollapsedEntry::Excerpt(excerpt.buffer_id, excerpt.id)) } - PanelEntry::Search(_) | PanelEntry::Outline(..) => return, + PanelEntry::Outline(OutlineEntry::Outline(outline)) => Some(CollapsedEntry::Outline( + outline.buffer_id, + outline.excerpt_id, + outline.outline.range.clone(), + )), + PanelEntry::Search(_) => return, }; let Some(collapsed_entry) = entry_to_expand else { return; @@ -1565,7 +1622,14 @@ impl OutlinePanel { PanelEntry::Outline(OutlineEntry::Excerpt(excerpt)) => self .collapsed_entries .insert(CollapsedEntry::Excerpt(excerpt.buffer_id, excerpt.id)), - PanelEntry::Search(_) | PanelEntry::Outline(..) 
=> false, + PanelEntry::Outline(OutlineEntry::Outline(outline)) => { + self.collapsed_entries.insert(CollapsedEntry::Outline( + outline.buffer_id, + outline.excerpt_id, + outline.outline.range.clone(), + )) + } + PanelEntry::Search(_) => false, }; if collapsed { @@ -1780,7 +1844,17 @@ impl OutlinePanel { self.collapsed_entries.insert(collapsed_entry); } } - PanelEntry::Search(_) | PanelEntry::Outline(..) => return, + PanelEntry::Outline(OutlineEntry::Outline(outline)) => { + let collapsed_entry = CollapsedEntry::Outline( + outline.buffer_id, + outline.excerpt_id, + outline.outline.range.clone(), + ); + if !self.collapsed_entries.remove(&collapsed_entry) { + self.collapsed_entries.insert(collapsed_entry); + } + } + _ => {} } active_editor.update(cx, |editor, cx| { @@ -2108,7 +2182,7 @@ impl OutlinePanel { PanelEntry::Outline(OutlineEntry::Excerpt(excerpt.clone())), item_id, depth, - Some(icon), + icon, is_active, label_element, window, @@ -2160,10 +2234,31 @@ impl OutlinePanel { _ => false, }; - let icon = if self.is_singleton_active(cx) { - None + let has_children = self + .outline_children_cache + .get(&outline.buffer_id) + .and_then(|children_map| { + let key = (outline.outline.range.clone(), outline.outline.depth); + children_map.get(&key) + }) + .copied() + .unwrap_or(false); + let is_expanded = !self.collapsed_entries.contains(&CollapsedEntry::Outline( + outline.buffer_id, + outline.excerpt_id, + outline.outline.range.clone(), + )); + + let icon = if has_children { + FileIcons::get_chevron_icon(is_expanded, cx) + .map(|icon_path| { + Icon::from_path(icon_path) + .color(entry_label_color(is_active)) + .into_any_element() + }) + .unwrap_or_else(empty_icon) } else { - Some(empty_icon()) + empty_icon() }; self.entry_element( @@ -2287,7 +2382,7 @@ impl OutlinePanel { PanelEntry::Fs(rendered_entry.clone()), item_id, depth, - Some(icon), + icon, is_active, label_element, window, @@ -2358,7 +2453,7 @@ impl OutlinePanel { PanelEntry::FoldedDirs(folded_dir.clone()), item_id, depth, - Some(icon), + icon, is_active, label_element, window, @@ -2449,7 +2544,7 @@ impl OutlinePanel { }), ElementId::from(SharedString::from(format!("search-{match_range:?}"))), depth, - None, + empty_icon(), is_active, entire_label, window, @@ -2462,7 +2557,7 @@ impl OutlinePanel { rendered_entry: PanelEntry, item_id: ElementId, depth: usize, - icon_element: Option<AnyElement>, + icon_element: AnyElement, is_active: bool, label_element: gpui::AnyElement, window: &mut Window, @@ -2475,11 +2570,13 @@ impl OutlinePanel { .on_click({ let clicked_entry = rendered_entry.clone(); cx.listener(move |outline_panel, event: &gpui::ClickEvent, window, cx| { - if event.down.button == MouseButton::Right || event.down.first_mouse { + if event.is_right_click() || event.first_focus() { return; } - let change_focus = event.down.click_count > 1; + + let change_focus = event.click_count() > 1; outline_panel.toggle_expanded(&clicked_entry, window, cx); + outline_panel.scroll_editor_to_entry( &clicked_entry, true, @@ -2495,10 +2592,11 @@ impl OutlinePanel { .indent_level(depth) .indent_step_size(px(settings.indent_size)) .toggle_state(is_active) - .when_some(icon_element, |list_item, icon_element| { - list_item.child(h_flex().child(icon_element)) - }) - .child(h_flex().h_6().child(label_element).ml_1()) + .child( + h_flex() + .child(h_flex().w(px(16.)).justify_center().child(icon_element)) + .child(h_flex().h_6().child(label_element).ml_1()), + ) .on_secondary_mouse_down(cx.listener( move |outline_panel, event: &MouseDownEvent, window, cx| { 
// Stop propagation to prevent the catch-all context menu for the project @@ -2940,7 +3038,12 @@ impl OutlinePanel { outline_panel.fs_entries_depth = new_depth_map; outline_panel.fs_children_count = new_children_count; outline_panel.update_non_fs_items(window, cx); - outline_panel.update_cached_entries(debounce, window, cx); + + // Only update cached entries if we don't have outlines to fetch + // If we do have outlines to fetch, let fetch_outdated_outlines handle the update + if outline_panel.excerpt_fetch_ranges(cx).is_empty() { + outline_panel.update_cached_entries(debounce, window, cx); + } cx.notify(); }) @@ -2956,6 +3059,12 @@ impl OutlinePanel { cx: &mut Context<Self>, ) { self.clear_previous(window, cx); + + let default_expansion_depth = + OutlinePanelSettings::get_global(cx).expand_outlines_with_depth; + // We'll apply the expansion depth after outlines are loaded + self.pending_default_expansion_depth = Some(default_expansion_depth); + let buffer_search_subscription = cx.subscribe_in( &new_active_editor, window, @@ -3004,6 +3113,7 @@ impl OutlinePanel { self.selected_entry = SelectedEntry::None; self.pinned = false; self.mode = ItemsDisplayMode::Outline; + self.pending_default_expansion_depth = None; } fn location_for_editor_selection( @@ -3259,25 +3369,74 @@ impl OutlinePanel { || buffer_language.as_ref() == buffer_snapshot.language_at(outline.range.start) }); - outlines + + let outlines_with_children = outlines + .windows(2) + .filter_map(|window| { + let current = &window[0]; + let next = &window[1]; + if next.depth > current.depth { + Some((current.range.clone(), current.depth)) + } else { + None + } + }) + .collect::<HashSet<_>>(); + + (outlines, outlines_with_children) }) .await; + + let (fetched_outlines, outlines_with_children) = fetched_outlines; + outline_panel .update_in(cx, |outline_panel, window, cx| { + let pending_default_depth = + outline_panel.pending_default_expansion_depth.take(); + + let debounce = + if first_update.fetch_and(false, atomic::Ordering::AcqRel) { + None + } else { + Some(UPDATE_DEBOUNCE) + }; + if let Some(excerpt) = outline_panel .excerpts .entry(buffer_id) .or_default() .get_mut(&excerpt_id) { - let debounce = if first_update - .fetch_and(false, atomic::Ordering::AcqRel) - { - None - } else { - Some(UPDATE_DEBOUNCE) - }; excerpt.outlines = ExcerptOutlines::Outlines(fetched_outlines); + + if let Some(default_depth) = pending_default_depth { + if let ExcerptOutlines::Outlines(outlines) = + &excerpt.outlines + { + outlines + .iter() + .filter(|outline| { + (default_depth == 0 + || outline.depth >= default_depth) + && outlines_with_children.contains(&( + outline.range.clone(), + outline.depth, + )) + }) + .for_each(|outline| { + outline_panel.collapsed_entries.insert( + CollapsedEntry::Outline( + buffer_id, + excerpt_id, + outline.range.clone(), + ), + ); + }); + } + } + + // Even if no outlines to check, we still need to update cached entries + // to show the outline entries that were just fetched outline_panel.update_cached_entries(debounce, window, cx); } }) @@ -4083,7 +4242,7 @@ impl OutlinePanel { } fn add_excerpt_entries( - &self, + &mut self, state: &mut GenerationState, buffer_id: BufferId, entries_to_add: &[ExcerptId], @@ -4094,6 +4253,8 @@ impl OutlinePanel { cx: &mut Context<Self>, ) { if let Some(excerpts) = self.excerpts.get(&buffer_id) { + let buffer_snapshot = self.buffer_snapshot_for_id(buffer_id, cx); + for &excerpt_id in entries_to_add { let Some(excerpt) = excerpts.get(&excerpt_id) else { continue; @@ -4123,15 +4284,84 @@ impl 
OutlinePanel { continue; } - for outline in excerpt.iter_outlines() { + let mut last_depth_at_level: Vec<Option<Range<Anchor>>> = vec![None; 10]; + + let all_outlines: Vec<_> = excerpt.iter_outlines().collect(); + + let mut outline_has_children = HashMap::default(); + let mut visible_outlines = Vec::new(); + let mut collapsed_state: Option<(usize, Range<Anchor>)> = None; + + for (i, &outline) in all_outlines.iter().enumerate() { + let has_children = all_outlines + .get(i + 1) + .map(|next| next.depth > outline.depth) + .unwrap_or(false); + + outline_has_children + .insert((outline.range.clone(), outline.depth), has_children); + + let mut should_include = true; + + if let Some((collapsed_depth, collapsed_range)) = &collapsed_state { + if outline.depth <= *collapsed_depth { + collapsed_state = None; + } else if let Some(buffer_snapshot) = buffer_snapshot.as_ref() { + let outline_start = outline.range.start; + if outline_start + .cmp(&collapsed_range.start, buffer_snapshot) + .is_ge() + && outline_start + .cmp(&collapsed_range.end, buffer_snapshot) + .is_lt() + { + should_include = false; // Skip - inside collapsed range + } else { + collapsed_state = None; + } + } + } + + // Check if this outline itself is collapsed + if should_include + && self.collapsed_entries.contains(&CollapsedEntry::Outline( + buffer_id, + excerpt_id, + outline.range.clone(), + )) + { + collapsed_state = Some((outline.depth, outline.range.clone())); + } + + if should_include { + visible_outlines.push(outline); + } + } + + self.outline_children_cache + .entry(buffer_id) + .or_default() + .extend(outline_has_children); + + for outline in visible_outlines { + let outline_entry = OutlineEntryOutline { + buffer_id, + excerpt_id, + outline: outline.clone(), + }; + + if outline.depth < last_depth_at_level.len() { + last_depth_at_level[outline.depth] = Some(outline.range.clone()); + // Clear deeper levels when we go back to a shallower depth + for d in (outline.depth + 1)..last_depth_at_level.len() { + last_depth_at_level[d] = None; + } + } + self.push_entry( state, track_matches, - PanelEntry::Outline(OutlineEntry::Outline(OutlineEntryOutline { - buffer_id, - excerpt_id, - outline: outline.clone(), - })), + PanelEntry::Outline(OutlineEntry::Outline(outline_entry)), outline_base_depth + outline.depth, cx, ); @@ -5728,7 +5958,7 @@ mod tests { }); outline_panel.update_in(cx, |outline_panel, window, cx| { - outline_panel.open_excerpts(&editor::OpenExcerpts, window, cx); + outline_panel.open_excerpts(&editor::actions::OpenExcerpts, window, cx); }); cx.executor() .advance_clock(UPDATE_DEBOUNCE + Duration::from_millis(100)); @@ -6908,4 +7138,540 @@ outline: struct OutlineEntryExcerpt multi_buffer_snapshot.text_for_range(line_start..line_end).collect::<String>().trim().to_owned() }) } + + #[gpui::test] + async fn test_outline_keyboard_expand_collapse(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.background_executor.clone()); + fs.insert_tree( + "/test", + json!({ + "src": { + "lib.rs": indoc!(" + mod outer { + pub struct OuterStruct { + field: String, + } + impl OuterStruct { + pub fn new() -> Self { + Self { field: String::new() } + } + pub fn method(&self) { + println!(\"{}\", self.field); + } + } + mod inner { + pub fn inner_function() { + let x = 42; + println!(\"{}\", x); + } + pub struct InnerStruct { + value: i32, + } + } + } + fn main() { + let s = outer::OuterStruct::new(); + s.method(); + } + "), + } + }), + ) + .await; + + let project = Project::test(fs.clone(), ["/test".as_ref()], cx).await; + 
project.read_with(cx, |project, _| { + project.languages().add(Arc::new( + rust_lang() + .with_outline_query( + r#" + (struct_item + (visibility_modifier)? @context + "struct" @context + name: (_) @name) @item + (impl_item + "impl" @context + trait: (_)? @context + "for"? @context + type: (_) @context + body: (_)) @item + (function_item + (visibility_modifier)? @context + "fn" @context + name: (_) @name + parameters: (_) @context) @item + (mod_item + (visibility_modifier)? @context + "mod" @context + name: (_) @name) @item + (enum_item + (visibility_modifier)? @context + "enum" @context + name: (_) @name) @item + (field_declaration + (visibility_modifier)? @context + name: (_) @name + ":" @context + type: (_) @context) @item + "#, + ) + .unwrap(), + )) + }); + let workspace = add_outline_panel(&project, cx).await; + let cx = &mut VisualTestContext::from_window(*workspace, cx); + let outline_panel = outline_panel(&workspace, cx); + + outline_panel.update_in(cx, |outline_panel, window, cx| { + outline_panel.set_active(true, window, cx) + }); + + workspace + .update(cx, |workspace, window, cx| { + workspace.open_abs_path( + PathBuf::from("/test/src/lib.rs"), + OpenOptions { + visible: Some(OpenVisible::All), + ..Default::default() + }, + window, + cx, + ) + }) + .unwrap() + .await + .unwrap(); + + cx.executor() + .advance_clock(UPDATE_DEBOUNCE + Duration::from_millis(500)); + cx.run_until_parked(); + + // Force another update cycle to ensure outlines are fetched + outline_panel.update_in(cx, |panel, window, cx| { + panel.update_non_fs_items(window, cx); + panel.update_cached_entries(Some(UPDATE_DEBOUNCE), window, cx); + }); + cx.executor() + .advance_clock(UPDATE_DEBOUNCE + Duration::from_millis(500)); + cx.run_until_parked(); + + outline_panel.update(cx, |outline_panel, cx| { + assert_eq!( + display_entries( + &project, + &snapshot(&outline_panel, cx), + &outline_panel.cached_entries, + outline_panel.selected_entry(), + cx, + ), + indoc!( + " +outline: mod outer <==== selected + outline: pub struct OuterStruct + outline: field: String + outline: impl OuterStruct + outline: pub fn new() + outline: pub fn method(&self) + outline: mod inner + outline: pub fn inner_function() + outline: pub struct InnerStruct + outline: value: i32 +outline: fn main()" + ) + ); + }); + + let parent_outline = outline_panel + .read_with(cx, |panel, _cx| { + panel + .cached_entries + .iter() + .find_map(|entry| match &entry.entry { + PanelEntry::Outline(OutlineEntry::Outline(outline)) + if panel + .outline_children_cache + .get(&outline.buffer_id) + .and_then(|children_map| { + let key = + (outline.outline.range.clone(), outline.outline.depth); + children_map.get(&key) + }) + .copied() + .unwrap_or(false) => + { + Some(entry.entry.clone()) + } + _ => None, + }) + }) + .expect("Should find an outline with children"); + + outline_panel.update_in(cx, |panel, window, cx| { + panel.select_entry(parent_outline.clone(), true, window, cx); + panel.collapse_selected_entry(&CollapseSelectedEntry, window, cx); + }); + cx.executor() + .advance_clock(UPDATE_DEBOUNCE + Duration::from_millis(100)); + cx.run_until_parked(); + + outline_panel.update(cx, |outline_panel, cx| { + assert_eq!( + display_entries( + &project, + &snapshot(&outline_panel, cx), + &outline_panel.cached_entries, + outline_panel.selected_entry(), + cx, + ), + indoc!( + " +outline: mod outer <==== selected +outline: fn main()" + ) + ); + }); + + outline_panel.update_in(cx, |panel, window, cx| { + panel.expand_selected_entry(&ExpandSelectedEntry, window, cx); + 
}); + cx.executor() + .advance_clock(UPDATE_DEBOUNCE + Duration::from_millis(100)); + cx.run_until_parked(); + + outline_panel.update(cx, |outline_panel, cx| { + assert_eq!( + display_entries( + &project, + &snapshot(&outline_panel, cx), + &outline_panel.cached_entries, + outline_panel.selected_entry(), + cx, + ), + indoc!( + " +outline: mod outer <==== selected + outline: pub struct OuterStruct + outline: field: String + outline: impl OuterStruct + outline: pub fn new() + outline: pub fn method(&self) + outline: mod inner + outline: pub fn inner_function() + outline: pub struct InnerStruct + outline: value: i32 +outline: fn main()" + ) + ); + }); + + outline_panel.update_in(cx, |panel, window, cx| { + panel.collapsed_entries.clear(); + panel.update_cached_entries(None, window, cx); + }); + cx.executor() + .advance_clock(UPDATE_DEBOUNCE + Duration::from_millis(100)); + cx.run_until_parked(); + + outline_panel.update_in(cx, |panel, window, cx| { + let outlines_with_children: Vec<_> = panel + .cached_entries + .iter() + .filter_map(|entry| match &entry.entry { + PanelEntry::Outline(OutlineEntry::Outline(outline)) + if panel + .outline_children_cache + .get(&outline.buffer_id) + .and_then(|children_map| { + let key = (outline.outline.range.clone(), outline.outline.depth); + children_map.get(&key) + }) + .copied() + .unwrap_or(false) => + { + Some(entry.entry.clone()) + } + _ => None, + }) + .collect(); + + for outline in outlines_with_children { + panel.select_entry(outline, false, window, cx); + panel.collapse_selected_entry(&CollapseSelectedEntry, window, cx); + } + }); + cx.executor() + .advance_clock(UPDATE_DEBOUNCE + Duration::from_millis(100)); + cx.run_until_parked(); + + outline_panel.update(cx, |outline_panel, cx| { + assert_eq!( + display_entries( + &project, + &snapshot(&outline_panel, cx), + &outline_panel.cached_entries, + outline_panel.selected_entry(), + cx, + ), + indoc!( + " +outline: mod outer +outline: fn main()" + ) + ); + }); + + let collapsed_entries_count = + outline_panel.read_with(cx, |panel, _| panel.collapsed_entries.len()); + assert!( + collapsed_entries_count > 0, + "Should have collapsed entries tracked" + ); + } + + #[gpui::test] + async fn test_outline_click_toggle_behavior(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.background_executor.clone()); + fs.insert_tree( + "/test", + json!({ + "src": { + "main.rs": indoc!(" + struct Config { + name: String, + value: i32, + } + impl Config { + fn new(name: String) -> Self { + Self { name, value: 0 } + } + fn get_value(&self) -> i32 { + self.value + } + } + enum Status { + Active, + Inactive, + } + fn process_config(config: Config) -> Status { + if config.get_value() > 0 { + Status::Active + } else { + Status::Inactive + } + } + fn main() { + let config = Config::new(\"test\".to_string()); + let status = process_config(config); + } + "), + } + }), + ) + .await; + + let project = Project::test(fs.clone(), ["/test".as_ref()], cx).await; + project.read_with(cx, |project, _| { + project.languages().add(Arc::new( + rust_lang() + .with_outline_query( + r#" + (struct_item + (visibility_modifier)? @context + "struct" @context + name: (_) @name) @item + (impl_item + "impl" @context + trait: (_)? @context + "for"? @context + type: (_) @context + body: (_)) @item + (function_item + (visibility_modifier)? @context + "fn" @context + name: (_) @name + parameters: (_) @context) @item + (mod_item + (visibility_modifier)? @context + "mod" @context + name: (_) @name) @item + (enum_item + (visibility_modifier)? 
@context + "enum" @context + name: (_) @name) @item + (field_declaration + (visibility_modifier)? @context + name: (_) @name + ":" @context + type: (_) @context) @item + "#, + ) + .unwrap(), + )) + }); + + let workspace = add_outline_panel(&project, cx).await; + let cx = &mut VisualTestContext::from_window(*workspace, cx); + let outline_panel = outline_panel(&workspace, cx); + + outline_panel.update_in(cx, |outline_panel, window, cx| { + outline_panel.set_active(true, window, cx) + }); + + let _editor = workspace + .update(cx, |workspace, window, cx| { + workspace.open_abs_path( + PathBuf::from("/test/src/main.rs"), + OpenOptions { + visible: Some(OpenVisible::All), + ..Default::default() + }, + window, + cx, + ) + }) + .unwrap() + .await + .unwrap(); + + cx.executor() + .advance_clock(UPDATE_DEBOUNCE + Duration::from_millis(100)); + cx.run_until_parked(); + + outline_panel.update(cx, |outline_panel, _cx| { + outline_panel.selected_entry = SelectedEntry::None; + }); + + // Check initial state - all entries should be expanded by default + outline_panel.update(cx, |outline_panel, cx| { + assert_eq!( + display_entries( + &project, + &snapshot(&outline_panel, cx), + &outline_panel.cached_entries, + outline_panel.selected_entry(), + cx, + ), + indoc!( + " +outline: struct Config + outline: name: String + outline: value: i32 +outline: impl Config + outline: fn new(name: String) + outline: fn get_value(&self) +outline: enum Status +outline: fn process_config(config: Config) +outline: fn main()" + ) + ); + }); + + outline_panel.update(cx, |outline_panel, _cx| { + outline_panel.selected_entry = SelectedEntry::None; + }); + + cx.update(|window, cx| { + outline_panel.update(cx, |outline_panel, cx| { + outline_panel.select_first(&SelectFirst, window, cx); + }); + }); + + cx.executor() + .advance_clock(UPDATE_DEBOUNCE + Duration::from_millis(100)); + cx.run_until_parked(); + + outline_panel.update(cx, |outline_panel, cx| { + assert_eq!( + display_entries( + &project, + &snapshot(&outline_panel, cx), + &outline_panel.cached_entries, + outline_panel.selected_entry(), + cx, + ), + indoc!( + " +outline: struct Config <==== selected + outline: name: String + outline: value: i32 +outline: impl Config + outline: fn new(name: String) + outline: fn get_value(&self) +outline: enum Status +outline: fn process_config(config: Config) +outline: fn main()" + ) + ); + }); + + cx.update(|window, cx| { + outline_panel.update(cx, |outline_panel, cx| { + outline_panel.open_selected_entry(&OpenSelectedEntry, window, cx); + }); + }); + + cx.executor() + .advance_clock(UPDATE_DEBOUNCE + Duration::from_millis(100)); + cx.run_until_parked(); + + outline_panel.update(cx, |outline_panel, cx| { + assert_eq!( + display_entries( + &project, + &snapshot(&outline_panel, cx), + &outline_panel.cached_entries, + outline_panel.selected_entry(), + cx, + ), + indoc!( + " +outline: struct Config <==== selected +outline: impl Config + outline: fn new(name: String) + outline: fn get_value(&self) +outline: enum Status +outline: fn process_config(config: Config) +outline: fn main()" + ) + ); + }); + + cx.update(|window, cx| { + outline_panel.update(cx, |outline_panel, cx| { + outline_panel.open_selected_entry(&OpenSelectedEntry, window, cx); + }); + }); + + cx.executor() + .advance_clock(UPDATE_DEBOUNCE + Duration::from_millis(100)); + cx.run_until_parked(); + + outline_panel.update(cx, |outline_panel, cx| { + assert_eq!( + display_entries( + &project, + &snapshot(&outline_panel, cx), + &outline_panel.cached_entries, + 
outline_panel.selected_entry(),
+ cx,
+ ),
+ indoc!(
+ "
+outline: struct Config <==== selected
+ outline: name: String
+ outline: value: i32
+outline: impl Config
+ outline: fn new(name: String)
+ outline: fn get_value(&self)
+outline: enum Status
+outline: fn process_config(config: Config)
+outline: fn main()"
+ )
+ );
+ });
+ }
 }
diff --git a/crates/outline_panel/src/outline_panel_settings.rs b/crates/outline_panel/src/outline_panel_settings.rs
index 6b70cb54fbc23e03fdbc13c90b912648daf9515b..133d28b748d2978e07a540b3c8c7517b03dc4767 100644
--- a/crates/outline_panel/src/outline_panel_settings.rs
+++ b/crates/outline_panel/src/outline_panel_settings.rs
@@ -31,6 +31,7 @@ pub struct OutlinePanelSettings {
 pub auto_reveal_entries: bool,
 pub auto_fold_dirs: bool,
 pub scrollbar: ScrollbarSettings,
+ pub expand_outlines_with_depth: usize,
 }

 #[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
@@ -105,6 +106,13 @@ pub struct OutlinePanelSettingsContent {
 pub indent_guides: Option<IndentGuidesSettingsContent>,
 /// Scrollbar-related settings
 pub scrollbar: Option<ScrollbarSettingsContent>,
+ /// The default depth to which outline entries are expanded when they are revealed.
+ /// Entries nested deeper than this value start out collapsed.
+ /// - Set to 0 to collapse all items that have children
+ /// - Set to 1 or higher to collapse items at that depth or deeper
+ ///
+ /// Default: 100
+ pub expand_outlines_with_depth: Option<usize>,
 }

 impl Settings for OutlinePanelSettings {
diff --git a/crates/paths/src/paths.rs b/crates/paths/src/paths.rs
index 2f3b18898077bcc455ca8e616a9d550019cd3cbb..47a0f12c0634dbde48d015e4f577519babc67b34 100644
--- a/crates/paths/src/paths.rs
+++ b/crates/paths/src/paths.rs
@@ -35,6 +35,7 @@ pub fn remote_server_dir_relative() -> &'static Path {
 /// Sets a custom directory for all user data, overriding the default data directory.
 /// This function must be called before any other path operations that depend on the data directory.
+/// The directory's path will be canonicalized to an absolute path by a blocking FS operation.
 /// The directory will be created if it doesn't exist.
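+///
+/// A minimal usage sketch (hypothetical call site; the actual CLI flag wiring lives elsewhere):
+///
+/// ```ignore
+/// // Must be called before `data_dir()` or `config_dir()` are first used, otherwise it panics.
+/// // A relative path is canonicalized, so it has to exist already at this point.
+/// let dir = set_custom_data_dir("/tmp/zed-data");
+/// assert!(dir.is_absolute());
+/// ```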
/// /// # Arguments @@ -50,13 +51,20 @@ pub fn remote_server_dir_relative() -> &'static Path { /// /// Panics if: /// * Called after the data directory has been initialized (e.g., via `data_dir` or `config_dir`) +/// * The directory's path cannot be canonicalized to an absolute path /// * The directory cannot be created pub fn set_custom_data_dir(dir: &str) -> &'static PathBuf { if CURRENT_DATA_DIR.get().is_some() || CONFIG_DIR.get().is_some() { panic!("set_custom_data_dir called after data_dir or config_dir was initialized"); } CUSTOM_DATA_DIR.get_or_init(|| { - let path = PathBuf::from(dir); + let mut path = PathBuf::from(dir); + if path.is_relative() { + let abs_path = path + .canonicalize() + .expect("failed to canonicalize custom data directory's path to an absolute path"); + path = PathBuf::from(util::paths::SanitizedPath::from(abs_path)) + } std::fs::create_dir_all(&path).expect("failed to create custom data directory"); path }) diff --git a/crates/picker/src/picker.rs b/crates/picker/src/picker.rs index 692bdd5bd7a49a3d293603358c1c4d8a2061c42a..34af5fed02e66fe242c398ebcf910bc89d81a256 100644 --- a/crates/picker/src/picker.rs +++ b/crates/picker/src/picker.rs @@ -292,7 +292,7 @@ impl<D: PickerDelegate> Picker<D> { window: &mut Window, cx: &mut Context<Self>, ) -> Self { - let element_container = Self::create_element_container(container, cx); + let element_container = Self::create_element_container(container); let scrollbar_state = match &element_container { ElementContainer::UniformList(scroll_handle) => { ScrollbarState::new(scroll_handle.clone()) @@ -323,31 +323,13 @@ impl<D: PickerDelegate> Picker<D> { this } - fn create_element_container( - container: ContainerKind, - cx: &mut Context<Self>, - ) -> ElementContainer { + fn create_element_container(container: ContainerKind) -> ElementContainer { match container { ContainerKind::UniformList => { ElementContainer::UniformList(UniformListScrollHandle::new()) } ContainerKind::List => { - let entity = cx.entity().downgrade(); - ElementContainer::List(ListState::new( - 0, - gpui::ListAlignment::Top, - px(1000.), - move |ix, window, cx| { - entity - .upgrade() - .map(|entity| { - entity.update(cx, |this, cx| { - this.render_element(window, cx, ix).into_any_element() - }) - }) - .unwrap_or_else(|| div().into_any_element()) - }, - )) + ElementContainer::List(ListState::new(0, gpui::ListAlignment::Top, px(1000.))) } } } @@ -786,11 +768,16 @@ impl<D: PickerDelegate> Picker<D> { .py_1() .track_scroll(scroll_handle.clone()) .into_any_element(), - ElementContainer::List(state) => list(state.clone()) - .with_sizing_behavior(sizing_behavior) - .flex_grow() - .py_2() - .into_any_element(), + ElementContainer::List(state) => list( + state.clone(), + cx.processor(|this, ix, window, cx| { + this.render_element(window, cx, ix).into_any_element() + }), + ) + .with_sizing_behavior(sizing_behavior) + .flex_grow() + .py_2() + .into_any_element(), } } diff --git a/crates/picker/src/popover_menu.rs b/crates/picker/src/popover_menu.rs index dd1d9c2865586dc771d10765df410b65777c8caa..d05308ee71e87a472ffcb33e9727ef74fae70602 100644 --- a/crates/picker/src/popover_menu.rs +++ b/crates/picker/src/popover_menu.rs @@ -80,6 +80,7 @@ where { fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement { let picker = self.picker.clone(); + PopoverMenu::new("popover-menu") .menu(move |_window, _cx| Some(picker.clone())) .trigger_with_tooltip(self.trigger, self.tooltip) diff --git a/crates/prettier/src/prettier_server.js 
b/crates/prettier/src/prettier_server.js index 6799b4acebc8be68d99c95610a617f5b7f1ea1c4..b3d8a660a40d6f629ba63847f5e00d91046b7cd7 100644 --- a/crates/prettier/src/prettier_server.js +++ b/crates/prettier/src/prettier_server.js @@ -152,6 +152,10 @@ async function handleMessage(message, prettier) { throw new Error(`Message method is undefined: ${JSON.stringify(message)}`); } else if (method == "initialized") { return; + } else if (method === "shutdown") { + sendResponse({ result: {} }); + } else if (method == "exit") { + process.exit(0); } if (id === undefined) { diff --git a/crates/project/Cargo.toml b/crates/project/Cargo.toml index 729d61aab53a60a2d6f59c7b066b00ecc49ab913..57d6d6ca283af0fd51ed10622f55edc9fb086e7e 100644 --- a/crates/project/Cargo.toml +++ b/crates/project/Cargo.toml @@ -31,6 +31,7 @@ aho-corasick.workspace = true anyhow.workspace = true askpass.workspace = true async-trait.workspace = true +base64.workspace = true buffer_diff.workspace = true circular-buffer.workspace = true client.workspace = true @@ -72,6 +73,7 @@ settings.workspace = true sha2.workspace = true shellexpand.workspace = true shlex.workspace = true +smallvec.workspace = true smol.workspace = true snippet.workspace = true snippet_provider.workspace = true diff --git a/crates/project/src/context_server_store.rs b/crates/project/src/context_server_store.rs index fd31e638d4bf7774af83d430dca232d1ade74f01..c96ab4e8f3ba87133d9b64e9701130f5d32adfb9 100644 --- a/crates/project/src/context_server_store.rs +++ b/crates/project/src/context_server_store.rs @@ -13,6 +13,7 @@ use settings::{Settings as _, SettingsStore}; use util::ResultExt as _; use crate::{ + Project, project_settings::{ContextServerSettings, ProjectSettings}, worktree_store::WorktreeStore, }; @@ -144,6 +145,7 @@ pub struct ContextServerStore { context_server_settings: HashMap<Arc<str>, ContextServerSettings>, servers: HashMap<ContextServerId, ContextServerState>, worktree_store: Entity<WorktreeStore>, + project: WeakEntity<Project>, registry: Entity<ContextServerDescriptorRegistry>, update_servers_task: Option<Task<Result<()>>>, context_server_factory: Option<ContextServerFactory>, @@ -161,12 +163,17 @@ pub enum Event { impl EventEmitter<Event> for ContextServerStore {} impl ContextServerStore { - pub fn new(worktree_store: Entity<WorktreeStore>, cx: &mut Context<Self>) -> Self { + pub fn new( + worktree_store: Entity<WorktreeStore>, + weak_project: WeakEntity<Project>, + cx: &mut Context<Self>, + ) -> Self { Self::new_internal( true, None, ContextServerDescriptorRegistry::default_global(cx), worktree_store, + weak_project, cx, ) } @@ -184,9 +191,10 @@ impl ContextServerStore { pub fn test( registry: Entity<ContextServerDescriptorRegistry>, worktree_store: Entity<WorktreeStore>, + weak_project: WeakEntity<Project>, cx: &mut Context<Self>, ) -> Self { - Self::new_internal(false, None, registry, worktree_store, cx) + Self::new_internal(false, None, registry, worktree_store, weak_project, cx) } #[cfg(any(test, feature = "test-support"))] @@ -194,6 +202,7 @@ impl ContextServerStore { context_server_factory: ContextServerFactory, registry: Entity<ContextServerDescriptorRegistry>, worktree_store: Entity<WorktreeStore>, + weak_project: WeakEntity<Project>, cx: &mut Context<Self>, ) -> Self { Self::new_internal( @@ -201,6 +210,7 @@ impl ContextServerStore { Some(context_server_factory), registry, worktree_store, + weak_project, cx, ) } @@ -210,6 +220,7 @@ impl ContextServerStore { context_server_factory: Option<ContextServerFactory>, registry: 
Entity<ContextServerDescriptorRegistry>, worktree_store: Entity<WorktreeStore>, + weak_project: WeakEntity<Project>, cx: &mut Context<Self>, ) -> Self { let subscriptions = if maintain_server_loop { @@ -235,6 +246,7 @@ impl ContextServerStore { context_server_settings: Self::resolve_context_server_settings(&worktree_store, cx) .clone(), worktree_store, + project: weak_project, registry, needs_server_update: false, servers: HashMap::default(), @@ -360,7 +372,7 @@ impl ContextServerStore { let configuration = state.configuration(); self.stop_server(&state.server().id(), cx)?; - let new_server = self.create_context_server(id.clone(), configuration.clone())?; + let new_server = self.create_context_server(id.clone(), configuration.clone(), cx); self.run_server(new_server, configuration, cx); } Ok(()) @@ -449,14 +461,33 @@ impl ContextServerStore { &self, id: ContextServerId, configuration: Arc<ContextServerConfiguration>, - ) -> Result<Arc<ContextServer>> { + cx: &mut Context<Self>, + ) -> Arc<ContextServer> { + let root_path = self + .project + .read_with(cx, |project, cx| project.active_project_directory(cx)) + .ok() + .flatten() + .or_else(|| { + self.worktree_store.read_with(cx, |store, cx| { + store.visible_worktrees(cx).fold(None, |acc, item| { + if acc.is_none() { + item.read(cx).root_dir() + } else { + acc + } + }) + }) + }); + if let Some(factory) = self.context_server_factory.as_ref() { - Ok(factory(id, configuration)) + factory(id, configuration) } else { - Ok(Arc::new(ContextServer::stdio( + Arc::new(ContextServer::stdio( id, configuration.command().clone(), - ))) + root_path, + )) } } @@ -553,7 +584,7 @@ impl ContextServerStore { let mut servers_to_remove = HashSet::default(); let mut servers_to_stop = HashSet::default(); - this.update(cx, |this, _cx| { + this.update(cx, |this, cx| { for server_id in this.servers.keys() { // All servers that are not in desired_servers should be removed from the store. // This can happen if the user removed a server from the context server settings. 
@@ -572,14 +603,10 @@ impl ContextServerStore { let existing_config = state.as_ref().map(|state| state.configuration()); if existing_config.as_deref() != Some(&config) || is_stopped { let config = Arc::new(config); - if let Some(server) = this - .create_context_server(id.clone(), config.clone()) - .log_err() - { - servers_to_start.push((server, config)); - if this.servers.contains_key(&id) { - servers_to_stop.insert(id); - } + let server = this.create_context_server(id.clone(), config.clone(), cx); + servers_to_start.push((server, config)); + if this.servers.contains_key(&id) { + servers_to_stop.insert(id); } } } @@ -610,7 +637,7 @@ mod tests { use context_server::test::create_fake_transport; use gpui::{AppContext, TestAppContext, UpdateGlobal as _}; use serde_json::json; - use std::{cell::RefCell, rc::Rc}; + use std::{cell::RefCell, path::PathBuf, rc::Rc}; use util::path; #[gpui::test] @@ -630,7 +657,12 @@ mod tests { let registry = cx.new(|_| ContextServerDescriptorRegistry::new()); let store = cx.new(|cx| { - ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx) + ContextServerStore::test( + registry.clone(), + project.read(cx).worktree_store(), + project.downgrade(), + cx, + ) }); let server_1_id = ContextServerId(SERVER_1_ID.into()); @@ -705,7 +737,12 @@ mod tests { let registry = cx.new(|_| ContextServerDescriptorRegistry::new()); let store = cx.new(|cx| { - ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx) + ContextServerStore::test( + registry.clone(), + project.read(cx).worktree_store(), + project.downgrade(), + cx, + ) }); let server_1_id = ContextServerId(SERVER_1_ID.into()); @@ -758,7 +795,12 @@ mod tests { let registry = cx.new(|_| ContextServerDescriptorRegistry::new()); let store = cx.new(|cx| { - ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx) + ContextServerStore::test( + registry.clone(), + project.read(cx).worktree_store(), + project.downgrade(), + cx, + ) }); let server_id = ContextServerId(SERVER_1_ID.into()); @@ -842,6 +884,7 @@ mod tests { }), registry.clone(), project.read(cx).worktree_store(), + project.downgrade(), cx, ) }); @@ -931,7 +974,7 @@ mod tests { ContextServerSettings::Custom { enabled: true, command: ContextServerCommand { - path: "somebinary".to_string(), + path: "somebinary".into(), args: vec!["arg".to_string()], env: None, }, @@ -971,7 +1014,7 @@ mod tests { ContextServerSettings::Custom { enabled: true, command: ContextServerCommand { - path: "somebinary".to_string(), + path: "somebinary".into(), args: vec!["anotherArg".to_string()], env: None, }, @@ -1053,7 +1096,7 @@ mod tests { ContextServerSettings::Custom { enabled: true, command: ContextServerCommand { - path: "somebinary".to_string(), + path: "somebinary".into(), args: vec!["arg".to_string()], env: None, }, @@ -1074,6 +1117,7 @@ mod tests { }), registry.clone(), project.read(cx).worktree_store(), + project.downgrade(), cx, ) }); @@ -1104,7 +1148,7 @@ mod tests { ContextServerSettings::Custom { enabled: false, command: ContextServerCommand { - path: "somebinary".to_string(), + path: "somebinary".into(), args: vec!["arg".to_string()], env: None, }, @@ -1132,7 +1176,7 @@ mod tests { ContextServerSettings::Custom { enabled: true, command: ContextServerCommand { - path: "somebinary".to_string(), + path: "somebinary".into(), args: vec!["arg".to_string()], env: None, }, @@ -1184,7 +1228,7 @@ mod tests { ContextServerSettings::Custom { enabled: true, command: ContextServerCommand { - path: 
"somebinary".to_string(), + path: "somebinary".into(), args: vec!["arg".to_string()], env: None, }, @@ -1256,11 +1300,11 @@ mod tests { } struct FakeContextServerDescriptor { - path: String, + path: PathBuf, } impl FakeContextServerDescriptor { - fn new(path: impl Into<String>) -> Self { + fn new(path: impl Into<PathBuf>) -> Self { Self { path: path.into() } } } diff --git a/crates/project/src/context_server_store/extension.rs b/crates/project/src/context_server_store/extension.rs index 1eaecd987dd51158fc2f505c1ae9b0c8fcc076a3..1eb0fe7da129ba9dbd3ee640cb6e02474a3990b6 100644 --- a/crates/project/src/context_server_store/extension.rs +++ b/crates/project/src/context_server_store/extension.rs @@ -61,10 +61,7 @@ impl registry::ContextServerDescriptor for ContextServerDescriptor { let mut command = extension .context_server_command(id.clone(), extension_project.clone()) .await?; - command.command = extension - .path_from_extension(command.command.as_ref()) - .to_string_lossy() - .to_string(); + command.command = extension.path_from_extension(&command.command); log::info!("loaded command for context server {id}: {command:?}"); diff --git a/crates/project/src/debugger.rs b/crates/project/src/debugger.rs index d078988a51bb4f7e6ab824913e2d6584e1bd0d2e..6c22468040097768688d93cde0720320a9e45be9 100644 --- a/crates/project/src/debugger.rs +++ b/crates/project/src/debugger.rs @@ -15,7 +15,9 @@ pub mod breakpoint_store; pub mod dap_command; pub mod dap_store; pub mod locators; +mod memory; pub mod session; #[cfg(any(feature = "test-support", test))] pub mod test; +pub use memory::MemoryCell; diff --git a/crates/project/src/debugger/dap_command.rs b/crates/project/src/debugger/dap_command.rs index 411bacd3ba1557b7392eb0e981df1a4297772b31..3be3192369452b58fd2382471ca2f41f4aeac75f 100644 --- a/crates/project/src/debugger/dap_command.rs +++ b/crates/project/src/debugger/dap_command.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use anyhow::{Context as _, Ok, Result}; +use base64::Engine; use dap::{ Capabilities, ContinueArguments, ExceptionFilterOptions, InitializeRequestArguments, InitializeRequestArgumentsPathFormat, NextArguments, SetVariableResponse, SourceBreakpoint, @@ -10,6 +11,7 @@ use dap::{ proto_conversions::ProtoConversion, requests::{Continue, Next}, }; + use rpc::proto; use serde_json::Value; use util::ResultExt; @@ -105,7 +107,7 @@ impl<T: DapCommand> DapCommand for Arc<T> { #[derive(Debug, Hash, PartialEq, Eq)] pub struct StepCommand { - pub thread_id: u64, + pub thread_id: i64, pub granularity: Option<SteppingGranularity>, pub single_thread: Option<bool>, } @@ -481,7 +483,7 @@ impl DapCommand for ContinueCommand { #[derive(Debug, Hash, PartialEq, Eq)] pub(crate) struct PauseCommand { - pub thread_id: u64, + pub thread_id: i64, } impl LocalDapCommand for PauseCommand { @@ -610,7 +612,7 @@ impl DapCommand for DisconnectCommand { #[derive(Debug, Hash, PartialEq, Eq)] pub(crate) struct TerminateThreadsCommand { - pub thread_ids: Option<Vec<u64>>, + pub thread_ids: Option<Vec<i64>>, } impl LocalDapCommand for TerminateThreadsCommand { @@ -812,7 +814,7 @@ impl DapCommand for RestartCommand { } } -#[derive(Debug, Hash, PartialEq, Eq)] +#[derive(Clone, Debug, Hash, PartialEq, Eq)] pub struct VariablesCommand { pub variables_reference: u64, pub filter: Option<VariablesArgumentsFilter>, @@ -1180,7 +1182,7 @@ impl DapCommand for LoadedSourcesCommand { #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub(crate) struct StackTraceCommand { - pub thread_id: u64, + pub thread_id: i64, pub start_frame: Option<u64>, 
pub levels: Option<u64>,
+ }
@@ -1666,6 +1668,130 @@ impl LocalDapCommand for SetBreakpoints {
 Ok(message.breakpoints)
 }
 }
+
+#[derive(Clone, Debug, Hash, PartialEq, Eq)]
+pub enum DataBreakpointContext {
+ Variable {
+ variables_reference: u64,
+ name: String,
+ bytes: Option<u64>,
+ },
+ Expression {
+ expression: String,
+ frame_id: Option<u64>,
+ },
+ Address {
+ address: String,
+ bytes: Option<u64>,
+ },
+}
+
+impl DataBreakpointContext {
+ pub fn human_readable_label(&self) -> String {
+ match self {
+ DataBreakpointContext::Variable { name, .. } => format!("Variable: {}", name),
+ DataBreakpointContext::Expression { expression, .. } => {
+ format!("Expression: {}", expression)
+ }
+ DataBreakpointContext::Address { address, bytes } => {
+ let mut label = format!("Address: {}", address);
+ if let Some(bytes) = bytes {
+ label.push_str(&format!(
+ " ({} byte{})",
+ bytes,
+ if *bytes == 1 { "" } else { "s" }
+ ));
+ }
+ label
+ }
+ }
+ }
+}
+
+#[derive(Clone, Debug, Hash, PartialEq, Eq)]
+pub(crate) struct DataBreakpointInfoCommand {
+ pub context: Arc<DataBreakpointContext>,
+ pub mode: Option<String>,
+}
+
+impl LocalDapCommand for DataBreakpointInfoCommand {
+ type Response = dap::DataBreakpointInfoResponse;
+ type DapRequest = dap::requests::DataBreakpointInfo;
+ const CACHEABLE: bool = true;
+
+ // todo(debugger): We should expand this trait in the future to take a &self;
+ // depending on the command, is_supported could be different.
+ fn is_supported(capabilities: &Capabilities) -> bool {
+ capabilities.supports_data_breakpoints.unwrap_or(false)
+ }
+
+ fn to_dap(&self) -> <Self::DapRequest as dap::requests::Request>::Arguments {
+ let (variables_reference, name, frame_id, as_address, bytes) = match &*self.context {
+ DataBreakpointContext::Variable {
+ variables_reference,
+ name,
+ bytes,
+ } => (
+ Some(*variables_reference),
+ name.clone(),
+ None,
+ Some(false),
+ *bytes,
+ ),
+ DataBreakpointContext::Expression {
+ expression,
+ frame_id,
+ } => (None, expression.clone(), *frame_id, Some(false), None),
+ DataBreakpointContext::Address { address, bytes } => {
+ (None, address.clone(), None, Some(true), *bytes)
+ }
+ };
+
+ dap::DataBreakpointInfoArguments {
+ variables_reference,
+ name,
+ frame_id,
+ bytes,
+ as_address,
+ mode: self.mode.clone(),
+ }
+ }
+
+ fn response_from_dap(
+ &self,
+ message: <Self::DapRequest as dap::requests::Request>::Response,
+ ) -> Result<Self::Response> {
+ Ok(message)
+ }
+}
+
+#[derive(Clone, Debug, Hash, PartialEq, Eq)]
+pub(crate) struct SetDataBreakpointsCommand {
+ pub breakpoints: Vec<dap::DataBreakpoint>,
+}
+
+impl LocalDapCommand for SetDataBreakpointsCommand {
+ type Response = Vec<dap::Breakpoint>;
+ type DapRequest = dap::requests::SetDataBreakpoints;
+
+ fn is_supported(capabilities: &Capabilities) -> bool {
+ capabilities.supports_data_breakpoints.unwrap_or(false)
+ }
+
+ fn to_dap(&self) -> <Self::DapRequest as dap::requests::Request>::Arguments {
+ dap::SetDataBreakpointsArguments {
+ breakpoints: self.breakpoints.clone(),
+ }
+ }
+
+ fn response_from_dap(
+ &self,
+ message: <Self::DapRequest as dap::requests::Request>::Response,
+ ) -> Result<Self::Response> {
+ Ok(message.breakpoints)
+ }
+}
+
 #[derive(Clone, Debug, Hash, PartialEq)]
 pub(super) enum SetExceptionBreakpoints {
 Plain {
@@ -1774,3 +1900,76 @@ impl DapCommand for LocationsCommand {
 })
 }
 }
+
+#[derive(Clone, Debug, Hash, PartialEq, Eq)]
+pub(crate) struct ReadMemory {
+ pub(crate) memory_reference: String,
+ pub(crate) offset: Option<u64>,
+ pub(crate) count:
u64, +} + +#[derive(Clone, Debug, PartialEq)] +pub(crate) struct ReadMemoryResponse { + pub(super) address: Arc<str>, + pub(super) unreadable_bytes: Option<u64>, + pub(super) content: Arc<[u8]>, +} + +impl LocalDapCommand for ReadMemory { + type Response = ReadMemoryResponse; + type DapRequest = dap::requests::ReadMemory; + const CACHEABLE: bool = true; + + fn is_supported(capabilities: &Capabilities) -> bool { + capabilities + .supports_read_memory_request + .unwrap_or_default() + } + fn to_dap(&self) -> <Self::DapRequest as dap::requests::Request>::Arguments { + dap::ReadMemoryArguments { + memory_reference: self.memory_reference.clone(), + offset: self.offset, + count: self.count, + } + } + + fn response_from_dap( + &self, + message: <Self::DapRequest as dap::requests::Request>::Response, + ) -> Result<Self::Response> { + let data = if let Some(data) = message.data { + base64::engine::general_purpose::STANDARD + .decode(data) + .log_err() + .context("parsing base64 data from DAP's ReadMemory response")? + } else { + vec![] + }; + + Ok(ReadMemoryResponse { + address: message.address.into(), + content: data.into(), + unreadable_bytes: message.unreadable_bytes, + }) + } +} + +impl LocalDapCommand for dap::WriteMemoryArguments { + type Response = dap::WriteMemoryResponse; + type DapRequest = dap::requests::WriteMemory; + fn is_supported(capabilities: &Capabilities) -> bool { + capabilities + .supports_write_memory_request + .unwrap_or_default() + } + fn to_dap(&self) -> <Self::DapRequest as dap::requests::Request>::Arguments { + self.clone() + } + + fn response_from_dap( + &self, + message: <Self::DapRequest as dap::requests::Request>::Response, + ) -> Result<Self::Response> { + Ok(message) + } +} diff --git a/crates/project/src/debugger/dap_store.rs b/crates/project/src/debugger/dap_store.rs index 19e64adb2d6e412e25f49170109af1596273ea34..6f834b5dc0cfd3fc6357d92403bdb7cbfefdd4b0 100644 --- a/crates/project/src/debugger/dap_store.rs +++ b/crates/project/src/debugger/dap_store.rs @@ -6,6 +6,7 @@ use super::{ }; use crate::{ InlayHint, InlayHintLabel, ProjectEnvironment, ResolveState, + debugger::session::SessionQuirks, project_settings::ProjectSettings, terminals::{SshCommand, wrap_for_ssh}, worktree_store::WorktreeStore, @@ -385,10 +386,11 @@ impl DapStore { pub fn new_session( &mut self, - label: SharedString, + label: Option<SharedString>, adapter: DebugAdapterName, task_context: TaskContext, parent_session: Option<Entity<Session>>, + quirks: SessionQuirks, cx: &mut Context<Self>, ) -> Entity<Session> { let session_id = SessionId(util::post_inc(&mut self.next_session_id)); @@ -406,6 +408,7 @@ impl DapStore { label, adapter, task_context, + quirks, cx, ); @@ -560,6 +563,11 @@ impl DapStore { fn format_value(mut value: String) -> String { const LIMIT: usize = 100; + if let Some(index) = value.find("\n") { + value.truncate(index); + value.push_str("…"); + } + if value.len() > LIMIT { let mut index = LIMIT; // If index isn't a char boundary truncate will cause a panic @@ -567,7 +575,7 @@ impl DapStore { index -= 1; } value.truncate(index); - value.push_str("..."); + value.push_str("…"); } format!(": {}", value) @@ -912,12 +920,22 @@ impl dap::adapters::DapDelegate for DapAdapterDelegate { self.console.unbounded_send(msg).ok(); } + #[cfg(not(target_os = "windows"))] async fn which(&self, command: &OsStr) -> Option<PathBuf> { let worktree_abs_path = self.worktree.abs_path(); let shell_path = self.shell_env().await.get("PATH").cloned(); which::which_in(command, shell_path.as_ref(), 
worktree_abs_path).ok()
 }
+
+ #[cfg(target_os = "windows")]
+ async fn which(&self, command: &OsStr) -> Option<PathBuf> {
+ // On Windows, `PATH` is handled differently from Unix. Windows generally expects users to modify the `PATH` themselves,
+ // and every program loads it directly from the system at startup.
+ // There's also no concept of a default shell on Windows, and you can't really retrieve one, so trying to get shell environment variables
+ // from a specific directory doesn't make sense on Windows.
+ which::which(command).ok()
+ }
+
 async fn shell_env(&self) -> HashMap<String, String> {
 let task = self.load_shell_env_task.clone();
 task.await.unwrap_or_default()
diff --git a/crates/project/src/debugger/locators/cargo.rs b/crates/project/src/debugger/locators/cargo.rs
index 7d70371380192c99e1ace9676b02088f86ed9e5f..fa265dae586148f9c8efe14187ee26c805c65e42 100644
--- a/crates/project/src/debugger/locators/cargo.rs
+++ b/crates/project/src/debugger/locators/cargo.rs
@@ -128,7 +128,7 @@ impl DapLocator for CargoLocator {
 .chain(Some("--message-format=json".to_owned()))
 .collect(),
 );
- let mut child = Command::new(program)
+ let mut child = util::command::new_smol_command(program)
 .args(args)
 .envs(build_config.env.iter().map(|(k, v)| (k.clone(), v.clone())))
 .current_dir(cwd)
diff --git a/crates/project/src/debugger/memory.rs b/crates/project/src/debugger/memory.rs
new file mode 100644
index 0000000000000000000000000000000000000000..fec3c344c5a433eebb3a1f314a8fd911bd603022
--- /dev/null
+++ b/crates/project/src/debugger/memory.rs
@@ -0,0 +1,384 @@
+//! This module defines the format in which the debuggee's memory is represented.
+//!
+//! Each byte in memory can either be mapped or unmapped. We try to mimic that twofold:
+//! - We assume that the memory is divided into pages of a fixed size.
+//! - We assume that each page can be either mapped or unmapped.
+//! These two assumptions drive the shape of the memory representation.
+//! In particular, we want the unmapped pages to be represented without allocating any memory, as *most*
+//! of the memory in a program space is usually unmapped.
+//! Note that per DAP we don't know what the address space layout is, so we can't optimize off of it.
+//! Note that while we optimize for a paged layout, we also want to be able to represent memory that is not paged.
+//! This use case is relevant to embedded folks. Furthermore, we cater to a default 4k page size.
+//! It is picked arbitrarily as a ubiquitous default - other than that, the underlying format of Zed's memory storage should not be relevant
+//! to the users of this module.
+
+use std::{collections::BTreeMap, ops::RangeInclusive, sync::Arc};
+
+use gpui::BackgroundExecutor;
+use smallvec::SmallVec;
+
+const PAGE_SIZE: u64 = 4096;
+
+/// Represents the contents of a single page. We special-case unmapped pages to be allocation-free,
+/// since they're going to make up the majority of the memory in a program space (even though the user might not even get to see them - ever).
+#[derive(Clone, Debug)]
+pub(super) enum PageContents {
+ /// Whole page is unreadable.
+ Unmapped, + Mapped(Arc<MappedPageContents>), +} + +impl PageContents { + #[cfg(test)] + fn mapped(contents: Vec<u8>) -> Self { + PageContents::Mapped(Arc::new(MappedPageContents( + vec![PageChunk::Mapped(contents.into())].into(), + ))) + } +} + +#[derive(Clone, Debug)] +enum PageChunk { + Mapped(Arc<[u8]>), + Unmapped(u64), +} + +impl PageChunk { + fn len(&self) -> u64 { + match self { + PageChunk::Mapped(contents) => contents.len() as u64, + PageChunk::Unmapped(size) => *size, + } + } +} + +impl MappedPageContents { + fn len(&self) -> u64 { + self.0.iter().map(|chunk| chunk.len()).sum() + } +} +/// We hope for the whole page to be mapped in a single chunk, but we do leave the possibility open +/// of having interleaved read permissions in a single page; debuggee's execution environment might either +/// have a different page size OR it might not have paged memory layout altogether +/// (which might be relevant to embedded systems). +/// +/// As stated previously, the concept of a page in this module has to do more +/// with optimizing fetching of the memory and not with the underlying bits and pieces +/// of the memory of a debuggee. + +#[derive(Default, Debug)] +pub(super) struct MappedPageContents( + /// Most of the time there should be only one chunk (either mapped or unmapped), + /// but we do leave the possibility open of having multiple regions of memory in a single page. + SmallVec<[PageChunk; 1]>, +); + +type MemoryAddress = u64; +#[derive(Clone, Copy, Debug, PartialEq, PartialOrd, Ord, Eq)] +#[repr(transparent)] +pub(super) struct PageAddress(u64); + +impl PageAddress { + pub(super) fn iter_range( + range: RangeInclusive<PageAddress>, + ) -> impl Iterator<Item = PageAddress> { + let mut current = range.start().0; + let end = range.end().0; + + std::iter::from_fn(move || { + if current > end { + None + } else { + let addr = PageAddress(current); + current += PAGE_SIZE; + Some(addr) + } + }) + } +} + +pub(super) struct Memory { + pages: BTreeMap<PageAddress, PageContents>, +} + +/// Represents a single memory cell (or None if a given cell is unmapped/unknown). +#[derive(Copy, Clone, Debug, PartialEq, PartialOrd, Ord, Eq)] +#[repr(transparent)] +pub struct MemoryCell(pub Option<u8>); + +impl Memory { + pub(super) fn new() -> Self { + Self { + pages: Default::default(), + } + } + + pub(super) fn memory_range_to_page_range( + range: RangeInclusive<MemoryAddress>, + ) -> RangeInclusive<PageAddress> { + let start_page = (range.start() / PAGE_SIZE) * PAGE_SIZE; + let end_page = (range.end() / PAGE_SIZE) * PAGE_SIZE; + PageAddress(start_page)..=PageAddress(end_page) + } + + pub(super) fn build_page(&self, page_address: PageAddress) -> Option<MemoryPageBuilder> { + if self.pages.contains_key(&page_address) { + // We already know the state of this page. 
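+ // Returning None tells `read_single_page_memory` below that no fetch is needed for this page.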
+ None
+ } else {
+ Some(MemoryPageBuilder::new(page_address))
+ }
+ }
+
+ pub(super) fn insert_page(&mut self, address: PageAddress, page: PageContents) {
+ self.pages.insert(address, page);
+ }
+
+ pub(super) fn memory_range(&self, range: RangeInclusive<MemoryAddress>) -> MemoryIterator {
+ let pages = Self::memory_range_to_page_range(range.clone());
+ let pages = self
+ .pages
+ .range(pages)
+ .map(|(address, page)| (*address, page.clone()))
+ .collect::<Vec<_>>();
+ MemoryIterator::new(range, pages.into_iter())
+ }
+
+ pub(crate) fn clear(&mut self, background_executor: &BackgroundExecutor) {
+ let memory = std::mem::take(&mut self.pages);
+ background_executor
+ .spawn(async move {
+ drop(memory);
+ })
+ .detach();
+ }
+}
+
+/// Builder for memory pages.
+///
+/// Memory reads in DAP are sequential (or at least we make them so).
+/// A ReadMemory response includes an `unreadableBytes` property indicating the number of bytes
+/// that could not be read after the last successfully read byte.
+///
+/// We use it as follows:
+/// - We start off with a "large" 1-page ReadMemory request.
+/// - If it succeeds/fails wholesale, cool; we have no unknown memory regions in this page.
+/// - If it succeeds partially, we know # of mapped bytes.
+/// We might also know the # of unmapped bytes.
+/// However, we're still unsure about what's *after* the unreadable region.
+///
+/// This is where this builder comes in. It lets us track the state of figuring out the contents of a single page.
+pub(super) struct MemoryPageBuilder {
+ chunks: MappedPageContents,
+ base_address: PageAddress,
+ left_to_read: u64,
+}
+
+/// Represents a chunk of memory of which we don't know whether it's mapped or unmapped; thus we need
+/// to issue a request to figure out its state.
+pub(super) struct UnknownMemory {
+ pub(super) address: MemoryAddress,
+ pub(super) size: u64,
+}
+
+impl MemoryPageBuilder {
+ fn new(base_address: PageAddress) -> Self {
+ Self {
+ chunks: Default::default(),
+ base_address,
+ left_to_read: PAGE_SIZE,
+ }
+ }
+
+ pub(super) fn build(self) -> (PageAddress, PageContents) {
+ debug_assert_eq!(self.left_to_read, 0);
+ debug_assert_eq!(
+ self.chunks.len(),
+ PAGE_SIZE,
+ "Expected `build` to be called on a fully-fetched page"
+ );
+ let contents = if let Some(first) = self.chunks.0.first()
+ && self.chunks.0.len() == 1
+ && matches!(first, PageChunk::Unmapped(PAGE_SIZE))
+ {
+ PageContents::Unmapped
+ } else {
+ PageContents::Mapped(Arc::new(MappedPageContents(self.chunks.0)))
+ };
+ (self.base_address, contents)
+ }
+ /// Drives the fetching of memory, in an iterator-esque style.
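+ ///
+ /// A sketch of the intended driving loop (the real, DAP-backed driver is
+ /// `Session::memory_read_fetch_page_recursive`; `read` below is a stand-in for one ReadMemory round-trip):
+ ///
+ /// ```ignore
+ /// let mut builder = MemoryPageBuilder::new(PageAddress(0x1000));
+ /// while let Some(request) = builder.next_request() {
+ ///     match read(request.address, request.size) {
+ ///         // Partial success: record the bytes we got plus any gap the adapter reported.
+ ///         Ok((bytes, unreadable)) => {
+ ///             builder.known(bytes);
+ ///             builder.unknown(unreadable);
+ ///         }
+ ///         // A failed read marks the rest of the page as unmapped.
+ ///         Err(_) => builder.unknown(request.size),
+ ///     }
+ /// }
+ /// let (address, contents) = builder.build();
+ /// ```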
+ pub(super) fn next_request(&self) -> Option<UnknownMemory> {
+ if self.left_to_read == 0 {
+ None
+ } else {
+ let offset_in_current_page = PAGE_SIZE - self.left_to_read;
+ Some(UnknownMemory {
+ address: self.base_address.0 + offset_in_current_page,
+ size: self.left_to_read,
+ })
+ }
+ }
+ pub(super) fn unknown(&mut self, bytes: u64) {
+ if bytes == 0 {
+ return;
+ }
+ self.left_to_read -= bytes;
+ self.chunks.0.push(PageChunk::Unmapped(bytes));
+ }
+ pub(super) fn known(&mut self, data: Arc<[u8]>) {
+ if data.is_empty() {
+ return;
+ }
+ self.left_to_read -= data.len() as u64;
+ self.chunks.0.push(PageChunk::Mapped(data));
+ }
+}
+
+fn page_contents_into_iter(data: Arc<MappedPageContents>) -> Box<dyn Iterator<Item = MemoryCell>> {
+ let mut data_range = 0..data.0.len();
+ let iter = std::iter::from_fn(move || {
+ let data = &data;
+ let data_ref = data.clone();
+ data_range.next().map(move |index| {
+ let contents = &data_ref.0[index];
+ match contents {
+ PageChunk::Mapped(items) => {
+ let chunk_range = 0..items.len();
+ let items = items.clone();
+ Box::new(
+ chunk_range
+ .into_iter()
+ .map(move |ix| MemoryCell(Some(items[ix]))),
+ ) as Box<dyn Iterator<Item = MemoryCell>>
+ }
+ PageChunk::Unmapped(len) => {
+ Box::new(std::iter::repeat_n(MemoryCell(None), *len as usize))
+ }
+ }
+ })
+ })
+ .flatten();
+
+ Box::new(iter)
+}
+/// Defines an iteration over a range of memory. Some of this memory might be unmapped or straight up missing.
+/// Thus, this iterator alternates between synthesizing values and yielding known memory.
+pub struct MemoryIterator {
+ start: MemoryAddress,
+ end: MemoryAddress,
+ current_known_page: Option<(PageAddress, Box<dyn Iterator<Item = MemoryCell>>)>,
+ pages: std::vec::IntoIter<(PageAddress, PageContents)>,
+}
+
+impl MemoryIterator {
+ fn new(
+ range: RangeInclusive<MemoryAddress>,
+ pages: std::vec::IntoIter<(PageAddress, PageContents)>,
+ ) -> Self {
+ Self {
+ start: *range.start(),
+ end: *range.end(),
+ current_known_page: None,
+ pages,
+ }
+ }
+ fn fetch_next_page(&mut self) -> bool {
+ if let Some((mut address, chunk)) = self.pages.next() {
+ let mut contents = match chunk {
+ PageContents::Unmapped => None,
+ PageContents::Mapped(mapped_page_contents) => {
+ Some(page_contents_into_iter(mapped_page_contents))
+ }
+ };
+
+ if address.0 < self.start {
+ // Skip ahead until this iterator is at the start of the requested range,
+ // e.g. a page at address 20 with a range starting at 25 skips the first 5 cells.
+ let to_skip = self.start - address.0;
+ address.0 += to_skip;
+ if let Some(contents) = &mut contents {
+ contents.nth(to_skip as usize - 1);
+ }
+ }
+ self.current_known_page = contents.map(|contents| (address, contents));
+ true
+ } else {
+ false
+ }
+ }
+}
+impl Iterator for MemoryIterator {
+ type Item = MemoryCell;
+
+ fn next(&mut self) -> Option<Self::Item> {
+ if self.start > self.end {
+ return None;
+ }
+ if let Some((current_page_address, current_memory_chunk)) = self.current_known_page.as_mut()
+ {
+ if current_page_address.0 <= self.start {
+ if let Some(next_cell) = current_memory_chunk.next() {
+ self.start += 1;
+ return Some(next_cell);
+ } else {
+ self.current_known_page.take();
+ }
+ }
+ }
+ if !self.fetch_next_page() {
+ self.start += 1;
+ return Some(MemoryCell(None));
+ } else {
+ self.next()
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::debugger::{
+ MemoryCell,
+ memory::{MemoryIterator, PageAddress, PageContents},
+ };
+
+ #[test]
+ fn iterate_over_unmapped_memory() {
+ let empty_iterator = MemoryIterator::new(0..=127, Default::default());
+ let actual =
empty_iterator.collect::<Vec<_>>(); + let expected = vec![MemoryCell(None); 128]; + assert_eq!(actual.len(), expected.len()); + assert_eq!(actual, expected); + } + + #[test] + fn iterate_over_partially_mapped_memory() { + let it = MemoryIterator::new( + 0..=127, + vec![(PageAddress(5), PageContents::mapped(vec![1]))].into_iter(), + ); + let actual = it.collect::<Vec<_>>(); + let expected = std::iter::repeat_n(MemoryCell(None), 5) + .chain(std::iter::once(MemoryCell(Some(1)))) + .chain(std::iter::repeat_n(MemoryCell(None), 122)) + .collect::<Vec<_>>(); + assert_eq!(actual.len(), expected.len()); + assert_eq!(actual, expected); + } + + #[test] + fn reads_from_the_middle_of_a_page() { + let partial_iter = MemoryIterator::new( + 20..=30, + vec![(PageAddress(0), PageContents::mapped((0..255).collect()))].into_iter(), + ); + let actual = partial_iter.collect::<Vec<_>>(); + let expected = (20..=30) + .map(|val| MemoryCell(Some(val))) + .collect::<Vec<_>>(); + assert_eq!(actual.len(), expected.len()); + assert_eq!(actual, expected); + } +} diff --git a/crates/project/src/debugger/session.rs b/crates/project/src/debugger/session.rs index 59c35da4cac4328dc109b8463ef02868b4885d63..d9c28df497b3baa4543e6271106ddb1cd11b4419 100644 --- a/crates/project/src/debugger/session.rs +++ b/crates/project/src/debugger/session.rs @@ -1,18 +1,21 @@ use crate::debugger::breakpoint_store::BreakpointSessionState; +use crate::debugger::dap_command::{DataBreakpointContext, ReadMemory}; +use crate::debugger::memory::{self, Memory, MemoryIterator, MemoryPageBuilder, PageAddress}; use super::breakpoint_store::{ BreakpointStore, BreakpointStoreEvent, BreakpointUpdatedReason, SourceBreakpoint, }; use super::dap_command::{ - self, Attach, ConfigurationDone, ContinueCommand, DisconnectCommand, EvaluateCommand, - Initialize, Launch, LoadedSourcesCommand, LocalDapCommand, LocationsCommand, ModulesCommand, - NextCommand, PauseCommand, RestartCommand, RestartStackFrameCommand, ScopesCommand, - SetExceptionBreakpoints, SetVariableValueCommand, StackTraceCommand, StepBackCommand, - StepCommand, StepInCommand, StepOutCommand, TerminateCommand, TerminateThreadsCommand, - ThreadsCommand, VariablesCommand, + self, Attach, ConfigurationDone, ContinueCommand, DataBreakpointInfoCommand, DisconnectCommand, + EvaluateCommand, Initialize, Launch, LoadedSourcesCommand, LocalDapCommand, LocationsCommand, + ModulesCommand, NextCommand, PauseCommand, RestartCommand, RestartStackFrameCommand, + ScopesCommand, SetDataBreakpointsCommand, SetExceptionBreakpoints, SetVariableValueCommand, + StackTraceCommand, StepBackCommand, StepCommand, StepInCommand, StepOutCommand, + TerminateCommand, TerminateThreadsCommand, ThreadsCommand, VariablesCommand, }; use super::dap_store::DapStore; use anyhow::{Context as _, Result, anyhow}; +use base64::Engine; use collections::{HashMap, HashSet, IndexMap}; use dap::adapters::{DebugAdapterBinary, DebugAdapterName}; use dap::messages::Response; @@ -26,7 +29,7 @@ use dap::{ use dap::{ ExceptionBreakpointsFilter, ExceptionFilterOptions, OutputEvent, OutputEventCategory, RunInTerminalRequestArguments, StackFramePresentationHint, StartDebuggingRequestArguments, - StartDebuggingRequestArgumentsRequest, VariablePresentationHint, + StartDebuggingRequestArgumentsRequest, VariablePresentationHint, WriteMemoryArguments, }; use futures::SinkExt; use futures::channel::mpsc::UnboundedSender; @@ -42,6 +45,7 @@ use serde_json::Value; use smol::stream::StreamExt; use std::any::TypeId; use std::collections::BTreeMap; +use 
std::ops::RangeInclusive;
 use std::u64;
 use std::{
 any::Any,
@@ -52,20 +56,15 @@
 };
 use task::TaskContext;
 use text::{PointUtf16, ToPointUtf16};
-use util::ResultExt;
+use util::{ResultExt, debug_panic, maybe};
 use worktree::Worktree;

 #[derive(Debug, Copy, Clone, Hash, PartialEq, PartialOrd, Ord, Eq)]
 #[repr(transparent)]
-pub struct ThreadId(pub u64);
+pub struct ThreadId(pub i64);

-impl ThreadId {
- pub const MIN: ThreadId = ThreadId(u64::MIN);
- pub const MAX: ThreadId = ThreadId(u64::MAX);
-}
-
-impl From<u64> for ThreadId {
- fn from(id: u64) -> Self {
+impl From<i64> for ThreadId {
+ fn from(id: i64) -> Self {
 Self(id)
 }
 }
@@ -134,8 +133,18 @@ pub struct Watcher {
 pub presentation_hint: Option<VariablePresentationHint>,
 }

-pub enum Mode {
- Building,
+#[derive(Debug, Clone, PartialEq)]
+pub struct DataBreakpointState {
+ pub dap: dap::DataBreakpoint,
+ pub is_enabled: bool,
+ pub context: Arc<DataBreakpointContext>,
+}
+
+pub enum SessionState {
+ /// Represents a session that is still building/initializing.
+ /// Even if a session doesn't have a pre-build task, this state is used
+ /// to run all the async tasks that are required to start the session.
+ Booting(Option<Task<Result<()>>>),
 Running(RunningMode),
 }

@@ -151,6 +160,12 @@ pub struct RunningMode {
 messages_tx: UnboundedSender<Message>,
 }

+#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
+pub struct SessionQuirks {
+ pub compact: bool,
+ pub prefer_thread_name: bool,
+}
+
 fn client_source(abs_path: &Path) -> dap::Source {
 dap::Source {
 name: abs_path
@@ -554,15 +569,15 @@ impl RunningMode {
 }
 }

-impl Mode {
+impl SessionState {
 pub(super) fn request_dap<R: LocalDapCommand>(&self, request: R) -> Task<Result<R::Response>>
 where
 <R::DapRequest as dap::requests::Request>::Response: 'static,
 <R::DapRequest as dap::requests::Request>::Arguments: 'static + Send,
 {
 match self {
- Mode::Running(debug_adapter_client) => debug_adapter_client.request(request),
- Mode::Building => Task::ready(Err(anyhow!(
+ SessionState::Running(debug_adapter_client) => debug_adapter_client.request(request),
+ SessionState::Booting(_) => Task::ready(Err(anyhow!(
 "no adapter running to send request: {request:?}"
 ))),
 }
@@ -571,13 +586,13 @@
 /// Did this debug session stop at least once?
 pub(crate) fn has_ever_stopped(&self) -> bool {
 match self {
- Mode::Building => false,
- Mode::Running(running_mode) => running_mode.has_ever_stopped,
+ SessionState::Booting(_) => false,
+ SessionState::Running(running_mode) => running_mode.has_ever_stopped,
 }
 }

 fn stopped(&mut self) {
- if let Mode::Running(running) = self {
+ if let SessionState::Running(running) = self {
 running.has_ever_stopped = true;
 }
 }
@@ -654,9 +669,9 @@ type IsEnabled = bool;
 pub struct OutputToken(pub usize);

 /// Represents a current state of a single debug adapter and provides ways to mutate it.
pub struct Session { - pub mode: Mode, + pub mode: SessionState, id: SessionId, - label: SharedString, + label: Option<SharedString>, adapter: DebugAdapterName, pub(super) capabilities: Capabilities, child_session_ids: HashSet<SessionId>, @@ -676,9 +691,12 @@ pub struct Session { pub(crate) breakpoint_store: Entity<BreakpointStore>, ignore_breakpoints: bool, exception_breakpoints: BTreeMap<String, (ExceptionBreakpointsFilter, IsEnabled)>, + data_breakpoints: BTreeMap<String, DataBreakpointState>, background_tasks: Vec<Task<()>>, restart_task: Option<Task<()>>, task_context: TaskContext, + memory: memory::Memory, + quirks: SessionQuirks, } trait CacheableCommand: Any + Send + Sync { @@ -768,6 +786,7 @@ pub enum SessionEvent { request: RunInTerminalRequestArguments, sender: mpsc::Sender<Result<u32>>, }, + DataBreakpointInfo, ConsoleOutput, } @@ -792,9 +811,10 @@ impl Session { breakpoint_store: Entity<BreakpointStore>, session_id: SessionId, parent_session: Option<Entity<Session>>, - label: SharedString, + label: Option<SharedString>, adapter: DebugAdapterName, task_context: TaskContext, + quirks: SessionQuirks, cx: &mut App, ) -> Entity<Self> { cx.new::<Self>(|cx| { @@ -820,10 +840,9 @@ impl Session { BreakpointStoreEvent::SetDebugLine | BreakpointStoreEvent::ClearDebugLines => {} }) .detach(); - // cx.on_app_quit(Self::on_app_quit).detach(); let this = Self { - mode: Mode::Building, + mode: SessionState::Booting(None), id: session_id, child_session_ids: HashSet::default(), parent_session, @@ -844,10 +863,13 @@ impl Session { is_session_terminated: false, ignore_breakpoints: false, breakpoint_store, + data_breakpoints: Default::default(), exception_breakpoints: Default::default(), label, adapter, task_context, + memory: memory::Memory::new(), + quirks, }; this @@ -860,8 +882,8 @@ impl Session { pub fn worktree(&self) -> Option<Entity<Worktree>> { match &self.mode { - Mode::Building => None, - Mode::Running(local_mode) => local_mode.worktree.upgrade(), + SessionState::Booting(_) => None, + SessionState::Running(local_mode) => local_mode.worktree.upgrade(), } } @@ -920,7 +942,16 @@ impl Session { ) .await?; this.update(cx, |this, cx| { - this.mode = Mode::Running(mode); + match &mut this.mode { + SessionState::Booting(task) if task.is_some() => { + task.take().unwrap().detach_and_log_err(cx); + } + SessionState::Booting(_) => {} + SessionState::Running(_) => { + debug_panic!("Attempting to boot a session that is already running"); + } + }; + this.mode = SessionState::Running(mode); cx.emit(SessionStateEvent::Running); })?; @@ -1013,8 +1044,8 @@ impl Session { pub fn binary(&self) -> Option<&DebugAdapterBinary> { match &self.mode { - Mode::Building => None, - Mode::Running(running_mode) => Some(&running_mode.binary), + SessionState::Booting(_) => None, + SessionState::Running(running_mode) => Some(&running_mode.binary), } } @@ -1022,7 +1053,7 @@ impl Session { self.adapter.clone() } - pub fn label(&self) -> SharedString { + pub fn label(&self) -> Option<SharedString> { self.label.clone() } @@ -1059,26 +1090,26 @@ impl Session { pub fn is_started(&self) -> bool { match &self.mode { - Mode::Building => false, - Mode::Running(running) => running.is_started, + SessionState::Booting(_) => false, + SessionState::Running(running) => running.is_started, } } pub fn is_building(&self) -> bool { - matches!(self.mode, Mode::Building) + matches!(self.mode, SessionState::Booting(_)) } pub fn as_running_mut(&mut self) -> Option<&mut RunningMode> { match &mut self.mode { - Mode::Running(local_mode) => 
Some(local_mode), - Mode::Building => None, + SessionState::Running(local_mode) => Some(local_mode), + SessionState::Booting(_) => None, } } pub fn as_running(&self) -> Option<&RunningMode> { match &self.mode { - Mode::Running(local_mode) => Some(local_mode), - Mode::Building => None, + SessionState::Running(local_mode) => Some(local_mode), + SessionState::Booting(_) => None, } } @@ -1220,7 +1251,7 @@ impl Session { let adapter_id = self.adapter().to_string(); let request = Initialize { adapter_id }; - let Mode::Running(running) = &self.mode else { + let SessionState::Running(running) = &self.mode else { return Task::ready(Err(anyhow!( "Cannot send initialize request, task still building" ))); @@ -1269,10 +1300,12 @@ impl Session { cx: &mut Context<Self>, ) -> Task<Result<()>> { match &self.mode { - Mode::Running(local_mode) => { + SessionState::Running(local_mode) => { local_mode.initialize_sequence(&self.capabilities, initialize_rx, dap_store, cx) } - Mode::Building => Task::ready(Err(anyhow!("cannot initialize, still building"))), + SessionState::Booting(_) => { + Task::ready(Err(anyhow!("cannot initialize, still building"))) + } } } @@ -1283,7 +1316,7 @@ impl Session { cx: &mut Context<Self>, ) { match &mut self.mode { - Mode::Running(local_mode) => { + SessionState::Running(local_mode) => { if !matches!( self.thread_states.thread_state(active_thread_id), Some(ThreadStatus::Stopped) @@ -1307,7 +1340,7 @@ impl Session { }) .detach(); } - Mode::Building => {} + SessionState::Booting(_) => {} } } @@ -1587,7 +1620,7 @@ impl Session { fn request_inner<T: LocalDapCommand + PartialEq + Eq + Hash>( capabilities: &Capabilities, - mode: &Mode, + mode: &SessionState, request: T, process_result: impl FnOnce( &mut Self, @@ -1643,6 +1676,12 @@ impl Session { self.invalidate_command_type::<ModulesCommand>(); self.invalidate_command_type::<LoadedSourcesCommand>(); self.invalidate_command_type::<ThreadsCommand>(); + self.invalidate_command_type::<DataBreakpointInfoCommand>(); + self.invalidate_command_type::<ReadMemory>(); + let executor = self.as_running().map(|running| running.executor.clone()); + if let Some(executor) = executor { + self.memory.clear(&executor); + } } fn invalidate_state(&mut self, key: &RequestSlot) { @@ -1715,6 +1754,137 @@ impl Session { &self.modules } + // CodeLLDB returns the size of a pointed-to-memory, which we can use to make the experience of go-to-memory better. 
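+ // The `?${sizeof(..)}`-style expression below targets CodeLLDB's REPL; adapters that don't
+ // understand it fail the evaluate request or return something unparsable, in which case the
+ // returned task simply resolves to `None`.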
+ pub fn data_access_size( + &mut self, + frame_id: Option<u64>, + evaluate_name: &str, + cx: &mut Context<Self>, + ) -> Task<Option<u64>> { + let request = self.request( + EvaluateCommand { + expression: format!("?${{sizeof({evaluate_name})}}"), + frame_id, + + context: Some(EvaluateArgumentsContext::Repl), + source: None, + }, + |_, response, _| response.ok(), + cx, + ); + cx.background_spawn(async move { + let result = request.await?; + result.result.parse().ok() + }) + } + + pub fn memory_reference_of_expr( + &mut self, + frame_id: Option<u64>, + expression: String, + cx: &mut Context<Self>, + ) -> Task<Option<(String, Option<String>)>> { + let request = self.request( + EvaluateCommand { + expression, + frame_id, + + context: Some(EvaluateArgumentsContext::Repl), + source: None, + }, + |_, response, _| response.ok(), + cx, + ); + cx.background_spawn(async move { + let result = request.await?; + result + .memory_reference + .map(|reference| (reference, result.type_)) + }) + } + + pub fn write_memory(&mut self, address: u64, data: &[u8], cx: &mut Context<Self>) { + let data = base64::engine::general_purpose::STANDARD.encode(data); + self.request( + WriteMemoryArguments { + memory_reference: address.to_string(), + data, + allow_partial: None, + offset: None, + }, + |this, response, cx| { + this.memory.clear(cx.background_executor()); + this.invalidate_command_type::<ReadMemory>(); + this.invalidate_command_type::<VariablesCommand>(); + cx.emit(SessionEvent::Variables); + response.ok() + }, + cx, + ) + .detach(); + } + pub fn read_memory( + &mut self, + range: RangeInclusive<u64>, + cx: &mut Context<Self>, + ) -> MemoryIterator { + // This function is a bit more involved when it comes to fetching data. + // Since we attempt to read memory in pages, we need to account for some parts + // of memory being unreadable. Therefore, we start off by fetching a page per request. + // In case that fails, we try to re-fetch smaller regions until we have the full range. + let page_range = Memory::memory_range_to_page_range(range.clone()); + for page_address in PageAddress::iter_range(page_range) { + self.read_single_page_memory(page_address, cx); + } + self.memory.memory_range(range) + } + + fn read_single_page_memory(&mut self, page_start: PageAddress, cx: &mut Context<Self>) { + _ = maybe!({ + let builder = self.memory.build_page(page_start)?; + + self.memory_read_fetch_page_recursive(builder, cx); + Some(()) + }); + } + fn memory_read_fetch_page_recursive( + &mut self, + mut builder: MemoryPageBuilder, + cx: &mut Context<Self>, + ) { + let Some(next_request) = builder.next_request() else { + // We're done fetching. Let's grab the page and insert it into our memory store. + let (address, contents) = builder.build(); + self.memory.insert_page(address, contents); + + return; + }; + let size = next_request.size; + self.fetch( + ReadMemory { + memory_reference: format!("0x{:X}", next_request.address), + offset: Some(0), + count: next_request.size, + }, + move |this, memory, cx| { + if let Ok(memory) = memory { + builder.known(memory.content); + if let Some(unknown) = memory.unreadable_bytes { + builder.unknown(unknown); + } + // This is the recursive bit: if we're not yet done with + // the whole page, we'll kick off a new request with smaller range. + // Note that this function is recursive only conceptually; + // since it kicks off a new request with callback, we don't need to worry about stack overflow. 
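+ // For example, for the page at 0x1000 the first request asks for all 4096 bytes;
+ // if the reply carries 100 bytes of content and reports 64 unreadable bytes, the
+ // follow-up request starts at 0x1000 + 164 and asks for the remaining 3932 bytes.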
+ this.memory_read_fetch_page_recursive(builder, cx);
+ } else {
+ builder.unknown(size);
+ }
+ },
+ cx,
+ );
+ }
+
 pub fn ignore_breakpoints(&self) -> bool {
 self.ignore_breakpoints
 }
@@ -1745,6 +1915,10 @@ impl Session {
 }
 }

+ pub fn data_breakpoints(&self) -> impl Iterator<Item = &DataBreakpointState> {
+ self.data_breakpoints.values()
+ }
+
 pub fn exception_breakpoints(
 &self,
 ) -> impl Iterator<Item = &(ExceptionBreakpointsFilter, IsEnabled)> {
@@ -1778,6 +1952,45 @@
 }
 }

+ pub fn toggle_data_breakpoint(&mut self, id: &str, cx: &mut Context<'_, Session>) {
+ if let Some(state) = self.data_breakpoints.get_mut(id) {
+ state.is_enabled = !state.is_enabled;
+ self.send_data_breakpoints(cx);
+ }
+ }
+
+ fn send_data_breakpoints(&mut self, cx: &mut Context<Self>) {
+ if let Some(mode) = self.as_running() {
+ let breakpoints = self
+ .data_breakpoints
+ .values()
+ .filter_map(|state| state.is_enabled.then(|| state.dap.clone()))
+ .collect();
+ let command = SetDataBreakpointsCommand { breakpoints };
+ mode.request(command).detach_and_log_err(cx);
+ }
+ }
+
+ pub fn create_data_breakpoint(
+ &mut self,
+ context: Arc<DataBreakpointContext>,
+ data_id: String,
+ dap: dap::DataBreakpoint,
+ cx: &mut Context<Self>,
+ ) {
+ if self.data_breakpoints.remove(&data_id).is_none() {
+ self.data_breakpoints.insert(
+ data_id,
+ DataBreakpointState {
+ dap,
+ is_enabled: true,
+ context,
+ },
+ );
+ }
+ self.send_data_breakpoints(cx);
+ }
+
 pub fn breakpoints_enabled(&self) -> bool {
 self.ignore_breakpoints
 }
@@ -1907,28 +2120,36 @@
 self.thread_states.exit_all_threads();
 cx.notify();

- let task = if self
- .capabilities
- .supports_terminate_request
- .unwrap_or_default()
- {
- self.request(
- TerminateCommand {
- restart: Some(false),
- },
- Self::clear_active_debug_line_response,
- cx,
- )
- } else {
- self.request(
- DisconnectCommand {
- restart: Some(false),
- terminate_debuggee: Some(true),
- suspend_debuggee: Some(false),
- },
- Self::clear_active_debug_line_response,
- cx,
- )
+ let task = match &mut self.mode {
+ SessionState::Running(_) => {
+ if self
+ .capabilities
+ .supports_terminate_request
+ .unwrap_or_default()
+ {
+ self.request(
+ TerminateCommand {
+ restart: Some(false),
+ },
+ Self::clear_active_debug_line_response,
+ cx,
+ )
+ } else {
+ self.request(
+ DisconnectCommand {
+ restart: Some(false),
+ terminate_debuggee: Some(true),
+ suspend_debuggee: Some(false),
+ },
+ Self::clear_active_debug_line_response,
+ cx,
+ )
+ }
+ }
+ SessionState::Booting(build_task) => {
+ build_task.take();
+ Task::ready(Some(()))
+ }
 };

 cx.emit(SessionStateEvent::Shutdown);
@@ -1978,8 +2199,8 @@

 pub fn adapter_client(&self) -> Option<Arc<DebugAdapterClient>> {
 match self.mode {
- Mode::Running(ref local) => Some(local.client.clone()),
- Mode::Building => None,
+ SessionState::Running(ref local) => Some(local.client.clone()),
+ SessionState::Booting(_) => None,
 }
 }

@@ -2331,6 +2552,20 @@
 .unwrap_or_default()
 }

+ pub fn data_breakpoint_info(
+ &mut self,
+ context: Arc<DataBreakpointContext>,
+ mode: Option<String>,
+ cx: &mut Context<Self>,
+ ) -> Task<Option<dap::DataBreakpointInfoResponse>> {
+ let command = DataBreakpointInfoCommand {
+ context: context.clone(),
+ mode,
+ };
+
+ self.request(command, |_, response, _| response.ok(), cx)
+ }
+
 pub fn set_variable_value(
 &mut self,
 stack_frame_id: u64,
@@ -2349,6 +2584,8 @@
 move |this, response, cx| {
 let response = response.log_err()?;
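+ // Writing a variable can also change memory that the memory view has already cached,
+ // so the cached ReadMemory responses and the page store are dropped along with the
+ // variables cache before the UI is notified.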
this.invalidate_command_type::<VariablesCommand>(); + this.invalidate_command_type::<ReadMemory>(); + this.memory.clear(cx.background_executor()); this.refresh_watchers(stack_frame_id, cx); cx.emit(SessionEvent::Variables); Some(response) @@ -2388,6 +2625,8 @@ impl Session { cx.spawn(async move |this, cx| { let response = request.await; this.update(cx, |this, cx| { + this.memory.clear(cx.background_executor()); + this.invalidate_command_type::<ReadMemory>(); match response { Ok(response) => { let event = dap::OutputEvent { @@ -2443,7 +2682,7 @@ impl Session { } pub fn is_attached(&self) -> bool { - let Mode::Running(local_mode) = &self.mode else { + let SessionState::Running(local_mode) = &self.mode else { return false; }; local_mode.binary.request_args.request == StartDebuggingRequestArgumentsRequest::Attach @@ -2481,4 +2720,8 @@ impl Session { pub fn thread_state(&self, thread_id: ThreadId) -> Option<ThreadStatus> { self.thread_states.thread_state(thread_id) } + + pub fn quirks(&self) -> SessionQuirks { + self.quirks + } } diff --git a/crates/project/src/debugger/test.rs b/crates/project/src/debugger/test.rs index 3b9425e3690fb872b7953634cbdb995a40a5b021..53b88323e6326fe7d6d74f79a5e92845514c6b61 100644 --- a/crates/project/src/debugger/test.rs +++ b/crates/project/src/debugger/test.rs @@ -1,7 +1,7 @@ use std::{path::Path, sync::Arc}; use dap::client::DebugAdapterClient; -use gpui::{App, AppContext, Subscription}; +use gpui::{App, Subscription}; use super::session::{Session, SessionStateEvent}; @@ -19,14 +19,6 @@ pub fn intercept_debug_sessions<T: Fn(&Arc<DebugAdapterClient>) + 'static>( let client = session.adapter_client().unwrap(); register_default_handlers(session, &client, cx); configure(&client); - cx.background_spawn(async move { - client - .fake_event(dap::messages::Events::Initialized( - Some(Default::default()), - )) - .await - }) - .detach(); } }) .detach(); diff --git a/crates/project/src/git_store.rs b/crates/project/src/git_store.rs index 9ff3823e0f13a87fdcff944db7ad2d52350a7cce..01fc987816447340bcec77e53ddf77ff72146be9 100644 --- a/crates/project/src/git_store.rs +++ b/crates/project/src/git_store.rs @@ -14,9 +14,10 @@ use collections::HashMap; pub use conflict_set::{ConflictRegion, ConflictSet, ConflictSetSnapshot, ConflictSetUpdate}; use fs::Fs; use futures::{ - FutureExt, StreamExt as _, + FutureExt, StreamExt, channel::{mpsc, oneshot}, future::{self, Shared}, + stream::FuturesOrdered, }; use git::{ BuildPermalinkParams, GitHostingProviderRegistry, WORK_DIRECTORY_REPO_PATH, @@ -63,8 +64,8 @@ use sum_tree::{Edit, SumTree, TreeSet}; use text::{Bias, BufferId}; use util::{ResultExt, debug_panic, post_inc}; use worktree::{ - File, PathKey, PathProgress, PathSummary, PathTarget, UpdatedGitRepositoriesSet, - UpdatedGitRepository, Worktree, + File, PathChange, PathKey, PathProgress, PathSummary, PathTarget, ProjectEntryId, + UpdatedGitRepositoriesSet, UpdatedGitRepository, Worktree, }; pub struct GitStore { @@ -245,6 +246,8 @@ pub struct RepositorySnapshot { pub head_commit: Option<CommitDetails>, pub scan_id: u64, pub merge: MergeDetails, + pub remote_origin_url: Option<String>, + pub remote_upstream_url: Option<String>, } type JobId = u64; @@ -419,6 +422,8 @@ impl GitStore { client.add_entity_request_handler(Self::handle_fetch); client.add_entity_request_handler(Self::handle_stage); client.add_entity_request_handler(Self::handle_unstage); + client.add_entity_request_handler(Self::handle_stash); + client.add_entity_request_handler(Self::handle_stash_pop); 
client.add_entity_request_handler(Self::handle_commit); client.add_entity_request_handler(Self::handle_reset); client.add_entity_request_handler(Self::handle_show); @@ -1083,27 +1088,26 @@ impl GitStore { match event { WorktreeStoreEvent::WorktreeUpdatedEntries(worktree_id, updated_entries) => { - let mut paths_by_git_repo = HashMap::<_, Vec<_>>::default(); - for (relative_path, _, _) in updated_entries.iter() { - let Some((repo, repo_path)) = self.repository_and_path_for_project_path( - &(*worktree_id, relative_path.clone()).into(), - cx, - ) else { - continue; - }; - paths_by_git_repo.entry(repo).or_default().push(repo_path) - } - - for (repo, paths) in paths_by_git_repo { - repo.update(cx, |repo, cx| { - repo.paths_changed( - paths, - downstream - .as_ref() - .map(|downstream| downstream.updates_tx.clone()), - cx, - ); - }); + if let Some(worktree) = self + .worktree_store + .read(cx) + .worktree_for_id(*worktree_id, cx) + { + let paths_by_git_repo = + self.process_updated_entries(&worktree, updated_entries, cx); + let downstream = downstream + .as_ref() + .map(|downstream| downstream.updates_tx.clone()); + cx.spawn(async move |_, cx| { + let paths_by_git_repo = paths_by_git_repo.await; + for (repo, paths) in paths_by_git_repo { + repo.update(cx, |repo, cx| { + repo.paths_changed(paths, downstream.clone(), cx); + }) + .ok(); + } + }) + .detach(); } } WorktreeStoreEvent::WorktreeUpdatedGitRepositories(worktree_id, changed_repos) => { @@ -1696,6 +1700,48 @@ impl GitStore { Ok(proto::Ack {}) } + async fn handle_stash( + this: Entity<Self>, + envelope: TypedEnvelope<proto::Stash>, + mut cx: AsyncApp, + ) -> Result<proto::Ack> { + let repository_id = RepositoryId::from_proto(envelope.payload.repository_id); + let repository_handle = Self::repository_for_request(&this, repository_id, &mut cx)?; + + let entries = envelope + .payload + .paths + .into_iter() + .map(PathBuf::from) + .map(RepoPath::new) + .collect(); + + repository_handle + .update(&mut cx, |repository_handle, cx| { + repository_handle.stash_entries(entries, cx) + })? + .await?; + + Ok(proto::Ack {}) + } + + async fn handle_stash_pop( + this: Entity<Self>, + envelope: TypedEnvelope<proto::StashPop>, + mut cx: AsyncApp, + ) -> Result<proto::Ack> { + let repository_id = RepositoryId::from_proto(envelope.payload.repository_id); + let repository_handle = Self::repository_for_request(&this, repository_id, &mut cx)?; + + repository_handle + .update(&mut cx, |repository_handle, cx| { + repository_handle.stash_pop(cx) + })? 
+            .await?;
+
+        Ok(proto::Ack {})
+    }
+
     async fn handle_set_index_text(
         this: Entity<Self>,
         envelope: TypedEnvelope<proto::SetIndexText>,
@@ -1738,6 +1784,7 @@ impl GitStore {
                         name.zip(email),
                         CommitOptions {
                             amend: options.amend,
+                            signoff: options.signoff,
                         },
                         cx,
                     )
@@ -2190,6 +2237,80 @@ impl GitStore {
             .map(|(id, repo)| (*id, repo.read(cx).snapshot.clone()))
             .collect()
     }
+
+    fn process_updated_entries(
+        &self,
+        worktree: &Entity<Worktree>,
+        updated_entries: &[(Arc<Path>, ProjectEntryId, PathChange)],
+        cx: &mut App,
+    ) -> Task<HashMap<Entity<Repository>, Vec<RepoPath>>> {
+        let mut repo_paths = self
+            .repositories
+            .values()
+            .map(|repo| (repo.read(cx).work_directory_abs_path.clone(), repo.clone()))
+            .collect::<Vec<_>>();
+        let mut entries: Vec<_> = updated_entries
+            .iter()
+            .map(|(path, _, _)| path.clone())
+            .collect();
+        entries.sort();
+        let worktree = worktree.read(cx);
+
+        let entries = entries
+            .into_iter()
+            .filter_map(|path| worktree.absolutize(&path).ok())
+            .collect::<Arc<[_]>>();
+
+        let executor = cx.background_executor().clone();
+        cx.background_executor().spawn(async move {
+            repo_paths.sort_by(|lhs, rhs| lhs.0.cmp(&rhs.0));
+            let mut paths_by_git_repo = HashMap::<_, Vec<_>>::default();
+            let mut tasks = FuturesOrdered::new();
+            for (repo_path, repo) in repo_paths.into_iter().rev() {
+                let entries = entries.clone();
+                let task = executor.spawn(async move {
+                    // Find all repository paths that belong to this repo
+                    let mut ix = entries.partition_point(|path| path < &*repo_path);
+                    if ix == entries.len() {
+                        return None;
+                    };
+
+                    let mut paths = vec![];
+                    // All paths prefixed by a given repo will constitute a continuous range.
+                    while let Some(path) = entries.get(ix)
+                        && let Some(repo_path) =
+                            RepositorySnapshot::abs_path_to_repo_path_inner(&repo_path, &path)
+                    {
+                        paths.push((repo_path, ix));
+                        ix += 1;
+                    }
+                    Some((repo, paths))
+                });
+                tasks.push_back(task);
+            }
+
+            // Now, let's filter out the "duplicate" entries that were processed by multiple distinct repos.
+            let mut path_was_used = vec![false; entries.len()];
+            let tasks = tasks.collect::<Vec<_>>().await;
+            // Process tasks from the back: iterating backwards allows us to see more-specific paths first.
+            // We always want to assign a path to its innermost repository.
+ for t in tasks { + let Some((repo, paths)) = t else { + continue; + }; + let entry = paths_by_git_repo.entry(repo).or_default(); + for (repo_path, ix) in paths { + if path_was_used[ix] { + continue; + } + path_was_used[ix] = true; + entry.push(repo_path); + } + } + + paths_by_git_repo + }) + } } impl BufferGitState { @@ -2554,6 +2675,8 @@ impl RepositorySnapshot { head_commit: None, scan_id: 0, merge: Default::default(), + remote_origin_url: None, + remote_upstream_url: None, } } @@ -2659,8 +2782,16 @@ impl RepositorySnapshot { } pub fn abs_path_to_repo_path(&self, abs_path: &Path) -> Option<RepoPath> { + Self::abs_path_to_repo_path_inner(&self.work_directory_abs_path, abs_path) + } + + #[inline] + fn abs_path_to_repo_path_inner( + work_directory_abs_path: &Path, + abs_path: &Path, + ) -> Option<RepoPath> { abs_path - .strip_prefix(&self.work_directory_abs_path) + .strip_prefix(&work_directory_abs_path) .map(RepoPath::from) .ok() } @@ -3457,6 +3588,82 @@ impl Repository { self.unstage_entries(to_unstage, cx) } + pub fn stash_all(&mut self, cx: &mut Context<Self>) -> Task<anyhow::Result<()>> { + let to_stash = self + .cached_status() + .map(|entry| entry.repo_path.clone()) + .collect(); + + self.stash_entries(to_stash, cx) + } + + pub fn stash_entries( + &mut self, + entries: Vec<RepoPath>, + cx: &mut Context<Self>, + ) -> Task<anyhow::Result<()>> { + let id = self.id; + + cx.spawn(async move |this, cx| { + this.update(cx, |this, _| { + this.send_job(None, move |git_repo, _cx| async move { + match git_repo { + RepositoryState::Local { + backend, + environment, + .. + } => backend.stash_paths(entries, environment).await, + RepositoryState::Remote { project_id, client } => { + client + .request(proto::Stash { + project_id: project_id.0, + repository_id: id.to_proto(), + paths: entries + .into_iter() + .map(|repo_path| repo_path.as_ref().to_proto()) + .collect(), + }) + .await + .context("sending stash request")?; + Ok(()) + } + } + }) + })? + .await??; + Ok(()) + }) + } + + pub fn stash_pop(&mut self, cx: &mut Context<Self>) -> Task<anyhow::Result<()>> { + let id = self.id; + cx.spawn(async move |this, cx| { + this.update(cx, |this, _| { + this.send_job(None, move |git_repo, _cx| async move { + match git_repo { + RepositoryState::Local { + backend, + environment, + .. + } => backend.stash_pop(environment).await, + RepositoryState::Remote { project_id, client } => { + client + .request(proto::StashPop { + project_id: project_id.0, + repository_id: id.to_proto(), + }) + .await + .context("sending stash pop request")?; + Ok(()) + } + } + }) + })? + .await??; + Ok(()) + }) + } + pub fn commit( &mut self, message: SharedString, @@ -3488,6 +3695,7 @@ impl Repository { email: email.map(String::from), options: Some(proto::commit::CommitOptions { amend: options.amend, + signoff: options.signoff, }), }) .await @@ -3821,6 +4029,25 @@ impl Repository { }) } + pub fn default_branch(&mut self) -> oneshot::Receiver<Result<Option<SharedString>>> { + let id = self.id; + self.send_job(None, move |repo, _| async move { + match repo { + RepositoryState::Local { backend, .. 
} => backend.default_branch().await, + RepositoryState::Remote { project_id, client } => { + let response = client + .request(proto::GetDefaultBranch { + project_id: project_id.0, + repository_id: id.to_proto(), + }) + .await?; + + anyhow::Ok(response.branch.map(SharedString::from)) + } + } + }) + } + pub fn diff(&mut self, diff_type: DiffType, _cx: &App) -> oneshot::Receiver<Result<String>> { let id = self.id; self.send_job(None, move |repo, _cx| async move { @@ -4277,7 +4504,7 @@ impl Repository { for (repo_path, status) in &*statuses.entries { changed_paths.remove(repo_path); - if cursor.seek_forward(&PathTarget::Path(repo_path), Bias::Left, &()) { + if cursor.seek_forward(&PathTarget::Path(repo_path), Bias::Left) { if cursor.item().is_some_and(|entry| entry.status == *status) { continue; } @@ -4290,7 +4517,7 @@ impl Repository { } let mut cursor = prev_statuses.cursor::<PathProgress>(&()); for path in changed_paths.into_iter() { - if cursor.seek_forward(&PathTarget::Path(&path), Bias::Left, &()) { + if cursor.seek_forward(&PathTarget::Path(&path), Bias::Left) { changed_path_statuses.push(Edit::Remove(PathKey(path.0))); } } @@ -4395,17 +4622,17 @@ fn serialize_blame_buffer_response(blame: Option<git::blame::Blame>) -> proto::B start_line: entry.range.start, end_line: entry.range.end, original_line_number: entry.original_line_number, - author: entry.author.clone(), - author_mail: entry.author_mail.clone(), + author: entry.author, + author_mail: entry.author_mail, author_time: entry.author_time, - author_tz: entry.author_tz.clone(), - committer: entry.committer_name.clone(), - committer_mail: entry.committer_email.clone(), + author_tz: entry.author_tz, + committer: entry.committer_name, + committer_mail: entry.committer_email, committer_time: entry.committer_time, - committer_tz: entry.committer_tz.clone(), - summary: entry.summary.clone(), - previous: entry.previous.clone(), - filename: entry.filename.clone(), + committer_tz: entry.committer_tz, + summary: entry.summary, + previous: entry.previous, + filename: entry.filename, }) .collect::<Vec<_>>(); @@ -4595,6 +4822,10 @@ async fn compute_snapshot( None => None, }; + // Used by edit prediction data collection + let remote_origin_url = backend.remote_url("origin"); + let remote_upstream_url = backend.remote_url("upstream"); + let snapshot = RepositorySnapshot { id, statuses_by_path, @@ -4603,6 +4834,8 @@ async fn compute_snapshot( branch, head_commit, merge: merge_details, + remote_origin_url, + remote_upstream_url, }; Ok((snapshot, events)) diff --git a/crates/project/src/git_store/git_traversal.rs b/crates/project/src/git_store/git_traversal.rs index 68ed03cfe9e41abf480fbe7a5bf10f84e10ce553..777042cb02cf87c127f050a88d8504dcb181678c 100644 --- a/crates/project/src/git_store/git_traversal.rs +++ b/crates/project/src/git_store/git_traversal.rs @@ -1,6 +1,6 @@ use collections::HashMap; -use git::status::GitSummary; -use std::{ops::Deref, path::Path}; +use git::{repository::RepoPath, status::GitSummary}; +use std::{collections::BTreeMap, ops::Deref, path::Path}; use sum_tree::Cursor; use text::Bias; use worktree::{Entry, PathProgress, PathTarget, Traversal}; @@ -11,7 +11,7 @@ use super::{RepositoryId, RepositorySnapshot, StatusEntry}; pub struct GitTraversal<'a> { traversal: Traversal<'a>, current_entry_summary: Option<GitSummary>, - repo_snapshots: &'a HashMap<RepositoryId, RepositorySnapshot>, + repo_root_to_snapshot: BTreeMap<&'a Path, &'a RepositorySnapshot>, repo_location: Option<(RepositoryId, Cursor<'a, StatusEntry, 
PathProgress<'a>>)>,
 }
@@ -20,16 +20,46 @@ impl<'a> GitTraversal<'a> {
         repo_snapshots: &'a HashMap<RepositoryId, RepositorySnapshot>,
         traversal: Traversal<'a>,
     ) -> GitTraversal<'a> {
+        let repo_root_to_snapshot = repo_snapshots
+            .values()
+            .map(|snapshot| (&*snapshot.work_directory_abs_path, snapshot))
+            .collect();
         let mut this = GitTraversal {
             traversal,
-            repo_snapshots,
             current_entry_summary: None,
             repo_location: None,
+            repo_root_to_snapshot,
         };
         this.synchronize_statuses(true);
         this
     }
+    fn repo_root_for_path(&self, path: &Path) -> Option<(&'a RepositorySnapshot, RepoPath)> {
+        // We might need to perform a range search multiple times, as there may be a nested repository in between
+        // the target and our path. E.g.:
+        // /our_root_repo/
+        //   .git/
+        //   other_repo/
+        //     .git/
+        //     our_query.txt
+        let mut query = path.ancestors();
+        while let Some(query) = query.next() {
+            let (_, snapshot) = self
+                .repo_root_to_snapshot
+                .range(Path::new("")..=query)
+                .last()?;
+
+            let stripped = snapshot
+                .abs_path_to_repo_path(path)
+                .map(|repo_path| (*snapshot, repo_path));
+            if stripped.is_some() {
+                return stripped;
+            }
+        }
+
+        None
+    }
+
     fn synchronize_statuses(&mut self, reset: bool) {
         self.current_entry_summary = None;
@@ -42,15 +72,7 @@ impl<'a> GitTraversal<'a> {
             return;
         };
-        let Some((repo, repo_path)) = self
-            .repo_snapshots
-            .values()
-            .filter_map(|repo_snapshot| {
-                let repo_path = repo_snapshot.abs_path_to_repo_path(&abs_path)?;
-                Some((repo_snapshot, repo_path))
-            })
-            .max_by_key(|(repo, _)| repo.work_directory_abs_path.clone())
-        else {
+        let Some((repo, repo_path)) = self.repo_root_for_path(&abs_path) else {
             self.repo_location = None;
             return;
         };
@@ -72,14 +94,13 @@ impl<'a> GitTraversal<'a> {
         if entry.is_dir() {
             let mut statuses = statuses.clone();
-            statuses.seek_forward(&PathTarget::Path(repo_path.as_ref()), Bias::Left, &());
-            let summary =
-                statuses.summary(&PathTarget::Successor(repo_path.as_ref()), Bias::Left, &());
+            statuses.seek_forward(&PathTarget::Path(repo_path.as_ref()), Bias::Left);
+            let summary = statuses.summary(&PathTarget::Successor(repo_path.as_ref()), Bias::Left);
             self.current_entry_summary = Some(summary);
         } else if entry.is_file() {
             // For a file entry, park the cursor on the corresponding status
-            if statuses.seek_forward(&PathTarget::Path(repo_path.as_ref()), Bias::Left, &()) {
+            if statuses.seek_forward(&PathTarget::Path(repo_path.as_ref()), Bias::Left) {
                 // TODO: Investigate statuses.item() being None here.
                 self.current_entry_summary = statuses.item().map(|item| item.status.into());
             } else {
diff --git a/crates/project/src/lsp_command.rs b/crates/project/src/lsp_command.rs
index 8ed37164361b64ceab7f46837bd89cf81e2a4550..c458b6b300c34ec03d144cf297277faf4a94f5db 100644
--- a/crates/project/src/lsp_command.rs
+++ b/crates/project/src/lsp_command.rs
@@ -350,7 +350,7 @@ impl LspCommand for PrepareRename {
             }
             Some(lsp::PrepareRenameResponse::DefaultBehavior { ..
}) => { let snapshot = buffer.snapshot(); - let (range, _) = snapshot.surrounding_word(self.position); + let (range, _) = snapshot.surrounding_word(self.position, false); let range = snapshot.anchor_after(range.start)..snapshot.anchor_before(range.end); Ok(PrepareRenameResponse::Success(range)) } @@ -2154,6 +2154,16 @@ impl LspCommand for GetHover { } } +impl GetCompletions { + pub fn can_resolve_completions(capabilities: &lsp::ServerCapabilities) -> bool { + capabilities + .completion_provider + .as_ref() + .and_then(|options| options.resolve_provider) + .unwrap_or(false) + } +} + #[async_trait(?Send)] impl LspCommand for GetCompletions { type Response = CoreCompletionResponse; @@ -2269,7 +2279,7 @@ impl LspCommand for GetCompletions { // the range based on the syntax tree. None => { if self.position != clipped_position { - log::info!("completion out of expected range"); + log::info!("completion out of expected range "); return false; } @@ -2297,7 +2307,7 @@ impl LspCommand for GetCompletions { range_for_token .get_or_insert_with(|| { let offset = self.position.to_offset(&snapshot); - let (range, kind) = snapshot.surrounding_word(offset); + let (range, kind) = snapshot.surrounding_word(offset, true); let range = if kind == Some(CharKind::Word) { range } else { @@ -2483,7 +2493,9 @@ pub(crate) fn parse_completion_text_edit( let start = snapshot.clip_point_utf16(range.start, Bias::Left); let end = snapshot.clip_point_utf16(range.end, Bias::Left); if start != range.start.0 || end != range.end.0 { - log::info!("completion out of expected range"); + log::info!( + "completion out of expected range, start: {start:?}, end: {end:?}, range: {range:?}" + ); return None; } snapshot.anchor_before(start)..snapshot.anchor_after(end) @@ -2760,6 +2772,23 @@ impl GetCodeActions { } } +impl OnTypeFormatting { + pub fn supports_on_type_formatting(trigger: &str, capabilities: &ServerCapabilities) -> bool { + let Some(on_type_formatting_options) = &capabilities.document_on_type_formatting_provider + else { + return false; + }; + on_type_formatting_options + .first_trigger_character + .contains(trigger) + || on_type_formatting_options + .more_trigger_character + .iter() + .flatten() + .any(|chars| chars.contains(trigger)) + } +} + #[async_trait(?Send)] impl LspCommand for OnTypeFormatting { type Response = Option<Transaction>; @@ -2771,20 +2800,7 @@ impl LspCommand for OnTypeFormatting { } fn check_capabilities(&self, capabilities: AdapterServerCapabilities) -> bool { - let Some(on_type_formatting_options) = &capabilities - .server_capabilities - .document_on_type_formatting_provider - else { - return false; - }; - on_type_formatting_options - .first_trigger_character - .contains(&self.trigger) - || on_type_formatting_options - .more_trigger_character - .iter() - .flatten() - .any(|chars| chars.contains(&self.trigger)) + Self::supports_on_type_formatting(&self.trigger, &capabilities.server_capabilities) } fn to_lsp( @@ -3268,6 +3284,16 @@ impl InlayHints { }) .unwrap_or(false) } + + pub fn check_capabilities(capabilities: &ServerCapabilities) -> bool { + capabilities + .inlay_hint_provider + .as_ref() + .is_some_and(|inlay_hint_provider| match inlay_hint_provider { + lsp::OneOf::Left(enabled) => *enabled, + lsp::OneOf::Right(_) => true, + }) + } } #[async_trait(?Send)] @@ -3281,17 +3307,7 @@ impl LspCommand for InlayHints { } fn check_capabilities(&self, capabilities: AdapterServerCapabilities) -> bool { - let Some(inlay_hint_provider) = &capabilities.server_capabilities.inlay_hint_provider - else { - return 
false; - }; - match inlay_hint_provider { - lsp::OneOf::Left(enabled) => *enabled, - lsp::OneOf::Right(inlay_hint_capabilities) => match inlay_hint_capabilities { - lsp::InlayHintServerCapabilities::Options(_) => true, - lsp::InlayHintServerCapabilities::RegistrationOptions(_) => false, - }, - } + Self::check_capabilities(&capabilities.server_capabilities) } fn to_lsp( @@ -3578,6 +3594,18 @@ impl LspCommand for GetCodeLens { } } +impl LinkedEditingRange { + pub fn check_server_capabilities(capabilities: ServerCapabilities) -> bool { + let Some(linked_editing_options) = capabilities.linked_editing_range_provider else { + return false; + }; + if let LinkedEditingRangeServerCapabilities::Simple(false) = linked_editing_options { + return false; + } + true + } +} + #[async_trait(?Send)] impl LspCommand for LinkedEditingRange { type Response = Vec<Range<Anchor>>; @@ -3589,16 +3617,7 @@ impl LspCommand for LinkedEditingRange { } fn check_capabilities(&self, capabilities: AdapterServerCapabilities) -> bool { - let Some(linked_editing_options) = &capabilities - .server_capabilities - .linked_editing_range_provider - else { - return false; - }; - if let LinkedEditingRangeServerCapabilities::Simple(false) = linked_editing_options { - return false; - } - true + Self::check_server_capabilities(capabilities.server_capabilities) } fn to_lsp( @@ -3822,7 +3841,7 @@ impl GetDocumentDiagnostics { code, code_description: match diagnostic.code_description { Some(code_description) => Some(CodeDescription { - href: lsp::Url::parse(&code_description).unwrap(), + href: Some(lsp::Url::parse(&code_description).unwrap()), }), None => None, }, @@ -3898,7 +3917,7 @@ impl GetDocumentDiagnostics { tags, code_description: diagnostic .code_description - .map(|desc| desc.href.to_string()), + .and_then(|desc| desc.href.map(|url| url.to_string())), message: diagnostic.message, data: diagnostic.data.as_ref().map(|data| data.to_string()), }) @@ -4216,8 +4235,9 @@ impl LspCommand for GetDocumentColor { server_capabilities .server_capabilities .color_provider + .as_ref() .is_some_and(|capability| match capability { - lsp::ColorProviderCapability::Simple(supported) => supported, + lsp::ColorProviderCapability::Simple(supported) => *supported, lsp::ColorProviderCapability::ColorProvider(..) => true, lsp::ColorProviderCapability::Options(..) 
=> true, }) diff --git a/crates/project/src/lsp_store.rs b/crates/project/src/lsp_store.rs index e7afac0689755a589513a4e99e35e8fe69b5219e..b88cf42ff51639f159333e2a81f5cd768bb7ff46 100644 --- a/crates/project/src/lsp_store.rs +++ b/crates/project/src/lsp_store.rs @@ -1,4 +1,5 @@ pub mod clangd_ext; +pub mod json_language_server_ext; pub mod lsp_ext_command; pub mod rust_analyzer_ext; @@ -29,7 +30,7 @@ use clock::Global; use collections::{BTreeMap, BTreeSet, HashMap, HashSet, btree_map}; use futures::{ AsyncWriteExt, Future, FutureExt, StreamExt, - future::{Shared, join_all}, + future::{Either, Shared, join_all, pending, select}, select, select_biased, stream::FuturesUnordered, }; @@ -45,6 +46,7 @@ use language::{ DiagnosticEntry, DiagnosticSet, DiagnosticSourceKind, Diff, File as _, Language, LanguageName, LanguageRegistry, LanguageToolchainStore, LocalFile, LspAdapter, LspAdapterDelegate, Patch, PointUtf16, TextBufferSnapshot, ToOffset, ToPointUtf16, Transaction, Unclipped, + WorkspaceFoldersContent, language_settings::{ FormatOnSave, Formatter, LanguageSettings, SelectedFormatter, language_settings, }, @@ -56,12 +58,12 @@ use language::{ range_from_lsp, range_to_lsp, }; use lsp::{ - CodeActionKind, CompletionContext, DiagnosticSeverity, DiagnosticTag, - DidChangeWatchedFilesRegistrationOptions, Edit, FileOperationFilter, FileOperationPatternKind, - FileOperationRegistrationOptions, FileRename, FileSystemWatcher, LanguageServer, - LanguageServerBinary, LanguageServerBinaryOptions, LanguageServerId, LanguageServerName, - LanguageServerSelector, LspRequestFuture, MessageActionItem, MessageType, OneOf, - RenameFilesParams, SymbolKind, TextEdit, WillRenameFiles, WorkDoneProgressCancelParams, + AdapterServerCapabilities, CodeActionKind, CompletionContext, DiagnosticSeverity, + DiagnosticTag, DidChangeWatchedFilesRegistrationOptions, Edit, FileOperationFilter, + FileOperationPatternKind, FileOperationRegistrationOptions, FileRename, FileSystemWatcher, + LanguageServer, LanguageServerBinary, LanguageServerBinaryOptions, LanguageServerId, + LanguageServerName, LanguageServerSelector, LspRequestFuture, MessageActionItem, MessageType, + OneOf, RenameFilesParams, SymbolKind, TextEdit, WillRenameFiles, WorkDoneProgressCancelParams, WorkspaceFolder, notification::DidRenameFiles, }; use node_runtime::read_package_installed_version; @@ -85,13 +87,16 @@ use std::{ cmp::{Ordering, Reverse}, convert::TryInto, ffi::OsStr, + future::ready, iter, mem, ops::{ControlFlow, Range}, path::{self, Path, PathBuf}, + pin::pin, rc::Rc, sync::Arc, time::{Duration, Instant}, }; +use sum_tree::Dimensions; use text::{Anchor, BufferId, LineEnding, OffsetRangeExt}; use url::Url; use util::{ @@ -135,6 +140,20 @@ impl FormatTrigger { } } +#[derive(Debug)] +pub struct DocumentDiagnosticsUpdate<'a, D> { + pub diagnostics: D, + pub result_id: Option<String>, + pub server_id: LanguageServerId, + pub disk_based_sources: Cow<'a, [String]>, +} + +pub struct DocumentDiagnostics { + diagnostics: Vec<DiagnosticEntry<Unclipped<PointUtf16>>>, + document_abs_path: PathBuf, + version: Option<i32>, +} + pub struct LocalLspStore { weak: WeakEntity<LspStore>, worktree_store: Entity<WorktreeStore>, @@ -214,6 +233,7 @@ impl LocalLspStore { let binary = self.get_language_server_binary(adapter.clone(), delegate.clone(), true, cx); let pending_workspace_folders: Arc<Mutex<BTreeSet<Url>>> = Default::default(); + let pending_server = cx.spawn({ let adapter = adapter.clone(); let server_name = adapter.name.clone(); @@ -239,14 +259,18 @@ impl 
LocalLspStore { return Ok(server); } + let code_action_kinds = adapter.code_action_kinds(); lsp::LanguageServer::new( stderr_capture, server_id, server_name, binary, &root_path, - adapter.code_action_kinds(), - pending_workspace_folders, + code_action_kinds, + Some(pending_workspace_folders).filter(|_| { + adapter.adapter.workspace_folders_content() + == WorkspaceFoldersContent::SubprojectRoots + }), cx, ) } @@ -415,7 +439,7 @@ impl LocalLspStore { if settings.as_ref().is_some_and(|b| b.path.is_some()) { let settings = settings.unwrap(); - return cx.spawn(async move |_| { + return cx.background_spawn(async move { let mut env = delegate.shell_env().await; env.extend(settings.env.unwrap_or_default()); @@ -493,12 +517,16 @@ impl LocalLspStore { adapter.process_diagnostics(&mut params, server_id, buffer); } - this.merge_diagnostics( - server_id, - params, - None, + this.merge_lsp_diagnostics( DiagnosticSourceKind::Pushed, - &adapter.disk_based_diagnostic_sources, + vec![DocumentDiagnosticsUpdate { + server_id, + diagnostics: params, + result_id: None, + disk_based_sources: Cow::Borrowed( + &adapter.disk_based_diagnostic_sources, + ), + }], |_, diagnostic, cx| match diagnostic.source_kind { DiagnosticSourceKind::Other | DiagnosticSourceKind::Pushed => { adapter.retain_old_diagnostic(diagnostic, cx) @@ -572,8 +600,7 @@ impl LocalLspStore { }; let root = server.workspace_folders(); Ok(Some( - root.iter() - .cloned() + root.into_iter() .map(|uri| WorkspaceFolder { uri, name: Default::default(), @@ -613,7 +640,7 @@ impl LocalLspStore { .on_request::<lsp::request::RegisterCapability, _, _>({ let this = this.clone(); move |params, cx| { - let this = this.clone(); + let lsp_store = this.clone(); let mut cx = cx.clone(); async move { for reg in params.registrations { @@ -621,7 +648,7 @@ impl LocalLspStore { "workspace/didChangeWatchedFiles" => { if let Some(options) = reg.register_options { let options = serde_json::from_value(options)?; - this.update(&mut cx, |this, cx| { + lsp_store.update(&mut cx, |this, cx| { this.as_local_mut()?.on_lsp_did_change_watched_files( server_id, ®.id, options, cx, ); @@ -630,8 +657,9 @@ impl LocalLspStore { } } "textDocument/rangeFormatting" => { - this.read_with(&mut cx, |this, _| { - if let Some(server) = this.language_server_for_id(server_id) + lsp_store.update(&mut cx, |lsp_store, cx| { + if let Some(server) = + lsp_store.language_server_for_id(server_id) { let options = reg .register_options @@ -650,14 +678,16 @@ impl LocalLspStore { server.update_capabilities(|capabilities| { capabilities.document_range_formatting_provider = Some(provider); - }) + }); + notify_server_capabilities_updated(&server, cx); } anyhow::Ok(()) })??; } "textDocument/onTypeFormatting" => { - this.read_with(&mut cx, |this, _| { - if let Some(server) = this.language_server_for_id(server_id) + lsp_store.update(&mut cx, |lsp_store, cx| { + if let Some(server) = + lsp_store.language_server_for_id(server_id) { let options = reg .register_options @@ -674,15 +704,17 @@ impl LocalLspStore { capabilities .document_on_type_formatting_provider = Some(options); - }) + }); + notify_server_capabilities_updated(&server, cx); } } anyhow::Ok(()) })??; } "textDocument/formatting" => { - this.read_with(&mut cx, |this, _| { - if let Some(server) = this.language_server_for_id(server_id) + lsp_store.update(&mut cx, |lsp_store, cx| { + if let Some(server) = + lsp_store.language_server_for_id(server_id) { let options = reg .register_options @@ -701,7 +733,8 @@ impl LocalLspStore { 
server.update_capabilities(|capabilities| { capabilities.document_formatting_provider = Some(provider); - }) + }); + notify_server_capabilities_updated(&server, cx); } anyhow::Ok(()) })??; @@ -710,8 +743,9 @@ impl LocalLspStore { // Ignore payload since we notify clients of setting changes unconditionally, relying on them pulling the latest settings. } "textDocument/rename" => { - this.read_with(&mut cx, |this, _| { - if let Some(server) = this.language_server_for_id(server_id) + lsp_store.update(&mut cx, |lsp_store, cx| { + if let Some(server) = + lsp_store.language_server_for_id(server_id) { let options = reg .register_options @@ -728,7 +762,8 @@ impl LocalLspStore { server.update_capabilities(|capabilities| { capabilities.rename_provider = Some(options); - }) + }); + notify_server_capabilities_updated(&server, cx); } anyhow::Ok(()) })??; @@ -746,14 +781,15 @@ impl LocalLspStore { .on_request::<lsp::request::UnregisterCapability, _, _>({ let this = this.clone(); move |params, cx| { - let this = this.clone(); + let lsp_store = this.clone(); let mut cx = cx.clone(); async move { for unreg in params.unregisterations.iter() { match unreg.method.as_str() { "workspace/didChangeWatchedFiles" => { - this.update(&mut cx, |this, cx| { - this.as_local_mut()? + lsp_store.update(&mut cx, |lsp_store, cx| { + lsp_store + .as_local_mut()? .on_lsp_unregister_did_change_watched_files( server_id, &unreg.id, cx, ); @@ -764,44 +800,52 @@ impl LocalLspStore { // Ignore payload since we notify clients of setting changes unconditionally, relying on them pulling the latest settings. } "textDocument/rename" => { - this.read_with(&mut cx, |this, _| { - if let Some(server) = this.language_server_for_id(server_id) + lsp_store.update(&mut cx, |lsp_store, cx| { + if let Some(server) = + lsp_store.language_server_for_id(server_id) { server.update_capabilities(|capabilities| { capabilities.rename_provider = None - }) + }); + notify_server_capabilities_updated(&server, cx); } })?; } "textDocument/rangeFormatting" => { - this.read_with(&mut cx, |this, _| { - if let Some(server) = this.language_server_for_id(server_id) + lsp_store.update(&mut cx, |lsp_store, cx| { + if let Some(server) = + lsp_store.language_server_for_id(server_id) { server.update_capabilities(|capabilities| { capabilities.document_range_formatting_provider = None - }) + }); + notify_server_capabilities_updated(&server, cx); } })?; } "textDocument/onTypeFormatting" => { - this.read_with(&mut cx, |this, _| { - if let Some(server) = this.language_server_for_id(server_id) + lsp_store.update(&mut cx, |lsp_store, cx| { + if let Some(server) = + lsp_store.language_server_for_id(server_id) { server.update_capabilities(|capabilities| { capabilities.document_on_type_formatting_provider = None; - }) + }); + notify_server_capabilities_updated(&server, cx); } })?; } "textDocument/formatting" => { - this.read_with(&mut cx, |this, _| { - if let Some(server) = this.language_server_for_id(server_id) + lsp_store.update(&mut cx, |lsp_store, cx| { + if let Some(server) = + lsp_store.language_server_for_id(server_id) { server.update_capabilities(|capabilities| { capabilities.document_formatting_provider = None; - }) + }); + notify_server_capabilities_updated(&server, cx); } })?; } @@ -1032,6 +1076,7 @@ impl LocalLspStore { }) .detach(); + json_language_server_ext::register_requests(this.clone(), language_server); rust_analyzer_ext::register_notifications(this.clone(), language_server); clangd_ext::register_notifications(this, language_server, adapter); } @@ -1270,15 +1315,11 
@@ impl LocalLspStore { // grouped with the previous transaction in the history // based on the transaction group interval buffer.finalize_last_transaction(); - let transaction_id = buffer + buffer .start_transaction() .context("transaction already open")?; - let transaction = buffer - .get_transaction(transaction_id) - .expect("transaction started") - .clone(); buffer.end_transaction(cx); - buffer.push_transaction(transaction, cx.background_executor().now()); + let transaction_id = buffer.push_empty_transaction(cx.background_executor().now()); buffer.finalize_last_transaction(); anyhow::Ok(transaction_id) })??; @@ -2420,36 +2461,11 @@ impl LocalLspStore { let server_id = server_node.server_id_or_init( |LaunchDisposition { server_name, - attach, path, settings, }| { - let server_id = match attach { - language::Attach::InstancePerRoot => { - // todo: handle instance per root proper. - if let Some(server_ids) = self - .language_server_ids - .get(&(worktree_id, server_name.clone())) - { - server_ids.iter().cloned().next().unwrap() - } else { - let language_name = language.name(); - let adapter = self.languages - .lsp_adapters(&language_name) - .into_iter() - .find(|adapter| &adapter.name() == server_name) - .expect("To find LSP adapter"); - let server_id = self.start_language_server( - &worktree, - delegate.clone(), - adapter, - settings, - cx, - ); - server_id - } - } - language::Attach::Shared => { + let server_id = + { let uri = Url::from_file_path( worktree.read(cx).abs_path().join(&path.path), ); @@ -2484,20 +2500,8 @@ impl LocalLspStore { } else { unreachable!("Language server ID should be available, as it's registered on demand") } - } + }; - let lsp_store = self.weak.clone(); - let server_name = server_node.name(); - let buffer_abs_path = abs_path.to_string_lossy().to_string(); - cx.defer(move |cx| { - lsp_store.update(cx, |_, cx| cx.emit(LspStoreEvent::LanguageServerUpdate { - language_server_id: server_id, - name: server_name, - message: proto::update_language_server::Variant::RegisteredForBuffer(proto::RegisteredForBuffer { - buffer_abs_path, - }) - })).ok(); - }); server_id }, )?; @@ -2533,11 +2537,13 @@ impl LocalLspStore { snapshot: initial_snapshot.clone(), }; + let mut registered = false; self.buffer_snapshots .entry(buffer_id) .or_default() .entry(server.server_id()) .or_insert_with(|| { + registered = true; server.register_buffer( uri.clone(), adapter.language_id(&language.name()), @@ -2552,15 +2558,18 @@ impl LocalLspStore { .entry(buffer_id) .or_default() .insert(server.server_id()); - cx.emit(LspStoreEvent::LanguageServerUpdate { - language_server_id: server.server_id(), - name: None, - message: proto::update_language_server::Variant::RegisteredForBuffer( - proto::RegisteredForBuffer { - buffer_abs_path: abs_path.to_string_lossy().to_string(), - }, - ), - }); + if registered { + cx.emit(LspStoreEvent::LanguageServerUpdate { + language_server_id: server.server_id(), + name: None, + message: proto::update_language_server::Variant::RegisteredForBuffer( + proto::RegisteredForBuffer { + buffer_abs_path: abs_path.to_string_lossy().to_string(), + buffer_id: buffer_id.to_proto(), + }, + ), + }); + } } } @@ -3512,6 +3521,20 @@ impl LocalLspStore { } } +fn notify_server_capabilities_updated(server: &LanguageServer, cx: &mut Context<LspStore>) { + if let Some(capabilities) = serde_json::to_string(&server.capabilities()).ok() { + cx.emit(LspStoreEvent::LanguageServerUpdate { + language_server_id: server.server_id(), + name: Some(server.name()), + message: 
proto::update_language_server::Variant::MetadataUpdated( + proto::ServerMetadataUpdated { + capabilities: Some(capabilities), + }, + ), + }); + } +} + #[derive(Debug)] pub struct FormattableBuffer { handle: Entity<Buffer>, @@ -3551,7 +3574,9 @@ pub struct LspStore { _maintain_buffer_languages: Task<()>, diagnostic_summaries: HashMap<WorktreeId, HashMap<Arc<Path>, HashMap<LanguageServerId, DiagnosticSummary>>>, - lsp_data: HashMap<BufferId, DocumentColorData>, + pub(super) lsp_server_capabilities: HashMap<LanguageServerId, lsp::ServerCapabilities>, + lsp_document_colors: HashMap<BufferId, DocumentColorData>, + lsp_code_lens: HashMap<BufferId, CodeLensData>, } #[derive(Debug, Default, Clone)] @@ -3561,6 +3586,7 @@ pub struct DocumentColors { } type DocumentColorTask = Shared<Task<std::result::Result<DocumentColors, Arc<anyhow::Error>>>>; +type CodeLensTask = Shared<Task<std::result::Result<Vec<CodeAction>, Arc<anyhow::Error>>>>; #[derive(Debug, Default)] struct DocumentColorData { @@ -3570,8 +3596,15 @@ struct DocumentColorData { colors_update: Option<(Global, DocumentColorTask)>, } +#[derive(Debug, Default)] +struct CodeLensData { + lens_for_version: Global, + lens: HashMap<LanguageServerId, Vec<CodeAction>>, + update: Option<(Global, CodeLensTask)>, +} + #[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum ColorFetchStrategy { +pub enum LspFetchStrategy { IgnoreCache, UseCache { known_cache_version: Option<usize> }, } @@ -3595,8 +3628,8 @@ pub enum LspStoreEvent { RefreshInlayHints, RefreshCodeLens, DiagnosticsUpdated { - language_server_id: LanguageServerId, - path: ProjectPath, + server_id: LanguageServerId, + paths: Vec<ProjectPath>, }, DiskBasedDiagnosticsStarted { language_server_id: LanguageServerId, @@ -3613,7 +3646,7 @@ pub enum LspStoreEvent { #[derive(Clone, Debug, Serialize)] pub struct LanguageServerStatus { - pub name: String, + pub name: LanguageServerName, pub pending_work: BTreeMap<String, LanguageServerProgress>, pub has_pending_diagnostic_updates: bool, progress_tokens: HashSet<String>, @@ -3656,7 +3689,6 @@ impl LspStore { client.add_entity_request_handler(Self::handle_apply_additional_edits_for_completion); client.add_entity_request_handler(Self::handle_register_buffer_with_language_servers); client.add_entity_request_handler(Self::handle_rename_project_entry); - client.add_entity_request_handler(Self::handle_language_server_id_for_name); client.add_entity_request_handler(Self::handle_pull_workspace_diagnostics); client.add_entity_request_handler(Self::handle_lsp_command::<GetCodeActions>); client.add_entity_request_handler(Self::handle_lsp_command::<GetCompletions>); @@ -3804,7 +3836,9 @@ impl LspStore { language_server_statuses: Default::default(), nonce: StdRng::from_entropy().r#gen(), diagnostic_summaries: HashMap::default(), - lsp_data: HashMap::default(), + lsp_server_capabilities: HashMap::default(), + lsp_document_colors: HashMap::default(), + lsp_code_lens: HashMap::default(), active_entry: None, _maintain_workspace_config, _maintain_buffer_languages: Self::maintain_buffer_languages(languages, cx), @@ -3819,6 +3853,9 @@ impl LspStore { request: R, cx: &mut Context<LspStore>, ) -> Task<anyhow::Result<<R as LspCommand>::Response>> { + if !self.is_capable_for_proto_request(&buffer, &request, cx) { + return Task::ready(Ok(R::Response::default())); + } let message = request.to_proto(upstream_project_id, buffer.read(cx)); cx.spawn(async move |this, cx| { let response = client.request(message).await?; @@ -3861,7 +3898,9 @@ impl LspStore { language_server_statuses: 
Default::default(), nonce: StdRng::from_entropy().r#gen(), diagnostic_summaries: HashMap::default(), - lsp_data: HashMap::default(), + lsp_server_capabilities: HashMap::default(), + lsp_document_colors: HashMap::default(), + lsp_code_lens: HashMap::default(), active_entry: None, toolchain_store, _maintain_workspace_config, @@ -4162,7 +4201,8 @@ impl LspStore { *refcount }; if refcount == 0 { - lsp_store.lsp_data.remove(&buffer_id); + lsp_store.lsp_document_colors.remove(&buffer_id); + lsp_store.lsp_code_lens.remove(&buffer_id); let local = lsp_store.as_local_mut().unwrap(); local.registered_buffers.remove(&buffer_id); local.buffers_opened_in_servers.remove(&buffer_id); @@ -4418,36 +4458,96 @@ impl LspStore { pub(crate) fn send_diagnostic_summaries(&self, worktree: &mut Worktree) { if let Some((client, downstream_project_id)) = self.downstream_client.clone() { - if let Some(summaries) = self.diagnostic_summaries.get(&worktree.id()) { - for (path, summaries) in summaries { - for (&server_id, summary) in summaries { - client - .send(proto::UpdateDiagnosticSummary { - project_id: downstream_project_id, - worktree_id: worktree.id().to_proto(), - summary: Some(summary.to_proto(server_id, path)), - }) - .log_err(); - } + if let Some(diangostic_summaries) = self.diagnostic_summaries.get(&worktree.id()) { + let mut summaries = + diangostic_summaries + .into_iter() + .flat_map(|(path, summaries)| { + summaries + .into_iter() + .map(|(server_id, summary)| summary.to_proto(*server_id, path)) + }); + if let Some(summary) = summaries.next() { + client + .send(proto::UpdateDiagnosticSummary { + project_id: downstream_project_id, + worktree_id: worktree.id().to_proto(), + summary: Some(summary), + more_summaries: summaries.collect(), + }) + .log_err(); } } } } - pub fn request_lsp<R: LspCommand>( + // TODO: remove MultiLspQuery: instead, the proto handler should pick appropriate server(s) + // Then, use `send_lsp_proto_request` or analogue for most of the LSP proto requests and inline this check inside + fn is_capable_for_proto_request<R>( + &self, + buffer: &Entity<Buffer>, + request: &R, + cx: &Context<Self>, + ) -> bool + where + R: LspCommand, + { + self.check_if_capable_for_proto_request( + buffer, + |capabilities| { + request.check_capabilities(AdapterServerCapabilities { + server_capabilities: capabilities.clone(), + code_action_kinds: None, + }) + }, + cx, + ) + } + + fn check_if_capable_for_proto_request<F>( + &self, + buffer: &Entity<Buffer>, + check: F, + cx: &Context<Self>, + ) -> bool + where + F: Fn(&lsp::ServerCapabilities) -> bool, + { + let Some(language) = buffer.read(cx).language().cloned() else { + return false; + }; + let relevant_language_servers = self + .languages + .lsp_adapters(&language.name()) + .into_iter() + .map(|lsp_adapter| lsp_adapter.name()) + .collect::<HashSet<_>>(); + self.language_server_statuses + .iter() + .filter_map(|(server_id, server_status)| { + relevant_language_servers + .contains(&server_status.name) + .then_some(server_id) + }) + .filter_map(|server_id| self.lsp_server_capabilities.get(&server_id)) + .any(check) + } + + pub fn request_lsp<R>( &mut self, - buffer_handle: Entity<Buffer>, + buffer: Entity<Buffer>, server: LanguageServerToQuery, request: R, cx: &mut Context<Self>, ) -> Task<Result<R::Response>> where + R: LspCommand, <R::LspRequest as lsp::request::Request>::Result: Send, <R::LspRequest as lsp::request::Request>::Params: Send, { if let Some((upstream_client, upstream_project_id)) = self.upstream_client() { return self.send_lsp_proto_request( - 
buffer_handle, + buffer, upstream_client, upstream_project_id, request, @@ -4455,7 +4555,7 @@ impl LspStore { ); } - let Some(language_server) = buffer_handle.update(cx, |buffer, cx| match server { + let Some(language_server) = buffer.update(cx, |buffer, cx| match server { LanguageServerToQuery::FirstCapable => self.as_local().and_then(|local| { local .language_servers_for_buffer(buffer, cx) @@ -4475,8 +4575,7 @@ impl LspStore { return Task::ready(Ok(Default::default())); }; - let buffer = buffer_handle.read(cx); - let file = File::from_dyn(buffer.file()).and_then(File::as_local); + let file = File::from_dyn(buffer.read(cx).file()).and_then(File::as_local); let Some(file) = file else { return Task::ready(Ok(Default::default())); @@ -4484,7 +4583,7 @@ impl LspStore { let lsp_params = match request.to_lsp_params_or_response( &file.abs_path(cx), - buffer, + buffer.read(cx), &language_server, cx, ) { @@ -4560,7 +4659,7 @@ impl LspStore { .response_from_lsp( response, this.upgrade().context("no app context")?, - buffer_handle, + buffer, language_server.server_id(), cx.clone(), ) @@ -4630,7 +4729,8 @@ impl LspStore { ) }) { let buffer = buffer_handle.read(cx); - if !local.registered_buffers.contains_key(&buffer.remote_id()) { + let buffer_id = buffer.remote_id(); + if !local.registered_buffers.contains_key(&buffer_id) { continue; } if let Some((file, language)) = File::from_dyn(buffer.file()) @@ -4688,35 +4788,11 @@ impl LspStore { let server_id = node.server_id_or_init( |LaunchDisposition { server_name, - attach, + path, settings, - }| match attach { - language::Attach::InstancePerRoot => { - // todo: handle instance per root proper. - if let Some(server_ids) = local - .language_server_ids - .get(&(worktree_id, server_name.clone())) - { - server_ids.iter().cloned().next().unwrap() - } else { - let adapter = local - .languages - .lsp_adapters(&language) - .into_iter() - .find(|adapter| &adapter.name() == server_name) - .expect("To find LSP adapter"); - let server_id = local.start_language_server( - &worktree, - delegate.clone(), - adapter, - settings, - cx, - ); - server_id - } - } - language::Attach::Shared => { + }| + { let uri = Url::from_file_path( worktree.read(cx).abs_path().join(&path.path), ); @@ -4745,7 +4821,6 @@ impl LspStore { } server_id } - }, ); if let Some(language_server_id) = server_id { @@ -4756,6 +4831,7 @@ impl LspStore { proto::update_language_server::Variant::RegisteredForBuffer( proto::RegisteredForBuffer { buffer_abs_path: abs_path.to_string_lossy().to_string(), + buffer_id: buffer_id.to_proto(), }, ), }); @@ -4931,19 +5007,24 @@ impl LspStore { pub fn resolve_inlay_hint( &self, - hint: InlayHint, - buffer_handle: Entity<Buffer>, + mut hint: InlayHint, + buffer: Entity<Buffer>, server_id: LanguageServerId, cx: &mut Context<Self>, ) -> Task<anyhow::Result<InlayHint>> { if let Some((upstream_client, project_id)) = self.upstream_client() { + if !self.check_if_capable_for_proto_request(&buffer, InlayHints::can_resolve_inlays, cx) + { + hint.resolve_state = ResolveState::Resolved; + return Task::ready(Ok(hint)); + } let request = proto::ResolveInlayHint { project_id, - buffer_id: buffer_handle.read(cx).remote_id().into(), + buffer_id: buffer.read(cx).remote_id().into(), language_server_id: server_id.0 as u64, hint: Some(InlayHints::project_to_proto_hint(hint.clone())), }; - cx.spawn(async move |_, _| { + cx.background_spawn(async move { let response = upstream_client .request(request) .await @@ -4955,7 +5036,7 @@ impl LspStore { } }) } else { - let Some(lang_server) = 
buffer_handle.update(cx, |buffer, cx| { + let Some(lang_server) = buffer.update(cx, |buffer, cx| { self.language_server_for_local_buffer(buffer, server_id, cx) .map(|(_, server)| server.clone()) }) else { @@ -4964,7 +5045,7 @@ impl LspStore { if !InlayHints::can_resolve_inlays(&lang_server.capabilities()) { return Task::ready(Ok(hint)); } - let buffer_snapshot = buffer_handle.read(cx).snapshot(); + let buffer_snapshot = buffer.read(cx).snapshot(); cx.spawn(async move |_, cx| { let resolve_task = lang_server.request::<lsp::request::InlayHintResolveRequest>( InlayHints::project_to_lsp_hint(hint, &buffer_snapshot), @@ -4975,7 +5056,7 @@ impl LspStore { .context("inlay hint resolve LSP request")?; let resolved_hint = InlayHints::lsp_to_project_hint( resolved_hint, - &buffer_handle, + &buffer, server_id, ResolveState::Resolved, false, @@ -5086,7 +5167,7 @@ impl LspStore { } } - pub(crate) fn linked_edit( + pub(crate) fn linked_edits( &mut self, buffer: &Entity<Buffer>, position: Anchor, @@ -5101,10 +5182,7 @@ impl LspStore { local .language_servers_for_buffer(buffer, cx) .filter(|(_, server)| { - server - .capabilities() - .linked_editing_range_provider - .is_some() + LinkedEditingRange::check_server_capabilities(server.capabilities()) }) .filter(|(adapter, _)| { scope @@ -5131,7 +5209,7 @@ impl LspStore { }) == Some(true) }) else { - return Task::ready(Ok(vec![])); + return Task::ready(Ok(Vec::new())); }; self.request_lsp( @@ -5150,6 +5228,15 @@ impl LspStore { cx: &mut Context<Self>, ) -> Task<Result<Option<Transaction>>> { if let Some((client, project_id)) = self.upstream_client() { + if !self.check_if_capable_for_proto_request( + &buffer, + |capabilities| { + OnTypeFormatting::supports_on_type_formatting(&trigger, capabilities) + }, + cx, + ) { + return Task::ready(Ok(None)); + } let request = proto::OnTypeFormatting { project_id, buffer_id: buffer.read(cx).remote_id().into(), @@ -5157,7 +5244,7 @@ impl LspStore { trigger, version: serialize_version(&buffer.read(cx).version()), }; - cx.spawn(async move |_, _| { + cx.background_spawn(async move { client .request(request) .await? 
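The hunks above and below repeat one pattern for remote (collaborative) projects: before an LSP request is proxied to the upstream host, the client consults the `lsp::ServerCapabilities` it has cached per language server and returns an empty/default result when no relevant server supports the feature, avoiding a pointless round trip. A minimal, self-contained sketch of that gating idea follows; the types and names (`ServerCapabilities` stand-in, `RemoteLspClient`, `send_remote_request`) are simplified placeholders for illustration, not the actual `LspStore` API.

```rust
use std::collections::HashMap;

/// Simplified stand-in for `lsp::ServerCapabilities`; only the field this sketch needs.
#[derive(Default)]
struct ServerCapabilities {
    /// Trigger characters advertised for on-type formatting, if supported.
    document_on_type_formatting_provider: Option<Vec<String>>,
}

type LanguageServerId = u64;

struct RemoteLspClient {
    /// Capabilities cached per server as metadata updates arrive from the host.
    cached_capabilities: HashMap<LanguageServerId, ServerCapabilities>,
}

impl RemoteLspClient {
    /// True if any cached server passes the capability check.
    fn any_server_capable(&self, check: impl Fn(&ServerCapabilities) -> bool) -> bool {
        self.cached_capabilities.values().any(|caps| check(caps))
    }

    /// Gating pattern: short-circuit with a default response instead of sending a
    /// request that no server on the host could fulfill.
    fn on_type_formatting(&self, trigger: &str) -> Option<String> {
        let capable = self.any_server_capable(|caps| {
            caps.document_on_type_formatting_provider
                .as_ref()
                .is_some_and(|triggers| triggers.iter().any(|t| t == trigger))
        });
        if !capable {
            // Mirrors the early `Task::ready(Ok(None))` returns in the real code.
            return None;
        }
        // ...otherwise serialize the request and send it upstream...
        self.send_remote_request(trigger)
    }

    fn send_remote_request(&self, _trigger: &str) -> Option<String> {
        // Placeholder for the proto round trip to the host.
        Some(String::new())
    }
}
```

As the diff shows, this check is applied only on the remote branches (where an `upstream_client` exists); local sessions continue to query each running server's live capabilities directly.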
@@ -5261,6 +5348,10 @@ impl LspStore { cx: &mut Context<Self>, ) -> Task<Result<Vec<LocationLink>>> { if let Some((upstream_client, project_id)) = self.upstream_client() { + let request = GetDefinitions { position }; + if !self.is_capable_for_proto_request(buffer_handle, &request, cx) { + return Task::ready(Ok(Vec::new())); + } let request_task = upstream_client.request(proto::MultiLspQuery { buffer_id: buffer_handle.read(cx).remote_id().into(), version: serialize_version(&buffer_handle.read(cx).version()), @@ -5269,7 +5360,7 @@ impl LspStore { proto::AllLanguageServers {}, )), request: Some(proto::multi_lsp_query::Request::GetDefinition( - GetDefinitions { position }.to_proto(project_id, buffer_handle.read(cx)), + request.to_proto(project_id, buffer_handle.read(cx)), )), }); let buffer = buffer_handle.clone(); @@ -5316,7 +5407,7 @@ impl LspStore { GetDefinitions { position }, cx, ); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { Ok(definitions_task .await .into_iter() @@ -5334,6 +5425,10 @@ impl LspStore { cx: &mut Context<Self>, ) -> Task<Result<Vec<LocationLink>>> { if let Some((upstream_client, project_id)) = self.upstream_client() { + let request = GetDeclarations { position }; + if !self.is_capable_for_proto_request(buffer_handle, &request, cx) { + return Task::ready(Ok(Vec::new())); + } let request_task = upstream_client.request(proto::MultiLspQuery { buffer_id: buffer_handle.read(cx).remote_id().into(), version: serialize_version(&buffer_handle.read(cx).version()), @@ -5342,7 +5437,7 @@ impl LspStore { proto::AllLanguageServers {}, )), request: Some(proto::multi_lsp_query::Request::GetDeclaration( - GetDeclarations { position }.to_proto(project_id, buffer_handle.read(cx)), + request.to_proto(project_id, buffer_handle.read(cx)), )), }); let buffer = buffer_handle.clone(); @@ -5389,7 +5484,7 @@ impl LspStore { GetDeclarations { position }, cx, ); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { Ok(declarations_task .await .into_iter() @@ -5402,23 +5497,27 @@ impl LspStore { pub fn type_definitions( &mut self, - buffer_handle: &Entity<Buffer>, + buffer: &Entity<Buffer>, position: PointUtf16, cx: &mut Context<Self>, ) -> Task<Result<Vec<LocationLink>>> { if let Some((upstream_client, project_id)) = self.upstream_client() { + let request = GetTypeDefinitions { position }; + if !self.is_capable_for_proto_request(&buffer, &request, cx) { + return Task::ready(Ok(Vec::new())); + } let request_task = upstream_client.request(proto::MultiLspQuery { - buffer_id: buffer_handle.read(cx).remote_id().into(), - version: serialize_version(&buffer_handle.read(cx).version()), + buffer_id: buffer.read(cx).remote_id().into(), + version: serialize_version(&buffer.read(cx).version()), project_id, strategy: Some(proto::multi_lsp_query::Strategy::All( proto::AllLanguageServers {}, )), request: Some(proto::multi_lsp_query::Request::GetTypeDefinition( - GetTypeDefinitions { position }.to_proto(project_id, buffer_handle.read(cx)), + request.to_proto(project_id, buffer.read(cx)), )), }); - let buffer = buffer_handle.clone(); + let buffer = buffer.clone(); cx.spawn(async move |weak_project, cx| { let Some(project) = weak_project.upgrade() else { return Ok(Vec::new()); @@ -5457,12 +5556,12 @@ impl LspStore { }) } else { let type_definitions_task = self.request_multiple_lsp_locally( - buffer_handle, + buffer, Some(position), GetTypeDefinitions { position }, cx, ); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { Ok(type_definitions_task .await .into_iter() @@ 
-5475,23 +5574,27 @@ impl LspStore { pub fn implementations( &mut self, - buffer_handle: &Entity<Buffer>, + buffer: &Entity<Buffer>, position: PointUtf16, cx: &mut Context<Self>, ) -> Task<Result<Vec<LocationLink>>> { if let Some((upstream_client, project_id)) = self.upstream_client() { + let request = GetImplementations { position }; + if !self.is_capable_for_proto_request(buffer, &request, cx) { + return Task::ready(Ok(Vec::new())); + } let request_task = upstream_client.request(proto::MultiLspQuery { - buffer_id: buffer_handle.read(cx).remote_id().into(), - version: serialize_version(&buffer_handle.read(cx).version()), + buffer_id: buffer.read(cx).remote_id().into(), + version: serialize_version(&buffer.read(cx).version()), project_id, strategy: Some(proto::multi_lsp_query::Strategy::All( proto::AllLanguageServers {}, )), request: Some(proto::multi_lsp_query::Request::GetImplementation( - GetImplementations { position }.to_proto(project_id, buffer_handle.read(cx)), + request.to_proto(project_id, buffer.read(cx)), )), }); - let buffer = buffer_handle.clone(); + let buffer = buffer.clone(); cx.spawn(async move |weak_project, cx| { let Some(project) = weak_project.upgrade() else { return Ok(Vec::new()); @@ -5530,12 +5633,12 @@ impl LspStore { }) } else { let implementations_task = self.request_multiple_lsp_locally( - buffer_handle, + buffer, Some(position), GetImplementations { position }, cx, ); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { Ok(implementations_task .await .into_iter() @@ -5548,23 +5651,27 @@ impl LspStore { pub fn references( &mut self, - buffer_handle: &Entity<Buffer>, + buffer: &Entity<Buffer>, position: PointUtf16, cx: &mut Context<Self>, ) -> Task<Result<Vec<Location>>> { if let Some((upstream_client, project_id)) = self.upstream_client() { + let request = GetReferences { position }; + if !self.is_capable_for_proto_request(&buffer, &request, cx) { + return Task::ready(Ok(Vec::new())); + } let request_task = upstream_client.request(proto::MultiLspQuery { - buffer_id: buffer_handle.read(cx).remote_id().into(), - version: serialize_version(&buffer_handle.read(cx).version()), + buffer_id: buffer.read(cx).remote_id().into(), + version: serialize_version(&buffer.read(cx).version()), project_id, strategy: Some(proto::multi_lsp_query::Strategy::All( proto::AllLanguageServers {}, )), request: Some(proto::multi_lsp_query::Request::GetReferences( - GetReferences { position }.to_proto(project_id, buffer_handle.read(cx)), + request.to_proto(project_id, buffer.read(cx)), )), }); - let buffer = buffer_handle.clone(); + let buffer = buffer.clone(); cx.spawn(async move |weak_project, cx| { let Some(project) = weak_project.upgrade() else { return Ok(Vec::new()); @@ -5603,12 +5710,12 @@ impl LspStore { }) } else { let references_task = self.request_multiple_lsp_locally( - buffer_handle, + buffer, Some(position), GetReferences { position }, cx, ); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { Ok(references_task .await .into_iter() @@ -5621,28 +5728,31 @@ impl LspStore { pub fn code_actions( &mut self, - buffer_handle: &Entity<Buffer>, + buffer: &Entity<Buffer>, range: Range<Anchor>, kinds: Option<Vec<CodeActionKind>>, cx: &mut Context<Self>, ) -> Task<Result<Vec<CodeAction>>> { if let Some((upstream_client, project_id)) = self.upstream_client() { + let request = GetCodeActions { + range: range.clone(), + kinds: kinds.clone(), + }; + if !self.is_capable_for_proto_request(buffer, &request, cx) { + return Task::ready(Ok(Vec::new())); + } let 
request_task = upstream_client.request(proto::MultiLspQuery { - buffer_id: buffer_handle.read(cx).remote_id().into(), - version: serialize_version(&buffer_handle.read(cx).version()), + buffer_id: buffer.read(cx).remote_id().into(), + version: serialize_version(&buffer.read(cx).version()), project_id, strategy: Some(proto::multi_lsp_query::Strategy::All( proto::AllLanguageServers {}, )), request: Some(proto::multi_lsp_query::Request::GetCodeActions( - GetCodeActions { - range: range.clone(), - kinds: kinds.clone(), - } - .to_proto(project_id, buffer_handle.read(cx)), + request.to_proto(project_id, buffer.read(cx)), )), }); - let buffer = buffer_handle.clone(); + let buffer = buffer.clone(); cx.spawn(async move |weak_project, cx| { let Some(project) = weak_project.upgrade() else { return Ok(Vec::new()); @@ -5684,7 +5794,7 @@ impl LspStore { }) } else { let all_actions_task = self.request_multiple_lsp_locally( - buffer_handle, + buffer, Some(range.start), GetCodeActions { range: range.clone(), @@ -5692,7 +5802,7 @@ impl LspStore { }, cx, ); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { Ok(all_actions_task .await .into_iter() @@ -5702,69 +5812,172 @@ impl LspStore { } } - pub fn code_lens( + pub fn code_lens_actions( &mut self, - buffer_handle: &Entity<Buffer>, + buffer: &Entity<Buffer>, cx: &mut Context<Self>, - ) -> Task<Result<Vec<CodeAction>>> { + ) -> CodeLensTask { + let version_queried_for = buffer.read(cx).version(); + let buffer_id = buffer.read(cx).remote_id(); + + if let Some(cached_data) = self.lsp_code_lens.get(&buffer_id) { + if !version_queried_for.changed_since(&cached_data.lens_for_version) { + let has_different_servers = self.as_local().is_some_and(|local| { + local + .buffers_opened_in_servers + .get(&buffer_id) + .cloned() + .unwrap_or_default() + != cached_data.lens.keys().copied().collect() + }); + if !has_different_servers { + return Task::ready(Ok(cached_data.lens.values().flatten().cloned().collect())) + .shared(); + } + } + } + + let lsp_data = self.lsp_code_lens.entry(buffer_id).or_default(); + if let Some((updating_for, running_update)) = &lsp_data.update { + if !version_queried_for.changed_since(&updating_for) { + return running_update.clone(); + } + } + let buffer = buffer.clone(); + let query_version_queried_for = version_queried_for.clone(); + let new_task = cx + .spawn(async move |lsp_store, cx| { + cx.background_executor() + .timer(Duration::from_millis(30)) + .await; + let fetched_lens = lsp_store + .update(cx, |lsp_store, cx| lsp_store.fetch_code_lens(&buffer, cx)) + .map_err(Arc::new)? 
+ .await + .context("fetching code lens") + .map_err(Arc::new); + let fetched_lens = match fetched_lens { + Ok(fetched_lens) => fetched_lens, + Err(e) => { + lsp_store + .update(cx, |lsp_store, _| { + lsp_store.lsp_code_lens.entry(buffer_id).or_default().update = None; + }) + .ok(); + return Err(e); + } + }; + + lsp_store + .update(cx, |lsp_store, _| { + let lsp_data = lsp_store.lsp_code_lens.entry(buffer_id).or_default(); + if lsp_data.lens_for_version == query_version_queried_for { + lsp_data.lens.extend(fetched_lens.clone()); + } else if !lsp_data + .lens_for_version + .changed_since(&query_version_queried_for) + { + lsp_data.lens_for_version = query_version_queried_for; + lsp_data.lens = fetched_lens.clone(); + } + lsp_data.update = None; + lsp_data.lens.values().flatten().cloned().collect() + }) + .map_err(Arc::new) + }) + .shared(); + lsp_data.update = Some((version_queried_for, new_task.clone())); + new_task + } + + fn fetch_code_lens( + &mut self, + buffer: &Entity<Buffer>, + cx: &mut Context<Self>, + ) -> Task<Result<HashMap<LanguageServerId, Vec<CodeAction>>>> { if let Some((upstream_client, project_id)) = self.upstream_client() { + let request = GetCodeLens; + if !self.is_capable_for_proto_request(buffer, &request, cx) { + return Task::ready(Ok(HashMap::default())); + } let request_task = upstream_client.request(proto::MultiLspQuery { - buffer_id: buffer_handle.read(cx).remote_id().into(), - version: serialize_version(&buffer_handle.read(cx).version()), + buffer_id: buffer.read(cx).remote_id().into(), + version: serialize_version(&buffer.read(cx).version()), project_id, strategy: Some(proto::multi_lsp_query::Strategy::All( proto::AllLanguageServers {}, )), request: Some(proto::multi_lsp_query::Request::GetCodeLens( - GetCodeLens.to_proto(project_id, buffer_handle.read(cx)), + request.to_proto(project_id, buffer.read(cx)), )), }); - let buffer = buffer_handle.clone(); - cx.spawn(async move |weak_project, cx| { - let Some(project) = weak_project.upgrade() else { - return Ok(Vec::new()); + let buffer = buffer.clone(); + cx.spawn(async move |weak_lsp_store, cx| { + let Some(lsp_store) = weak_lsp_store.upgrade() else { + return Ok(HashMap::default()); }; let responses = request_task.await?.responses; - let code_lens = join_all( + let code_lens_actions = join_all( responses .into_iter() - .filter_map(|lsp_response| match lsp_response.response? { - proto::lsp_response::Response::GetCodeLensResponse(response) => { - Some(response) - } - unexpected => { - debug_panic!("Unexpected response: {unexpected:?}"); - None - } + .filter_map(|lsp_response| { + let response = match lsp_response.response? { + proto::lsp_response::Response::GetCodeLensResponse(response) => { + Some(response) + } + unexpected => { + debug_panic!("Unexpected response: {unexpected:?}"); + None + } + }?; + let server_id = LanguageServerId::from_proto(lsp_response.server_id); + Some((server_id, response)) }) - .map(|code_lens_response| { - GetCodeLens.response_from_proto( - code_lens_response, - project.clone(), - buffer.clone(), - cx.clone(), - ) + .map(|(server_id, code_lens_response)| { + let lsp_store = lsp_store.clone(); + let buffer = buffer.clone(); + let cx = cx.clone(); + async move { + ( + server_id, + GetCodeLens + .response_from_proto( + code_lens_response, + lsp_store, + buffer, + cx, + ) + .await, + ) + } }), ) .await; - Ok(code_lens - .into_iter() - .collect::<Result<Vec<Vec<_>>>>()? 
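The new `code_lens_actions` entry point memoizes results per buffer: if the buffer version has not advanced past the cached one and the same servers are still attached, the cached lenses come back as an already-completed shared task; otherwise a compatible in-flight update is reused, or a debounced fetch is started and recorded. A rough model of that bookkeeping, under the simplifying assumption that versions are plain integers (the real code compares `clock::Global` versions with `changed_since`) and lenses are strings:

use std::collections::HashMap;

// Hypothetical simplification of the cache the hunk above introduces.
enum LensLookup {
    Cached(Vec<String>),
    UpdateAlreadyRunning,
    NeedsFetch,
}

#[derive(Default)]
struct CodeLensCache {
    lens_for_version: u64,
    lens: Vec<String>,
    update_for_version: Option<u64>,
}

#[derive(Default)]
struct LensStore {
    per_buffer: HashMap<u64, CodeLensCache>,
}

impl LensStore {
    fn lookup(&mut self, buffer_id: u64, queried_version: u64) -> LensLookup {
        let cache = self.per_buffer.entry(buffer_id).or_default();
        if queried_version <= cache.lens_for_version {
            // The buffer has not changed since the lenses were computed.
            return LensLookup::Cached(cache.lens.clone());
        }
        match cache.update_for_version {
            // A fetch covering this version is already in flight; await it.
            Some(updating_for) if queried_version <= updating_for => {
                LensLookup::UpdateAlreadyRunning
            }
            // Otherwise a new (debounced, shared) fetch has to be started.
            _ => LensLookup::NeedsFetch,
        }
    }
}

The real task is also `.shared()` and debounced by roughly 30 ms, so concurrent callers await the same fetch instead of issuing duplicates.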
+ let mut has_errors = false; + let code_lens_actions = code_lens_actions .into_iter() - .flatten() - .collect()) + .filter_map(|(server_id, code_lens)| match code_lens { + Ok(code_lens) => Some((server_id, code_lens)), + Err(e) => { + has_errors = true; + log::error!("{e:#}"); + None + } + }) + .collect::<HashMap<_, _>>(); + anyhow::ensure!( + !has_errors || !code_lens_actions.is_empty(), + "Failed to fetch code lens" + ); + Ok(code_lens_actions) }) } else { - let code_lens_task = - self.request_multiple_lsp_locally(buffer_handle, None::<usize>, GetCodeLens, cx); - cx.spawn(async move |_, _| { - Ok(code_lens_task - .await - .into_iter() - .flat_map(|(_, code_lens)| code_lens) - .collect()) - }) + let code_lens_actions_task = + self.request_multiple_lsp_locally(buffer, None::<usize>, GetCodeLens, cx); + cx.background_spawn( + async move { Ok(code_lens_actions_task.await.into_iter().collect()) }, + ) } } @@ -5779,11 +5992,15 @@ impl LspStore { let language_registry = self.languages.clone(); if let Some((upstream_client, project_id)) = self.upstream_client() { + let request = GetCompletions { position, context }; + if !self.is_capable_for_proto_request(buffer, &request, cx) { + return Task::ready(Ok(Vec::new())); + } let task = self.send_lsp_proto_request( buffer.clone(), upstream_client, project_id, - GetCompletions { position, context }, + request, cx, ); let language = buffer.read(cx).language().cloned(); @@ -5921,11 +6138,17 @@ impl LspStore { cx: &mut Context<Self>, ) -> Task<Result<bool>> { let client = self.upstream_client(); - let buffer_id = buffer.read(cx).remote_id(); let buffer_snapshot = buffer.read(cx).snapshot(); - cx.spawn(async move |this, cx| { + if !self.check_if_capable_for_proto_request( + &buffer, + GetCompletions::can_resolve_completions, + cx, + ) { + return Task::ready(Ok(false)); + } + cx.spawn(async move |lsp_store, cx| { let mut did_resolve = false; if let Some((client, project_id)) = client { for completion_index in completion_indices { @@ -5962,7 +6185,7 @@ impl LspStore { completion.source.server_id() }; if let Some(server_id) = server_id { - let server_and_adapter = this + let server_and_adapter = lsp_store .read_with(cx, |lsp_store, _| { let server = lsp_store.language_server_for_id(server_id)?; let adapter = @@ -5977,7 +6200,6 @@ impl LspStore { let resolved = Self::resolve_completion_local( server, - &buffer_snapshot, completions.clone(), completion_index, ) @@ -6010,18 +6232,11 @@ impl LspStore { async fn resolve_completion_local( server: Arc<lsp::LanguageServer>, - snapshot: &BufferSnapshot, completions: Rc<RefCell<Box<[Completion]>>>, completion_index: usize, ) -> Result<()> { let server_id = server.server_id(); - let can_resolve = server - .capabilities() - .completion_provider - .as_ref() - .and_then(|options| options.resolve_provider) - .unwrap_or(false); - if !can_resolve { + if !GetCompletions::can_resolve_completions(&server.capabilities()) { return Ok(()); } @@ -6055,26 +6270,8 @@ impl LspStore { .into_response() .context("resolve completion")?; - if let Some(text_edit) = resolved_completion.text_edit.as_ref() { - // Technically we don't have to parse the whole `text_edit`, since the only - // language server we currently use that does update `text_edit` in `completionItem/resolve` - // is `typescript-language-server` and they only update `text_edit.new_text`. - // But we should not rely on that. 
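`resolve_completion_local` now asks a factored-out `GetCompletions::can_resolve_completions` helper instead of reading the capability inline, and it no longer takes a buffer snapshot, because the resolve response may not be used to rewrite the completion's text edits. Presumably the helper carries the same logic this hunk removes inline; a sketch against plain `lsp_types` structures:

use lsp_types::ServerCapabilities;

// completionItem/resolve is only attempted when the server advertised
// completionProvider.resolveProvider == true.
fn can_resolve_completions(capabilities: &ServerCapabilities) -> bool {
    capabilities
        .completion_provider
        .as_ref()
        .and_then(|options| options.resolve_provider)
        .unwrap_or(false)
}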
- let edit = parse_completion_text_edit(text_edit, snapshot); - - if let Some(mut parsed_edit) = edit { - LineEnding::normalize(&mut parsed_edit.new_text); - - let mut completions = completions.borrow_mut(); - let completion = &mut completions[completion_index]; - - completion.new_text = parsed_edit.new_text; - completion.replace_range = parsed_edit.replace_range; - if let CompletionSource::Lsp { insert_range, .. } = &mut completion.source { - *insert_range = parsed_edit.insert_range; - } - } - } + // We must not use any data such as sortText, filterText, insertText and textEdit to edit `Completion` since they are not suppose change during resolve. + // Refer: https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#textDocument_completion let mut completions = completions.borrow_mut(); let completion = &mut completions[completion_index]; @@ -6324,12 +6521,10 @@ impl LspStore { }) else { return Task::ready(Ok(None)); }; - let snapshot = buffer_handle.read(&cx).snapshot(); cx.spawn(async move |this, cx| { Self::resolve_completion_local( server.clone(), - &snapshot, completions.clone(), completion_index, ) @@ -6392,16 +6587,24 @@ impl LspStore { pub fn pull_diagnostics( &mut self, - buffer_handle: Entity<Buffer>, + buffer: Entity<Buffer>, cx: &mut Context<Self>, - ) -> Task<Result<Vec<LspPullDiagnostics>>> { - let buffer = buffer_handle.read(cx); - let buffer_id = buffer.remote_id(); + ) -> Task<Result<Option<Vec<LspPullDiagnostics>>>> { + let buffer_id = buffer.read(cx).remote_id(); if let Some((client, upstream_project_id)) = self.upstream_client() { + if !self.is_capable_for_proto_request( + &buffer, + &GetDocumentDiagnostics { + previous_result_id: None, + }, + cx, + ) { + return Task::ready(Ok(None)); + } let request_task = client.request(proto::MultiLspQuery { buffer_id: buffer_id.to_proto(), - version: serialize_version(&buffer_handle.read(cx).version()), + version: serialize_version(&buffer.read(cx).version()), project_id: upstream_project_id, strategy: Some(proto::multi_lsp_query::Strategy::All( proto::AllLanguageServers {}, @@ -6410,12 +6613,12 @@ impl LspStore { proto::GetDocumentDiagnostics { project_id: upstream_project_id, buffer_id: buffer_id.to_proto(), - version: serialize_version(&buffer_handle.read(cx).version()), + version: serialize_version(&buffer.read(cx).version()), }, )), }); cx.background_spawn(async move { - Ok(request_task + let _proto_responses = request_task .await? .responses .into_iter() @@ -6428,11 +6631,14 @@ impl LspStore { None } }) - .flat_map(GetDocumentDiagnostics::diagnostics_from_proto) - .collect()) + .collect::<Vec<_>>(); + // Proto requests cause the diagnostics to be pulled from language server(s) on the local side + // and then, buffer state updated with the diagnostics received, which will be later propagated to the client. + // Do not attempt to further process the dummy responses here. 
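Per the LSP specification linked above, `completionItem/resolve` must not change `sortText`, `filterText`, `insertText`, or `textEdit`, which is why the edit-parsing block is dropped. Separately, `pull_diagnostics` now returns `Option<Vec<LspPullDiagnostics>>`; a hypothetical caller-side sketch of that contract, with diagnostics reduced to strings:

// `None` means "the pull was forwarded upstream and its results will come back
// as regular buffer diagnostic updates", so there is nothing further to merge
// here; `Some(..)` means local servers produced diagnostics that still need merging.
async fn refresh(pull: impl std::future::Future<Output = Option<Vec<String>>>) {
    let Some(diagnostics) = pull.await else {
        return;
    };
    for diagnostic in diagnostics {
        println!("merge {diagnostic}");
    }
}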
+ Ok(None) }) } else { - let server_ids = buffer_handle.update(cx, |buffer, cx| { + let server_ids = buffer.update(cx, |buffer, cx| { self.language_servers_for_local_buffer(buffer, cx) .map(|(_, server)| server.server_id()) .collect::<Vec<_>>() @@ -6442,7 +6648,7 @@ impl LspStore { .map(|server_id| { let result_id = self.result_id(server_id, buffer_id, cx); self.request_lsp( - buffer_handle.clone(), + buffer.clone(), LanguageServerToQuery::Other(server_id), GetDocumentDiagnostics { previous_result_id: result_id, @@ -6457,41 +6663,43 @@ impl LspStore { for diagnostics in join_all(pull_diagnostics).await { responses.extend(diagnostics?); } - Ok(responses) + Ok(Some(responses)) }) } } pub fn inlay_hints( &mut self, - buffer_handle: Entity<Buffer>, + buffer: Entity<Buffer>, range: Range<Anchor>, cx: &mut Context<Self>, ) -> Task<anyhow::Result<Vec<InlayHint>>> { - let buffer = buffer_handle.read(cx); let range_start = range.start; let range_end = range.end; - let buffer_id = buffer.remote_id().into(); - let lsp_request = InlayHints { range }; + let buffer_id = buffer.read(cx).remote_id().into(); + let request = InlayHints { range }; if let Some((client, project_id)) = self.upstream_client() { - let request = proto::InlayHints { + if !self.is_capable_for_proto_request(&buffer, &request, cx) { + return Task::ready(Ok(Vec::new())); + } + let proto_request = proto::InlayHints { project_id, buffer_id, start: Some(serialize_anchor(&range_start)), end: Some(serialize_anchor(&range_end)), - version: serialize_version(&buffer_handle.read(cx).version()), + version: serialize_version(&buffer.read(cx).version()), }; cx.spawn(async move |project, cx| { let response = client - .request(request) + .request(proto_request) .await .context("inlay hints proto request")?; LspCommand::response_from_proto( - lsp_request, + request, response, project.upgrade().context("No project")?, - buffer_handle.clone(), + buffer.clone(), cx.clone(), ) .await @@ -6499,13 +6707,13 @@ impl LspStore { }) } else { let lsp_request_task = self.request_lsp( - buffer_handle.clone(), + buffer.clone(), LanguageServerToQuery::FirstCapable, - lsp_request, + request, cx, ); cx.spawn(async move |_, cx| { - buffer_handle + buffer .update(cx, |buffer, _| { buffer.wait_for_edits(vec![range_start.timestamp, range_end.timestamp]) })? @@ -6521,75 +6729,93 @@ impl LspStore { buffer: Entity<Buffer>, cx: &mut Context<Self>, ) -> Task<anyhow::Result<()>> { - let buffer_id = buffer.read(cx).remote_id(); let diagnostics = self.pull_diagnostics(buffer, cx); cx.spawn(async move |lsp_store, cx| { - let diagnostics = diagnostics.await.context("pulling diagnostics")?; + let Some(diagnostics) = diagnostics.await.context("pulling diagnostics")? 
else { + return Ok(()); + }; lsp_store.update(cx, |lsp_store, cx| { if lsp_store.as_local().is_none() { return; } - for diagnostics_set in diagnostics { - let LspPullDiagnostics::Response { - server_id, - uri, - diagnostics, - } = diagnostics_set - else { - continue; - }; - - let adapter = lsp_store.language_server_adapter_for_id(server_id); - let disk_based_sources = adapter - .as_ref() - .map(|adapter| adapter.disk_based_diagnostic_sources.as_slice()) - .unwrap_or(&[]); - match diagnostics { - PulledDiagnostics::Unchanged { result_id } => { - lsp_store - .merge_diagnostics( - server_id, - lsp::PublishDiagnosticsParams { - uri: uri.clone(), - diagnostics: Vec::new(), - version: None, - }, - Some(result_id), - DiagnosticSourceKind::Pulled, - disk_based_sources, - |_, _, _| true, - cx, - ) - .log_err(); - } - PulledDiagnostics::Changed { + let mut unchanged_buffers = HashSet::default(); + let mut changed_buffers = HashSet::default(); + let server_diagnostics_updates = diagnostics + .into_iter() + .filter_map(|diagnostics_set| match diagnostics_set { + LspPullDiagnostics::Response { + server_id, + uri, diagnostics, - result_id, - } => { - lsp_store - .merge_diagnostics( + } => Some((server_id, uri, diagnostics)), + LspPullDiagnostics::Default => None, + }) + .fold( + HashMap::default(), + |mut acc, (server_id, uri, diagnostics)| { + let (result_id, diagnostics) = match diagnostics { + PulledDiagnostics::Unchanged { result_id } => { + unchanged_buffers.insert(uri.clone()); + (Some(result_id), Vec::new()) + } + PulledDiagnostics::Changed { + result_id, + diagnostics, + } => { + changed_buffers.insert(uri.clone()); + (result_id, diagnostics) + } + }; + let disk_based_sources = Cow::Owned( + lsp_store + .language_server_adapter_for_id(server_id) + .as_ref() + .map(|adapter| adapter.disk_based_diagnostic_sources.as_slice()) + .unwrap_or(&[]) + .to_vec(), + ); + acc.entry(server_id).or_insert_with(Vec::new).push( + DocumentDiagnosticsUpdate { server_id, - lsp::PublishDiagnosticsParams { - uri: uri.clone(), + diagnostics: lsp::PublishDiagnosticsParams { + uri, diagnostics, version: None, }, result_id, - DiagnosticSourceKind::Pulled, disk_based_sources, - |buffer, old_diagnostic, _| match old_diagnostic.source_kind { - DiagnosticSourceKind::Pulled => { - buffer.remote_id() != buffer_id - } - DiagnosticSourceKind::Other - | DiagnosticSourceKind::Pushed => true, - }, - cx, - ) - .log_err(); - } - } + }, + ); + acc + }, + ); + + for diagnostic_updates in server_diagnostics_updates.into_values() { + lsp_store + .merge_lsp_diagnostics( + DiagnosticSourceKind::Pulled, + diagnostic_updates, + |buffer, old_diagnostic, cx| { + File::from_dyn(buffer.file()) + .and_then(|file| { + let abs_path = file.as_local()?.abs_path(cx); + lsp::Url::from_file_path(abs_path).ok() + }) + .is_none_or(|buffer_uri| { + unchanged_buffers.contains(&buffer_uri) + || match old_diagnostic.source_kind { + DiagnosticSourceKind::Pulled => { + !changed_buffers.contains(&buffer_uri) + } + DiagnosticSourceKind::Other + | DiagnosticSourceKind::Pushed => true, + } + }) + }, + cx, + ) + .log_err(); } }) }) @@ -6597,7 +6823,7 @@ impl LspStore { pub fn document_colors( &mut self, - fetch_strategy: ColorFetchStrategy, + fetch_strategy: LspFetchStrategy, buffer: Entity<Buffer>, cx: &mut Context<Self>, ) -> Option<DocumentColorTask> { @@ -6605,11 +6831,11 @@ impl LspStore { let buffer_id = buffer.read(cx).remote_id(); match fetch_strategy { - ColorFetchStrategy::IgnoreCache => {} - ColorFetchStrategy::UseCache { + LspFetchStrategy::IgnoreCache 
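`pull_diagnostics_for_buffer` above no longer merges each server response one at a time: responses are folded into one batch per server, while two URI sets record which documents came back unchanged versus changed, and the merge callback uses those sets to decide whether previously pulled diagnostics for a document may be kept. A reduced sketch of that fold, with URIs and diagnostics as plain strings:

use std::collections::{HashMap, HashSet};

enum Pulled {
    Unchanged,
    Changed(Vec<String>),
}

fn group_by_server(
    responses: Vec<(u64, String, Pulled)>,
) -> (HashMap<u64, Vec<(String, Vec<String>)>>, HashSet<String>, HashSet<String>) {
    let mut unchanged_uris = HashSet::new();
    let mut changed_uris = HashSet::new();
    let mut per_server: HashMap<u64, Vec<(String, Vec<String>)>> = HashMap::new();
    for (server_id, uri, pulled) in responses {
        let diagnostics = match pulled {
            Pulled::Unchanged => {
                unchanged_uris.insert(uri.clone());
                Vec::new() // an unchanged report carries no entries of its own
            }
            Pulled::Changed(diagnostics) => {
                changed_uris.insert(uri.clone());
                diagnostics
            }
        };
        per_server.entry(server_id).or_default().push((uri, diagnostics));
    }
    (per_server, unchanged_uris, changed_uris)
}

A document in the unchanged set keeps all of its previously pulled diagnostics, a changed document keeps none of them, and pushed diagnostics from other sources are always preserved.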
=> {} + LspFetchStrategy::UseCache { known_cache_version, } => { - if let Some(cached_data) = self.lsp_data.get(&buffer_id) { + if let Some(cached_data) = self.lsp_document_colors.get(&buffer_id) { if !version_queried_for.changed_since(&cached_data.colors_for_version) { let has_different_servers = self.as_local().is_some_and(|local| { local @@ -6642,7 +6868,7 @@ impl LspStore { } } - let lsp_data = self.lsp_data.entry(buffer_id).or_default(); + let lsp_data = self.lsp_document_colors.entry(buffer_id).or_default(); if let Some((updating_for, running_update)) = &lsp_data.colors_update { if !version_queried_for.changed_since(&updating_for) { return Some(running_update.clone()); @@ -6656,14 +6882,14 @@ impl LspStore { .await; let fetched_colors = lsp_store .update(cx, |lsp_store, cx| { - lsp_store.fetch_document_colors_for_buffer(buffer.clone(), cx) + lsp_store.fetch_document_colors_for_buffer(&buffer, cx) })? .await .context("fetching document colors") .map_err(Arc::new); let fetched_colors = match fetched_colors { Ok(fetched_colors) => { - if fetch_strategy != ColorFetchStrategy::IgnoreCache + if fetch_strategy != LspFetchStrategy::IgnoreCache && Some(true) == buffer .update(cx, |buffer, _| { @@ -6679,7 +6905,7 @@ impl LspStore { lsp_store .update(cx, |lsp_store, _| { lsp_store - .lsp_data + .lsp_document_colors .entry(buffer_id) .or_default() .colors_update = None; @@ -6691,7 +6917,7 @@ impl LspStore { lsp_store .update(cx, |lsp_store, _| { - let lsp_data = lsp_store.lsp_data.entry(buffer_id).or_default(); + let lsp_data = lsp_store.lsp_document_colors.entry(buffer_id).or_default(); if lsp_data.colors_for_version == query_version_queried_for { lsp_data.colors.extend(fetched_colors.clone()); @@ -6725,10 +6951,15 @@ impl LspStore { fn fetch_document_colors_for_buffer( &mut self, - buffer: Entity<Buffer>, + buffer: &Entity<Buffer>, cx: &mut Context<Self>, ) -> Task<anyhow::Result<HashMap<LanguageServerId, HashSet<DocumentColor>>>> { if let Some((client, project_id)) = self.upstream_client() { + let request = GetDocumentColor {}; + if !self.is_capable_for_proto_request(buffer, &request, cx) { + return Task::ready(Ok(HashMap::default())); + } + let request_task = client.request(proto::MultiLspQuery { project_id, buffer_id: buffer.read(cx).remote_id().to_proto(), @@ -6737,9 +6968,10 @@ impl LspStore { proto::AllLanguageServers {}, )), request: Some(proto::multi_lsp_query::Request::GetDocumentColor( - GetDocumentColor {}.to_proto(project_id, buffer.read(cx)), + request.to_proto(project_id, buffer.read(cx)), )), }); + let buffer = buffer.clone(); cx.spawn(async move |project, cx| { let Some(project) = project.upgrade() else { return Ok(HashMap::default()); @@ -6764,7 +6996,7 @@ impl LspStore { } }) .map(|(server_id, color_response)| { - let response = GetDocumentColor {}.response_from_proto( + let response = request.response_from_proto( color_response, project.clone(), buffer.clone(), @@ -6785,8 +7017,8 @@ impl LspStore { }) } else { let document_colors_task = - self.request_multiple_lsp_locally(&buffer, None::<usize>, GetDocumentColor, cx); - cx.spawn(async move |_, _| { + self.request_multiple_lsp_locally(buffer, None::<usize>, GetDocumentColor, cx); + cx.background_spawn(async move { Ok(document_colors_task .await .into_iter() @@ -6811,6 +7043,10 @@ impl LspStore { let position = position.to_point_utf16(buffer.read(cx)); if let Some((client, upstream_project_id)) = self.upstream_client() { + let request = GetSignatureHelp { position }; + if !self.is_capable_for_proto_request(buffer, &request, 
cx) { + return Task::ready(Vec::new()); + } let request_task = client.request(proto::MultiLspQuery { buffer_id: buffer.read(cx).remote_id().into(), version: serialize_version(&buffer.read(cx).version()), @@ -6819,7 +7055,7 @@ impl LspStore { proto::AllLanguageServers {}, )), request: Some(proto::multi_lsp_query::Request::GetSignatureHelp( - GetSignatureHelp { position }.to_proto(upstream_project_id, buffer.read(cx)), + request.to_proto(upstream_project_id, buffer.read(cx)), )), }); let buffer = buffer.clone(); @@ -6865,7 +7101,7 @@ impl LspStore { GetSignatureHelp { position }, cx, ); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { all_actions_task .await .into_iter() @@ -6882,6 +7118,10 @@ impl LspStore { cx: &mut Context<Self>, ) -> Task<Vec<Hover>> { if let Some((client, upstream_project_id)) = self.upstream_client() { + let request = GetHover { position }; + if !self.is_capable_for_proto_request(buffer, &request, cx) { + return Task::ready(Vec::new()); + } let request_task = client.request(proto::MultiLspQuery { buffer_id: buffer.read(cx).remote_id().into(), version: serialize_version(&buffer.read(cx).version()), @@ -6890,7 +7130,7 @@ impl LspStore { proto::AllLanguageServers {}, )), request: Some(proto::multi_lsp_query::Request::GetHover( - GetHover { position }.to_proto(upstream_project_id, buffer.read(cx)), + request.to_proto(upstream_project_id, buffer.read(cx)), )), }); let buffer = buffer.clone(); @@ -6942,7 +7182,7 @@ impl LspStore { GetHover { position }, cx, ); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { all_actions_task .await .into_iter() @@ -7210,7 +7450,9 @@ impl LspStore { let build_incremental_change = || { buffer - .edits_since::<(PointUtf16, usize)>(previous_snapshot.snapshot.version()) + .edits_since::<Dimensions<PointUtf16, usize>>( + previous_snapshot.snapshot.version(), + ) .map(|edit| { let edit_start = edit.new.start.0; let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0); @@ -7325,21 +7567,23 @@ impl LspStore { } pub(crate) async fn refresh_workspace_configurations( - this: &WeakEntity<Self>, + lsp_store: &WeakEntity<Self>, fs: Arc<dyn Fs>, cx: &mut AsyncApp, ) { maybe!(async move { - let servers = this - .update(cx, |this, cx| { - let Some(local) = this.as_local() else { + let mut refreshed_servers = HashSet::default(); + let servers = lsp_store + .update(cx, |lsp_store, cx| { + let toolchain_store = lsp_store.toolchain_store(cx); + let Some(local) = lsp_store.as_local() else { return Vec::default(); }; local .language_server_ids .iter() .flat_map(|((worktree_id, _), server_ids)| { - let worktree = this + let worktree = lsp_store .worktree_store .read(cx) .worktree_for_id(*worktree_id, cx); @@ -7355,43 +7599,54 @@ impl LspStore { ) }); - server_ids.iter().filter_map(move |server_id| { + let fs = fs.clone(); + let toolchain_store = toolchain_store.clone(); + server_ids.iter().filter_map(|server_id| { + let delegate = delegate.clone()? as Arc<dyn LspAdapterDelegate>; let states = local.language_servers.get(server_id)?; match states { LanguageServerState::Starting { .. } => None, LanguageServerState::Running { adapter, server, .. - } => Some(( - adapter.adapter.clone(), - server.clone(), - delegate.clone()? 
as Arc<dyn LspAdapterDelegate>, - )), + } => { + let fs = fs.clone(); + let toolchain_store = toolchain_store.clone(); + let adapter = adapter.clone(); + let server = server.clone(); + refreshed_servers.insert(server.name()); + Some(cx.spawn(async move |_, cx| { + let settings = + LocalLspStore::workspace_configuration_for_adapter( + adapter.adapter.clone(), + fs.as_ref(), + &delegate, + toolchain_store, + cx, + ) + .await + .ok()?; + server + .notify::<lsp::notification::DidChangeConfiguration>( + &lsp::DidChangeConfigurationParams { settings }, + ) + .ok()?; + Some(()) + })) + } } - }) + }).collect::<Vec<_>>() }) .collect::<Vec<_>>() }) .ok()?; - let toolchain_store = this.update(cx, |this, cx| this.toolchain_store(cx)).ok()?; - for (adapter, server, delegate) in servers { - let settings = LocalLspStore::workspace_configuration_for_adapter( - adapter, - fs.as_ref(), - &delegate, - toolchain_store.clone(), - cx, - ) - .await - .ok()?; - - server - .notify::<lsp::notification::DidChangeConfiguration>( - &lsp::DidChangeConfigurationParams { settings }, - ) - .ok(); - } + log::info!("Refreshing workspace configurations for servers {refreshed_servers:?}"); + // TODO this asynchronous job runs concurrently with extension (de)registration and may take enough time for a certain extension + // to stop and unregister its language server wrapper. + // This is racy : an extension might have already removed all `local.language_servers` state, but here we `.clone()` and hold onto it anyway. + // This now causes errors in the logs, we should find a way to remove such servers from the processing everywhere. + let _: Vec<Option<()>> = join_all(servers).await; Some(()) }) .await; @@ -7480,16 +7735,20 @@ impl LspStore { self.downstream_client = Some((downstream_client.clone(), project_id)); for (server_id, status) in &self.language_server_statuses { - downstream_client - .send(proto::StartLanguageServer { - project_id, - server: Some(proto::LanguageServer { - id: server_id.0 as u64, - name: status.name.clone(), - worktree_id: None, - }), - }) - .log_err(); + if let Some(server) = self.language_server_for_id(*server_id) { + downstream_client + .send(proto::StartLanguageServer { + project_id, + server: Some(proto::LanguageServer { + id: server_id.to_proto(), + name: status.name.to_string(), + worktree_id: None, + }), + capabilities: serde_json::to_string(&server.capabilities()) + .expect("serializing server LSP capabilities"), + }) + .log_err(); + } } } @@ -7516,7 +7775,7 @@ impl LspStore { ( LanguageServerId(server.id as usize), LanguageServerStatus { - name: server.name, + name: LanguageServerName::from_proto(server.name), pending_work: Default::default(), has_pending_diagnostic_updates: false, progress_tokens: Default::default(), @@ -7578,87 +7837,135 @@ impl LspStore { cx: &mut Context<Self>, ) -> anyhow::Result<()> { self.merge_diagnostic_entries( - server_id, - abs_path, - result_id, - version, - diagnostics, + vec![DocumentDiagnosticsUpdate { + diagnostics: DocumentDiagnostics { + diagnostics, + document_abs_path: abs_path, + version, + }, + result_id, + server_id, + disk_based_sources: Cow::Borrowed(&[]), + }], |_, _, _| false, cx, - ) + )?; + Ok(()) } - pub fn merge_diagnostic_entries( + pub fn merge_diagnostic_entries<'a>( &mut self, - server_id: LanguageServerId, - abs_path: PathBuf, - result_id: Option<String>, - version: Option<i32>, - mut diagnostics: Vec<DiagnosticEntry<Unclipped<PointUtf16>>>, - filter: impl Fn(&Buffer, &Diagnostic, &App) -> bool + Clone, + diagnostic_updates: 
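`refresh_workspace_configurations` above now spawns one task per running server, each recomputing the workspace configuration and sending `workspace/didChangeConfiguration` on its own, and the tasks are awaited together at the end; the TODO notes this still races with extension deregistration. A sketch of that fan-out with stand-in types (server names and JSON settings are placeholders):

use futures::future::join_all;

// Each running server gets its own async task; the tasks then run concurrently
// instead of being awaited one after another.
async fn refresh_all(servers: Vec<String>) {
    let tasks = servers.into_iter().map(|server_name| async move {
        // Placeholder for workspace_configuration_for_adapter(..).await
        let settings = format!("{{\"server\": \"{server_name}\"}}");
        // Placeholder for server.notify::<DidChangeConfiguration>(..)
        println!("didChangeConfiguration -> {server_name}: {settings}");
        Some(())
    });
    let _results: Vec<Option<()>> = join_all(tasks).await;
}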
Vec<DocumentDiagnosticsUpdate<'a, DocumentDiagnostics>>, + merge: impl Fn(&Buffer, &Diagnostic, &App) -> bool + Clone, cx: &mut Context<Self>, ) -> anyhow::Result<()> { - let Some((worktree, relative_path)) = - self.worktree_store.read(cx).find_worktree(&abs_path, cx) - else { - log::warn!("skipping diagnostics update, no worktree found for path {abs_path:?}"); - return Ok(()); - }; + let mut diagnostics_summary = None::<proto::UpdateDiagnosticSummary>; + let mut updated_diagnostics_paths = HashMap::default(); + for mut update in diagnostic_updates { + let abs_path = &update.diagnostics.document_abs_path; + let server_id = update.server_id; + let Some((worktree, relative_path)) = + self.worktree_store.read(cx).find_worktree(abs_path, cx) + else { + log::warn!("skipping diagnostics update, no worktree found for path {abs_path:?}"); + return Ok(()); + }; - let project_path = ProjectPath { - worktree_id: worktree.read(cx).id(), - path: relative_path.into(), - }; + let worktree_id = worktree.read(cx).id(); + let project_path = ProjectPath { + worktree_id, + path: relative_path.into(), + }; - if let Some(buffer_handle) = self.buffer_store.read(cx).get_by_path(&project_path) { - let snapshot = buffer_handle.read(cx).snapshot(); - let buffer = buffer_handle.read(cx); - let reused_diagnostics = buffer - .get_diagnostics(server_id) - .into_iter() - .flat_map(|diag| { - diag.iter() - .filter(|v| filter(buffer, &v.diagnostic, cx)) - .map(|v| { - let start = Unclipped(v.range.start.to_point_utf16(&snapshot)); - let end = Unclipped(v.range.end.to_point_utf16(&snapshot)); - DiagnosticEntry { - range: start..end, - diagnostic: v.diagnostic.clone(), - } - }) - }) - .collect::<Vec<_>>(); + if let Some(buffer_handle) = self.buffer_store.read(cx).get_by_path(&project_path) { + let snapshot = buffer_handle.read(cx).snapshot(); + let buffer = buffer_handle.read(cx); + let reused_diagnostics = buffer + .get_diagnostics(server_id) + .into_iter() + .flat_map(|diag| { + diag.iter() + .filter(|v| merge(buffer, &v.diagnostic, cx)) + .map(|v| { + let start = Unclipped(v.range.start.to_point_utf16(&snapshot)); + let end = Unclipped(v.range.end.to_point_utf16(&snapshot)); + DiagnosticEntry { + range: start..end, + diagnostic: v.diagnostic.clone(), + } + }) + }) + .collect::<Vec<_>>(); - self.as_local_mut() - .context("cannot merge diagnostics on a remote LspStore")? - .update_buffer_diagnostics( - &buffer_handle, + self.as_local_mut() + .context("cannot merge diagnostics on a remote LspStore")? 
+ .update_buffer_diagnostics( + &buffer_handle, + server_id, + update.result_id, + update.diagnostics.version, + update.diagnostics.diagnostics.clone(), + reused_diagnostics.clone(), + cx, + )?; + + update.diagnostics.diagnostics.extend(reused_diagnostics); + } + + let updated = worktree.update(cx, |worktree, cx| { + self.update_worktree_diagnostics( + worktree.id(), server_id, - result_id, - version, - diagnostics.clone(), - reused_diagnostics.clone(), + project_path.path.clone(), + update.diagnostics.diagnostics, cx, - )?; - - diagnostics.extend(reused_diagnostics); + ) + })?; + match updated { + ControlFlow::Continue(new_summary) => { + if let Some((project_id, new_summary)) = new_summary { + match &mut diagnostics_summary { + Some(diagnostics_summary) => { + diagnostics_summary + .more_summaries + .push(proto::DiagnosticSummary { + path: project_path.path.as_ref().to_proto(), + language_server_id: server_id.0 as u64, + error_count: new_summary.error_count, + warning_count: new_summary.warning_count, + }) + } + None => { + diagnostics_summary = Some(proto::UpdateDiagnosticSummary { + project_id: project_id, + worktree_id: worktree_id.to_proto(), + summary: Some(proto::DiagnosticSummary { + path: project_path.path.as_ref().to_proto(), + language_server_id: server_id.0 as u64, + error_count: new_summary.error_count, + warning_count: new_summary.warning_count, + }), + more_summaries: Vec::new(), + }) + } + } + } + updated_diagnostics_paths + .entry(server_id) + .or_insert_with(Vec::new) + .push(project_path); + } + ControlFlow::Break(()) => {} + } } - let updated = worktree.update(cx, |worktree, cx| { - self.update_worktree_diagnostics( - worktree.id(), - server_id, - project_path.path.clone(), - diagnostics, - cx, - ) - })?; - if updated { - cx.emit(LspStoreEvent::DiagnosticsUpdated { - language_server_id: server_id, - path: project_path, - }) + if let Some((diagnostics_summary, (downstream_client, _))) = + diagnostics_summary.zip(self.downstream_client.as_ref()) + { + downstream_client.send(diagnostics_summary).log_err(); + } + for (server_id, paths) in updated_diagnostics_paths { + cx.emit(LspStoreEvent::DiagnosticsUpdated { server_id, paths }); } Ok(()) } @@ -7667,10 +7974,10 @@ impl LspStore { &mut self, worktree_id: WorktreeId, server_id: LanguageServerId, - worktree_path: Arc<Path>, + path_in_worktree: Arc<Path>, diagnostics: Vec<DiagnosticEntry<Unclipped<PointUtf16>>>, _: &mut Context<Worktree>, - ) -> Result<bool> { + ) -> Result<ControlFlow<(), Option<(u64, proto::DiagnosticSummary)>>> { let local = match &mut self.mode { LspStoreMode::Local(local_lsp_store) => local_lsp_store, _ => anyhow::bail!("update_worktree_diagnostics called on remote"), @@ -7678,7 +7985,9 @@ impl LspStore { let summaries_for_tree = self.diagnostic_summaries.entry(worktree_id).or_default(); let diagnostics_for_tree = local.diagnostics.entry(worktree_id).or_default(); - let summaries_by_server_id = summaries_for_tree.entry(worktree_path.clone()).or_default(); + let summaries_by_server_id = summaries_for_tree + .entry(path_in_worktree.clone()) + .or_default(); let old_summary = summaries_by_server_id .remove(&server_id) @@ -7686,18 +7995,19 @@ impl LspStore { let new_summary = DiagnosticSummary::new(&diagnostics); if new_summary.is_empty() { - if let Some(diagnostics_by_server_id) = diagnostics_for_tree.get_mut(&worktree_path) { + if let Some(diagnostics_by_server_id) = diagnostics_for_tree.get_mut(&path_in_worktree) + { if let Ok(ix) = diagnostics_by_server_id.binary_search_by_key(&server_id, |e| e.0) { 
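`merge_diagnostic_entries` now accepts a whole batch of `DocumentDiagnosticsUpdate`s and forwards at most one `UpdateDiagnosticSummary` downstream: the first changed path fills the message's `summary` field and every further path is appended to `more_summaries`, as in this simplified model (summaries reduced to strings):

// Only one message is sent downstream per batch.
struct UpdateDiagnosticSummary {
    summary: Option<String>,
    more_summaries: Vec<String>,
}

fn batch_summaries(per_path_summaries: Vec<String>) -> Option<UpdateDiagnosticSummary> {
    let mut batched: Option<UpdateDiagnosticSummary> = None;
    for summary in per_path_summaries {
        match &mut batched {
            Some(message) => message.more_summaries.push(summary),
            None => {
                batched = Some(UpdateDiagnosticSummary {
                    summary: Some(summary),
                    more_summaries: Vec::new(),
                })
            }
        }
    }
    batched
}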
diagnostics_by_server_id.remove(ix); } if diagnostics_by_server_id.is_empty() { - diagnostics_for_tree.remove(&worktree_path); + diagnostics_for_tree.remove(&path_in_worktree); } } } else { summaries_by_server_id.insert(server_id, new_summary); let diagnostics_by_server_id = diagnostics_for_tree - .entry(worktree_path.clone()) + .entry(path_in_worktree.clone()) .or_default(); match diagnostics_by_server_id.binary_search_by_key(&server_id, |e| e.0) { Ok(ix) => { @@ -7710,23 +8020,22 @@ impl LspStore { } if !old_summary.is_empty() || !new_summary.is_empty() { - if let Some((downstream_client, project_id)) = &self.downstream_client { - downstream_client - .send(proto::UpdateDiagnosticSummary { - project_id: *project_id, - worktree_id: worktree_id.to_proto(), - summary: Some(proto::DiagnosticSummary { - path: worktree_path.to_proto(), - language_server_id: server_id.0 as u64, - error_count: new_summary.error_count as u32, - warning_count: new_summary.warning_count as u32, - }), - }) - .log_err(); + if let Some((_, project_id)) = &self.downstream_client { + Ok(ControlFlow::Continue(Some(( + *project_id, + proto::DiagnosticSummary { + path: path_in_worktree.to_proto(), + language_server_id: server_id.0 as u64, + error_count: new_summary.error_count as u32, + warning_count: new_summary.warning_count as u32, + }, + )))) + } else { + Ok(ControlFlow::Continue(None)) } + } else { + Ok(ControlFlow::Break(())) } - - Ok(!old_summary.is_empty() || !new_summary.is_empty()) } pub fn open_buffer_for_symbol( @@ -7931,7 +8240,7 @@ impl LspStore { }) .collect::<FuturesUnordered<_>>(); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { let mut responses = Vec::with_capacity(response_results.len()); while let Some((server_id, response_result)) = response_results.next().await { if let Some(response) = response_result.log_err() { @@ -8530,34 +8839,6 @@ impl LspStore { Ok(proto::Ack {}) } - async fn handle_language_server_id_for_name( - lsp_store: Entity<Self>, - envelope: TypedEnvelope<proto::LanguageServerIdForName>, - mut cx: AsyncApp, - ) -> Result<proto::LanguageServerIdForNameResponse> { - let name = &envelope.payload.name; - let buffer_id = BufferId::new(envelope.payload.buffer_id)?; - lsp_store - .update(&mut cx, |lsp_store, cx| { - let buffer = lsp_store.buffer_store.read(cx).get_existing(buffer_id)?; - let server_id = buffer.update(cx, |buffer, cx| { - lsp_store - .language_servers_for_local_buffer(buffer, cx) - .find_map(|(adapter, server)| { - if adapter.name.0.as_ref() == name { - Some(server.server_id()) - } else { - None - } - }) - }); - Ok(server_id) - })? 
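`update_worktree_diagnostics` now reports its outcome through `ControlFlow` instead of a bare bool, distinguishing "nothing changed" from "changed, but there is no downstream client to notify" and "changed, forward this summary". A compact sketch of that contract (error and warning counts collapsed into one placeholder count):

use std::ops::ControlFlow;

// Break(()) = summaries did not change, nothing to emit;
// Continue(None) = changed, but no downstream client is connected;
// Continue(Some((project_id, count))) = changed and a summary must be sent.
fn worktree_summary_outcome(
    changed: bool,
    downstream_project: Option<u64>,
    diagnostic_count: u32,
) -> ControlFlow<(), Option<(u64, u32)>> {
    if !changed {
        return ControlFlow::Break(());
    }
    ControlFlow::Continue(downstream_project.map(|project_id| (project_id, diagnostic_count)))
}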
- .map(|server_id| proto::LanguageServerIdForNameResponse { - server_id: server_id.map(|id| id.to_proto()), - }) - } - async fn handle_rename_project_entry( this: Entity<Self>, envelope: TypedEnvelope<proto::RenameProjectEntry>, @@ -8607,23 +8888,30 @@ impl LspStore { envelope: TypedEnvelope<proto::UpdateDiagnosticSummary>, mut cx: AsyncApp, ) -> Result<()> { - this.update(&mut cx, |this, cx| { + this.update(&mut cx, |lsp_store, cx| { let worktree_id = WorktreeId::from_proto(envelope.payload.worktree_id); - if let Some(message) = envelope.payload.summary { + let mut updated_diagnostics_paths = HashMap::default(); + let mut diagnostics_summary = None::<proto::UpdateDiagnosticSummary>; + for message_summary in envelope + .payload + .summary + .into_iter() + .chain(envelope.payload.more_summaries) + { let project_path = ProjectPath { worktree_id, - path: Arc::<Path>::from_proto(message.path), + path: Arc::<Path>::from_proto(message_summary.path), }; let path = project_path.path.clone(); - let server_id = LanguageServerId(message.language_server_id as usize); + let server_id = LanguageServerId(message_summary.language_server_id as usize); let summary = DiagnosticSummary { - error_count: message.error_count as usize, - warning_count: message.warning_count as usize, + error_count: message_summary.error_count as usize, + warning_count: message_summary.warning_count as usize, }; if summary.is_empty() { if let Some(worktree_summaries) = - this.diagnostic_summaries.get_mut(&worktree_id) + lsp_store.diagnostic_summaries.get_mut(&worktree_id) { if let Some(summaries) = worktree_summaries.get_mut(&path) { summaries.remove(&server_id); @@ -8633,49 +8921,84 @@ impl LspStore { } } } else { - this.diagnostic_summaries + lsp_store + .diagnostic_summaries .entry(worktree_id) .or_default() .entry(path) .or_default() .insert(server_id, summary); } - if let Some((downstream_client, project_id)) = &this.downstream_client { - downstream_client - .send(proto::UpdateDiagnosticSummary { - project_id: *project_id, - worktree_id: worktree_id.to_proto(), - summary: Some(proto::DiagnosticSummary { - path: project_path.path.as_ref().to_proto(), - language_server_id: server_id.0 as u64, - error_count: summary.error_count as u32, - warning_count: summary.warning_count as u32, - }), - }) - .log_err(); + + if let Some((_, project_id)) = &lsp_store.downstream_client { + match &mut diagnostics_summary { + Some(diagnostics_summary) => { + diagnostics_summary + .more_summaries + .push(proto::DiagnosticSummary { + path: project_path.path.as_ref().to_proto(), + language_server_id: server_id.0 as u64, + error_count: summary.error_count as u32, + warning_count: summary.warning_count as u32, + }) + } + None => { + diagnostics_summary = Some(proto::UpdateDiagnosticSummary { + project_id: *project_id, + worktree_id: worktree_id.to_proto(), + summary: Some(proto::DiagnosticSummary { + path: project_path.path.as_ref().to_proto(), + language_server_id: server_id.0 as u64, + error_count: summary.error_count as u32, + warning_count: summary.warning_count as u32, + }), + more_summaries: Vec::new(), + }) + } + } } - cx.emit(LspStoreEvent::DiagnosticsUpdated { - language_server_id: LanguageServerId(message.language_server_id as usize), - path: project_path, - }); + updated_diagnostics_paths + .entry(server_id) + .or_insert_with(Vec::new) + .push(project_path); + } + + if let Some((diagnostics_summary, (downstream_client, _))) = + diagnostics_summary.zip(lsp_store.downstream_client.as_ref()) + { + 
downstream_client.send(diagnostics_summary).log_err(); + } + for (server_id, paths) in updated_diagnostics_paths { + cx.emit(LspStoreEvent::DiagnosticsUpdated { server_id, paths }); } Ok(()) })? } async fn handle_start_language_server( - this: Entity<Self>, + lsp_store: Entity<Self>, envelope: TypedEnvelope<proto::StartLanguageServer>, mut cx: AsyncApp, ) -> Result<()> { let server = envelope.payload.server.context("invalid server")?; - - this.update(&mut cx, |this, cx| { + let server_capabilities = + serde_json::from_str::<lsp::ServerCapabilities>(&envelope.payload.capabilities) + .with_context(|| { + format!( + "incorrect server capabilities {}", + envelope.payload.capabilities + ) + })?; + lsp_store.update(&mut cx, |lsp_store, cx| { let server_id = LanguageServerId(server.id as usize); - this.language_server_statuses.insert( + let server_name = LanguageServerName::from_proto(server.name.clone()); + lsp_store + .lsp_server_capabilities + .insert(server_id, server_capabilities); + lsp_store.language_server_statuses.insert( server_id, LanguageServerStatus { - name: server.name.clone(), + name: server_name.clone(), pending_work: Default::default(), has_pending_diagnostic_updates: false, progress_tokens: Default::default(), @@ -8683,7 +9006,7 @@ impl LspStore { ); cx.emit(LspStoreEvent::LanguageServerAdded( server_id, - LanguageServerName(server.name.into()), + server_name, server.worktree_id.map(WorktreeId::from_proto), )); cx.notify(); @@ -8744,7 +9067,8 @@ impl LspStore { } non_lsp @ proto::update_language_server::Variant::StatusUpdate(_) - | non_lsp @ proto::update_language_server::Variant::RegisteredForBuffer(_) => { + | non_lsp @ proto::update_language_server::Variant::RegisteredForBuffer(_) + | non_lsp @ proto::update_language_server::Variant::MetadataUpdated(_) => { cx.emit(LspStoreEvent::LanguageServerUpdate { language_server_id, name: envelope @@ -9130,7 +9454,39 @@ impl LspStore { } }; - let lsp::ProgressParamsValue::WorkDone(progress) = progress.value; + match progress.value { + lsp::ProgressParamsValue::WorkDone(progress) => { + self.handle_work_done_progress( + progress, + language_server_id, + disk_based_diagnostics_progress_token, + token, + cx, + ); + } + lsp::ProgressParamsValue::WorkspaceDiagnostic(report) => { + if let Some(LanguageServerState::Running { + workspace_refresh_task: Some(workspace_refresh_task), + .. + }) = self + .as_local_mut() + .and_then(|local| local.language_servers.get_mut(&language_server_id)) + { + workspace_refresh_task.progress_tx.try_send(()).ok(); + self.apply_workspace_diagnostic_report(language_server_id, report, cx) + } + } + } + } + + fn handle_work_done_progress( + &mut self, + progress: lsp::WorkDoneProgress, + language_server_id: LanguageServerId, + disk_based_diagnostics_progress_token: Option<String>, + token: String, + cx: &mut Context<Self>, + ) { let language_server_status = if let Some(status) = self.language_server_statuses.get_mut(&language_server_id) { status @@ -10131,6 +10487,7 @@ impl LspStore { error_count: 0, warning_count: 0, }), + more_summaries: Vec::new(), }) .log_err(); } @@ -10159,7 +10516,7 @@ impl LspStore { let name = self .language_server_statuses .remove(&server_id) - .map(|status| LanguageServerName::from(status.name.as_str())) + .map(|status| status.name.clone()) .or_else(|| { if let Some(LanguageServerState::Running { adapter, .. 
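The progress handler above now distinguishes the two `$/progress` payloads: ordinary work-done progress still drives the status UI, while workspace-diagnostic partial results additionally ping the refresh task's progress channel (keeping its timeout disarmed, see the timer sketch further below) before being applied. A stand-in sketch of that dispatch, since the extended `ProgressParamsValue` variant lives in Zed's own lsp crate:

use std::sync::mpsc::Sender;

// Stand-in enum for the two progress payloads handled above.
enum Progress {
    WorkDone(String),
    WorkspaceDiagnostic(Vec<String>),
}

fn handle_progress(progress: Progress, workspace_progress_tx: &Sender<()>) {
    match progress {
        Progress::WorkDone(status) => println!("status update: {status}"),
        Progress::WorkspaceDiagnostic(report) => {
            // Let the pending workspace pull know results are still streaming,
            // then apply the partial report immediately.
            workspace_progress_tx.send(()).ok();
            println!("apply {} partial diagnostic items", report.len());
        }
    }
}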
}) = server_state.as_ref() { Some(adapter.name()) @@ -10419,52 +10776,80 @@ impl LspStore { ) } + #[cfg(any(test, feature = "test-support"))] pub fn update_diagnostics( &mut self, - language_server_id: LanguageServerId, - params: lsp::PublishDiagnosticsParams, + server_id: LanguageServerId, + diagnostics: lsp::PublishDiagnosticsParams, result_id: Option<String>, source_kind: DiagnosticSourceKind, disk_based_sources: &[String], cx: &mut Context<Self>, ) -> Result<()> { - self.merge_diagnostics( - language_server_id, - params, - result_id, + self.merge_lsp_diagnostics( source_kind, - disk_based_sources, + vec![DocumentDiagnosticsUpdate { + diagnostics, + result_id, + server_id, + disk_based_sources: Cow::Borrowed(disk_based_sources), + }], |_, _, _| false, cx, ) } - pub fn merge_diagnostics( + pub fn merge_lsp_diagnostics( &mut self, - language_server_id: LanguageServerId, - mut params: lsp::PublishDiagnosticsParams, - result_id: Option<String>, source_kind: DiagnosticSourceKind, - disk_based_sources: &[String], - filter: impl Fn(&Buffer, &Diagnostic, &App) -> bool + Clone, + lsp_diagnostics: Vec<DocumentDiagnosticsUpdate<lsp::PublishDiagnosticsParams>>, + merge: impl Fn(&Buffer, &Diagnostic, &App) -> bool + Clone, cx: &mut Context<Self>, ) -> Result<()> { anyhow::ensure!(self.mode.is_local(), "called update_diagnostics on remote"); - let abs_path = params - .uri - .to_file_path() - .map_err(|()| anyhow!("URI is not a file"))?; + let updates = lsp_diagnostics + .into_iter() + .filter_map(|update| { + let abs_path = update.diagnostics.uri.to_file_path().ok()?; + Some(DocumentDiagnosticsUpdate { + diagnostics: self.lsp_to_document_diagnostics( + abs_path, + source_kind, + update.server_id, + update.diagnostics, + &update.disk_based_sources, + ), + result_id: update.result_id, + server_id: update.server_id, + disk_based_sources: update.disk_based_sources, + }) + }) + .collect(); + self.merge_diagnostic_entries(updates, merge, cx)?; + Ok(()) + } + + fn lsp_to_document_diagnostics( + &mut self, + document_abs_path: PathBuf, + source_kind: DiagnosticSourceKind, + server_id: LanguageServerId, + mut lsp_diagnostics: lsp::PublishDiagnosticsParams, + disk_based_sources: &[String], + ) -> DocumentDiagnostics { let mut diagnostics = Vec::default(); let mut primary_diagnostic_group_ids = HashMap::default(); let mut sources_by_group_id = HashMap::default(); let mut supporting_diagnostics = HashMap::default(); - let adapter = self.language_server_adapter_for_id(language_server_id); + let adapter = self.language_server_adapter_for_id(server_id); // Ensure that primary diagnostics are always the most severe - params.diagnostics.sort_by_key(|item| item.severity); + lsp_diagnostics + .diagnostics + .sort_by_key(|item| item.severity); - for diagnostic in ¶ms.diagnostics { + for diagnostic in &lsp_diagnostics.diagnostics { let source = diagnostic.source.as_ref(); let range = range_from_lsp(diagnostic.range); let is_supporting = diagnostic @@ -10486,7 +10871,7 @@ impl LspStore { .map_or(false, |tags| tags.contains(&DiagnosticTag::UNNECESSARY)); let underline = self - .language_server_adapter_for_id(language_server_id) + .language_server_adapter_for_id(server_id) .map_or(true, |adapter| adapter.underline_diagnostic(diagnostic)); if is_supporting { @@ -10512,7 +10897,7 @@ impl LspStore { code_description: diagnostic .code_description .as_ref() - .map(|d| d.href.clone()), + .and_then(|d| d.href.clone()), severity: diagnostic.severity.unwrap_or(DiagnosticSeverity::ERROR), markdown: adapter.as_ref().and_then(|adapter| 
{ adapter.diagnostic_message_to_markdown(&diagnostic.message) @@ -10528,7 +10913,7 @@ impl LspStore { }); if let Some(infos) = &diagnostic.related_information { for info in infos { - if info.location.uri == params.uri && !info.message.is_empty() { + if info.location.uri == lsp_diagnostics.uri && !info.message.is_empty() { let range = range_from_lsp(info.location.range); diagnostics.push(DiagnosticEntry { range, @@ -10539,7 +10924,7 @@ impl LspStore { code_description: diagnostic .code_description .as_ref() - .map(|c| c.href.clone()), + .and_then(|d| d.href.clone()), severity: DiagnosticSeverity::INFORMATION, markdown: adapter.as_ref().and_then(|adapter| { adapter.diagnostic_message_to_markdown(&info.message) @@ -10576,16 +10961,11 @@ impl LspStore { } } - self.merge_diagnostic_entries( - language_server_id, - abs_path, - result_id, - params.version, + DocumentDiagnostics { diagnostics, - filter, - cx, - )?; - Ok(()) + document_abs_path, + version: lsp_diagnostics.version, + } } fn insert_newly_running_language_server( @@ -10652,7 +11032,7 @@ impl LspStore { self.language_server_statuses.insert( server_id, LanguageServerStatus { - name: language_server.name().to_string(), + name: language_server.name(), pending_work: Default::default(), has_pending_diagnostic_updates: false, progress_tokens: Default::default(), @@ -10666,18 +11046,23 @@ impl LspStore { )); cx.emit(LspStoreEvent::RefreshInlayHints); + let server_capabilities = language_server.capabilities(); if let Some((downstream_client, project_id)) = self.downstream_client.as_ref() { downstream_client .send(proto::StartLanguageServer { project_id: *project_id, server: Some(proto::LanguageServer { - id: server_id.0 as u64, + id: server_id.to_proto(), name: language_server.name().to_string(), worktree_id: Some(key.0.to_proto()), }), + capabilities: serde_json::to_string(&server_capabilities) + .expect("serializing server LSP capabilities"), }) .log_err(); } + self.lsp_server_capabilities + .insert(server_id, server_capabilities); // Tell the language server about every open buffer in the worktree that matches the language. 
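`StartLanguageServer` messages now carry the server's capabilities as JSON: the host serializes them when announcing a server (as in the hunk above), and the remote side parses them back so `is_capable_for_proto_request` can consult them before issuing queries. A small round-trip sketch using the plain `lsp_types`, `serde_json`, and `anyhow` crates:

use anyhow::Context as _;
use lsp_types::ServerCapabilities;

// Host side: capabilities are serialized next to the StartLanguageServer message.
fn encode_capabilities(capabilities: &ServerCapabilities) -> serde_json::Result<String> {
    serde_json::to_string(capabilities)
}

// Remote side: parse them back, with the same error context the diff adds.
fn decode_capabilities(payload: &str) -> anyhow::Result<ServerCapabilities> {
    serde_json::from_str(payload)
        .with_context(|| format!("incorrect server capabilities {payload}"))
}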
// Also check for buffers in worktrees that reused this server @@ -10725,10 +11110,11 @@ impl LspStore { let local = self.as_local_mut().unwrap(); - if local.registered_buffers.contains_key(&buffer.remote_id()) { + let buffer_id = buffer.remote_id(); + if local.registered_buffers.contains_key(&buffer_id) { let versions = local .buffer_snapshots - .entry(buffer.remote_id()) + .entry(buffer_id) .or_default() .entry(server_id) .and_modify(|_| { @@ -10754,10 +11140,10 @@ impl LspStore { version, initial_snapshot.text(), ); - buffer_paths_registered.push(file.abs_path(cx)); + buffer_paths_registered.push((buffer_id, file.abs_path(cx))); local .buffers_opened_in_servers - .entry(buffer.remote_id()) + .entry(buffer_id) .or_default() .insert(server_id); } @@ -10781,13 +11167,14 @@ impl LspStore { } }); - for abs_path in buffer_paths_registered { + for (buffer_id, abs_path) in buffer_paths_registered { cx.emit(LspStoreEvent::LanguageServerUpdate { language_server_id: server_id, name: Some(adapter.name()), message: proto::update_language_server::Variant::RegisteredForBuffer( proto::RegisteredForBuffer { buffer_abs_path: abs_path.to_string_lossy().to_string(), + buffer_id: buffer_id.to_proto(), }, ), }); @@ -11245,9 +11632,13 @@ impl LspStore { } fn cleanup_lsp_data(&mut self, for_server: LanguageServerId) { - for buffer_lsp_data in self.lsp_data.values_mut() { - buffer_lsp_data.colors.remove(&for_server); - buffer_lsp_data.cache_version += 1; + self.lsp_server_capabilities.remove(&for_server); + for buffer_colors in self.lsp_document_colors.values_mut() { + buffer_colors.colors.remove(&for_server); + buffer_colors.cache_version += 1; + } + for buffer_lens in self.lsp_code_lens.values_mut() { + buffer_lens.lens.remove(&for_server); } if let Some(local) = self.as_local_mut() { local.buffer_pull_diagnostics_result_ids.remove(&for_server); @@ -11291,13 +11682,13 @@ impl LspStore { pub fn pull_workspace_diagnostics(&mut self, server_id: LanguageServerId) { if let Some(LanguageServerState::Running { - workspace_refresh_task: Some((tx, _)), + workspace_refresh_task: Some(workspace_refresh_task), .. }) = self .as_local_mut() .and_then(|local| local.language_servers.get_mut(&server_id)) { - tx.try_send(()).ok(); + workspace_refresh_task.refresh_tx.try_send(()).ok(); } } @@ -11313,14 +11704,103 @@ impl LspStore { local.language_server_ids_for_buffer(buffer, cx) }) { if let Some(LanguageServerState::Running { - workspace_refresh_task: Some((tx, _)), + workspace_refresh_task: Some(workspace_refresh_task), .. 
}) = local.language_servers.get_mut(&server_id) { - tx.try_send(()).ok(); + workspace_refresh_task.refresh_tx.try_send(()).ok(); } } } + + fn apply_workspace_diagnostic_report( + &mut self, + server_id: LanguageServerId, + report: lsp::WorkspaceDiagnosticReportResult, + cx: &mut Context<Self>, + ) { + let workspace_diagnostics = + GetDocumentDiagnostics::deserialize_workspace_diagnostics_report(report, server_id); + let mut unchanged_buffers = HashSet::default(); + let mut changed_buffers = HashSet::default(); + let workspace_diagnostics_updates = workspace_diagnostics + .into_iter() + .filter_map( + |workspace_diagnostics| match workspace_diagnostics.diagnostics { + LspPullDiagnostics::Response { + server_id, + uri, + diagnostics, + } => Some((server_id, uri, diagnostics, workspace_diagnostics.version)), + LspPullDiagnostics::Default => None, + }, + ) + .fold( + HashMap::default(), + |mut acc, (server_id, uri, diagnostics, version)| { + let (result_id, diagnostics) = match diagnostics { + PulledDiagnostics::Unchanged { result_id } => { + unchanged_buffers.insert(uri.clone()); + (Some(result_id), Vec::new()) + } + PulledDiagnostics::Changed { + result_id, + diagnostics, + } => { + changed_buffers.insert(uri.clone()); + (result_id, diagnostics) + } + }; + let disk_based_sources = Cow::Owned( + self.language_server_adapter_for_id(server_id) + .as_ref() + .map(|adapter| adapter.disk_based_diagnostic_sources.as_slice()) + .unwrap_or(&[]) + .to_vec(), + ); + acc.entry(server_id) + .or_insert_with(Vec::new) + .push(DocumentDiagnosticsUpdate { + server_id, + diagnostics: lsp::PublishDiagnosticsParams { + uri, + diagnostics, + version, + }, + result_id, + disk_based_sources, + }); + acc + }, + ); + + for diagnostic_updates in workspace_diagnostics_updates.into_values() { + self.merge_lsp_diagnostics( + DiagnosticSourceKind::Pulled, + diagnostic_updates, + |buffer, old_diagnostic, cx| { + File::from_dyn(buffer.file()) + .and_then(|file| { + let abs_path = file.as_local()?.abs_path(cx); + lsp::Url::from_file_path(abs_path).ok() + }) + .is_none_or(|buffer_uri| { + unchanged_buffers.contains(&buffer_uri) + || match old_diagnostic.source_kind { + DiagnosticSourceKind::Pulled => { + !changed_buffers.contains(&buffer_uri) + } + DiagnosticSourceKind::Other | DiagnosticSourceKind::Pushed => { + true + } + } + }) + }, + cx, + ) + .log_err(); + } + } } fn subscribe_to_binary_statuses( @@ -11373,7 +11853,7 @@ fn subscribe_to_binary_statuses( fn lsp_workspace_diagnostics_refresh( server: Arc<LanguageServer>, cx: &mut Context<'_, LspStore>, -) -> Option<(mpsc::Sender<()>, Task<()>)> { +) -> Option<WorkspaceRefreshTask> { let identifier = match server.capabilities().diagnostic_provider? 
{ lsp::DiagnosticServerCapabilities::Options(diagnostic_options) => { if !diagnostic_options.workspace_diagnostics { @@ -11390,19 +11870,22 @@ fn lsp_workspace_diagnostics_refresh( } }; - let (mut tx, mut rx) = mpsc::channel(1); - tx.try_send(()).ok(); + let (progress_tx, mut progress_rx) = mpsc::channel(1); + let (mut refresh_tx, mut refresh_rx) = mpsc::channel(1); + refresh_tx.try_send(()).ok(); let workspace_query_language_server = cx.spawn(async move |lsp_store, cx| { let mut attempts = 0; let max_attempts = 50; + let mut requests = 0; loop { - let Some(()) = rx.recv().await else { + let Some(()) = refresh_rx.recv().await else { return; }; 'request: loop { + requests += 1; if attempts > max_attempts { log::error!( "Failed to pull workspace diagnostics {max_attempts} times, aborting" @@ -11431,14 +11914,29 @@ fn lsp_workspace_diagnostics_refresh( return; }; + let token = format!("workspace/diagnostic-{}-{}", server.server_id(), requests); + + progress_rx.try_recv().ok(); + let timer = + LanguageServer::default_request_timer(cx.background_executor().clone()).fuse(); + let progress = pin!(progress_rx.recv().fuse()); let response_result = server - .request::<lsp::WorkspaceDiagnosticRequest>(lsp::WorkspaceDiagnosticParams { - previous_result_ids, - identifier: identifier.clone(), - work_done_progress_params: Default::default(), - partial_result_params: Default::default(), - }) + .request_with_timer::<lsp::WorkspaceDiagnosticRequest, _>( + lsp::WorkspaceDiagnosticParams { + previous_result_ids, + identifier: identifier.clone(), + work_done_progress_params: Default::default(), + partial_result_params: lsp::PartialResultParams { + partial_result_token: Some(lsp::ProgressToken::String(token)), + }, + }, + select(timer, progress).then(|either| match either { + Either::Left((message, ..)) => ready(message).left_future(), + Either::Right(..) => pending::<String>().right_future(), + }), + ) .await; + // https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#diagnostic_refresh // > If a server closes a workspace diagnostic pull request the client should re-trigger the request. 
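The workspace pull above replaces the fixed request timeout with a soft one: a progress token is attached to the request, and the default timer is raced against the first partial-result notification, so a server that is still streaming results is never cut off. A conceptual sketch of that race using `futures::future::select` (the message and channel types are simplified stand-ins):

use futures::future::{select, Either};

// If the timer wins, the request is abandoned with its timeout message; if a
// progress notification arrives first, the timeout is effectively disarmed.
async fn timeout_unless_progress<T, P>(timer: T, progress: P) -> Option<String>
where
    T: std::future::Future<Output = String> + Unpin,
    P: std::future::Future<Output = ()> + Unpin,
{
    match select(timer, progress).await {
        Either::Left((timeout_message, _pending_progress)) => Some(timeout_message),
        Either::Right(((), _pending_timer)) => {
            // In the real code a `pending::<String>()` future is substituted here,
            // so the request never times out once progress has been observed.
            None
        }
    }
}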
match response_result { @@ -11458,72 +11956,11 @@ fn lsp_workspace_diagnostics_refresh( attempts = 0; if lsp_store .update(cx, |lsp_store, cx| { - let workspace_diagnostics = - GetDocumentDiagnostics::deserialize_workspace_diagnostics_report(pulled_diagnostics, server.server_id()); - for workspace_diagnostics in workspace_diagnostics { - let LspPullDiagnostics::Response { - server_id, - uri, - diagnostics, - } = workspace_diagnostics.diagnostics - else { - continue; - }; - - let adapter = lsp_store.language_server_adapter_for_id(server_id); - let disk_based_sources = adapter - .as_ref() - .map(|adapter| adapter.disk_based_diagnostic_sources.as_slice()) - .unwrap_or(&[]); - - match diagnostics { - PulledDiagnostics::Unchanged { result_id } => { - lsp_store - .merge_diagnostics( - server_id, - lsp::PublishDiagnosticsParams { - uri: uri.clone(), - diagnostics: Vec::new(), - version: None, - }, - Some(result_id), - DiagnosticSourceKind::Pulled, - disk_based_sources, - |_, _, _| true, - cx, - ) - .log_err(); - } - PulledDiagnostics::Changed { - diagnostics, - result_id, - } => { - lsp_store - .merge_diagnostics( - server_id, - lsp::PublishDiagnosticsParams { - uri: uri.clone(), - diagnostics, - version: workspace_diagnostics.version, - }, - result_id, - DiagnosticSourceKind::Pulled, - disk_based_sources, - |buffer, old_diagnostic, cx| match old_diagnostic.source_kind { - DiagnosticSourceKind::Pulled => { - let buffer_url = File::from_dyn(buffer.file()).map(|f| f.abs_path(cx)) - .and_then(|abs_path| file_path_to_lsp_url(&abs_path).ok()); - buffer_url.is_none_or(|buffer_url| buffer_url != uri) - }, - DiagnosticSourceKind::Other - | DiagnosticSourceKind::Pushed => true, - }, - cx, - ) - .log_err(); - } - } - } + lsp_store.apply_workspace_diagnostic_report( + server.server_id(), + pulled_diagnostics, + cx, + ) }) .is_err() { @@ -11536,7 +11973,11 @@ fn lsp_workspace_diagnostics_refresh( } }); - Some((tx, workspace_query_language_server)) + Some(WorkspaceRefreshTask { + refresh_tx, + progress_tx, + task: workspace_query_language_server, + }) } fn resolve_word_completion(snapshot: &BufferSnapshot, completion: &mut Completion) { @@ -11906,6 +12347,13 @@ impl LanguageServerLogType { } } +pub struct WorkspaceRefreshTask { + refresh_tx: mpsc::Sender<()>, + progress_tx: mpsc::Sender<()>, + #[allow(dead_code)] + task: Task<()>, +} + pub enum LanguageServerState { Starting { startup: Task<Option<Arc<LanguageServer>>>, @@ -11917,7 +12365,7 @@ pub enum LanguageServerState { adapter: Arc<CachedLspAdapter>, server: Arc<LanguageServer>, simulate_disk_based_diagnostics_completion: Option<Task<()>>, - workspace_refresh_task: Option<(mpsc::Sender<()>, Task<()>)>, + workspace_refresh_task: Option<WorkspaceRefreshTask>, }, } diff --git a/crates/project/src/lsp_store/clangd_ext.rs b/crates/project/src/lsp_store/clangd_ext.rs index 6a09bb99b4ae6e17ac000ecac3d1aef6d0d2b5ee..274b1b898086eeddf72710052397dd9963833663 100644 --- a/crates/project/src/lsp_store/clangd_ext.rs +++ b/crates/project/src/lsp_store/clangd_ext.rs @@ -1,14 +1,14 @@ -use std::sync::Arc; +use std::{borrow::Cow, sync::Arc}; use ::serde::{Deserialize, Serialize}; use gpui::WeakEntity; use language::{CachedLspAdapter, Diagnostic, DiagnosticSourceKind}; -use lsp::LanguageServer; +use lsp::{LanguageServer, LanguageServerName}; use util::ResultExt as _; -use crate::LspStore; +use crate::{LspStore, lsp_store::DocumentDiagnosticsUpdate}; -pub const CLANGD_SERVER_NAME: &str = "clangd"; +pub const CLANGD_SERVER_NAME: LanguageServerName = 
LanguageServerName::new_static("clangd"); const INACTIVE_REGION_MESSAGE: &str = "inactive region"; const INACTIVE_DIAGNOSTIC_SEVERITY: lsp::DiagnosticSeverity = lsp::DiagnosticSeverity::INFORMATION; @@ -34,7 +34,7 @@ pub fn is_inactive_region(diag: &Diagnostic) -> bool { && diag .source .as_ref() - .is_some_and(|v| v == CLANGD_SERVER_NAME) + .is_some_and(|v| v == &CLANGD_SERVER_NAME.0) } pub fn is_lsp_inactive_region(diag: &lsp::Diagnostic) -> bool { @@ -43,7 +43,7 @@ pub fn is_lsp_inactive_region(diag: &lsp::Diagnostic) -> bool { && diag .source .as_ref() - .is_some_and(|v| v == CLANGD_SERVER_NAME) + .is_some_and(|v| v == &CLANGD_SERVER_NAME.0) } pub fn register_notifications( @@ -51,7 +51,7 @@ pub fn register_notifications( language_server: &LanguageServer, adapter: Arc<CachedLspAdapter>, ) { - if language_server.name().0 != CLANGD_SERVER_NAME { + if language_server.name() != CLANGD_SERVER_NAME { return; } let server_id = language_server.server_id(); @@ -81,12 +81,16 @@ pub fn register_notifications( version: params.text_document.version, diagnostics, }; - this.merge_diagnostics( - server_id, - mapped_diagnostics, - None, + this.merge_lsp_diagnostics( DiagnosticSourceKind::Pushed, - &adapter.disk_based_diagnostic_sources, + vec![DocumentDiagnosticsUpdate { + server_id, + diagnostics: mapped_diagnostics, + result_id: None, + disk_based_sources: Cow::Borrowed( + &adapter.disk_based_diagnostic_sources, + ), + }], |_, diag, _| !is_inactive_region(diag), cx, ) diff --git a/crates/project/src/lsp_store/json_language_server_ext.rs b/crates/project/src/lsp_store/json_language_server_ext.rs new file mode 100644 index 0000000000000000000000000000000000000000..3eb93386a99bf40dffc5f6de75d56248936b38e3 --- /dev/null +++ b/crates/project/src/lsp_store/json_language_server_ext.rs @@ -0,0 +1,101 @@ +use anyhow::Context as _; +use collections::HashMap; +use gpui::WeakEntity; +use lsp::LanguageServer; + +use crate::LspStore; +/// https://github.com/Microsoft/vscode/blob/main/extensions/json-language-features/server/README.md#schema-content-request +/// +/// Represents a "JSON language server-specific, non-standardized, extension to the LSP" with which the vscode-json-language-server +/// can request the contents of a schema that is associated with a uri scheme it does not support. +/// In our case, we provide the uris for actions on server startup under the `zed://schemas/action/{normalize_action_name}` scheme. +/// We can then respond to this request with the schema content on demand, thereby greatly reducing the total size of the JSON we send to the server on startup +struct SchemaContentRequest {} + +impl lsp::request::Request for SchemaContentRequest { + type Params = Vec<String>; + + type Result = String; + + const METHOD: &'static str = "vscode/content"; +} + +pub fn register_requests(_lsp_store: WeakEntity<LspStore>, language_server: &LanguageServer) { + language_server + .on_request::<SchemaContentRequest, _, _>(|params, cx| { + // PERF: Use a cache (`OnceLock`?) 
to avoid recomputing the action schemas + let mut generator = settings::KeymapFile::action_schema_generator(); + let all_schemas = cx.update(|cx| HashMap::from_iter(cx.action_schemas(&mut generator))); + async move { + let all_schemas = all_schemas?; + let Some(uri) = params.get(0) else { + anyhow::bail!("No URI"); + }; + let normalized_action_name = uri + .strip_prefix("zed://schemas/action/") + .context("Invalid URI")?; + let action_name = denormalize_action_name(normalized_action_name); + let schema = root_schema_from_action_schema( + all_schemas + .get(action_name.as_str()) + .and_then(Option::as_ref), + &mut generator, + ) + .to_value(); + + serde_json::to_string(&schema).context("Failed to serialize schema") + } + }) + .detach(); +} + +pub fn normalize_action_name(action_name: &str) -> String { + action_name.replace("::", "__") +} + +pub fn denormalize_action_name(action_name: &str) -> String { + action_name.replace("__", "::") +} + +pub fn normalized_action_file_name(action_name: &str) -> String { + normalized_action_name_to_file_name(normalize_action_name(action_name)) +} + +pub fn normalized_action_name_to_file_name(mut normalized_action_name: String) -> String { + normalized_action_name.push_str(".json"); + normalized_action_name +} + +pub fn url_schema_for_action(action_name: &str) -> serde_json::Value { + let normalized_name = normalize_action_name(action_name); + let file_name = normalized_action_name_to_file_name(normalized_name.clone()); + serde_json::json!({ + "fileMatch": [file_name], + "url": format!("zed://schemas/action/{}", normalized_name) + }) +} + +fn root_schema_from_action_schema( + action_schema: Option<&schemars::Schema>, + generator: &mut schemars::SchemaGenerator, +) -> schemars::Schema { + let Some(action_schema) = action_schema else { + return schemars::json_schema!(false); + }; + let meta_schema = generator + .settings() + .meta_schema + .as_ref() + .expect("meta_schema should be present in schemars settings") + .to_string(); + let defs = generator.definitions(); + let mut schema = schemars::json_schema!({ + "$schema": meta_schema, + "allowTrailingCommas": true, + "$defs": defs, + }); + schema + .ensure_object() + .extend(std::mem::take(action_schema.clone().ensure_object())); + schema +} diff --git a/crates/project/src/lsp_store/rust_analyzer_ext.rs b/crates/project/src/lsp_store/rust_analyzer_ext.rs index d78715d38579c24b6aa0f5c1841c8c0298ddd9d7..6c425717a82e94985c60db8d1034d470f1aeec35 100644 --- a/crates/project/src/lsp_store/rust_analyzer_ext.rs +++ b/crates/project/src/lsp_store/rust_analyzer_ext.rs @@ -2,12 +2,12 @@ use ::serde::{Deserialize, Serialize}; use anyhow::Context as _; use gpui::{App, Entity, Task, WeakEntity}; use language::ServerHealth; -use lsp::LanguageServer; +use lsp::{LanguageServer, LanguageServerName}; use rpc::proto; use crate::{LspStore, LspStoreEvent, Project, ProjectPath, lsp_store}; -pub const RUST_ANALYZER_NAME: &str = "rust-analyzer"; +pub const RUST_ANALYZER_NAME: LanguageServerName = LanguageServerName::new_static("rust-analyzer"); pub const CARGO_DIAGNOSTICS_SOURCE_NAME: &str = "rustc"; /// Experimental: Informs the end user about the state of the server @@ -97,13 +97,9 @@ pub fn cancel_flycheck( cx.spawn(async move |cx| { let buffer = buffer.await?; - let Some(rust_analyzer_server) = project - .update(cx, |project, cx| { - buffer.update(cx, |buffer, cx| { - project.language_server_id_for_name(buffer, RUST_ANALYZER_NAME, cx) - }) - })? 
- .await + let Some(rust_analyzer_server) = project.read_with(cx, |project, cx| { + project.language_server_id_for_name(buffer.read(cx), &RUST_ANALYZER_NAME, cx) + })? else { return Ok(()); }; @@ -148,13 +144,9 @@ pub fn run_flycheck( cx.spawn(async move |cx| { let buffer = buffer.await?; - let Some(rust_analyzer_server) = project - .update(cx, |project, cx| { - buffer.update(cx, |buffer, cx| { - project.language_server_id_for_name(buffer, RUST_ANALYZER_NAME, cx) - }) - })? - .await + let Some(rust_analyzer_server) = project.read_with(cx, |project, cx| { + project.language_server_id_for_name(buffer.read(cx), &RUST_ANALYZER_NAME, cx) + })? else { return Ok(()); }; @@ -204,13 +196,9 @@ pub fn clear_flycheck( cx.spawn(async move |cx| { let buffer = buffer.await?; - let Some(rust_analyzer_server) = project - .update(cx, |project, cx| { - buffer.update(cx, |buffer, cx| { - project.language_server_id_for_name(buffer, RUST_ANALYZER_NAME, cx) - }) - })? - .await + let Some(rust_analyzer_server) = project.read_with(cx, |project, cx| { + project.language_server_id_for_name(buffer.read(cx), &RUST_ANALYZER_NAME, cx) + })? else { return Ok(()); }; diff --git a/crates/project/src/manifest_tree/path_trie.rs b/crates/project/src/manifest_tree/path_trie.rs index 0f7575324b040bc951db730ee97f7a08350d571f..1a0736765a43b9e1365334de95eacbe9dbf64382 100644 --- a/crates/project/src/manifest_tree/path_trie.rs +++ b/crates/project/src/manifest_tree/path_trie.rs @@ -6,7 +6,7 @@ use std::{ sync::Arc, }; -/// [RootPathTrie] is a workhorse of [super::ManifestTree]. It is responsible for determining the closest known project root for a given path. +/// [RootPathTrie] is a workhorse of [super::ManifestTree]. It is responsible for determining the closest known entry for a given path. /// It also determines how much of a given path is unexplored, thus letting callers fill in that gap if needed. /// Conceptually, it allows one to annotate Worktree entries with arbitrary extra metadata and run closest-ancestor searches. /// @@ -20,19 +20,16 @@ pub(super) struct RootPathTrie<Label> { } /// Label presence is a marker that allows to optimize searches within [RootPathTrie]; node label can be: -/// - Present; we know there's definitely a project root at this node and it is the only label of that kind on the path to the root of a worktree -/// (none of it's ancestors or descendants can contain the same present label) +/// - Present; we know there's definitely a project root at this node. /// - Known Absent - we know there's definitely no project root at this node and none of it's ancestors are Present (descendants can be present though!). -/// - Forbidden - we know there's definitely no project root at this node and none of it's ancestors or descendants can be Present. /// The distinction is there to optimize searching; when we encounter a node with unknown status, we don't need to look at it's full path /// to the root of the worktree; it's sufficient to explore only the path between last node with a KnownAbsent state and the directory of a path, since we run searches -/// from the leaf up to the root of the worktree. When any of the ancestors is forbidden, we don't need to look at the node or its ancestors. -/// When there's a present labeled node on the path to the root, we don't need to ask the adapter to run the search at all. +/// from the leaf up to the root of the worktree. 
/// /// In practical terms, it means that by storing label presence we don't need to do a project discovery on a given folder more than once /// (unless the node is invalidated, which can happen when FS entries are renamed/removed). /// -/// Storing project absence allows us to recognize which paths have already been scanned for a project root unsuccessfully. This way we don't need to run +/// Storing absent nodes allows us to recognize which paths have already been scanned for a project root unsuccessfully. This way we don't need to run /// such scan more than once. #[derive(Clone, Copy, Debug, PartialOrd, PartialEq, Ord, Eq)] pub(super) enum LabelPresence { @@ -237,4 +234,25 @@ mod tests { Path::new("a/") ); } + + #[test] + fn path_to_a_root_can_contain_multiple_known_nodes() { + let mut trie = RootPathTrie::<()>::new(); + trie.insert( + &TriePath::from(Path::new("a/b")), + (), + LabelPresence::Present, + ); + trie.insert(&TriePath::from(Path::new("a")), (), LabelPresence::Present); + let mut visited_paths = BTreeSet::new(); + trie.walk(&TriePath::from(Path::new("a/b/c")), &mut |path, nodes| { + assert_eq!(nodes.get(&()), Some(&LabelPresence::Present)); + if path.as_ref() != Path::new("a") && path.as_ref() != Path::new("a/b") { + panic!("Unexpected path: {}", path.as_ref().display()); + } + assert!(visited_paths.insert(path.clone())); + ControlFlow::Continue(()) + }); + assert_eq!(visited_paths.len(), 2); + } } diff --git a/crates/project/src/manifest_tree/server_tree.rs b/crates/project/src/manifest_tree/server_tree.rs index 0283f06eec0f2859f99bddb0e5be10bb8f4197fa..81cb1c450c4626bfa691c98e88d26536705dfb3d 100644 --- a/crates/project/src/manifest_tree/server_tree.rs +++ b/crates/project/src/manifest_tree/server_tree.rs @@ -13,10 +13,10 @@ use std::{ sync::{Arc, Weak}, }; -use collections::{HashMap, IndexMap}; +use collections::IndexMap; use gpui::{App, AppContext as _, Entity, Subscription}; use language::{ - Attach, CachedLspAdapter, LanguageName, LanguageRegistry, ManifestDelegate, + CachedLspAdapter, LanguageName, LanguageRegistry, ManifestDelegate, language_settings::AllLanguageSettings, }; use lsp::LanguageServerName; @@ -38,7 +38,6 @@ pub(crate) struct ServersForWorktree { pub struct LanguageServerTree { manifest_tree: Entity<ManifestTree>, pub(crate) instances: BTreeMap<WorktreeId, ServersForWorktree>, - attach_kind_cache: HashMap<LanguageServerName, Attach>, languages: Arc<LanguageRegistry>, _subscriptions: Subscription, } @@ -53,7 +52,6 @@ pub struct LanguageServerTreeNode(Weak<InnerTreeNode>); #[derive(Debug)] pub(crate) struct LaunchDisposition<'a> { pub(crate) server_name: &'a LanguageServerName, - pub(crate) attach: Attach, pub(crate) path: ProjectPath, pub(crate) settings: Arc<LspSettings>, } @@ -62,7 +60,6 @@ impl<'a> From<&'a InnerTreeNode> for LaunchDisposition<'a> { fn from(value: &'a InnerTreeNode) -> Self { LaunchDisposition { server_name: &value.name, - attach: value.attach, path: value.path.clone(), settings: value.settings.clone(), } @@ -105,7 +102,6 @@ impl From<Weak<InnerTreeNode>> for LanguageServerTreeNode { pub struct InnerTreeNode { id: OnceLock<LanguageServerId>, name: LanguageServerName, - attach: Attach, path: ProjectPath, settings: Arc<LspSettings>, } @@ -113,14 +109,12 @@ pub struct InnerTreeNode { impl InnerTreeNode { fn new( name: LanguageServerName, - attach: Attach, path: ProjectPath, settings: impl Into<Arc<LspSettings>>, ) -> Self { InnerTreeNode { id: Default::default(), name, - attach, path, settings: settings.into(), } @@ -130,8 +124,11 @@ impl 
InnerTreeNode { /// Determines how the list of adapters to query should be constructed. pub(crate) enum AdapterQuery<'a> { /// Search for roots of all adapters associated with a given language name. + /// Layman: Look for all project roots along the queried path that have any + /// language server associated with this language running. Language(&'a LanguageName), /// Search for roots of adapter with a given name. + /// Layman: Look for all project roots along the queried path that have this server running. Adapter(&'a LanguageServerName), } @@ -147,7 +144,7 @@ impl LanguageServerTree { }), manifest_tree, instances: Default::default(), - attach_kind_cache: Default::default(), + languages, }) } @@ -223,7 +220,6 @@ impl LanguageServerTree { .and_then(|name| roots.get(&name)) .cloned() .unwrap_or_else(|| root_path.clone()); - let attach = adapter.attach_kind(); let inner_node = self .instances @@ -237,7 +233,6 @@ impl LanguageServerTree { ( Arc::new(InnerTreeNode::new( adapter.name(), - attach, root_path.clone(), settings.clone(), )), @@ -379,7 +374,6 @@ pub(crate) struct ServerTreeRebase<'a> { impl<'tree> ServerTreeRebase<'tree> { fn new(new_tree: &'tree mut LanguageServerTree) -> Self { let old_contents = std::mem::take(&mut new_tree.instances); - new_tree.attach_kind_cache.clear(); let all_server_ids = old_contents .values() .flat_map(|nodes| { @@ -446,10 +440,7 @@ impl<'tree> ServerTreeRebase<'tree> { .get(&disposition.path.worktree_id) .and_then(|worktree_nodes| worktree_nodes.roots.get(&disposition.path.path)) .and_then(|roots| roots.get(&disposition.name)) - .filter(|(old_node, _)| { - disposition.attach == old_node.attach - && disposition.settings == old_node.settings - }) + .filter(|(old_node, _)| disposition.settings == old_node.settings) else { return Some(node); }; diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index 22ec8438a2b55f458b1ac0520f15fcabaf90087c..b3a9e6fdf5e61cd8f057923cc6c76aaf08818501 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -73,11 +73,10 @@ use gpui::{ App, AppContext, AsyncApp, BorrowAppContext, Context, Entity, EventEmitter, Hsla, SharedString, Task, WeakEntity, Window, }; -use itertools::Itertools; use language::{ - Buffer, BufferEvent, Capability, CodeLabel, CursorShape, DiagnosticSourceKind, Language, - LanguageName, LanguageRegistry, PointUtf16, ToOffset, ToPointUtf16, Toolchain, ToolchainList, - Transaction, Unclipped, language_settings::InlayHintKind, proto::split_operations, + Buffer, BufferEvent, Capability, CodeLabel, CursorShape, Language, LanguageName, + LanguageRegistry, PointUtf16, ToOffset, ToPointUtf16, Toolchain, ToolchainList, Transaction, + Unclipped, language_settings::InlayHintKind, proto::split_operations, }; use lsp::{ CodeActionKind, CompletionContext, CompletionItemKind, DocumentHighlightKind, InsertTextMode, @@ -97,7 +96,7 @@ use rpc::{ }; use search::{SearchInputKind, SearchQuery, SearchResult}; use search_history::SearchHistory; -use settings::{InvalidSettingsError, Settings, SettingsLocation, SettingsStore}; +use settings::{InvalidSettingsError, Settings, SettingsLocation, SettingsSources, SettingsStore}; use smol::channel::Receiver; use snippet::Snippet; use snippet_provider::SnippetProvider; @@ -113,7 +112,7 @@ use std::{ use task_store::TaskStore; use terminals::Terminals; -use text::{Anchor, BufferId, Point}; +use text::{Anchor, BufferId, OffsetRangeExt, Point, Rope}; use toolchain_store::EmptyToolchainStore; use util::{ ResultExt as _, @@ -277,6 +276,13 @@ pub enum 
Event { LanguageServerAdded(LanguageServerId, LanguageServerName, Option<WorktreeId>), LanguageServerRemoved(LanguageServerId), LanguageServerLog(LanguageServerId, LanguageServerLogType, String), + // [`lsp::notification::DidOpenTextDocument`] was sent to this server using the buffer data. + // Zed's buffer-related data is updated accordingly. + LanguageServerBufferRegistered { + server_id: LanguageServerId, + buffer_id: BufferId, + buffer_abs_path: PathBuf, + }, Toast { notification_id: SharedString, message: String, @@ -299,7 +305,7 @@ pub enum Event { language_server_id: LanguageServerId, }, DiagnosticsUpdated { - path: ProjectPath, + paths: Vec<ProjectPath>, language_server_id: LanguageServerId, }, RemoteIdChanged(Option<u64>), @@ -590,7 +596,7 @@ pub(crate) struct CoreCompletion { } /// A code action provided by a language server. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq)] pub struct CodeAction { /// The id of the language server that produced this code action. pub server_id: LanguageServerId, @@ -604,7 +610,7 @@ pub struct CodeAction { } /// An action sent back by a language server. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq)] pub enum LspAction { /// An action with the full data, may have a command or may not. /// May require resolving. @@ -661,10 +667,10 @@ pub enum ResolveState { } impl InlayHint { - pub fn text(&self) -> String { + pub fn text(&self) -> Rope { match &self.label { - InlayHintLabel::String(s) => s.to_owned(), - InlayHintLabel::LabelParts(parts) => parts.iter().map(|part| &part.value).join(""), + InlayHintLabel::String(s) => Rope::from(s), + InlayHintLabel::LabelParts(parts) => parts.iter().map(|part| &*part.value).collect(), } } } @@ -942,10 +948,38 @@ pub enum PulledDiagnostics { }, } +/// Whether to disable all AI features in Zed. 
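+///
+/// For illustration (derived from the fields below): with `KEY = "disable_ai"` and
+/// `FileContent = Option<bool>`, the corresponding settings entry is a single
+/// boolean, e.g. `"disable_ai": true`.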
+/// +/// Default: false +#[derive(Copy, Clone, Debug)] +pub struct DisableAiSettings { + pub disable_ai: bool, +} + +impl settings::Settings for DisableAiSettings { + const KEY: Option<&'static str> = Some("disable_ai"); + + type FileContent = Option<bool>; + + fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> { + Ok(Self { + disable_ai: sources + .user + .or(sources.server) + .copied() + .flatten() + .unwrap_or(sources.default.ok_or_else(Self::missing_default)?), + }) + } + + fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {} +} + impl Project { pub fn init_settings(cx: &mut App) { WorktreeSettings::register(cx); ProjectSettings::register(cx); + DisableAiSettings::register(cx); } pub fn init(client: &Arc<Client>, cx: &mut App) { @@ -998,8 +1032,9 @@ impl Project { cx.subscribe(&worktree_store, Self::on_worktree_store_event) .detach(); + let weak_self = cx.weak_entity(); let context_server_store = - cx.new(|cx| ContextServerStore::new(worktree_store.clone(), cx)); + cx.new(|cx| ContextServerStore::new(worktree_store.clone(), weak_self, cx)); let environment = cx.new(|_| ProjectEnvironment::new(env)); let manifest_tree = ManifestTree::new(worktree_store.clone(), cx); @@ -1167,8 +1202,9 @@ impl Project { cx.subscribe(&worktree_store, Self::on_worktree_store_event) .detach(); + let weak_self = cx.weak_entity(); let context_server_store = - cx.new(|cx| ContextServerStore::new(worktree_store.clone(), cx)); + cx.new(|cx| ContextServerStore::new(worktree_store.clone(), weak_self, cx)); let buffer_store = cx.new(|cx| { BufferStore::remote( @@ -1360,10 +1396,7 @@ impl Project { fs: Arc<dyn Fs>, cx: AsyncApp, ) -> Result<Entity<Self>> { - client - .authenticate_and_connect(true, &cx) - .await - .into_response()?; + client.connect(true, &cx).await.into_response()?; let subscriptions = [ EntitySubscription::Project(client.subscribe_to_entity::<Self>(remote_id)?), @@ -1428,8 +1461,6 @@ impl Project { let image_store = cx.new(|cx| { ImageStore::remote(worktree_store.clone(), client.clone().into(), remote_id, cx) })?; - let context_server_store = - cx.new(|cx| ContextServerStore::new(worktree_store.clone(), cx))?; let environment = cx.new(|_| ProjectEnvironment::new(None))?; @@ -1496,6 +1527,10 @@ impl Project { let snippets = SnippetProvider::new(fs.clone(), BTreeSet::from_iter([]), cx); + let weak_self = cx.weak_entity(); + let context_server_store = + cx.new(|cx| ContextServerStore::new(worktree_store.clone(), weak_self, cx)); + let mut worktrees = Vec::new(); for worktree in response.payload.worktrees { let worktree = @@ -2860,18 +2895,17 @@ impl Project { cx: &mut Context<Self>, ) { match event { - LspStoreEvent::DiagnosticsUpdated { - language_server_id, - path, - } => cx.emit(Event::DiagnosticsUpdated { - path: path.clone(), - language_server_id: *language_server_id, - }), - LspStoreEvent::LanguageServerAdded(language_server_id, name, worktree_id) => cx.emit( - Event::LanguageServerAdded(*language_server_id, name.clone(), *worktree_id), + LspStoreEvent::DiagnosticsUpdated { server_id, paths } => { + cx.emit(Event::DiagnosticsUpdated { + paths: paths.clone(), + language_server_id: *server_id, + }) + } + LspStoreEvent::LanguageServerAdded(server_id, name, worktree_id) => cx.emit( + Event::LanguageServerAdded(*server_id, name.clone(), *worktree_id), ), - LspStoreEvent::LanguageServerRemoved(language_server_id) => { - cx.emit(Event::LanguageServerRemoved(*language_server_id)) + LspStoreEvent::LanguageServerRemoved(server_id) 
=> { + cx.emit(Event::LanguageServerRemoved(*server_id)) } LspStoreEvent::LanguageServerLog(server_id, log_type, string) => cx.emit( Event::LanguageServerLog(*server_id, log_type.clone(), string.clone()), @@ -2902,8 +2936,8 @@ impl Project { } LspStoreEvent::LanguageServerUpdate { language_server_id, - message, name, + message, } => { if self.is_local() { self.enqueue_buffer_ordered_message( @@ -2915,6 +2949,32 @@ impl Project { ) .ok(); } + + match message { + proto::update_language_server::Variant::MetadataUpdated(update) => { + if let Some(capabilities) = update + .capabilities + .as_ref() + .and_then(|capabilities| serde_json::from_str(capabilities).ok()) + { + self.lsp_store.update(cx, |lsp_store, _| { + lsp_store + .lsp_server_capabilities + .insert(*language_server_id, capabilities); + }); + } + } + proto::update_language_server::Variant::RegisteredForBuffer(update) => { + if let Some(buffer_id) = BufferId::new(update.buffer_id).ok() { + cx.emit(Event::LanguageServerBufferRegistered { + buffer_id, + server_id: *language_server_id, + buffer_abs_path: PathBuf::from(&update.buffer_abs_path), + }); + } + } + _ => (), + } } LspStoreEvent::Notification(message) => cx.emit(Event::Toast { notification_id: "lsp".into(), @@ -3364,8 +3424,14 @@ impl Project { cx: &mut Context<Self>, ) -> Task<Result<Vec<LocationLink>>> { let position = position.to_point_utf16(buffer.read(cx)); - self.lsp_store.update(cx, |lsp_store, cx| { + let guard = self.retain_remotely_created_models(cx); + let task = self.lsp_store.update(cx, |lsp_store, cx| { lsp_store.definitions(buffer, position, cx) + }); + cx.background_spawn(async move { + let result = task.await; + drop(guard); + result }) } @@ -3376,8 +3442,14 @@ impl Project { cx: &mut Context<Self>, ) -> Task<Result<Vec<LocationLink>>> { let position = position.to_point_utf16(buffer.read(cx)); - self.lsp_store.update(cx, |lsp_store, cx| { + let guard = self.retain_remotely_created_models(cx); + let task = self.lsp_store.update(cx, |lsp_store, cx| { lsp_store.declarations(buffer, position, cx) + }); + cx.background_spawn(async move { + let result = task.await; + drop(guard); + result }) } @@ -3388,8 +3460,14 @@ impl Project { cx: &mut Context<Self>, ) -> Task<Result<Vec<LocationLink>>> { let position = position.to_point_utf16(buffer.read(cx)); - self.lsp_store.update(cx, |lsp_store, cx| { + let guard = self.retain_remotely_created_models(cx); + let task = self.lsp_store.update(cx, |lsp_store, cx| { lsp_store.type_definitions(buffer, position, cx) + }); + cx.background_spawn(async move { + let result = task.await; + drop(guard); + result }) } @@ -3400,8 +3478,14 @@ impl Project { cx: &mut Context<Self>, ) -> Task<Result<Vec<LocationLink>>> { let position = position.to_point_utf16(buffer.read(cx)); - self.lsp_store.update(cx, |lsp_store, cx| { + let guard = self.retain_remotely_created_models(cx); + let task = self.lsp_store.update(cx, |lsp_store, cx| { lsp_store.implementations(buffer, position, cx) + }); + cx.background_spawn(async move { + let result = task.await; + drop(guard); + result }) } @@ -3412,17 +3496,24 @@ impl Project { cx: &mut Context<Self>, ) -> Task<Result<Vec<Location>>> { let position = position.to_point_utf16(buffer.read(cx)); - self.lsp_store.update(cx, |lsp_store, cx| { + let guard = self.retain_remotely_created_models(cx); + let task = self.lsp_store.update(cx, |lsp_store, cx| { lsp_store.references(buffer, position, cx) + }); + cx.background_spawn(async move { + let result = task.await; + drop(guard); + result }) } - fn 
document_highlights_impl( + pub fn document_highlights<T: ToPointUtf16>( &mut self, buffer: &Entity<Buffer>, - position: PointUtf16, + position: T, cx: &mut Context<Self>, ) -> Task<Result<Vec<DocumentHighlight>>> { + let position = position.to_point_utf16(buffer.read(cx)); self.request_lsp( buffer.clone(), LanguageServerToQuery::FirstCapable, @@ -3431,16 +3522,6 @@ impl Project { ) } - pub fn document_highlights<T: ToPointUtf16>( - &mut self, - buffer: &Entity<Buffer>, - position: T, - cx: &mut Context<Self>, - ) -> Task<Result<Vec<DocumentHighlight>>> { - let position = position.to_point_utf16(buffer.read(cx)); - self.document_highlights_impl(buffer, position, cx) - } - pub fn document_symbols( &mut self, buffer: &Entity<Buffer>, @@ -3539,14 +3620,14 @@ impl Project { .update(cx, |lsp_store, cx| lsp_store.hover(buffer, position, cx)) } - pub fn linked_edit( + pub fn linked_edits( &self, buffer: &Entity<Buffer>, position: Anchor, cx: &mut Context<Self>, ) -> Task<Result<Vec<Range<Anchor>>>> { self.lsp_store.update(cx, |lsp_store, cx| { - lsp_store.linked_edit(buffer, position, cx) + lsp_store.linked_edits(buffer, position, cx) }) } @@ -3577,20 +3658,29 @@ impl Project { }) } - pub fn code_lens<T: Clone + ToOffset>( + pub fn code_lens_actions<T: Clone + ToOffset>( &mut self, - buffer_handle: &Entity<Buffer>, + buffer: &Entity<Buffer>, range: Range<T>, cx: &mut Context<Self>, ) -> Task<Result<Vec<CodeAction>>> { - let snapshot = buffer_handle.read(cx).snapshot(); - let range = snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end); + let snapshot = buffer.read(cx).snapshot(); + let range = range.clone().to_owned().to_point(&snapshot); + let range_start = snapshot.anchor_before(range.start); + let range_end = if range.start == range.end { + range_start + } else { + snapshot.anchor_after(range.end) + }; + let range = range_start..range_end; let code_lens_actions = self .lsp_store - .update(cx, |lsp_store, cx| lsp_store.code_lens(buffer_handle, cx)); + .update(cx, |lsp_store, cx| lsp_store.code_lens_actions(buffer, cx)); cx.background_spawn(async move { - let mut code_lens_actions = code_lens_actions.await?; + let mut code_lens_actions = code_lens_actions + .await + .map_err(|e| anyhow!("code lens fetch failed: {e:#}"))?; code_lens_actions.retain(|code_lens_action| { range .start @@ -3629,12 +3719,13 @@ impl Project { }) } - fn prepare_rename_impl( + pub fn prepare_rename<T: ToPointUtf16>( &mut self, buffer: Entity<Buffer>, - position: PointUtf16, + position: T, cx: &mut Context<Self>, ) -> Task<Result<PrepareRenameResponse>> { + let position = position.to_point_utf16(buffer.read(cx)); self.request_lsp( buffer, LanguageServerToQuery::FirstCapable, @@ -3642,15 +3733,6 @@ impl Project { cx, ) } - pub fn prepare_rename<T: ToPointUtf16>( - &mut self, - buffer: Entity<Buffer>, - position: T, - cx: &mut Context<Self>, - ) -> Task<Result<PrepareRenameResponse>> { - let position = position.to_point_utf16(buffer.read(cx)); - self.prepare_rename_impl(buffer, position, cx) - } pub fn perform_rename<T: ToPointUtf16>( &mut self, @@ -3746,27 +3828,6 @@ impl Project { }) } - pub fn update_diagnostics( - &mut self, - language_server_id: LanguageServerId, - source_kind: DiagnosticSourceKind, - result_id: Option<String>, - params: lsp::PublishDiagnosticsParams, - disk_based_sources: &[String], - cx: &mut Context<Self>, - ) -> Result<(), anyhow::Error> { - self.lsp_store.update(cx, |lsp_store, cx| { - lsp_store.update_diagnostics( - language_server_id, - params, - result_id, - source_kind, - 
disk_based_sources, - cx, - ) - }) - } - pub fn search(&mut self, query: SearchQuery, cx: &mut Context<Self>) -> Receiver<SearchResult> { let (result_tx, result_rx) = smol::channel::unbounded(); @@ -3953,7 +4014,7 @@ impl Project { let task = self.lsp_store.update(cx, |lsp_store, cx| { lsp_store.request_lsp(buffer_handle, server, request, cx) }); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { let result = task.await; drop(guard); result @@ -4918,63 +4979,53 @@ impl Project { } pub fn any_language_server_supports_inlay_hints(&self, buffer: &Buffer, cx: &mut App) -> bool { - self.lsp_store.update(cx, |this, cx| { - this.language_servers_for_local_buffer(buffer, cx) - .any( - |(_, server)| match server.capabilities().inlay_hint_provider { - Some(lsp::OneOf::Left(enabled)) => enabled, - Some(lsp::OneOf::Right(_)) => true, - None => false, - }, - ) + let Some(language) = buffer.language().cloned() else { + return false; + }; + self.lsp_store.update(cx, |lsp_store, _| { + let relevant_language_servers = lsp_store + .languages + .lsp_adapters(&language.name()) + .into_iter() + .map(|lsp_adapter| lsp_adapter.name()) + .collect::<HashSet<_>>(); + lsp_store + .language_server_statuses() + .filter_map(|(server_id, server_status)| { + relevant_language_servers + .contains(&server_status.name) + .then_some(server_id) + }) + .filter_map(|server_id| lsp_store.lsp_server_capabilities.get(&server_id)) + .any(InlayHints::check_capabilities) }) } pub fn language_server_id_for_name( &self, buffer: &Buffer, - name: &str, - cx: &mut App, - ) -> Task<Option<LanguageServerId>> { - if self.is_local() { - Task::ready(self.lsp_store.update(cx, |lsp_store, cx| { - lsp_store - .language_servers_for_local_buffer(buffer, cx) - .find_map(|(adapter, server)| { - if adapter.name.0 == name { - Some(server.server_id()) - } else { - None - } - }) - })) - } else if let Some(project_id) = self.remote_id() { - let request = self.client.request(proto::LanguageServerIdForName { - project_id, - buffer_id: buffer.remote_id().to_proto(), - name: name.to_string(), - }); - cx.background_spawn(async move { - let response = request.await.log_err()?; - response.server_id.map(LanguageServerId::from_proto) - }) - } else if let Some(ssh_client) = self.ssh_client.as_ref() { - let request = - ssh_client - .read(cx) - .proto_client() - .request(proto::LanguageServerIdForName { - project_id: SSH_PROJECT_ID, - buffer_id: buffer.remote_id().to_proto(), - name: name.to_string(), - }); - cx.background_spawn(async move { - let response = request.await.log_err()?; - response.server_id.map(LanguageServerId::from_proto) - }) - } else { - Task::ready(None) + name: &LanguageServerName, + cx: &App, + ) -> Option<LanguageServerId> { + let language = buffer.language()?; + let relevant_language_servers = self + .languages + .lsp_adapters(&language.name()) + .into_iter() + .map(|lsp_adapter| lsp_adapter.name()) + .collect::<HashSet<_>>(); + if !relevant_language_servers.contains(name) { + return None; } + self.language_server_statuses(cx) + .filter(|(_, server_status)| relevant_language_servers.contains(&server_status.name)) + .find_map(|(server_id, server_status)| { + if &server_status.name == name { + Some(server_id) + } else { + None + } + }) } pub fn has_language_servers_for(&self, buffer: &Buffer, cx: &mut App) -> bool { diff --git a/crates/project/src/project_settings.rs b/crates/project/src/project_settings.rs index 1c35f1652232113ed83c41fc6dee3d6b32251358..20be7fef85c79910904fe577f0691fba57424d45 100644 --- 
a/crates/project/src/project_settings.rs +++ b/crates/project/src/project_settings.rs @@ -326,6 +326,79 @@ impl DiagnosticSeverity { } } +/// Determines the severity of the diagnostic that should be moved to. +#[derive(PartialEq, PartialOrd, Clone, Copy, Debug, Eq, Deserialize, JsonSchema)] +#[serde(rename_all = "snake_case")] +pub enum GoToDiagnosticSeverity { + /// Errors + Error = 3, + /// Warnings + Warning = 2, + /// Information + Information = 1, + /// Hints + Hint = 0, +} + +impl From<lsp::DiagnosticSeverity> for GoToDiagnosticSeverity { + fn from(severity: lsp::DiagnosticSeverity) -> Self { + match severity { + lsp::DiagnosticSeverity::ERROR => Self::Error, + lsp::DiagnosticSeverity::WARNING => Self::Warning, + lsp::DiagnosticSeverity::INFORMATION => Self::Information, + lsp::DiagnosticSeverity::HINT => Self::Hint, + _ => Self::Error, + } + } +} + +impl GoToDiagnosticSeverity { + pub fn min() -> Self { + Self::Hint + } + + pub fn max() -> Self { + Self::Error + } +} + +/// Allows filtering diagnostics that should be moved to. +#[derive(PartialEq, Clone, Copy, Debug, Deserialize, JsonSchema)] +#[serde(untagged)] +pub enum GoToDiagnosticSeverityFilter { + /// Move to diagnostics of a specific severity. + Only(GoToDiagnosticSeverity), + + /// Specify a range of severities to include. + Range { + /// Minimum severity to move to. Defaults no "error". + #[serde(default = "GoToDiagnosticSeverity::min")] + min: GoToDiagnosticSeverity, + /// Maximum severity to move to. Defaults to "hint". + #[serde(default = "GoToDiagnosticSeverity::max")] + max: GoToDiagnosticSeverity, + }, +} + +impl Default for GoToDiagnosticSeverityFilter { + fn default() -> Self { + Self::Range { + min: GoToDiagnosticSeverity::min(), + max: GoToDiagnosticSeverity::max(), + } + } +} + +impl GoToDiagnosticSeverityFilter { + pub fn matches(&self, severity: lsp::DiagnosticSeverity) -> bool { + let severity: GoToDiagnosticSeverity = severity.into(); + match self { + Self::Only(target) => *target == severity, + Self::Range { min, max } => severity >= *min && severity <= *max, + } + } +} + #[derive(Copy, Clone, Debug, Default, Serialize, Deserialize, JsonSchema)] pub struct GitSettings { /// Whether or not to show the git gutter. 
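// A minimal sketch of how the untagged filter defined above deserializes,
// assuming serde_json: a bare snake_case severity selects a single level,
// while an object with optional `min`/`max` keys selects a range (unset
// bounds fall back to `hint` and `error` respectively).
//
// fn goto_severity_filter_examples() -> serde_json::Result<()> {
//     let only: GoToDiagnosticSeverityFilter =
//         serde_json::from_value(serde_json::json!("warning"))?;
//     assert!(matches!(
//         only,
//         GoToDiagnosticSeverityFilter::Only(GoToDiagnosticSeverity::Warning)
//     ));
//
//     let at_least_warning: GoToDiagnosticSeverityFilter =
//         serde_json::from_value(serde_json::json!({ "min": "warning" }))?;
//     assert!(at_least_warning.matches(lsp::DiagnosticSeverity::ERROR));
//     assert!(!at_least_warning.matches(lsp::DiagnosticSeverity::HINT));
//     Ok(())
// }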
@@ -508,7 +581,7 @@ impl Settings for ProjectSettings { #[derive(Deserialize)] struct VsCodeContextServerCommand { - command: String, + command: PathBuf, args: Option<Vec<String>>, env: Option<HashMap<String, String>>, // note: we don't support envFile and type diff --git a/crates/project/src/project_tests.rs b/crates/project/src/project_tests.rs index 779cf95add9ad5547e13d85d87c0dcc3935ab326..cb3c9efe60584df3b2353641d4b676d85da51476 100644 --- a/crates/project/src/project_tests.rs +++ b/crates/project/src/project_tests.rs @@ -18,9 +18,10 @@ use git::{ use git2::RepositoryInitOptions; use gpui::{App, BackgroundExecutor, SemanticVersion, UpdateGlobal}; use http_client::Url; +use itertools::Itertools; use language::{ - Diagnostic, DiagnosticEntry, DiagnosticSet, DiskState, FakeLspAdapter, LanguageConfig, - LanguageMatcher, LanguageName, LineEnding, OffsetRangeExt, Point, ToPoint, + Diagnostic, DiagnosticEntry, DiagnosticSet, DiagnosticSourceKind, DiskState, FakeLspAdapter, + LanguageConfig, LanguageMatcher, LanguageName, LineEnding, OffsetRangeExt, Point, ToPoint, language_settings::{AllLanguageSettings, LanguageSettingsContent, language_settings}, tree_sitter_rust, tree_sitter_typescript, }; @@ -1100,7 +1101,7 @@ async fn test_reporting_fs_changes_to_language_servers(cx: &mut gpui::TestAppCon let fake_server = fake_servers.next().await.unwrap(); let (server_id, server_name) = lsp_store.read_with(cx, |lsp_store, _| { let (id, status) = lsp_store.language_server_statuses().next().unwrap(); - (id, LanguageServerName::from(status.name.as_str())) + (id, status.name.clone()) }); // Simulate jumping to a definition in a dependency outside of the worktree. @@ -1618,7 +1619,7 @@ async fn test_disk_based_diagnostics_progress(cx: &mut gpui::TestAppContext) { events.next().await.unwrap(), Event::DiagnosticsUpdated { language_server_id: LanguageServerId(0), - path: (worktree_id, Path::new("a.rs")).into() + paths: vec![(worktree_id, Path::new("a.rs")).into()], } ); @@ -1666,7 +1667,7 @@ async fn test_disk_based_diagnostics_progress(cx: &mut gpui::TestAppContext) { events.next().await.unwrap(), Event::DiagnosticsUpdated { language_server_id: LanguageServerId(0), - path: (worktree_id, Path::new("a.rs")).into() + paths: vec![(worktree_id, Path::new("a.rs")).into()], } ); @@ -1698,7 +1699,7 @@ async fn test_restarting_server_with_diagnostics_running(cx: &mut gpui::TestAppC name: "the-language-server", disk_based_diagnostics_sources: vec!["disk".into()], disk_based_diagnostics_progress_token: Some(progress_token.into()), - ..Default::default() + ..FakeLspAdapter::default() }, ); @@ -1710,6 +1711,7 @@ async fn test_restarting_server_with_diagnostics_running(cx: &mut gpui::TestAppC }) .await .unwrap(); + let buffer_id = buffer.read_with(cx, |buffer, _| buffer.remote_id()); // Simulate diagnostics starting to update. 
let fake_server = fake_servers.next().await.unwrap(); fake_server.start_progress(progress_token).await; @@ -1736,6 +1738,14 @@ async fn test_restarting_server_with_diagnostics_running(cx: &mut gpui::TestAppC ); assert_eq!(events.next().await.unwrap(), Event::RefreshInlayHints); fake_server.start_progress(progress_token).await; + assert_eq!( + events.next().await.unwrap(), + Event::LanguageServerBufferRegistered { + server_id: LanguageServerId(1), + buffer_id, + buffer_abs_path: PathBuf::from(path!("/dir/a.rs")), + } + ); assert_eq!( events.next().await.unwrap(), Event::DiskBasedDiagnosticsStarted { diff --git a/crates/project/src/search.rs b/crates/project/src/search.rs index 44732b23cd4fb7ff8044b65e55d0ea09f0c3b5c9..4f024837c8be8946c8feb00f398779b604afbbf0 100644 --- a/crates/project/src/search.rs +++ b/crates/project/src/search.rs @@ -193,6 +193,30 @@ impl SearchQuery { } pub fn from_proto(message: proto::SearchQuery) -> Result<Self> { + let files_to_include = if message.files_to_include.is_empty() { + message + .files_to_include_legacy + .split(',') + .map(str::trim) + .filter(|&glob_str| !glob_str.is_empty()) + .map(|s| s.to_string()) + .collect() + } else { + message.files_to_include + }; + + let files_to_exclude = if message.files_to_exclude.is_empty() { + message + .files_to_exclude_legacy + .split(',') + .map(str::trim) + .filter(|&glob_str| !glob_str.is_empty()) + .map(|s| s.to_string()) + .collect() + } else { + message.files_to_exclude + }; + if message.regex { Self::regex( message.query, @@ -200,8 +224,8 @@ impl SearchQuery { message.case_sensitive, message.include_ignored, false, - deserialize_path_matches(&message.files_to_include)?, - deserialize_path_matches(&message.files_to_exclude)?, + PathMatcher::new(files_to_include)?, + PathMatcher::new(files_to_exclude)?, message.match_full_paths, None, // search opened only don't need search remote ) @@ -211,8 +235,8 @@ impl SearchQuery { message.whole_word, message.case_sensitive, message.include_ignored, - deserialize_path_matches(&message.files_to_include)?, - deserialize_path_matches(&message.files_to_exclude)?, + PathMatcher::new(files_to_include)?, + PathMatcher::new(files_to_exclude)?, false, None, // search opened only don't need search remote ) @@ -236,15 +260,20 @@ impl SearchQuery { } pub fn to_proto(&self) -> proto::SearchQuery { + let files_to_include = self.files_to_include().sources().to_vec(); + let files_to_exclude = self.files_to_exclude().sources().to_vec(); proto::SearchQuery { query: self.as_str().to_string(), regex: self.is_regex(), whole_word: self.whole_word(), case_sensitive: self.case_sensitive(), include_ignored: self.include_ignored(), - files_to_include: self.files_to_include().sources().join(","), - files_to_exclude: self.files_to_exclude().sources().join(","), + files_to_include: files_to_include.clone(), + files_to_exclude: files_to_exclude.clone(), match_full_paths: self.match_full_paths(), + // Populate legacy fields for backwards compatibility + files_to_include_legacy: files_to_include.join(","), + files_to_exclude_legacy: files_to_exclude.join(","), } } @@ -520,14 +549,6 @@ impl SearchQuery { } } -pub fn deserialize_path_matches(glob_set: &str) -> anyhow::Result<PathMatcher> { - let globs = glob_set - .split(',') - .map(str::trim) - .filter(|&glob_str| !glob_str.is_empty()); - Ok(PathMatcher::new(globs)?) 
-} - #[cfg(test)] mod tests { use super::*; diff --git a/crates/project/src/search_history.rs b/crates/project/src/search_history.rs index 382d04f8e47c8005de28670b756c0701f12d5dbc..90b169bb0c5c83a9eb722c964bfb549ead1d5494 100644 --- a/crates/project/src/search_history.rs +++ b/crates/project/src/search_history.rs @@ -45,12 +45,6 @@ impl SearchHistory { } pub fn add(&mut self, cursor: &mut SearchHistoryCursor, search_string: String) { - if let Some(selected_ix) = cursor.selection { - if self.history.get(selected_ix) == Some(&search_string) { - return; - } - } - if self.insertion_behavior == QueryInsertionBehavior::ReplacePreviousIfContains { if let Some(previously_searched) = self.history.back_mut() { if search_string.contains(previously_searched.as_str()) { @@ -72,18 +66,12 @@ impl SearchHistory { } pub fn next(&mut self, cursor: &mut SearchHistoryCursor) -> Option<&str> { - let history_size = self.history.len(); - if history_size == 0 { - return None; - } - let selected = cursor.selection?; - if selected == history_size - 1 { - return None; - } let next_index = selected + 1; + + let next = self.history.get(next_index)?; cursor.selection = Some(next_index); - Some(&self.history[next_index]) + Some(next) } pub fn current(&self, cursor: &SearchHistoryCursor) -> Option<&str> { @@ -92,25 +80,17 @@ impl SearchHistory { .and_then(|selected_ix| self.history.get(selected_ix).map(|s| s.as_str())) } + /// Get the previous history entry using the given `SearchHistoryCursor`. + /// Uses the last element in the history when there is no cursor. pub fn previous(&mut self, cursor: &mut SearchHistoryCursor) -> Option<&str> { - let history_size = self.history.len(); - if history_size == 0 { - return None; - } - let prev_index = match cursor.selection { - Some(selected_index) => { - if selected_index == 0 { - return None; - } else { - selected_index - 1 - } - } - None => history_size - 1, + Some(index) => index.checked_sub(1)?, + None => self.history.len().checked_sub(1)?, }; + let previous = self.history.get(prev_index)?; cursor.selection = Some(prev_index); - Some(&self.history[prev_index]) + Some(previous) } } @@ -158,6 +138,14 @@ mod tests { ); assert_eq!(search_history.current(&cursor), Some("rustlang")); + // add item when it equals to current item if it's not the last one + search_history.add(&mut cursor, "php".to_string()); + search_history.previous(&mut cursor); + assert_eq!(search_history.current(&cursor), Some("rustlang")); + search_history.add(&mut cursor, "rustlang".to_string()); + assert_eq!(search_history.history.len(), 3, "Should add item"); + assert_eq!(search_history.current(&cursor), Some("rustlang")); + // push enough items to test SEARCH_HISTORY_LIMIT for i in 0..MAX_HISTORY_LEN * 2 { search_history.add(&mut cursor, format!("item{i}")); diff --git a/crates/project/src/terminals.rs b/crates/project/src/terminals.rs index 385fdf9082baaf86bcdb547841cc98d161c5c508..973d4e881191dcb21414fbd5d0f7cc85467e329c 100644 --- a/crates/project/src/terminals.rs +++ b/crates/project/src/terminals.rs @@ -16,7 +16,7 @@ use std::{ use task::{DEFAULT_REMOTE_SHELL, Shell, ShellBuilder, SpawnInTerminal}; use terminal::{ TaskState, TaskStatus, Terminal, TerminalBuilder, - terminal_settings::{self, TerminalSettings, VenvSettings}, + terminal_settings::{self, ActivateScript, TerminalSettings, VenvSettings}, }; use util::{ ResultExt, @@ -169,7 +169,7 @@ impl Project { .read(cx) .get_cli_environment() .unwrap_or_default(); - env.extend(settings.env.clone()); + env.extend(settings.env); match self.ssh_details(cx) { 
Some(SshDetails { @@ -213,17 +213,24 @@ impl Project { cx: &mut Context<Self>, ) -> Result<Entity<Terminal>> { let this = &mut *self; + let ssh_details = this.ssh_details(cx); let path: Option<Arc<Path>> = match &kind { TerminalKind::Shell(path) => path.as_ref().map(|path| Arc::from(path.as_ref())), TerminalKind::Task(spawn_task) => { if let Some(cwd) = &spawn_task.cwd { - Some(Arc::from(cwd.as_ref())) + if ssh_details.is_some() { + Some(Arc::from(cwd.as_ref())) + } else { + let cwd = cwd.to_string_lossy(); + let tilde_substituted = shellexpand::tilde(&cwd); + Some(Arc::from(Path::new(tilde_substituted.as_ref()))) + } } else { this.active_project_directory(cx) } } }; - let ssh_details = this.ssh_details(cx); + let is_ssh_terminal = ssh_details.is_some(); let mut settings_location = None; @@ -247,7 +254,7 @@ impl Project { .unwrap_or_default(); // Then extend it with the explicit env variables from the settings, so they take // precedence. - env.extend(settings.env.clone()); + env.extend(settings.env); let local_path = if is_ssh_terminal { None } else { path.clone() }; @@ -256,8 +263,11 @@ impl Project { let (spawn_task, shell) = match kind { TerminalKind::Shell(_) => { if let Some(python_venv_directory) = &python_venv_directory { - python_venv_activate_command = - this.python_activate_command(python_venv_directory, &settings.detect_venv); + python_venv_activate_command = this.python_activate_command( + python_venv_directory, + &settings.detect_venv, + &settings.shell, + ); } match ssh_details { @@ -510,10 +520,27 @@ impl Project { }) } + fn activate_script_kind(shell: Option<&str>) -> ActivateScript { + let shell_env = std::env::var("SHELL").ok(); + let shell_path = shell.or_else(|| shell_env.as_deref()); + let shell = std::path::Path::new(shell_path.unwrap_or("")) + .file_name() + .and_then(|name| name.to_str()) + .unwrap_or(""); + match shell { + "fish" => ActivateScript::Fish, + "tcsh" => ActivateScript::Csh, + "nu" => ActivateScript::Nushell, + "powershell" | "pwsh" => ActivateScript::PowerShell, + _ => ActivateScript::Default, + } + } + fn python_activate_command( &self, venv_base_directory: &Path, venv_settings: &VenvSettings, + shell: &Shell, ) -> Option<String> { let venv_settings = venv_settings.as_option()?; let activate_keyword = match venv_settings.activate_script { @@ -523,36 +550,62 @@ impl Project { }, terminal_settings::ActivateScript::Nushell => "overlay use", terminal_settings::ActivateScript::PowerShell => ".", + terminal_settings::ActivateScript::Pyenv => "pyenv", _ => "source", }; - let activate_script_name = match venv_settings.activate_script { - terminal_settings::ActivateScript::Default => "activate", + let script_kind = + if venv_settings.activate_script == terminal_settings::ActivateScript::Default { + match shell { + Shell::Program(program) => Self::activate_script_kind(Some(program)), + Shell::WithArguments { + program, + args: _, + title_override: _, + } => Self::activate_script_kind(Some(program)), + Shell::System => Self::activate_script_kind(None), + } + } else { + venv_settings.activate_script + }; + + let activate_script_name = match script_kind { + terminal_settings::ActivateScript::Default + | terminal_settings::ActivateScript::Pyenv => "activate", terminal_settings::ActivateScript::Csh => "activate.csh", terminal_settings::ActivateScript::Fish => "activate.fish", terminal_settings::ActivateScript::Nushell => "activate.nu", terminal_settings::ActivateScript::PowerShell => "activate.ps1", }; - let path = venv_base_directory - .join(match 
std::env::consts::OS { - "windows" => "Scripts", - _ => "bin", - }) - .join(activate_script_name) - .to_string_lossy() - .to_string(); - let quoted = shlex::try_quote(&path).ok()?; + let line_ending = match std::env::consts::OS { "windows" => "\r", _ => "\n", }; - smol::block_on(self.fs.metadata(path.as_ref())) - .ok() - .flatten()?; - Some(format!( - "{} {} ; clear{}", - activate_keyword, quoted, line_ending - )) + if venv_settings.venv_name.is_empty() { + let path = venv_base_directory + .join(match std::env::consts::OS { + "windows" => "Scripts", + _ => "bin", + }) + .join(activate_script_name) + .to_string_lossy() + .to_string(); + let quoted = shlex::try_quote(&path).ok()?; + smol::block_on(self.fs.metadata(path.as_ref())) + .ok() + .flatten()?; + + Some(format!( + "{} {} ; clear{}", + activate_keyword, quoted, line_ending + )) + } else { + Some(format!( + "{activate_keyword} {activate_script_name} {name}; clear{line_ending}", + name = venv_settings.venv_name + )) + } } fn activate_python_virtual_environment( @@ -616,7 +669,7 @@ pub fn wrap_for_ssh( format!("cd \"$HOME/{trimmed_path}\"; {env_changes} {to_run}") } else { - format!("cd {path}; {env_changes} {to_run}") + format!("cd \"{path}\"; {env_changes} {to_run}") } } else { format!("cd; {env_changes} {to_run}") diff --git a/crates/project_panel/Cargo.toml b/crates/project_panel/Cargo.toml index ce5fec0b138a2371667246bd25aaa6c9b9cb5185..b9d43d9873c7317c413376df5e822fd6ee3e643f 100644 --- a/crates/project_panel/Cargo.toml +++ b/crates/project_panel/Cargo.toml @@ -19,6 +19,7 @@ command_palette_hooks.workspace = true db.workspace = true editor.workspace = true file_icons.workspace = true +git_ui.workspace = true indexmap.workspace = true git.workspace = true gpui.workspace = true diff --git a/crates/project_panel/src/project_panel.rs b/crates/project_panel/src/project_panel.rs index 0ec9bac33f89c81527e520322555d6d1071273b4..45581a97c4dfd6148e80510a116d66a5e92354d7 100644 --- a/crates/project_panel/src/project_panel.rs +++ b/crates/project_panel/src/project_panel.rs @@ -16,6 +16,7 @@ use editor::{ }; use file_icons::FileIcons; use git::status::GitSummary; +use git_ui::file_diff_view::FileDiffView; use gpui::{ Action, AnyElement, App, ArcCow, AsyncWindowContext, Bounds, ClipboardItem, Context, CursorStyle, DismissEvent, Div, DragMoveEvent, Entity, EventEmitter, ExternalPaths, @@ -33,6 +34,7 @@ use project::{ Entry, EntryKind, Fs, GitEntry, GitEntryRef, GitTraversal, Project, ProjectEntryId, ProjectPath, Worktree, WorktreeId, git_store::{GitStoreEvent, git_traversal::ChildEntriesGitIter}, + project_settings::GoToDiagnosticSeverityFilter, relativize_path, }; use project_panel_settings::{ @@ -92,7 +94,7 @@ pub struct ProjectPanel { unfolded_dir_ids: HashSet<ProjectEntryId>, // Currently selected leaf entry (see auto-folding for a definition of that) in a file tree selection: Option<SelectedEntry>, - marked_entries: BTreeSet<SelectedEntry>, + marked_entries: Vec<SelectedEntry>, context_menu: Option<(Entity<ContextMenu>, Point<Pixels>, Subscription)>, edit_state: Option<EditState>, filename_editor: Entity<Editor>, @@ -107,11 +109,13 @@ pub struct ProjectPanel { hide_scrollbar_task: Option<Task<()>>, diagnostics: HashMap<(WorktreeId, PathBuf), DiagnosticSeverity>, max_width_item_index: Option<usize>, + diagnostic_summary_update: Task<()>, // We keep track of the mouse down state on entries so we don't flash the UI // in case a user clicks to open a file. 
mouse_down: bool, hover_expand_task: Option<Task<()>>, previous_drag_position: Option<Point<Pixels>>, + sticky_items_count: usize, } struct DragTargetEntry { @@ -186,7 +190,6 @@ struct EntryDetails { #[derive(Debug, PartialEq, Eq, Clone)] struct StickyDetails { sticky_index: usize, - is_last: bool, } /// Permanently deletes the selected file or directory. @@ -207,6 +210,24 @@ struct Trash { pub skip_prompt: bool, } +/// Selects the next entry with diagnostics. +#[derive(PartialEq, Clone, Default, Debug, Deserialize, JsonSchema, Action)] +#[action(namespace = project_panel)] +#[serde(deny_unknown_fields)] +struct SelectNextDiagnostic { + #[serde(default)] + pub severity: GoToDiagnosticSeverityFilter, +} + +/// Selects the previous entry with diagnostics. +#[derive(PartialEq, Clone, Default, Debug, Deserialize, JsonSchema, Action)] +#[action(namespace = project_panel)] +#[serde(deny_unknown_fields)] +struct SelectPrevDiagnostic { + #[serde(default)] + pub severity: GoToDiagnosticSeverityFilter, +} + actions!( project_panel, [ @@ -256,14 +277,12 @@ actions!( SelectNextGitEntry, /// Selects the previous entry with git changes. SelectPrevGitEntry, - /// Selects the next entry with diagnostics. - SelectNextDiagnostic, - /// Selects the previous entry with diagnostics. - SelectPrevDiagnostic, /// Selects the next directory. SelectNextDirectory, /// Selects the previous directory. SelectPrevDirectory, + /// Opens a diff view to compare two marked files. + CompareMarkedFiles, ] ); @@ -305,6 +324,35 @@ pub fn init(cx: &mut App) { }); } }); + + workspace.register_action(|workspace, action: &Rename, window, cx| { + workspace.open_panel::<ProjectPanel>(window, cx); + if let Some(panel) = workspace.panel::<ProjectPanel>(cx) { + panel.update(cx, |panel, cx| { + if let Some(first_marked) = panel.marked_entries.first() { + let first_marked = *first_marked; + panel.marked_entries.clear(); + panel.selection = Some(first_marked); + } + panel.rename(action, window, cx); + }); + } + }); + + workspace.register_action(|workspace, action: &Duplicate, window, cx| { + workspace.open_panel::<ProjectPanel>(window, cx); + if let Some(panel) = workspace.panel::<ProjectPanel>(cx) { + panel.update(cx, |panel, cx| { + panel.duplicate(action, window, cx); + }); + } + }); + + workspace.register_action(|workspace, action: &Delete, window, cx| { + if let Some(panel) = workspace.panel::<ProjectPanel>(cx) { + panel.update(cx, |panel, cx| panel.delete(action, window, cx)); + } + }); }) .detach(); } @@ -331,7 +379,7 @@ struct DraggedProjectEntryView { selection: SelectedEntry, details: EntryDetails, click_offset: Point<Pixels>, - selections: Arc<BTreeSet<SelectedEntry>>, + selections: Arc<[SelectedEntry]>, } struct ItemColors { @@ -342,12 +390,20 @@ struct ItemColors { focused: Hsla, } -fn get_item_color(cx: &App) -> ItemColors { +fn get_item_color(is_sticky: bool, cx: &App) -> ItemColors { let colors = cx.theme().colors(); ItemColors { - default: colors.panel_background, - hover: colors.element_hover, + default: if is_sticky { + colors.panel_overlay_background + } else { + colors.panel_background + }, + hover: if is_sticky { + colors.panel_overlay_hover + } else { + colors.element_hover + }, marked: colors.element_selected, focused: colors.panel_focused_border, drag_over: colors.drop_target_background, @@ -389,7 +445,15 @@ impl ProjectPanel { } } project::Event::ActiveEntryChanged(None) => { - this.marked_entries.clear(); + let is_active_item_file_diff_view = this + .workspace + .upgrade() + .and_then(|ws| 
ws.read(cx).active_item(cx)) + .map(|item| item.act_as_type(TypeId::of::<FileDiffView>(), cx).is_some()) + .unwrap_or(false); + if !is_active_item_file_diff_view { + this.marked_entries.clear(); + } } project::Event::RevealInProjectPanel(entry_id) => { if let Some(()) = this @@ -406,8 +470,16 @@ impl ProjectPanel { | project::Event::DiagnosticsUpdated { .. } => { if ProjectPanelSettings::get_global(cx).show_diagnostics != ShowDiagnostics::Off { - this.update_diagnostics(cx); - cx.notify(); + this.diagnostic_summary_update = cx.spawn(async move |this, cx| { + cx.background_executor() + .timer(Duration::from_millis(30)) + .await; + this.update(cx, |this, cx| { + this.update_diagnostics(cx); + cx.notify(); + }) + .log_err(); + }); } } project::Event::WorktreeRemoved(id) => { @@ -512,6 +584,9 @@ impl ProjectPanel { if project_panel_settings.hide_root != new_settings.hide_root { this.update_visible_entries(None, cx); } + if project_panel_settings.sticky_scroll && !new_settings.sticky_scroll { + this.sticky_items_count = 0; + } project_panel_settings = new_settings; this.update_diagnostics(cx); cx.notify(); @@ -550,10 +625,12 @@ impl ProjectPanel { .parent_entity(&cx.entity()), max_width_item_index: None, diagnostics: Default::default(), + diagnostic_summary_update: Task::ready(()), scroll_handle, mouse_down: false, hover_expand_task: None, previous_drag_position: None, + sticky_items_count: 0, }; this.update_visible_entries(None, cx); @@ -610,7 +687,7 @@ impl ProjectPanel { project_panel.update(cx, |project_panel, _| { let entry = SelectedEntry { worktree_id, entry_id }; project_panel.marked_entries.clear(); - project_panel.marked_entries.insert(entry); + project_panel.marked_entries.push(entry); project_panel.selection = Some(entry); }); if !focus_opened_item { @@ -821,6 +898,7 @@ impl ProjectPanel { let should_hide_rename = is_root && (cfg!(target_os = "windows") || (settings.hide_root && visible_worktrees_count == 1)); + let should_show_compare = !is_dir && self.file_abs_paths_to_diff(cx).is_some(); let context_menu = ContextMenu::build(window, cx, |menu, _, _| { menu.context(self.focus_handle.clone()).map(|menu| { @@ -852,6 +930,10 @@ impl ProjectPanel { .when(is_foldable, |menu| { menu.action("Fold Directory", Box::new(FoldDirectory)) }) + .when(should_show_compare, |menu| { + menu.separator() + .action("Compare marked files", Box::new(CompareMarkedFiles)) + }) .separator() .action("Cut", Box::new(Cut)) .action("Copy", Box::new(Copy)) @@ -1196,7 +1278,7 @@ impl ProjectPanel { }; self.selection = Some(selection); if window.modifiers().shift { - self.marked_entries.insert(selection); + self.marked_entries.push(selection); } self.autoscroll(cx); cx.notify(); @@ -1941,7 +2023,7 @@ impl ProjectPanel { }; self.selection = Some(selection); if window.modifiers().shift { - self.marked_entries.insert(selection); + self.marked_entries.push(selection); } self.autoscroll(cx); @@ -1955,7 +2037,7 @@ impl ProjectPanel { fn select_prev_diagnostic( &mut self, - _: &SelectPrevDiagnostic, + action: &SelectPrevDiagnostic, _: &mut Window, cx: &mut Context<Self>, ) { @@ -1974,7 +2056,8 @@ impl ProjectPanel { && entry.is_file() && self .diagnostics - .contains_key(&(worktree_id, entry.path.to_path_buf())) + .get(&(worktree_id, entry.path.to_path_buf())) + .is_some_and(|severity| action.severity.matches(*severity)) }, cx, ); @@ -1990,7 +2073,7 @@ impl ProjectPanel { fn select_next_diagnostic( &mut self, - _: &SelectNextDiagnostic, + action: &SelectNextDiagnostic, _: &mut Window, cx: &mut Context<Self>, ) { @@ 
-2009,7 +2092,8 @@ impl ProjectPanel { && entry.is_file() && self .diagnostics - .contains_key(&(worktree_id, entry.path.to_path_buf())) + .get(&(worktree_id, entry.path.to_path_buf())) + .is_some_and(|severity| action.severity.matches(*severity)) }, cx, ); @@ -2176,7 +2260,7 @@ impl ProjectPanel { }; self.selection = Some(selection); if window.modifiers().shift { - self.marked_entries.insert(selection); + self.marked_entries.push(selection); } self.autoscroll(cx); cx.notify(); @@ -2204,8 +2288,11 @@ impl ProjectPanel { fn autoscroll(&mut self, cx: &mut Context<Self>) { if let Some((_, _, index)) = self.selection.and_then(|s| self.index_for_selection(s)) { - self.scroll_handle - .scroll_to_item(index, ScrollStrategy::Center); + self.scroll_handle.scroll_to_item_with_offset( + index, + ScrollStrategy::Center, + self.sticky_items_count, + ); cx.notify(); } } @@ -2501,6 +2588,43 @@ impl ProjectPanel { } } + fn file_abs_paths_to_diff(&self, cx: &Context<Self>) -> Option<(PathBuf, PathBuf)> { + let mut selections_abs_path = self + .marked_entries + .iter() + .filter_map(|entry| { + let project = self.project.read(cx); + let worktree = project.worktree_for_id(entry.worktree_id, cx)?; + let entry = worktree.read(cx).entry_for_id(entry.entry_id)?; + if !entry.is_file() { + return None; + } + worktree.read(cx).absolutize(&entry.path).ok() + }) + .rev(); + + let last_path = selections_abs_path.next()?; + let previous_to_last = selections_abs_path.next()?; + Some((previous_to_last, last_path)) + } + + fn compare_marked_files( + &mut self, + _: &CompareMarkedFiles, + window: &mut Window, + cx: &mut Context<Self>, + ) { + let selected_files = self.file_abs_paths_to_diff(cx); + if let Some((file_path1, file_path2)) = selected_files { + self.workspace + .update(cx, |workspace, cx| { + FileDiffView::open(file_path1, file_path2, workspace, window, cx) + .detach_and_log_err(cx); + }) + .ok(); + } + } + fn open_system(&mut self, _: &OpenWithSystem, _: &mut Window, cx: &mut Context<Self>) { if let Some((worktree, entry)) = self.selected_entry(cx) { let abs_path = worktree.abs_path().join(&entry.path); @@ -2660,26 +2784,7 @@ impl ProjectPanel { } fn index_for_selection(&self, selection: SelectedEntry) -> Option<(usize, usize, usize)> { - let mut entry_index = 0; - let mut visible_entries_index = 0; - for (worktree_index, (worktree_id, worktree_entries, _)) in - self.visible_entries.iter().enumerate() - { - if *worktree_id == selection.worktree_id { - for entry in worktree_entries { - if entry.id == selection.entry_id { - return Some((worktree_index, entry_index, visible_entries_index)); - } else { - visible_entries_index += 1; - entry_index += 1; - } - } - break; - } else { - visible_entries_index += worktree_entries.len(); - } - } - None + self.index_for_entry(selection.entry_id, selection.worktree_id) } fn disjoint_entries(&self, cx: &App) -> BTreeSet<SelectedEntry> { @@ -3290,12 +3395,12 @@ impl ProjectPanel { entry_id: ProjectEntryId, worktree_id: WorktreeId, ) -> Option<(usize, usize, usize)> { - let mut worktree_ix = 0; let mut total_ix = 0; - for (current_worktree_id, visible_worktree_entries, _) in &self.visible_entries { + for (worktree_ix, (current_worktree_id, visible_worktree_entries, _)) in + self.visible_entries.iter().enumerate() + { if worktree_id != *current_worktree_id { total_ix += visible_worktree_entries.len(); - worktree_ix += 1; continue; } @@ -3850,7 +3955,7 @@ impl ProjectPanel { let filename_text_color = details.filename_text_color; let diagnostic_severity = 
details.diagnostic_severity; - let item_colors = get_item_color(cx); + let item_colors = get_item_color(is_sticky, cx); let canonical_path = details .canonical_path @@ -3862,11 +3967,9 @@ impl ProjectPanel { let depth = details.depth; let worktree_id = details.worktree_id; - let selections = Arc::new(self.marked_entries.clone()); - let dragged_selection = DraggedSelection { active_selection: selection, - marked_selections: selections, + marked_selections: Arc::from(self.marked_entries.clone()), }; let bg_color = if is_marked { @@ -3938,31 +4041,14 @@ impl ProjectPanel { } }; - let show_sticky_shadow = details.sticky.as_ref().map_or(false, |item| { - if item.is_last { - let is_scrollable = self.scroll_handle.is_scrollable(); - let is_scrolled = self.scroll_handle.offset().y < px(0.); - is_scrollable && is_scrolled - } else { - false - } - }); - let shadow_color_top = hsla(0.0, 0.0, 0.0, 0.1); - let shadow_color_bottom = hsla(0.0, 0.0, 0.0, 0.); - let sticky_shadow = div() - .absolute() - .left_0() - .bottom_neg_1p5() - .h_1p5() - .w_full() - .bg(linear_gradient( - 0., - linear_color_stop(shadow_color_top, 1.), - linear_color_stop(shadow_color_bottom, 0.), - )); + let id: ElementId = if is_sticky { + SharedString::from(format!("project_panel_sticky_item_{}", entry_id.to_usize())).into() + } else { + (entry_id.to_proto() as usize).into() + }; div() - .id(entry_id.to_proto() as usize) + .id(id.clone()) .relative() .group(GROUP_NAME) .cursor_pointer() @@ -3972,7 +4058,9 @@ impl ProjectPanel { .border_r_2() .border_color(border_color) .hover(|style| style.bg(bg_hover_color).border_color(border_hover_color)) - .when(show_sticky_shadow, |this| this.child(sticky_shadow)) + .when(is_sticky, |this| { + this.block_mouse_except_scroll() + }) .when(!is_sticky, |this| { this .when(is_highlighted && folded_directory_drag_target.is_none(), |this| this.border_color(transparent_white()).bg(item_colors.drag_over)) @@ -4052,7 +4140,7 @@ impl ProjectPanel { }); if drag_state.items().count() == 1 { this.marked_entries.clear(); - this.marked_entries.insert(drag_state.active_selection); + this.marked_entries.push(drag_state.active_selection); } this.hover_expand_task.take(); @@ -4119,89 +4207,99 @@ impl ProjectPanel { }), ) .on_click( - cx.listener(move |this, event: &gpui::ClickEvent, window, cx| { - if event.down.button == MouseButton::Right - || event.down.first_mouse + cx.listener(move |project_panel, event: &gpui::ClickEvent, window, cx| { + if event.is_right_click() || event.first_focus() || show_editor { return; } - if event.down.button == MouseButton::Left { - this.mouse_down = false; + if event.standard_click() { + project_panel.mouse_down = false; } cx.stop_propagation(); - if let Some(selection) = this.selection.filter(|_| event.modifiers().shift) { - let current_selection = this.index_for_selection(selection); + if let Some(selection) = project_panel.selection.filter(|_| event.modifiers().shift) { + let current_selection = project_panel.index_for_selection(selection); let clicked_entry = SelectedEntry { entry_id, worktree_id, }; - let target_selection = this.index_for_selection(clicked_entry); + let target_selection = project_panel.index_for_selection(clicked_entry); if let Some(((_, _, source_index), (_, _, target_index))) = current_selection.zip(target_selection) { let range_start = source_index.min(target_index); let range_end = source_index.max(target_index) + 1; - let mut new_selections = BTreeSet::new(); - this.for_each_visible_entry( + let mut new_selections = Vec::new(); + 
project_panel.for_each_visible_entry( range_start..range_end, window, cx, |entry_id, details, _, _| { - new_selections.insert(SelectedEntry { + new_selections.push(SelectedEntry { entry_id, worktree_id: details.worktree_id, }); }, ); - this.marked_entries = this - .marked_entries - .union(&new_selections) - .cloned() - .collect(); + for selection in &new_selections { + if !project_panel.marked_entries.contains(selection) { + project_panel.marked_entries.push(*selection); + } + } - this.selection = Some(clicked_entry); - this.marked_entries.insert(clicked_entry); + project_panel.selection = Some(clicked_entry); + if !project_panel.marked_entries.contains(&clicked_entry) { + project_panel.marked_entries.push(clicked_entry); + } } } else if event.modifiers().secondary() { - if event.down.click_count > 1 { - this.split_entry(entry_id, cx); + if event.click_count() > 1 { + project_panel.split_entry(entry_id, cx); } else { - this.selection = Some(selection); - if !this.marked_entries.insert(selection) { - this.marked_entries.remove(&selection); + project_panel.selection = Some(selection); + if let Some(position) = project_panel.marked_entries.iter().position(|e| *e == selection) { + project_panel.marked_entries.remove(position); + } else { + project_panel.marked_entries.push(selection); + } } } } else if kind.is_dir() { - this.marked_entries.clear(); + project_panel.marked_entries.clear(); if is_sticky { - if let Some((_, _, index)) = this.index_for_entry(entry_id, worktree_id) { - let strategy = sticky_index - .map(ScrollStrategy::ToPosition) - .unwrap_or(ScrollStrategy::Top); - this.scroll_handle.scroll_to_item(index, strategy); + if let Some((_, _, index)) = project_panel.index_for_entry(entry_id, worktree_id) { + project_panel.scroll_handle.scroll_to_item_with_offset(index, ScrollStrategy::Top, sticky_index.unwrap_or(0)); cx.notify(); + // move down by 1px so that the clicked item + // doesn't count as sticky anymore + cx.on_next_frame(window, |_, window, cx| { + cx.on_next_frame(window, |this, _, cx| { + let mut offset = this.scroll_handle.offset(); + offset.y += px(1.); + this.scroll_handle.set_offset(offset); + cx.notify(); + }); + }); return; } } if event.modifiers().alt { - this.toggle_expand_all(entry_id, window, cx); + project_panel.toggle_expand_all(entry_id, window, cx); } else { - this.toggle_expanded(entry_id, window, cx); + project_panel.toggle_expanded(entry_id, window, cx); } } else { let preview_tabs_enabled = PreviewTabsSettings::get_global(cx).enabled; - let click_count = event.up.click_count; + let click_count = event.click_count(); let focus_opened_item = !preview_tabs_enabled || click_count > 1; let allow_preview = preview_tabs_enabled && click_count == 1; - this.open_entry(entry_id, focus_opened_item, allow_preview, cx); + project_panel.open_entry(entry_id, focus_opened_item, allow_preview, cx); } }), ) .child( - ListItem::new(entry_id.to_proto() as usize) + ListItem::new(id) .indent_level(depth) .indent_step_size(px(settings.indent_size)) .spacing(match settings.entry_spacing { @@ -4301,6 +4399,7 @@ impl ProjectPanel { .collect::<Vec<_>>(); let components_len = components.len(); + // TODO this can underflow let active_index = components_len - 1 - folded_ancestors.current_ancestor_depth; @@ -4766,12 +4865,21 @@ impl ProjectPanel { { anyhow::bail!("can't reveal an ignored entry in the project panel"); } + let is_active_item_file_diff_view = self + .workspace + .upgrade() + .and_then(|ws| ws.read(cx).active_item(cx)) + .map(|item| item.act_as_type(TypeId::of::<FileDiffView>(), 
cx).is_some()) + .unwrap_or(false); + if is_active_item_file_diff_view { + return Ok(()); + } let worktree_id = worktree.id(); self.expand_entry(worktree_id, entry_id, cx); self.update_visible_entries(Some((worktree_id, entry_id)), cx); self.marked_entries.clear(); - self.marked_entries.insert(SelectedEntry { + self.marked_entries.push(SelectedEntry { worktree_id, entry_id, }); @@ -4924,7 +5032,6 @@ impl ProjectPanel { .unwrap_or_default(); let sticky_details = Some(StickyDetails { sticky_index: index, - is_last: index == last_item_index, }); let details = self.details_for_entry( entry, @@ -4936,7 +5043,24 @@ impl ProjectPanel { window, cx, ); - self.render_entry(entry.id, details, window, cx).into_any() + self.render_entry(entry.id, details, window, cx) + .when(index == last_item_index, |this| { + let shadow_color_top = hsla(0.0, 0.0, 0.0, 0.1); + let shadow_color_bottom = hsla(0.0, 0.0, 0.0, 0.); + let sticky_shadow = div() + .absolute() + .left_0() + .bottom_neg_1p5() + .h_1p5() + .w_full() + .bg(linear_gradient( + 0., + linear_color_stop(shadow_color_top, 1.), + linear_color_stop(shadow_color_bottom, 0.), + )); + this.child(sticky_shadow) + }) + .into_any() }) .collect() } @@ -4970,7 +5094,16 @@ impl Render for ProjectPanel { let indent_size = ProjectPanelSettings::get_global(cx).indent_size; let show_indent_guides = ProjectPanelSettings::get_global(cx).indent_guides.show == ShowIndentGuides::Always; - let show_sticky_scroll = ProjectPanelSettings::get_global(cx).sticky_scroll; + let show_sticky_entries = { + if ProjectPanelSettings::get_global(cx).sticky_scroll { + let is_scrollable = self.scroll_handle.is_scrollable(); + let is_scrolled = self.scroll_handle.offset().y < px(0.); + is_scrollable && is_scrolled + } else { + false + } + }; + let is_local = project.is_local(); if has_worktree { @@ -5068,7 +5201,10 @@ impl Render for ProjectPanel { this.hide_scrollbar(window, cx); } })) - .on_click(cx.listener(|this, _event, _, cx| { + .on_click(cx.listener(|this, event, _, cx| { + if matches!(event, gpui::ClickEvent::Keyboard(_)) { + return; + } cx.stop_propagation(); this.selection = None; this.marked_entries.clear(); @@ -5098,6 +5234,7 @@ impl Render for ProjectPanel { .on_action(cx.listener(Self::unfold_directory)) .on_action(cx.listener(Self::fold_directory)) .on_action(cx.listener(Self::remove_from_project)) + .on_action(cx.listener(Self::compare_marked_files)) .when(!project.is_read_only(cx), |el| { el.on_action(cx.listener(Self::new_file)) .on_action(cx.listener(Self::new_directory)) @@ -5109,7 +5246,7 @@ impl Render for ProjectPanel { .on_action(cx.listener(Self::paste)) .on_action(cx.listener(Self::duplicate)) .on_click(cx.listener(|this, event: &gpui::ClickEvent, window, cx| { - if event.up.click_count > 1 { + if event.click_count() > 1 { if let Some(entry_id) = this.last_worktree_root_id { let project = this.project.read(cx); @@ -5262,7 +5399,7 @@ impl Render for ProjectPanel { }), ) }) - .when(show_sticky_scroll, |list| { + .when(show_sticky_entries, |list| { let sticky_items = ui::sticky_items( cx.entity().clone(), |this, range, window, cx| { @@ -5282,7 +5419,10 @@ impl Render for ProjectPanel { items }, |this, marker_entry, window, cx| { - this.render_sticky_entries(marker_entry, window, cx) + let sticky_entries = + this.render_sticky_entries(marker_entry, window, cx); + this.sticky_items_count = sticky_entries.len(); + sticky_entries }, ); list.with_decoration(if show_indent_guides { diff --git a/crates/project_panel/src/project_panel_tests.rs 
b/crates/project_panel/src/project_panel_tests.rs index 7699256bc9ba3666cfb8e602ad1e62b0f2e62749..6c62c8db930895db6a8c46fa23fdf1423e90d5a9 100644 --- a/crates/project_panel/src/project_panel_tests.rs +++ b/crates/project_panel/src/project_panel_tests.rs @@ -8,7 +8,7 @@ use settings::SettingsStore; use std::path::{Path, PathBuf}; use util::path; use workspace::{ - AppState, Pane, + AppState, ItemHandle, Pane, item::{Item, ProjectItem}, register_project_item, }; @@ -3068,7 +3068,7 @@ async fn test_multiple_marked_entries(cx: &mut gpui::TestAppContext) { panel.update(cx, |this, cx| { let drag = DraggedSelection { active_selection: this.selection.unwrap(), - marked_selections: Arc::new(this.marked_entries.clone()), + marked_selections: this.marked_entries.clone().into(), }; let target_entry = this .project @@ -5562,10 +5562,10 @@ async fn test_highlight_entry_for_selection_drag(cx: &mut gpui::TestAppContext) worktree_id, entry_id: child_file.id, }, - marked_selections: Arc::new(BTreeSet::from([SelectedEntry { + marked_selections: Arc::new([SelectedEntry { worktree_id, entry_id: child_file.id, - }])), + }]), }; let result = panel.highlight_entry_for_selection_drag(parent_dir, worktree, &dragged_selection, cx); @@ -5604,7 +5604,7 @@ async fn test_highlight_entry_for_selection_drag(cx: &mut gpui::TestAppContext) worktree_id, entry_id: child_file.id, }, - marked_selections: Arc::new(BTreeSet::from([ + marked_selections: Arc::new([ SelectedEntry { worktree_id, entry_id: child_file.id, @@ -5613,7 +5613,7 @@ async fn test_highlight_entry_for_selection_drag(cx: &mut gpui::TestAppContext) worktree_id, entry_id: sibling_file.id, }, - ])), + ]), }; let result = panel.highlight_entry_for_selection_drag(parent_dir, worktree, &dragged_selection, cx); @@ -5821,6 +5821,186 @@ async fn test_hide_root(cx: &mut gpui::TestAppContext) { } } +#[gpui::test] +async fn test_compare_selected_files(cx: &mut gpui::TestAppContext) { + init_test_with_editor(cx); + + let fs = FakeFs::new(cx.executor().clone()); + fs.insert_tree( + "/root", + json!({ + "file1.txt": "content of file1", + "file2.txt": "content of file2", + "dir1": { + "file3.txt": "content of file3" + } + }), + ) + .await; + + let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await; + let workspace = cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx)); + let cx = &mut VisualTestContext::from_window(*workspace, cx); + let panel = workspace.update(cx, ProjectPanel::new).unwrap(); + + let file1_path = path!("root/file1.txt"); + let file2_path = path!("root/file2.txt"); + select_path_with_mark(&panel, file1_path, cx); + select_path_with_mark(&panel, file2_path, cx); + + panel.update_in(cx, |panel, window, cx| { + panel.compare_marked_files(&CompareMarkedFiles, window, cx); + }); + cx.executor().run_until_parked(); + + workspace + .update(cx, |workspace, _, cx| { + let active_items = workspace + .panes() + .iter() + .filter_map(|pane| pane.read(cx).active_item()) + .collect::<Vec<_>>(); + assert_eq!(active_items.len(), 1); + let diff_view = active_items + .into_iter() + .next() + .unwrap() + .downcast::<FileDiffView>() + .expect("Open item should be an FileDiffView"); + assert_eq!(diff_view.tab_content_text(0, cx), "file1.txt ↔ file2.txt"); + assert_eq!( + diff_view.tab_tooltip_text(cx).unwrap(), + format!("{} ↔ {}", file1_path, file2_path) + ); + }) + .unwrap(); + + let file1_entry_id = find_project_entry(&panel, file1_path, cx).unwrap(); + let file2_entry_id = find_project_entry(&panel, file2_path, cx).unwrap(); + let 
worktree_id = panel.update(cx, |panel, cx| { + panel + .project + .read(cx) + .worktrees(cx) + .next() + .unwrap() + .read(cx) + .id() + }); + + let expected_entries = [ + SelectedEntry { + worktree_id, + entry_id: file1_entry_id, + }, + SelectedEntry { + worktree_id, + entry_id: file2_entry_id, + }, + ]; + panel.update(cx, |panel, _cx| { + assert_eq!( + &panel.marked_entries, &expected_entries, + "Should keep marked entries after comparison" + ); + }); + + panel.update(cx, |panel, cx| { + panel.project.update(cx, |_, cx| { + cx.emit(project::Event::RevealInProjectPanel(file2_entry_id)) + }) + }); + + panel.update(cx, |panel, _cx| { + assert_eq!( + &panel.marked_entries, &expected_entries, + "Marked entries should persist after focusing back on the project panel" + ); + }); +} + +#[gpui::test] +async fn test_compare_files_context_menu(cx: &mut gpui::TestAppContext) { + init_test_with_editor(cx); + + let fs = FakeFs::new(cx.executor().clone()); + fs.insert_tree( + "/root", + json!({ + "file1.txt": "content of file1", + "file2.txt": "content of file2", + "dir1": {}, + "dir2": { + "file3.txt": "content of file3" + } + }), + ) + .await; + + let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await; + let workspace = cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx)); + let cx = &mut VisualTestContext::from_window(*workspace, cx); + let panel = workspace.update(cx, ProjectPanel::new).unwrap(); + + // Test 1: When only one file is selected, there should be no compare option + select_path(&panel, "root/file1.txt", cx); + + let selected_files = panel.update(cx, |panel, cx| panel.file_abs_paths_to_diff(cx)); + assert_eq!( + selected_files, None, + "Should not have compare option when only one file is selected" + ); + + // Test 2: When multiple files are selected, there should be a compare option + select_path_with_mark(&panel, "root/file1.txt", cx); + select_path_with_mark(&panel, "root/file2.txt", cx); + + let selected_files = panel.update(cx, |panel, cx| panel.file_abs_paths_to_diff(cx)); + assert!( + selected_files.is_some(), + "Should have files selected for comparison" + ); + if let Some((file1, file2)) = selected_files { + assert!( + file1.to_string_lossy().ends_with("file1.txt") + && file2.to_string_lossy().ends_with("file2.txt"), + "Should have file1.txt and file2.txt as the selected files when multi-selecting" + ); + } + + // Test 3: Selecting a directory shouldn't count as a comparable file + select_path_with_mark(&panel, "root/dir1", cx); + + let selected_files = panel.update(cx, |panel, cx| panel.file_abs_paths_to_diff(cx)); + assert!( + selected_files.is_some(), + "Directory selection should not affect comparable files" + ); + if let Some((file1, file2)) = selected_files { + assert!( + file1.to_string_lossy().ends_with("file1.txt") + && file2.to_string_lossy().ends_with("file2.txt"), + "Selecting a directory should not affect the number of comparable files" + ); + } + + // Test 4: Marking one more file updates the comparison pair + select_path_with_mark(&panel, "root/dir2/file3.txt", cx); + + let selected_files = panel.update(cx, |panel, cx| panel.file_abs_paths_to_diff(cx)); + assert!( + selected_files.is_some(), + "Should still have files selected for comparison after marking a third file" + ); + if let Some((file1, file2)) = selected_files { + assert!( + file1.to_string_lossy().ends_with("file2.txt") + && file2.to_string_lossy().ends_with("file3.txt"), + "Should compare the two most recently marked files (file2.txt and file3.txt)" + ); + } +} + fn select_path(panel: &Entity<ProjectPanel>, path: 
impl AsRef<Path>, cx: &mut VisualTestContext) { let path = path.as_ref(); panel.update(cx, |panel, cx| { @@ -5855,7 +6035,7 @@ fn select_path_with_mark( entry_id, }; if !panel.marked_entries.contains(&entry) { - panel.marked_entries.insert(entry); + panel.marked_entries.push(entry); } panel.selection = Some(entry); return; diff --git a/crates/prompt_store/src/prompts.rs b/crates/prompt_store/src/prompts.rs index d737ef92464758d06c456c00b35c9c6390e1a3a9..7eb63eec5ea559432724622a7dc4ea5410cff62f 100644 --- a/crates/prompt_store/src/prompts.rs +++ b/crates/prompt_store/src/prompts.rs @@ -18,7 +18,7 @@ use util::{ResultExt, get_system_shell}; use crate::UserPromptId; -#[derive(Debug, Clone, Serialize)] +#[derive(Default, Debug, Clone, Serialize)] pub struct ProjectContext { pub worktrees: Vec<WorktreeContext>, /// Whether any worktree has a rules_file. Provided as a field because handlebars can't do this. @@ -71,14 +71,14 @@ pub struct UserRulesContext { pub contents: String, } -#[derive(Debug, Clone, Serialize)] +#[derive(Debug, Clone, Eq, PartialEq, Serialize)] pub struct WorktreeContext { pub root_name: String, pub abs_path: Arc<Path>, pub rules_file: Option<RulesFileContext>, } -#[derive(Debug, Clone, Serialize)] +#[derive(Debug, Clone, Eq, PartialEq, Serialize)] pub struct RulesFileContext { pub path_in_worktree: Arc<Path>, pub text: String, diff --git a/crates/proto/proto/app.proto b/crates/proto/proto/app.proto index 5330ee506a4179339dd0be2241bc345837bad900..353f19adb2e4cee68267bcbd0a4fd033a0ed9b58 100644 --- a/crates/proto/proto/app.proto +++ b/crates/proto/proto/app.proto @@ -79,11 +79,16 @@ message OpenServerSettings { uint64 project_id = 1; } -message GetPanicFiles { +message GetCrashFiles { } -message GetPanicFilesResponse { - repeated string file_contents = 2; +message GetCrashFilesResponse { + repeated CrashReport crashes = 1; +} + +message CrashReport { + optional string panic_contents = 1; + optional bytes minidump_contents = 2; } message Extension { diff --git a/crates/proto/proto/buffer.proto b/crates/proto/proto/buffer.proto index 09a05a50cd84381c4aaccd17a846e2eb38822392..f4dacf2fdca97bf9766c8de348a67cd18f8fb973 100644 --- a/crates/proto/proto/buffer.proto +++ b/crates/proto/proto/buffer.proto @@ -288,10 +288,12 @@ message SearchQuery { bool regex = 3; bool whole_word = 4; bool case_sensitive = 5; - string files_to_include = 6; - string files_to_exclude = 7; + repeated string files_to_include = 10; + repeated string files_to_exclude = 11; bool match_full_paths = 9; bool include_ignored = 8; + string files_to_include_legacy = 6; + string files_to_exclude_legacy = 7; } message FindSearchCandidates { diff --git a/crates/proto/proto/call.proto b/crates/proto/proto/call.proto index 5212f3b43f5e78aa86de00be00eb8a828fe8d17f..b5c882db568200b4a56f49393d93317ccf49cd12 100644 --- a/crates/proto/proto/call.proto +++ b/crates/proto/proto/call.proto @@ -71,6 +71,7 @@ message RejoinedProject { repeated WorktreeMetadata worktrees = 2; repeated Collaborator collaborators = 3; repeated LanguageServer language_servers = 4; + repeated string language_server_capabilities = 5; } message LeaveRoom {} @@ -199,6 +200,7 @@ message JoinProjectResponse { repeated WorktreeMetadata worktrees = 2; repeated Collaborator collaborators = 3; repeated LanguageServer language_servers = 4; + repeated string language_server_capabilities = 8; ChannelRole role = 6; reserved 7; } diff --git a/crates/proto/proto/debugger.proto b/crates/proto/proto/debugger.proto index 
09abd4bf1c1aa73e89d77c55ade1bce21f0027d4..c6f9c9f1342336c36ab8dfd0ec70a24ff6564476 100644 --- a/crates/proto/proto/debugger.proto +++ b/crates/proto/proto/debugger.proto @@ -188,7 +188,7 @@ message DapSetVariableValueResponse { message DapPauseRequest { uint64 project_id = 1; uint64 client_id = 2; - uint64 thread_id = 3; + int64 thread_id = 3; } message DapDisconnectRequest { @@ -202,7 +202,7 @@ message DapDisconnectRequest { message DapTerminateThreadsRequest { uint64 project_id = 1; uint64 client_id = 2; - repeated uint64 thread_ids = 3; + repeated int64 thread_ids = 3; } message DapThreadsRequest { @@ -246,7 +246,7 @@ message IgnoreBreakpointState { message DapNextRequest { uint64 project_id = 1; uint64 client_id = 2; - uint64 thread_id = 3; + int64 thread_id = 3; optional bool single_thread = 4; optional SteppingGranularity granularity = 5; } @@ -254,7 +254,7 @@ message DapNextRequest { message DapStepInRequest { uint64 project_id = 1; uint64 client_id = 2; - uint64 thread_id = 3; + int64 thread_id = 3; optional uint64 target_id = 4; optional bool single_thread = 5; optional SteppingGranularity granularity = 6; @@ -263,7 +263,7 @@ message DapStepInRequest { message DapStepOutRequest { uint64 project_id = 1; uint64 client_id = 2; - uint64 thread_id = 3; + int64 thread_id = 3; optional bool single_thread = 4; optional SteppingGranularity granularity = 5; } @@ -271,7 +271,7 @@ message DapStepOutRequest { message DapStepBackRequest { uint64 project_id = 1; uint64 client_id = 2; - uint64 thread_id = 3; + int64 thread_id = 3; optional bool single_thread = 4; optional SteppingGranularity granularity = 5; } @@ -279,7 +279,7 @@ message DapStepBackRequest { message DapContinueRequest { uint64 project_id = 1; uint64 client_id = 2; - uint64 thread_id = 3; + int64 thread_id = 3; optional bool single_thread = 4; } @@ -311,7 +311,7 @@ message DapLoadedSourcesResponse { message DapStackTraceRequest { uint64 project_id = 1; uint64 client_id = 2; - uint64 thread_id = 3; + int64 thread_id = 3; optional uint64 start_frame = 4; optional uint64 stack_trace_levels = 5; } @@ -358,7 +358,7 @@ message DapVariable { } message DapThread { - uint64 id = 1; + int64 id = 1; string name = 2; } diff --git a/crates/proto/proto/git.proto b/crates/proto/proto/git.proto index 1fdef2eea6e6a52203ba2d6160860e1080b999e3..c32da9b1100ff4b534d9c450674cd1a2b967066b 100644 --- a/crates/proto/proto/git.proto +++ b/crates/proto/proto/git.proto @@ -286,6 +286,17 @@ message Unstage { repeated string paths = 4; } +message Stash { + uint64 project_id = 1; + uint64 repository_id = 2; + repeated string paths = 3; +} + +message StashPop { + uint64 project_id = 1; + uint64 repository_id = 2; +} + message Commit { uint64 project_id = 1; reserved 2; @@ -294,9 +305,11 @@ message Commit { optional string email = 5; string message = 6; optional CommitOptions options = 7; + reserved 8; message CommitOptions { bool amend = 1; + bool signoff = 2; } } @@ -409,3 +422,12 @@ message BlameBufferResponse { reserved 1 to 4; } + +message GetDefaultBranch { + uint64 project_id = 1; + uint64 repository_id = 2; +} + +message GetDefaultBranchResponse { + optional string branch = 1; +} diff --git a/crates/proto/proto/lsp.proto b/crates/proto/proto/lsp.proto index e3c2f69c0b7587580a393b343eff1c4cd932fd72..ea9647feff0cf811f0464dc4eca22059b348be6f 100644 --- a/crates/proto/proto/lsp.proto +++ b/crates/proto/proto/lsp.proto @@ -518,12 +518,14 @@ message LanguageServer { message StartLanguageServer { uint64 project_id = 1; LanguageServer server = 2; + string 
capabilities = 3; } message UpdateDiagnosticSummary { uint64 project_id = 1; uint64 worktree_id = 2; DiagnosticSummary summary = 3; + repeated DiagnosticSummary more_summaries = 4; } message DiagnosticSummary { @@ -545,6 +547,7 @@ message UpdateLanguageServer { LspDiskBasedDiagnosticsUpdated disk_based_diagnostics_updated = 7; StatusUpdate status_update = 9; RegisteredForBuffer registered_for_buffer = 10; + ServerMetadataUpdated metadata_updated = 11; } } @@ -597,6 +600,11 @@ enum ServerBinaryStatus { message RegisteredForBuffer { string buffer_abs_path = 1; + uint64 buffer_id = 2; +} + +message ServerMetadataUpdated { + optional string capabilities = 1; } message LanguageServerLog { @@ -811,16 +819,6 @@ message LspResponse { uint64 server_id = 7; } -message LanguageServerIdForName { - uint64 project_id = 1; - uint64 buffer_id = 2; - string name = 3; -} - -message LanguageServerIdForNameResponse { - optional uint64 server_id = 1; -} - message LspExtRunnables { uint64 project_id = 1; uint64 buffer_id = 2; diff --git a/crates/proto/proto/zed.proto b/crates/proto/proto/zed.proto index 31f929ec9054c19bb7cfb3a2488dff42e7493c13..bb97bd500ae03e8f5aa9b5f62ccbe2e4783d97fb 100644 --- a/crates/proto/proto/zed.proto +++ b/crates/proto/proto/zed.proto @@ -294,9 +294,6 @@ message Envelope { GetPathMetadata get_path_metadata = 278; GetPathMetadataResponse get_path_metadata_response = 279; - GetPanicFiles get_panic_files = 280; - GetPanicFilesResponse get_panic_files_response = 281; - CancelLanguageServerWork cancel_language_server_work = 282; LspExtOpenDocs lsp_ext_open_docs = 283; @@ -365,9 +362,6 @@ message Envelope { GetDocumentSymbols get_document_symbols = 330; GetDocumentSymbolsResponse get_document_symbols_response = 331; - LanguageServerIdForName language_server_id_for_name = 332; - LanguageServerIdForNameResponse language_server_id_for_name_response = 333; - LoadCommitDiff load_commit_diff = 334; LoadCommitDiffResponse load_commit_diff_response = 335; @@ -396,8 +390,16 @@ message Envelope { GetDocumentColor get_document_color = 353; GetDocumentColorResponse get_document_color_response = 354; GetColorPresentation get_color_presentation = 355; - GetColorPresentationResponse get_color_presentation_response = 356; // current max + GetColorPresentationResponse get_color_presentation_response = 356; + + Stash stash = 357; + StashPop stash_pop = 358; + + GetDefaultBranch get_default_branch = 359; + GetDefaultBranchResponse get_default_branch_response = 360; + GetCrashFiles get_crash_files = 361; + GetCrashFilesResponse get_crash_files_response = 362; // current max } reserved 87 to 88; @@ -418,6 +420,8 @@ message Envelope { reserved 270; reserved 247 to 254; reserved 255 to 256; + reserved 280 to 281; + reserved 332 to 333; } message Hello { diff --git a/crates/proto/src/proto.rs b/crates/proto/src/proto.rs index 918ac9e93596ce5de102f841ab95073778aab056..9edb041b4b48169cf9c2fc21401b763dd1e611b5 100644 --- a/crates/proto/src/proto.rs +++ b/crates/proto/src/proto.rs @@ -99,8 +99,8 @@ messages!( (GetHoverResponse, Background), (GetNotifications, Foreground), (GetNotificationsResponse, Foreground), - (GetPanicFiles, Background), - (GetPanicFilesResponse, Background), + (GetCrashFiles, Background), + (GetCrashFilesResponse, Background), (GetPathMetadata, Background), (GetPathMetadataResponse, Background), (GetPermalinkToLine, Foreground), @@ -121,8 +121,6 @@ messages!( (GetImplementationResponse, Background), (GetLlmToken, Background), (GetLlmTokenResponse, Background), - (LanguageServerIdForName, 
Background), - (LanguageServerIdForNameResponse, Background), (OpenUnstagedDiff, Foreground), (OpenUnstagedDiffResponse, Foreground), (OpenUncommittedDiff, Foreground), @@ -261,6 +259,8 @@ messages!( (Unfollow, Foreground), (UnshareProject, Foreground), (Unstage, Background), + (Stash, Background), + (StashPop, Background), (UpdateBuffer, Foreground), (UpdateBufferFile, Foreground), (UpdateChannelBuffer, Foreground), @@ -313,7 +313,9 @@ messages!( (LogToDebugConsole, Background), (GetDocumentDiagnostics, Background), (GetDocumentDiagnosticsResponse, Background), - (PullWorkspaceDiagnostics, Background) + (PullWorkspaceDiagnostics, Background), + (GetDefaultBranch, Background), + (GetDefaultBranchResponse, Background), ); request_messages!( @@ -419,13 +421,14 @@ request_messages!( (TaskContextForLocation, TaskContext), (Test, Test), (Unstage, Ack), + (Stash, Ack), + (StashPop, Ack), (UpdateBuffer, Ack), (UpdateParticipantLocation, Ack), (UpdateProject, Ack), (UpdateWorktree, Ack), (UpdateRepository, Ack), (RemoveRepository, Ack), - (LanguageServerIdForName, LanguageServerIdForNameResponse), (LspExtExpandMacro, LspExtExpandMacroResponse), (LspExtOpenDocs, LspExtOpenDocsResponse), (LspExtRunnables, LspExtRunnablesResponse), @@ -456,7 +459,7 @@ request_messages!( (ActivateToolchain, Ack), (ActiveToolchain, ActiveToolchainResponse), (GetPathMetadata, GetPathMetadataResponse), - (GetPanicFiles, GetPanicFilesResponse), + (GetCrashFiles, GetCrashFilesResponse), (CancelLanguageServerWork, Ack), (SyncExtensions, SyncExtensionsResponse), (InstallExtension, Ack), @@ -479,7 +482,8 @@ request_messages!( (GetDebugAdapterBinary, DebugAdapterBinary), (RunDebugLocators, DebugRequest), (GetDocumentDiagnostics, GetDocumentDiagnosticsResponse), - (PullWorkspaceDiagnostics, Ack) + (PullWorkspaceDiagnostics, Ack), + (GetDefaultBranch, GetDefaultBranchResponse), ); entity_messages!( @@ -549,6 +553,8 @@ entity_messages!( TaskContextForLocation, UnshareProject, Unstage, + Stash, + StashPop, UpdateBuffer, UpdateBufferFile, UpdateDiagnosticSummary, @@ -579,7 +585,6 @@ entity_messages!( OpenServerSettings, GetPermalinkToLine, LanguageServerPromptRequest, - LanguageServerIdForName, GitGetBranches, UpdateGitBranch, ListToolchains, @@ -609,7 +614,8 @@ entity_messages!( GetDebugAdapterBinary, LogToDebugConsole, GetDocumentDiagnostics, - PullWorkspaceDiagnostics + PullWorkspaceDiagnostics, + GetDefaultBranch ); entity_messages!( @@ -778,6 +784,25 @@ pub fn split_repository_update( }]) } +impl MultiLspQuery { + pub fn request_str(&self) -> &str { + match self.request { + Some(multi_lsp_query::Request::GetHover(_)) => "GetHover", + Some(multi_lsp_query::Request::GetCodeActions(_)) => "GetCodeActions", + Some(multi_lsp_query::Request::GetSignatureHelp(_)) => "GetSignatureHelp", + Some(multi_lsp_query::Request::GetCodeLens(_)) => "GetCodeLens", + Some(multi_lsp_query::Request::GetDocumentDiagnostics(_)) => "GetDocumentDiagnostics", + Some(multi_lsp_query::Request::GetDocumentColor(_)) => "GetDocumentColor", + Some(multi_lsp_query::Request::GetDefinition(_)) => "GetDefinition", + Some(multi_lsp_query::Request::GetDeclaration(_)) => "GetDeclaration", + Some(multi_lsp_query::Request::GetTypeDefinition(_)) => "GetTypeDefinition", + Some(multi_lsp_query::Request::GetImplementation(_)) => "GetImplementation", + Some(multi_lsp_query::Request::GetReferences(_)) => "GetReferences", + None => "<unknown>", + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/recent_projects/src/recent_projects.rs 
b/crates/recent_projects/src/recent_projects.rs index 5dbde6496de8f672473f9d92c052e06e5943ade5..2093e96caeed5bebb2ec2e833efa61574c4be0a1 100644 --- a/crates/recent_projects/src/recent_projects.rs +++ b/crates/recent_projects/src/recent_projects.rs @@ -141,6 +141,7 @@ impl Focusable for RecentProjects { impl Render for RecentProjects { fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { v_flex() + .key_context("RecentProjects") .w(rems(self.rem_width)) .child(self.picker.clone()) .on_mouse_down_out(cx.listener(|this, _, window, cx| { diff --git a/crates/recent_projects/src/remote_servers.rs b/crates/recent_projects/src/remote_servers.rs index aa5103e62ba4d28f150287d5716fa83a3b17260d..354434a7fc9318e58ff7796d061fc2386aae950f 100644 --- a/crates/recent_projects/src/remote_servers.rs +++ b/crates/recent_projects/src/remote_servers.rs @@ -953,7 +953,7 @@ impl RemoteServerProjects { ) .child(Label::new(project.paths.join(", "))) .on_click(cx.listener(move |this, e: &ClickEvent, window, cx| { - let secondary_confirm = e.down.modifiers.platform; + let secondary_confirm = e.modifiers().platform; callback(this, secondary_confirm, window, cx) })) .when(is_from_zed, |server_list_item| { @@ -963,7 +963,7 @@ impl RemoteServerProjects { .child({ let project = project.clone(); // Right-margin to offset it from the Scrollbar - IconButton::new("remove-remote-project", IconName::TrashAlt) + IconButton::new("remove-remote-project", IconName::Trash) .icon_size(IconSize::Small) .shape(IconButtonShape::Square) .size(ButtonSize::Large) diff --git a/crates/remote/src/ssh_session.rs b/crates/remote/src/ssh_session.rs index e31d3dcfd59cfba783587af45a1b16d5892554c8..4306251e44acf988ce90fcd640d8c8bed36f1ee7 100644 --- a/crates/remote/src/ssh_session.rs +++ b/crates/remote/src/ssh_session.rs @@ -1742,7 +1742,7 @@ impl SshRemoteConnection { } }); - cx.spawn(async move |_| { + cx.background_spawn(async move { let result = futures::select! { result = stdin_task.fuse() => { result.context("stdin") diff --git a/crates/remote_server/Cargo.toml b/crates/remote_server/Cargo.toml index 443c47919f14b1fe4d19daeae762d711961d1a17..c6a546f3451e901f1cff99091dbbdabd063dc337 100644 --- a/crates/remote_server/Cargo.toml +++ b/crates/remote_server/Cargo.toml @@ -67,8 +67,11 @@ watch.workspace = true worktree.workspace = true [target.'cfg(not(windows))'.dependencies] +crashes.workspace = true +crash-handler.workspace = true fork.workspace = true libc.workspace = true +minidumper.workspace = true [dev-dependencies] assistant_tool.workspace = true diff --git a/crates/remote_server/src/main.rs b/crates/remote_server/src/main.rs index 98f635d856a31ef5f5522af3f4f7607b17608811..03b0c3eda3ca4556e9e9fa0f588b68effd84a5f9 100644 --- a/crates/remote_server/src/main.rs +++ b/crates/remote_server/src/main.rs @@ -12,6 +12,10 @@ struct Cli { /// by having Zed act like netcat communicating over a Unix socket. #[arg(long, hide = true)] askpass: Option<String>, + /// Used for recording minidumps on crashes by having the server run a separate + /// process communicating over a socket. + #[arg(long, hide = true)] + crash_handler: Option<PathBuf>, /// Used for loading the environment from the project. 
#[arg(long, hide = true)] printenv: bool, @@ -58,6 +62,11 @@ fn main() { return; } + if let Some(socket) = &cli.crash_handler { + crashes::crash_server(socket.as_path()); + return; + } + if cli.printenv { util::shell_env::print_env(); return; diff --git a/crates/remote_server/src/unix.rs b/crates/remote_server/src/unix.rs index 84ce08ff25bfab3e7d6e97768549105e048164e1..9bb5645dc7b80ee25cbb2f75f66b4aba73cb7784 100644 --- a/crates/remote_server/src/unix.rs +++ b/crates/remote_server/src/unix.rs @@ -17,6 +17,7 @@ use node_runtime::{NodeBinaryOptions, NodeRuntime}; use paths::logs_dir; use project::project_settings::ProjectSettings; +use proto::CrashReport; use release_channel::{AppVersion, RELEASE_CHANNEL, ReleaseChannel}; use remote::proxy::ProxyLaunchError; use remote::ssh_session::ChannelClient; @@ -33,6 +34,7 @@ use smol::io::AsyncReadExt; use smol::Async; use smol::{net::unix::UnixListener, stream::StreamExt as _}; +use std::collections::HashMap; use std::ffi::OsStr; use std::ops::ControlFlow; use std::str::FromStr; @@ -109,8 +111,9 @@ fn init_logging_server(log_file_path: PathBuf) -> Result<Receiver<Vec<u8>>> { Ok(rx) } -fn init_panic_hook() { - std::panic::set_hook(Box::new(|info| { +fn init_panic_hook(session_id: String) { + std::panic::set_hook(Box::new(move |info| { + crashes::handle_panic(); let payload = info .payload() .downcast_ref::<&str>() @@ -171,9 +174,11 @@ fn init_panic_hook() { architecture: env::consts::ARCH.into(), panicked_on: Utc::now().timestamp_millis(), backtrace, - system_id: None, // Set on SSH client - installation_id: None, // Set on SSH client - session_id: "".to_string(), // Set on SSH client + system_id: None, // Set on SSH client + installation_id: None, // Set on SSH client + + // used on this end to associate panics with minidumps, but will be replaced on the SSH client + session_id: session_id.clone(), }; if let Some(panic_data_json) = serde_json::to_string(&panic_data).log_err() { @@ -194,44 +199,69 @@ fn init_panic_hook() { })); } -fn handle_panic_requests(project: &Entity<HeadlessProject>, client: &Arc<ChannelClient>) { +fn handle_crash_files_requests(project: &Entity<HeadlessProject>, client: &Arc<ChannelClient>) { let client: AnyProtoClient = client.clone().into(); client.add_request_handler( project.downgrade(), - |_, _: TypedEnvelope<proto::GetPanicFiles>, _cx| async move { + |_, _: TypedEnvelope<proto::GetCrashFiles>, _cx| async move { + let mut crashes = Vec::new(); + let mut minidumps_by_session_id = HashMap::new(); let mut children = smol::fs::read_dir(paths::logs_dir()).await?; - let mut panic_files = Vec::new(); while let Some(child) = children.next().await { let child = child?; let child_path = child.path(); - if child_path.extension() != Some(OsStr::new("panic")) { - continue; - } - let filename = if let Some(filename) = child_path.file_name() { - filename.to_string_lossy() - } else { - continue; - }; - - if !filename.starts_with("zed") { - continue; - } + let extension = child_path.extension(); + if extension == Some(OsStr::new("panic")) { + let filename = if let Some(filename) = child_path.file_name() { + filename.to_string_lossy() + } else { + continue; + }; - let file_contents = smol::fs::read_to_string(&child_path) - .await - .context("error reading panic file")?; + if !filename.starts_with("zed") { + continue; + } - panic_files.push(file_contents); + let file_contents = smol::fs::read_to_string(&child_path) + .await + .context("error reading panic file")?; + + crashes.push(proto::CrashReport { + panic_contents: Some(file_contents), + 
minidump_contents: None, + }); + } else if extension == Some(OsStr::new("dmp")) { + let session_id = child_path.file_stem().unwrap().to_string_lossy(); + minidumps_by_session_id + .insert(session_id.to_string(), smol::fs::read(&child_path).await?); + } // We've done what we can, delete the file - std::fs::remove_file(child_path) + smol::fs::remove_file(&child_path) + .await .context("error removing panic") .log_err(); } - anyhow::Ok(proto::GetPanicFilesResponse { - file_contents: panic_files, - }) + + for crash in &mut crashes { + let panic: telemetry_events::Panic = + serde_json::from_str(crash.panic_contents.as_ref().unwrap())?; + if let dump @ Some(_) = minidumps_by_session_id.remove(&panic.session_id) { + crash.minidump_contents = dump; + } + } + + crashes.extend( + minidumps_by_session_id + .into_values() + .map(|dmp| CrashReport { + panic_contents: None, + minidump_contents: Some(dmp), + }), + ); + + anyhow::Ok(proto::GetCrashFilesResponse { crashes }) }, ); } @@ -409,7 +439,12 @@ pub fn execute_run( ControlFlow::Continue(_) => {} } - init_panic_hook(); + let app = gpui::Application::headless(); + let id = std::process::id().to_string(); + app.background_executor() + .spawn(crashes::init(id.clone())) + .detach(); + init_panic_hook(id); let log_rx = init_logging_server(log_file)?; log::info!( "starting up. pid_file: {:?}, stdin_socket: {:?}, stdout_socket: {:?}, stderr_socket: {:?}", @@ -425,7 +460,7 @@ pub fn execute_run( let listeners = ServerListeners::new(stdin_socket, stdout_socket, stderr_socket)?; let git_hosting_provider_registry = Arc::new(GitHostingProviderRegistry::new()); - gpui::Application::headless().run(move |cx| { + app.run(move |cx| { settings::init(cx); let app_version = AppVersion::load(env!("ZED_PKG_VERSION")); release_channel::init(app_version, cx); @@ -486,7 +521,7 @@ pub fn execute_run( ) }); - handle_panic_requests(&project, &session); + handle_crash_files_requests(&project, &session); cx.background_spawn(async move { cleanup_old_binaries() }) .detach(); @@ -530,12 +565,15 @@ impl ServerPaths { pub fn execute_proxy(identifier: String, is_reconnecting: bool) -> Result<()> { init_logging_proxy(); - init_panic_hook(); - - log::info!("starting proxy process. PID: {}", std::process::id()); let server_paths = ServerPaths::new(&identifier)?; + let id = std::process::id().to_string(); + smol::spawn(crashes::init(id.clone())).detach(); + init_panic_hook(id); + + log::info!("starting proxy process. 
PID: {}", std::process::id()); + let server_pid = check_pid_file(&server_paths.pid_file)?; let server_running = server_pid.is_some(); if is_reconnecting { diff --git a/crates/repl/src/notebook/cell.rs b/crates/repl/src/notebook/cell.rs index 2ed68c17d13dec4236ad4416e7f950a7f61dfb8f..18851417c0b4fd8206df8b52076b2c47044a79ff 100644 --- a/crates/repl/src/notebook/cell.rs +++ b/crates/repl/src/notebook/cell.rs @@ -38,7 +38,7 @@ pub enum CellControlType { impl CellControlType { fn icon_name(&self) -> IconName { match self { - CellControlType::RunCell => IconName::Play, + CellControlType::RunCell => IconName::PlayOutlined, CellControlType::RerunCell => IconName::ArrowCircle, CellControlType::ClearCell => IconName::ListX, CellControlType::CellOptions => IconName::Ellipsis, diff --git a/crates/repl/src/notebook/notebook_ui.rs b/crates/repl/src/notebook/notebook_ui.rs index d14f458fa9d4fcaf8b6cdd50bf276c36fa2ef0b6..2efa51e0cc0f63510231109d3eafd6090e208222 100644 --- a/crates/repl/src/notebook/notebook_ui.rs +++ b/crates/repl/src/notebook/notebook_ui.rs @@ -126,29 +126,7 @@ impl NotebookEditor { let cell_count = cell_order.len(); let this = cx.entity(); - let cell_list = ListState::new( - cell_count, - gpui::ListAlignment::Top, - px(1000.), - move |ix, window, cx| { - notebook_handle - .upgrade() - .and_then(|notebook_handle| { - notebook_handle.update(cx, |notebook, cx| { - notebook - .cell_order - .get(ix) - .and_then(|cell_id| notebook.cell_map.get(cell_id)) - .map(|cell| { - notebook - .render_cell(ix, cell, window, cx) - .into_any_element() - }) - }) - }) - .unwrap_or_else(|| div().into_any()) - }, - ); + let cell_list = ListState::new(cell_count, gpui::ListAlignment::Top, px(1000.)); Self { project, @@ -343,7 +321,7 @@ impl NotebookEditor { .child( Self::render_notebook_control( "run-all-cells", - IconName::Play, + IconName::PlayOutlined, window, cx, ) @@ -544,7 +522,19 @@ impl Render for NotebookEditor { .flex_1() .size_full() .overflow_y_scroll() - .child(list(self.cell_list.clone()).size_full()), + .child(list( + self.cell_list.clone(), + cx.processor(|this, ix, window, cx| { + this.cell_order + .get(ix) + .and_then(|cell_id| this.cell_map.get(cell_id)) + .map(|cell| { + this.render_cell(ix, cell, window, cx).into_any_element() + }) + .unwrap_or_else(|| div().into_any()) + }), + )) + .size_full(), ) .child(self.render_notebook_controls(window, cx)) } diff --git a/crates/repl/src/session.rs b/crates/repl/src/session.rs index 18d41f3eae97ce4288d95e1e0eabb57d4b47adec..729a6161350652a90fcf9687593a2f115481a945 100644 --- a/crates/repl/src/session.rs +++ b/crates/repl/src/session.rs @@ -90,7 +90,6 @@ impl EditorBlock { style: BlockStyle::Sticky, render: Self::create_output_area_renderer(execution_view.clone(), on_close.clone()), priority: 0, - render_in_minimap: false, }; let block_id = editor.insert_blocks([block], None, cx)[0]; diff --git a/crates/reqwest_client/src/reqwest_client.rs b/crates/reqwest_client/src/reqwest_client.rs index daff20ac4ad244a7491bc8f6a248d6df3e7e99f5..6461a0ae17d288a9fe282cda39ea1af9ba297e21 100644 --- a/crates/reqwest_client/src/reqwest_client.rs +++ b/crates/reqwest_client/src/reqwest_client.rs @@ -4,14 +4,13 @@ use std::{any::type_name, borrow::Cow, mem, pin::Pin, task::Poll, time::Duration use anyhow::anyhow; use bytes::{BufMut, Bytes, BytesMut}; -use futures::{AsyncRead, TryStreamExt as _}; +use futures::{AsyncRead, FutureExt as _, TryStreamExt as _}; use http_client::{RedirectPolicy, Url, http}; use regex::Regex; use reqwest::{ header::{HeaderMap, HeaderValue}, 
redirect, }; -use smol::future::FutureExt; const DEFAULT_CAPACITY: usize = 4096; static RUNTIME: OnceLock<tokio::runtime::Runtime> = OnceLock::new(); @@ -20,6 +19,7 @@ static REDACT_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"key=[^&]+") pub struct ReqwestClient { client: reqwest::Client, proxy: Option<Url>, + user_agent: Option<HeaderValue>, handle: tokio::runtime::Handle, } @@ -44,9 +44,11 @@ impl ReqwestClient { Ok(client.into()) } - pub fn proxy_and_user_agent(proxy: Option<Url>, agent: &str) -> anyhow::Result<Self> { + pub fn proxy_and_user_agent(proxy: Option<Url>, user_agent: &str) -> anyhow::Result<Self> { + let user_agent = HeaderValue::from_str(user_agent)?; + let mut map = HeaderMap::new(); - map.insert(http::header::USER_AGENT, HeaderValue::from_str(agent)?); + map.insert(http::header::USER_AGENT, user_agent.clone()); let mut client = Self::builder().default_headers(map); let client_has_proxy; @@ -73,6 +75,7 @@ impl ReqwestClient { .build()?; let mut client: ReqwestClient = client.into(); client.proxy = client_has_proxy.then_some(proxy).flatten(); + client.user_agent = Some(user_agent); Ok(client) } } @@ -96,6 +99,7 @@ impl From<reqwest::Client> for ReqwestClient { client, handle, proxy: None, + user_agent: None, } } } @@ -216,6 +220,10 @@ impl http_client::HttpClient for ReqwestClient { type_name::<Self>() } + fn user_agent(&self) -> Option<&HeaderValue> { + self.user_agent.as_ref() + } + fn send( &self, req: http::Request<http_client::AsyncBody>, @@ -265,6 +273,26 @@ impl http_client::HttpClient for ReqwestClient { } .boxed() } + + fn send_multipart_form<'a>( + &'a self, + url: &str, + form: reqwest::multipart::Form, + ) -> futures::future::BoxFuture<'a, anyhow::Result<http_client::Response<http_client::AsyncBody>>> + { + let response = self.client.post(url).multipart(form).send(); + self.handle + .spawn(async move { + let response = response.await?; + let mut builder = http::response::Builder::new().status(response.status()); + for (k, v) in response.headers() { + builder = builder.header(k, v) + } + Ok(builder.body(response.bytes().await?.into())?) + }) + .map(|e| e?) 
+ .boxed() + } } #[cfg(test)] diff --git a/crates/rope/src/rope.rs b/crates/rope/src/rope.rs index 535b863b7d7b1e66b8621b2da02c8f8d9c7f3912..d8ed3bfac86bdbb360f6a161242aa874a0fd51af 100644 --- a/crates/rope/src/rope.rs +++ b/crates/rope/src/rope.rs @@ -12,7 +12,7 @@ use std::{ ops::{self, AddAssign, Range}, str, }; -use sum_tree::{Bias, Dimension, SumTree}; +use sum_tree::{Bias, Dimension, Dimensions, SumTree}; pub use chunk::ChunkSlice; pub use offset_utf16::OffsetUtf16; @@ -41,9 +41,9 @@ impl Rope { self.push_chunk(chunk.as_slice()); let mut chunks = rope.chunks.cursor::<()>(&()); - chunks.next(&()); - chunks.next(&()); - self.chunks.append(chunks.suffix(&()), &()); + chunks.next(); + chunks.next(); + self.chunks.append(chunks.suffix(), &()); self.check_invariants(); return; } @@ -282,8 +282,8 @@ impl Rope { if offset >= self.summary().len { return self.summary().len_utf16; } - let mut cursor = self.chunks.cursor::<(usize, OffsetUtf16)>(&()); - cursor.seek(&offset, Bias::Left, &()); + let mut cursor = self.chunks.cursor::<Dimensions<usize, OffsetUtf16>>(&()); + cursor.seek(&offset, Bias::Left); let overshoot = offset - cursor.start().0; cursor.start().1 + cursor.item().map_or(Default::default(), |chunk| { @@ -295,8 +295,8 @@ impl Rope { if offset >= self.summary().len_utf16 { return self.summary().len; } - let mut cursor = self.chunks.cursor::<(OffsetUtf16, usize)>(&()); - cursor.seek(&offset, Bias::Left, &()); + let mut cursor = self.chunks.cursor::<Dimensions<OffsetUtf16, usize>>(&()); + cursor.seek(&offset, Bias::Left); let overshoot = offset - cursor.start().0; cursor.start().1 + cursor.item().map_or(Default::default(), |chunk| { @@ -308,8 +308,8 @@ impl Rope { if offset >= self.summary().len { return self.summary().lines; } - let mut cursor = self.chunks.cursor::<(usize, Point)>(&()); - cursor.seek(&offset, Bias::Left, &()); + let mut cursor = self.chunks.cursor::<Dimensions<usize, Point>>(&()); + cursor.seek(&offset, Bias::Left); let overshoot = offset - cursor.start().0; cursor.start().1 + cursor.item().map_or(Point::zero(), |chunk| { @@ -321,8 +321,8 @@ impl Rope { if offset >= self.summary().len { return self.summary().lines_utf16(); } - let mut cursor = self.chunks.cursor::<(usize, PointUtf16)>(&()); - cursor.seek(&offset, Bias::Left, &()); + let mut cursor = self.chunks.cursor::<Dimensions<usize, PointUtf16>>(&()); + cursor.seek(&offset, Bias::Left); let overshoot = offset - cursor.start().0; cursor.start().1 + cursor.item().map_or(PointUtf16::zero(), |chunk| { @@ -334,8 +334,8 @@ impl Rope { if point >= self.summary().lines { return self.summary().lines_utf16(); } - let mut cursor = self.chunks.cursor::<(Point, PointUtf16)>(&()); - cursor.seek(&point, Bias::Left, &()); + let mut cursor = self.chunks.cursor::<Dimensions<Point, PointUtf16>>(&()); + cursor.seek(&point, Bias::Left); let overshoot = point - cursor.start().0; cursor.start().1 + cursor.item().map_or(PointUtf16::zero(), |chunk| { @@ -347,8 +347,8 @@ impl Rope { if point >= self.summary().lines { return self.summary().len; } - let mut cursor = self.chunks.cursor::<(Point, usize)>(&()); - cursor.seek(&point, Bias::Left, &()); + let mut cursor = self.chunks.cursor::<Dimensions<Point, usize>>(&()); + cursor.seek(&point, Bias::Left); let overshoot = point - cursor.start().0; cursor.start().1 + cursor @@ -368,8 +368,8 @@ impl Rope { if point >= self.summary().lines_utf16() { return self.summary().len; } - let mut cursor = self.chunks.cursor::<(PointUtf16, usize)>(&()); - cursor.seek(&point, Bias::Left, &()); + let mut 
cursor = self.chunks.cursor::<Dimensions<PointUtf16, usize>>(&()); + cursor.seek(&point, Bias::Left); let overshoot = point - cursor.start().0; cursor.start().1 + cursor.item().map_or(0, |chunk| { @@ -381,8 +381,8 @@ impl Rope { if point.0 >= self.summary().lines_utf16() { return self.summary().lines; } - let mut cursor = self.chunks.cursor::<(PointUtf16, Point)>(&()); - cursor.seek(&point.0, Bias::Left, &()); + let mut cursor = self.chunks.cursor::<Dimensions<PointUtf16, Point>>(&()); + cursor.seek(&point.0, Bias::Left); let overshoot = Unclipped(point.0 - cursor.start().0); cursor.start().1 + cursor.item().map_or(Point::zero(), |chunk| { @@ -392,7 +392,7 @@ impl Rope { pub fn clip_offset(&self, mut offset: usize, bias: Bias) -> usize { let mut cursor = self.chunks.cursor::<usize>(&()); - cursor.seek(&offset, Bias::Left, &()); + cursor.seek(&offset, Bias::Left); if let Some(chunk) = cursor.item() { let mut ix = offset - cursor.start(); while !chunk.text.is_char_boundary(ix) { @@ -415,7 +415,7 @@ impl Rope { pub fn clip_offset_utf16(&self, offset: OffsetUtf16, bias: Bias) -> OffsetUtf16 { let mut cursor = self.chunks.cursor::<OffsetUtf16>(&()); - cursor.seek(&offset, Bias::Right, &()); + cursor.seek(&offset, Bias::Right); if let Some(chunk) = cursor.item() { let overshoot = offset - cursor.start(); *cursor.start() + chunk.as_slice().clip_offset_utf16(overshoot, bias) @@ -426,7 +426,7 @@ impl Rope { pub fn clip_point(&self, point: Point, bias: Bias) -> Point { let mut cursor = self.chunks.cursor::<Point>(&()); - cursor.seek(&point, Bias::Right, &()); + cursor.seek(&point, Bias::Right); if let Some(chunk) = cursor.item() { let overshoot = point - cursor.start(); *cursor.start() + chunk.as_slice().clip_point(overshoot, bias) @@ -437,7 +437,7 @@ impl Rope { pub fn clip_point_utf16(&self, point: Unclipped<PointUtf16>, bias: Bias) -> PointUtf16 { let mut cursor = self.chunks.cursor::<PointUtf16>(&()); - cursor.seek(&point.0, Bias::Right, &()); + cursor.seek(&point.0, Bias::Right); if let Some(chunk) = cursor.item() { let overshoot = Unclipped(point.0 - cursor.start()); *cursor.start() + chunk.as_slice().clip_point_utf16(overshoot, bias) @@ -450,10 +450,6 @@ impl Rope { self.clip_point(Point::new(row, u32::MAX), Bias::Left) .column } - - pub fn ptr_eq(&self, other: &Self) -> bool { - self.chunks.ptr_eq(&other.chunks) - } } impl<'a> From<&'a str> for Rope { @@ -475,11 +471,19 @@ impl<'a> FromIterator<&'a str> for Rope { } impl From<String> for Rope { + #[inline(always)] fn from(text: String) -> Self { Rope::from(text.as_str()) } } +impl From<&String> for Rope { + #[inline(always)] + fn from(text: &String) -> Self { + Rope::from(text.as_str()) + } +} + impl fmt::Display for Rope { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { for chunk in self.chunks() { @@ -514,7 +518,7 @@ pub struct Cursor<'a> { impl<'a> Cursor<'a> { pub fn new(rope: &'a Rope, offset: usize) -> Self { let mut chunks = rope.chunks.cursor(&()); - chunks.seek(&offset, Bias::Right, &()); + chunks.seek(&offset, Bias::Right); Self { rope, chunks, @@ -525,7 +529,7 @@ impl<'a> Cursor<'a> { pub fn seek_forward(&mut self, end_offset: usize) { debug_assert!(end_offset >= self.offset); - self.chunks.seek_forward(&end_offset, Bias::Right, &()); + self.chunks.seek_forward(&end_offset, Bias::Right); self.offset = end_offset; } @@ -540,14 +544,14 @@ impl<'a> Cursor<'a> { let mut slice = Rope::new(); if let Some(start_chunk) = self.chunks.item() { let start_ix = self.offset - self.chunks.start(); - let end_ix = cmp::min(end_offset, 
self.chunks.end(&())) - self.chunks.start(); + let end_ix = cmp::min(end_offset, self.chunks.end()) - self.chunks.start(); slice.push_chunk(start_chunk.slice(start_ix..end_ix)); } - if end_offset > self.chunks.end(&()) { - self.chunks.next(&()); + if end_offset > self.chunks.end() { + self.chunks.next(); slice.append(Rope { - chunks: self.chunks.slice(&end_offset, Bias::Right, &()), + chunks: self.chunks.slice(&end_offset, Bias::Right), }); if let Some(end_chunk) = self.chunks.item() { let end_ix = end_offset - self.chunks.start(); @@ -565,13 +569,13 @@ impl<'a> Cursor<'a> { let mut summary = D::zero(&()); if let Some(start_chunk) = self.chunks.item() { let start_ix = self.offset - self.chunks.start(); - let end_ix = cmp::min(end_offset, self.chunks.end(&())) - self.chunks.start(); + let end_ix = cmp::min(end_offset, self.chunks.end()) - self.chunks.start(); summary.add_assign(&D::from_chunk(start_chunk.slice(start_ix..end_ix))); } - if end_offset > self.chunks.end(&()) { - self.chunks.next(&()); - summary.add_assign(&self.chunks.summary(&end_offset, Bias::Right, &())); + if end_offset > self.chunks.end() { + self.chunks.next(); + summary.add_assign(&self.chunks.summary(&end_offset, Bias::Right)); if let Some(end_chunk) = self.chunks.item() { let end_ix = end_offset - self.chunks.start(); summary.add_assign(&D::from_chunk(end_chunk.slice(0..end_ix))); @@ -603,10 +607,10 @@ impl<'a> Chunks<'a> { pub fn new(rope: &'a Rope, range: Range<usize>, reversed: bool) -> Self { let mut chunks = rope.chunks.cursor(&()); let offset = if reversed { - chunks.seek(&range.end, Bias::Left, &()); + chunks.seek(&range.end, Bias::Left); range.end } else { - chunks.seek(&range.start, Bias::Right, &()); + chunks.seek(&range.start, Bias::Right); range.start }; Self { @@ -642,10 +646,10 @@ impl<'a> Chunks<'a> { Bias::Right }; - if offset >= self.chunks.end(&()) { - self.chunks.seek_forward(&offset, bias, &()); + if offset >= self.chunks.end() { + self.chunks.seek_forward(&offset, bias); } else { - self.chunks.seek(&offset, bias, &()); + self.chunks.seek(&offset, bias); } self.offset = offset; @@ -674,25 +678,25 @@ impl<'a> Chunks<'a> { found = self.offset <= self.range.end; } else { self.chunks - .search_forward(|summary| summary.text.lines.row > 0, &()); + .search_forward(|summary| summary.text.lines.row > 0); self.offset = *self.chunks.start(); if let Some(newline_ix) = self.peek().and_then(|chunk| chunk.find('\n')) { self.offset += newline_ix + 1; found = self.offset <= self.range.end; } else { - self.offset = self.chunks.end(&()); + self.offset = self.chunks.end(); } } - if self.offset == self.chunks.end(&()) { + if self.offset == self.chunks.end() { self.next(); } } if self.offset > self.range.end { self.offset = cmp::min(self.offset, self.range.end); - self.chunks.seek(&self.offset, Bias::Right, &()); + self.chunks.seek(&self.offset, Bias::Right); } found @@ -711,7 +715,7 @@ impl<'a> Chunks<'a> { let initial_offset = self.offset; if self.offset == *self.chunks.start() { - self.chunks.prev(&()); + self.chunks.prev(); } if let Some(chunk) = self.chunks.item() { @@ -729,14 +733,14 @@ impl<'a> Chunks<'a> { } self.chunks - .search_backward(|summary| summary.text.lines.row > 0, &()); + .search_backward(|summary| summary.text.lines.row > 0); self.offset = *self.chunks.start(); if let Some(chunk) = self.chunks.item() { if let Some(newline_ix) = chunk.text.rfind('\n') { self.offset += newline_ix + 1; if self.offset_is_valid() { - if self.offset == self.chunks.end(&()) { - self.chunks.next(&()); + if self.offset == 
self.chunks.end() { + self.chunks.next(); } return true; @@ -746,7 +750,7 @@ impl<'a> Chunks<'a> { if !self.offset_is_valid() || self.chunks.item().is_none() { self.offset = self.range.start; - self.chunks.seek(&self.offset, Bias::Right, &()); + self.chunks.seek(&self.offset, Bias::Right); } self.offset < initial_offset && self.offset == 0 @@ -765,7 +769,7 @@ impl<'a> Chunks<'a> { slice_start..slice_end } else { let slice_start = self.offset - chunk_start; - let slice_end = cmp::min(self.chunks.end(&()), self.range.end) - chunk_start; + let slice_end = cmp::min(self.chunks.end(), self.range.end) - chunk_start; slice_start..slice_end }; @@ -825,12 +829,12 @@ impl<'a> Iterator for Chunks<'a> { if self.reversed { self.offset -= chunk.len(); if self.offset <= *self.chunks.start() { - self.chunks.prev(&()); + self.chunks.prev(); } } else { self.offset += chunk.len(); - if self.offset >= self.chunks.end(&()) { - self.chunks.next(&()); + if self.offset >= self.chunks.end() { + self.chunks.next(); } } @@ -848,9 +852,9 @@ impl<'a> Bytes<'a> { pub fn new(rope: &'a Rope, range: Range<usize>, reversed: bool) -> Self { let mut chunks = rope.chunks.cursor(&()); if reversed { - chunks.seek(&range.end, Bias::Left, &()); + chunks.seek(&range.end, Bias::Left); } else { - chunks.seek(&range.start, Bias::Right, &()); + chunks.seek(&range.start, Bias::Right); } Self { chunks, @@ -861,7 +865,7 @@ impl<'a> Bytes<'a> { pub fn peek(&self) -> Option<&'a [u8]> { let chunk = self.chunks.item()?; - if self.reversed && self.range.start >= self.chunks.end(&()) { + if self.reversed && self.range.start >= self.chunks.end() { return None; } let chunk_start = *self.chunks.start(); @@ -881,9 +885,9 @@ impl<'a> Iterator for Bytes<'a> { let result = self.peek(); if result.is_some() { if self.reversed { - self.chunks.prev(&()); + self.chunks.prev(); } else { - self.chunks.next(&()); + self.chunks.next(); } } result @@ -905,9 +909,9 @@ impl io::Read for Bytes<'_> { if len == chunk.len() { if self.reversed { - self.chunks.prev(&()); + self.chunks.prev(); } else { - self.chunks.next(&()); + self.chunks.next(); } } Ok(len) @@ -1172,16 +1176,17 @@ pub trait TextDimension: fn add_assign(&mut self, other: &Self); } -impl<D1: TextDimension, D2: TextDimension> TextDimension for (D1, D2) { +impl<D1: TextDimension, D2: TextDimension> TextDimension for Dimensions<D1, D2, ()> { fn from_text_summary(summary: &TextSummary) -> Self { - ( + Dimensions( D1::from_text_summary(summary), D2::from_text_summary(summary), + (), ) } fn from_chunk(chunk: ChunkSlice) -> Self { - (D1::from_chunk(chunk), D2::from_chunk(chunk)) + Dimensions(D1::from_chunk(chunk), D2::from_chunk(chunk), ()) } fn add_assign(&mut self, other: &Self) { diff --git a/crates/rules_library/src/rules_library.rs b/crates/rules_library/src/rules_library.rs index f871416f391d844d324ee3a11d9c41465ea0dccd..ebec96dd7b6298720653768cfbe1b4515cd376fd 100644 --- a/crates/rules_library/src/rules_library.rs +++ b/crates/rules_library/src/rules_library.rs @@ -319,7 +319,7 @@ impl PickerDelegate for RulePickerDelegate { }) .into_any() } else { - IconButton::new("delete-rule", IconName::TrashAlt) + IconButton::new("delete-rule", IconName::Trash) .icon_color(Color::Muted) .icon_size(IconSize::Small) .shape(IconButtonShape::Square) @@ -611,7 +611,7 @@ impl RulesLibrary { this.update_in(cx, |this, window, cx| match rule { Ok(rule) => { let title_editor = cx.new(|cx| { - let mut editor = Editor::auto_width(window, cx); + let mut editor = Editor::single_line(window, cx); 
editor.set_placeholder_text("Untitled", cx); editor.set_text(rule_metadata.title.unwrap_or_default(), window, cx); if prompt_id.is_built_in() { @@ -1101,7 +1101,7 @@ impl RulesLibrary { inlay_hints_style: editor::make_inlay_hints_style( cx, ), - inline_completion_styles: + edit_prediction_styles: editor::make_suggestion_styles(cx), ..EditorStyle::default() }, @@ -1163,7 +1163,7 @@ impl RulesLibrary { }) .into_any() } else { - IconButton::new("delete-rule", IconName::TrashAlt) + IconButton::new("delete-rule", IconName::Trash) .icon_size(IconSize::Small) .tooltip(move |window, cx| { Tooltip::for_action( diff --git a/crates/search/src/buffer_search.rs b/crates/search/src/buffer_search.rs index c2590ec9b04df03434a9434ebbd44af9c6ebb698..5d77a95027a6cae193ec293ce7e9254a0fa0363c 100644 --- a/crates/search/src/buffer_search.rs +++ b/crates/search/src/buffer_search.rs @@ -228,16 +228,17 @@ impl Render for BufferSearchBar { if in_replace { key_context.add("in_replace"); } - let editor_border = if self.query_error.is_some() { + let query_border = if self.query_error.is_some() { Color::Error.color(cx) } else { cx.theme().colors().border }; + let replacement_border = cx.theme().colors().border; let container_width = window.viewport_size().width; let input_width = SearchInputWidth::calc_width(container_width); - let input_base_styles = || { + let input_base_styles = |border_color| { h_flex() .min_w_32() .w(input_width) @@ -246,7 +247,7 @@ impl Render for BufferSearchBar { .pr_1() .py_1() .border_1() - .border_color(editor_border) + .border_color(border_color) .rounded_lg() }; @@ -256,7 +257,7 @@ impl Render for BufferSearchBar { el.child(Label::new("Find in results").color(Color::Hint)) }) .child( - input_base_styles() + input_base_styles(query_border) .id("editor-scroll") .track_scroll(&self.editor_scroll_handle) .child(self.render_text_input(&self.query_editor, color_override, cx)) @@ -430,11 +431,13 @@ impl Render for BufferSearchBar { let replace_line = should_show_replace_input.then(|| { h_flex() .gap_2() - .child(input_base_styles().child(self.render_text_input( - &self.replacement_editor, - None, - cx, - ))) + .child( + input_base_styles(replacement_border).child(self.render_text_input( + &self.replacement_editor, + None, + cx, + )), + ) .child( h_flex() .min_w_64() @@ -700,7 +703,11 @@ impl BufferSearchBar { window: &mut Window, cx: &mut Context<Self>, ) -> Self { - let query_editor = cx.new(|cx| Editor::single_line(window, cx)); + let query_editor = cx.new(|cx| { + let mut editor = Editor::single_line(window, cx); + editor.set_use_autoclose(false); + editor + }); cx.subscribe_in(&query_editor, window, Self::on_query_editor_event) .detach(); let replacement_editor = cx.new(|cx| Editor::single_line(window, cx)); @@ -771,6 +778,7 @@ impl BufferSearchBar { pub fn dismiss(&mut self, _: &Dismiss, window: &mut Window, cx: &mut Context<Self>) { self.dismissed = true; + self.query_error = None; for searchable_item in self.searchable_items_with_matches.keys() { if let Some(searchable_item) = WeakSearchableItemHandle::upgrade(searchable_item.as_ref(), cx) diff --git a/crates/search/src/project_search.rs b/crates/search/src/project_search.rs index 57ca5e56b9447f8552abac55c6d79a5f6e8326a1..15c1099aec1ecb0e6e6873f465474dad706a3657 100644 --- a/crates/search/src/project_search.rs +++ b/crates/search/src/project_search.rs @@ -195,6 +195,7 @@ pub struct ProjectSearch { #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] enum InputPanel { Query, + Replacement, Exclude, Include, } @@ -354,8 +355,9 @@ impl 
ProjectSearch { while let Some(new_ranges) = new_ranges.next().await { project_search - .update(cx, |project_search, _| { + .update(cx, |project_search, cx| { project_search.match_ranges.extend(new_ranges); + cx.notify(); }) .ok()?; } @@ -1962,7 +1964,7 @@ impl Render for ProjectSearchBar { MultipleInputs, } - let input_base_styles = |base_style: BaseStyle| { + let input_base_styles = |base_style: BaseStyle, panel: InputPanel| { h_flex() .min_w_32() .map(|div| match base_style { @@ -1974,11 +1976,11 @@ impl Render for ProjectSearchBar { .pr_1() .py_1() .border_1() - .border_color(search.border_color_for(InputPanel::Query, cx)) + .border_color(search.border_color_for(panel, cx)) .rounded_lg() }; - let query_column = input_base_styles(BaseStyle::SingleInput) + let query_column = input_base_styles(BaseStyle::SingleInput, InputPanel::Query) .on_action(cx.listener(|this, action, window, cx| this.confirm(action, window, cx))) .on_action(cx.listener(|this, action, window, cx| { this.previous_history_query(action, window, cx) @@ -2167,7 +2169,7 @@ impl Render for ProjectSearchBar { .child(h_flex().min_w_64().child(mode_column).child(matches_column)); let replace_line = search.replace_enabled.then(|| { - let replace_column = input_base_styles(BaseStyle::SingleInput) + let replace_column = input_base_styles(BaseStyle::SingleInput, InputPanel::Replacement) .child(self.render_text_input(&search.replacement_editor, cx)); let focus_handle = search.replacement_editor.read(cx).focus_handle(cx); @@ -2241,7 +2243,7 @@ impl Render for ProjectSearchBar { .gap_2() .w(input_width) .child( - input_base_styles(BaseStyle::MultipleInputs) + input_base_styles(BaseStyle::MultipleInputs, InputPanel::Include) .on_action(cx.listener(|this, action, window, cx| { this.previous_history_query(action, window, cx) })) @@ -2251,7 +2253,7 @@ impl Render for ProjectSearchBar { .child(self.render_text_input(&search.included_files_editor, cx)), ) .child( - input_base_styles(BaseStyle::MultipleInputs) + input_base_styles(BaseStyle::MultipleInputs, InputPanel::Exclude) .on_action(cx.listener(|this, action, window, cx| { this.previous_history_query(action, window, cx) })) diff --git a/crates/search/src/search_status_button.rs b/crates/search/src/search_status_button.rs index fcdf36041f282376716aac3bde78baf8a667a68e..ff2ee1641d07a68c52e88a9686e90b2f3f40c4c5 100644 --- a/crates/search/src/search_status_button.rs +++ b/crates/search/src/search_status_button.rs @@ -1,9 +1,6 @@ use editor::EditorSettings; use settings::Settings as _; -use ui::{ - ButtonCommon, ButtonLike, Clickable, Color, Context, Icon, IconName, IconSize, ParentElement, - Render, Styled, Tooltip, Window, h_flex, -}; +use ui::{ButtonCommon, Clickable, Context, Render, Tooltip, Window, prelude::*}; use workspace::{ItemHandle, StatusItemView}; pub struct SearchButton; @@ -16,18 +13,15 @@ impl SearchButton { impl Render for SearchButton { fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl ui::IntoElement { - let button = h_flex().gap_2(); + let button = div(); + if !EditorSettings::get_global(cx).search.button { - return button; + return button.w_0().invisible(); } button.child( - ButtonLike::new("project-search-indicator") - .child( - Icon::new(IconName::MagnifyingGlass) - .size(IconSize::Small) - .color(Color::Default), - ) + IconButton::new("project-search-indicator", IconName::MagnifyingGlass) + .icon_size(IconSize::Small) .tooltip(|window, cx| { Tooltip::for_action( "Project Search", diff --git 
a/crates/semantic_index/src/project_index_debug_view.rs b/crates/semantic_index/src/project_index_debug_view.rs index 1b0d87fca0d3f0bd676741f3142927ada47180f3..8d6a49c45caf336c03fe0a2b62ecbca9e079fc65 100644 --- a/crates/semantic_index/src/project_index_debug_view.rs +++ b/crates/semantic_index/src/project_index_debug_view.rs @@ -115,21 +115,9 @@ impl ProjectIndexDebugView { .collect::<Vec<_>>(); this.update(cx, |this, cx| { - let view = cx.entity().downgrade(); this.selected_path = Some(PathState { path: file_path, - list_state: ListState::new( - chunks.len(), - gpui::ListAlignment::Top, - px(100.), - move |ix, _, cx| { - if let Some(view) = view.upgrade() { - view.update(cx, |view, cx| view.render_chunk(ix, cx)) - } else { - div().into_any() - } - }, - ), + list_state: ListState::new(chunks.len(), gpui::ListAlignment::Top, px(100.)), chunks, }); cx.notify(); @@ -219,7 +207,13 @@ impl Render for ProjectIndexDebugView { cx.notify(); })), ) - .child(list(selected_path.list_state.clone()).size_full()) + .child( + list( + selected_path.list_state.clone(), + cx.processor(|this, ix, _, cx| this.render_chunk(ix, cx)), + ) + .size_full(), + ) .size_full() .into_any_element() } else { diff --git a/crates/settings/src/keymap_file.rs b/crates/settings/src/keymap_file.rs index 19bc58ea2342dae25be0bebfcab600771594989c..7802671fecdcafe26a22057b8484ddfcbe7556fd 100644 --- a/crates/settings/src/keymap_file.rs +++ b/crates/settings/src/keymap_file.rs @@ -10,6 +10,7 @@ use serde::Deserialize; use serde_json::{Value, json}; use std::borrow::Cow; use std::{any::TypeId, fmt::Write, rc::Rc, sync::Arc, sync::LazyLock}; +use util::ResultExt as _; use util::{ asset_str, markdown::{MarkdownEscaped, MarkdownInlineCode, MarkdownString}, @@ -607,14 +608,31 @@ impl KeymapFile { mut keymap_contents: String, tab_size: usize, ) -> Result<String> { - // if trying to replace a keybinding that is not user-defined, treat it as an add operation match operation { + // if trying to replace a keybinding that is not user-defined, treat it as an add operation KeybindUpdateOperation::Replace { target_keybind_source: target_source, source, - .. + target, } if target_source != KeybindSource::User => { - operation = KeybindUpdateOperation::Add(source); + operation = KeybindUpdateOperation::Add { + source, + from: Some(target), + }; + } + // if trying to remove a keybinding that is not user-defined, treat it as creating a binding + // that binds it to `zed::NoAction` + KeybindUpdateOperation::Remove { + target, + target_keybind_source, + } if target_keybind_source != KeybindSource::User => { + let mut source = target.clone(); + source.action_name = gpui::NoAction.name(); + source.action_arguments.take(); + operation = KeybindUpdateOperation::Add { + source, + from: Some(target), + }; } _ => {} } @@ -623,49 +641,48 @@ impl KeymapFile { // We don't want to modify the file if it's invalid. let keymap = Self::parse(&keymap_contents).context("Failed to parse keymap")?; + if let KeybindUpdateOperation::Remove { target, .. 
} = operation { + let target_action_value = target + .action_value() + .context("Failed to generate target action JSON value")?; + let Some((index, keystrokes_str)) = + find_binding(&keymap, &target, &target_action_value) + else { + anyhow::bail!("Failed to find keybinding to remove"); + }; + let is_only_binding = keymap.0[index] + .bindings + .as_ref() + .map_or(true, |bindings| bindings.len() == 1); + let key_path: &[&str] = if is_only_binding { + &[] + } else { + &["bindings", keystrokes_str] + }; + let (replace_range, replace_value) = replace_top_level_array_value_in_json_text( + &keymap_contents, + key_path, + None, + None, + index, + tab_size, + ) + .context("Failed to remove keybinding")?; + keymap_contents.replace_range(replace_range, &replace_value); + return Ok(keymap_contents); + } + if let KeybindUpdateOperation::Replace { source, target, .. } = operation { - let mut found_index = None; let target_action_value = target .action_value() .context("Failed to generate target action JSON value")?; let source_action_value = source .action_value() .context("Failed to generate source action JSON value")?; - 'sections: for (index, section) in keymap.sections().enumerate() { - if section.context != target.context.unwrap_or("") { - continue; - } - if section.use_key_equivalents != target.use_key_equivalents { - continue; - } - let Some(bindings) = &section.bindings else { - continue; - }; - for (keystrokes, action) in bindings { - let Ok(keystrokes) = keystrokes - .split_whitespace() - .map(Keystroke::parse) - .collect::<Result<Vec<_>, _>>() - else { - continue; - }; - if keystrokes.len() != target.keystrokes.len() - || !keystrokes - .iter() - .zip(target.keystrokes) - .all(|(a, b)| a.should_match(b)) - { - continue; - } - if action.0 != target_action_value { - continue; - } - found_index = Some(index); - break 'sections; - } - } - if let Some(index) = found_index { + if let Some((index, keystrokes_str)) = + find_binding(&keymap, &target, &target_action_value) + { if target.context == source.context { // if we are only changing the keybinding (common case) // not the context, etc.
Then just update the binding in place @@ -673,7 +690,7 @@ impl KeymapFile { let (replace_range, replace_value) = replace_top_level_array_value_in_json_text( &keymap_contents, - &["bindings", &target.keystrokes_unparsed()], + &["bindings", keystrokes_str], Some(&source_action_value), Some(&source.keystrokes_unparsed()), index, @@ -695,7 +712,7 @@ impl KeymapFile { let (replace_range, replace_value) = replace_top_level_array_value_in_json_text( &keymap_contents, - &["bindings", &target.keystrokes_unparsed()], + &["bindings", keystrokes_str], Some(&source_action_value), Some(&source.keystrokes_unparsed()), index, @@ -725,7 +742,7 @@ impl KeymapFile { let (replace_range, replace_value) = replace_top_level_array_value_in_json_text( &keymap_contents, - &["bindings", &target.keystrokes_unparsed()], + &["bindings", keystrokes_str], None, None, index, @@ -733,7 +750,10 @@ impl KeymapFile { ) .context("Failed to replace keybinding")?; keymap_contents.replace_range(replace_range, &replace_value); - operation = KeybindUpdateOperation::Add(source); + operation = KeybindUpdateOperation::Add { + source, + from: Some(target), + }; } } else { log::warn!( @@ -743,16 +763,28 @@ impl KeymapFile { source.keystrokes, source_action_value, ); - operation = KeybindUpdateOperation::Add(source); + operation = KeybindUpdateOperation::Add { + source, + from: Some(target), + }; } } - if let KeybindUpdateOperation::Add(keybinding) = operation { + if let KeybindUpdateOperation::Add { + source: keybinding, + from, + } = operation + { let mut value = serde_json::Map::with_capacity(4); if let Some(context) = keybinding.context { value.insert("context".to_string(), context.into()); } - if keybinding.use_key_equivalents { + let use_key_equivalents = from.and_then(|from| { + let action_value = from.action_value().context("Failed to serialize action value. 
`use_key_equivalents` on new keybinding may be incorrect.").log_err()?; + let (index, _) = find_binding(&keymap, &from, &action_value)?; + Some(keymap.0[index].use_key_equivalents) + }).unwrap_or(false); + if use_key_equivalents { value.insert("use_key_equivalents".to_string(), true.into()); } @@ -771,9 +803,51 @@ keymap_contents.replace_range(replace_range, &replace_value); } return Ok(keymap_contents); + + fn find_binding<'a, 'b>( + keymap: &'b KeymapFile, + target: &KeybindUpdateTarget<'a>, + target_action_value: &Value, + ) -> Option<(usize, &'b str)> { + let target_context_parsed = + KeyBindingContextPredicate::parse(target.context.unwrap_or("")).ok(); + for (index, section) in keymap.sections().enumerate() { + let section_context_parsed = + KeyBindingContextPredicate::parse(&section.context).ok(); + if section_context_parsed != target_context_parsed { + continue; + } + let Some(bindings) = &section.bindings else { + continue; + }; + for (keystrokes_str, action) in bindings { + let Ok(keystrokes) = keystrokes_str + .split_whitespace() + .map(Keystroke::parse) + .collect::<Result<Vec<_>, _>>() + else { + continue; + }; + if keystrokes.len() != target.keystrokes.len() + || !keystrokes + .iter() + .zip(target.keystrokes) + .all(|(a, b)| a.should_match(b)) + { + continue; + } + if &action.0 != target_action_value { + continue; + } + return Some((index, &keystrokes_str)); + } + } + None + } } } +#[derive(Clone)] pub enum KeybindUpdateOperation<'a> { Replace { /// Describes the keybind to create @@ -782,25 +856,82 @@ target: KeybindUpdateTarget<'a>, target_keybind_source: KeybindSource, }, - Add(KeybindUpdateTarget<'a>), + Add { + source: KeybindUpdateTarget<'a>, + from: Option<KeybindUpdateTarget<'a>>, + }, + Remove { + target: KeybindUpdateTarget<'a>, + target_keybind_source: KeybindSource, + }, +} + +impl KeybindUpdateOperation<'_> { + pub fn generate_telemetry( + &self, + ) -> ( + // The keybind that is created + String, + // The keybinding that was removed + String, + // The source of the keybinding + String, + ) { + let (new_binding, removed_binding, source) = match &self { + KeybindUpdateOperation::Replace { + source, + target, + target_keybind_source, + } => (Some(source), Some(target), Some(*target_keybind_source)), + KeybindUpdateOperation::Add { source, ..
} => (Some(source), None, None), + KeybindUpdateOperation::Remove { + target, + target_keybind_source, + } => (None, Some(target), Some(*target_keybind_source)), + }; + + let new_binding = new_binding + .map(KeybindUpdateTarget::telemetry_string) + .unwrap_or("null".to_owned()); + let removed_binding = removed_binding + .map(KeybindUpdateTarget::telemetry_string) + .unwrap_or("null".to_owned()); + + let source = source + .as_ref() + .map(KeybindSource::name) + .map(ToOwned::to_owned) + .unwrap_or("null".to_owned()); + + (new_binding, removed_binding, source) + } } +impl<'a> KeybindUpdateOperation<'a> { + pub fn add(source: KeybindUpdateTarget<'a>) -> Self { + Self::Add { source, from: None } + } +} + +#[derive(Debug, Clone)] pub struct KeybindUpdateTarget<'a> { pub context: Option<&'a str>, pub keystrokes: &'a [Keystroke], pub action_name: &'a str, - pub use_key_equivalents: bool, - pub input: Option<&'a str>, + pub action_arguments: Option<&'a str>, } impl<'a> KeybindUpdateTarget<'a> { fn action_value(&self) -> Result<Value> { + if self.action_name == gpui::NoAction.name() { + return Ok(Value::Null); + } let action_name: Value = self.action_name.into(); - let value = match self.input { - Some(input) => { - let input = serde_json::from_str::<Value>(input) - .context("Failed to parse action input as JSON")?; - serde_json::json!([action_name, input]) + let value = match self.action_arguments { + Some(args) => { + let args = serde_json::from_str::<Value>(args) + .context("Failed to parse action arguments as JSON")?; + serde_json::json!([action_name, args]) } None => action_name, }; @@ -816,21 +947,33 @@ impl<'a> KeybindUpdateTarget<'a> { keystrokes.pop(); keystrokes } + + fn telemetry_string(&self) -> String { + format!( + "action_name: {}, context: {}, action_arguments: {}, keystrokes: {}", + self.action_name, + self.context.unwrap_or("global"), + self.action_arguments.unwrap_or("none"), + self.keystrokes_unparsed() + ) + } } -#[derive(Clone, Copy, PartialEq, Eq)] +#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord)] pub enum KeybindSource { User, - Default, - Base, Vim, + Base, + #[default] + Default, + Unknown, } impl KeybindSource { - const BASE: KeyBindingMetaIndex = KeyBindingMetaIndex(0); - const DEFAULT: KeyBindingMetaIndex = KeyBindingMetaIndex(1); - const VIM: KeyBindingMetaIndex = KeyBindingMetaIndex(2); - const USER: KeyBindingMetaIndex = KeyBindingMetaIndex(3); + const BASE: KeyBindingMetaIndex = KeyBindingMetaIndex(KeybindSource::Base as u32); + const DEFAULT: KeyBindingMetaIndex = KeyBindingMetaIndex(KeybindSource::Default as u32); + const VIM: KeyBindingMetaIndex = KeyBindingMetaIndex(KeybindSource::Vim as u32); + const USER: KeyBindingMetaIndex = KeyBindingMetaIndex(KeybindSource::User as u32); pub fn name(&self) -> &'static str { match self { @@ -838,6 +981,7 @@ impl KeybindSource { KeybindSource::Default => "Default", KeybindSource::Base => "Base", KeybindSource::Vim => "Vim", + KeybindSource::Unknown => "Unknown", } } @@ -847,6 +991,7 @@ impl KeybindSource { KeybindSource::Default => Self::DEFAULT, KeybindSource::Base => Self::BASE, KeybindSource::Vim => Self::VIM, + KeybindSource::Unknown => KeyBindingMetaIndex(*self as u32), } } @@ -856,7 +1001,7 @@ impl KeybindSource { Self::BASE => KeybindSource::Base, Self::DEFAULT => KeybindSource::Default, Self::VIM => KeybindSource::Vim, - _ => unreachable!(), + _ => KeybindSource::Unknown, } } } @@ -869,12 +1014,13 @@ impl From<KeyBindingMetaIndex> for KeybindSource { impl From<KeybindSource> for KeyBindingMetaIndex { 
fn from(source: KeybindSource) -> Self { - return source.meta(); + source.meta() } } #[cfg(test)] mod tests { + use gpui::Keystroke; use unindent::Unindent; use crate::{ @@ -897,38 +1043,36 @@ mod tests { KeymapFile::parse(json).unwrap(); } + #[track_caller] + fn check_keymap_update( + input: impl ToString, + operation: KeybindUpdateOperation, + expected: impl ToString, + ) { + let result = KeymapFile::update_keybinding(operation, input.to_string(), 4) + .expect("Update succeeded"); + pretty_assertions::assert_eq!(expected.to_string(), result); + } + + #[track_caller] + fn parse_keystrokes(keystrokes: &str) -> Vec<Keystroke> { + return keystrokes + .split(' ') + .map(|s| Keystroke::parse(s).expect("Keystrokes valid")) + .collect(); + } + #[test] fn keymap_update() { - use gpui::Keystroke; - zlog::init_test(); - #[track_caller] - fn check_keymap_update( - input: impl ToString, - operation: KeybindUpdateOperation, - expected: impl ToString, - ) { - let result = KeymapFile::update_keybinding(operation, input.to_string(), 4) - .expect("Update succeeded"); - pretty_assertions::assert_eq!(expected.to_string(), result); - } - - #[track_caller] - fn parse_keystrokes(keystrokes: &str) -> Vec<Keystroke> { - return keystrokes - .split(' ') - .map(|s| Keystroke::parse(s).expect("Keystrokes valid")) - .collect(); - } check_keymap_update( "[]", - KeybindUpdateOperation::Add(KeybindUpdateTarget { + KeybindUpdateOperation::add(KeybindUpdateTarget { keystrokes: &parse_keystrokes("ctrl-a"), action_name: "zed::SomeAction", context: None, - use_key_equivalents: false, - input: None, + action_arguments: None, }), r#"[ { @@ -949,12 +1093,11 @@ mod tests { } ]"# .unindent(), - KeybindUpdateOperation::Add(KeybindUpdateTarget { + KeybindUpdateOperation::add(KeybindUpdateTarget { keystrokes: &parse_keystrokes("ctrl-b"), action_name: "zed::SomeOtherAction", context: None, - use_key_equivalents: false, - input: None, + action_arguments: None, }), r#"[ { @@ -980,12 +1123,11 @@ mod tests { } ]"# .unindent(), - KeybindUpdateOperation::Add(KeybindUpdateTarget { + KeybindUpdateOperation::add(KeybindUpdateTarget { keystrokes: &parse_keystrokes("ctrl-b"), action_name: "zed::SomeOtherAction", context: None, - use_key_equivalents: false, - input: Some(r#"{"foo": "bar"}"#), + action_arguments: Some(r#"{"foo": "bar"}"#), }), r#"[ { @@ -1016,12 +1158,11 @@ mod tests { } ]"# .unindent(), - KeybindUpdateOperation::Add(KeybindUpdateTarget { + KeybindUpdateOperation::add(KeybindUpdateTarget { keystrokes: &parse_keystrokes("ctrl-b"), action_name: "zed::SomeOtherAction", context: Some("Zed > Editor && some_condition = true"), - use_key_equivalents: true, - input: Some(r#"{"foo": "bar"}"#), + action_arguments: Some(r#"{"foo": "bar"}"#), }), r#"[ { @@ -1031,7 +1172,6 @@ mod tests { }, { "context": "Zed > Editor && some_condition = true", - "use_key_equivalents": true, "bindings": { "ctrl-b": [ "zed::SomeOtherAction", @@ -1059,15 +1199,13 @@ mod tests { keystrokes: &parse_keystrokes("ctrl-a"), action_name: "zed::SomeAction", context: None, - use_key_equivalents: false, - input: None, + action_arguments: None, }, source: KeybindUpdateTarget { keystrokes: &parse_keystrokes("ctrl-b"), action_name: "zed::SomeOtherAction", context: None, - use_key_equivalents: false, - input: Some(r#"{"foo": "bar"}"#), + action_arguments: Some(r#"{"foo": "bar"}"#), }, target_keybind_source: KeybindSource::Base, }, @@ -1105,15 +1243,13 @@ mod tests { keystrokes: &parse_keystrokes("a"), action_name: "zed::SomeAction", context: None, - use_key_equivalents: 
false, - input: None, + action_arguments: None, }, source: KeybindUpdateTarget { keystrokes: &parse_keystrokes("ctrl-b"), action_name: "zed::SomeOtherAction", context: None, - use_key_equivalents: false, - input: Some(r#"{"foo": "bar"}"#), + action_arguments: Some(r#"{"foo": "bar"}"#), }, target_keybind_source: KeybindSource::User, }, @@ -1146,15 +1282,13 @@ mod tests { keystrokes: &parse_keystrokes("ctrl-a"), action_name: "zed::SomeNonexistentAction", context: None, - use_key_equivalents: false, - input: None, + action_arguments: None, }, source: KeybindUpdateTarget { keystrokes: &parse_keystrokes("ctrl-b"), action_name: "zed::SomeOtherAction", context: None, - use_key_equivalents: false, - input: None, + action_arguments: None, }, target_keybind_source: KeybindSource::User, }, @@ -1189,15 +1323,13 @@ mod tests { keystrokes: &parse_keystrokes("ctrl-a"), action_name: "zed::SomeAction", context: None, - use_key_equivalents: false, - input: None, + action_arguments: None, }, source: KeybindUpdateTarget { keystrokes: &parse_keystrokes("ctrl-b"), action_name: "zed::SomeOtherAction", context: None, - use_key_equivalents: false, - input: Some(r#"{"foo": "bar"}"#), + action_arguments: Some(r#"{"foo": "bar"}"#), }, target_keybind_source: KeybindSource::User, }, @@ -1234,15 +1366,13 @@ mod tests { keystrokes: &parse_keystrokes("a"), action_name: "foo::bar", context: Some("SomeContext"), - use_key_equivalents: false, - input: None, + action_arguments: None, }, source: KeybindUpdateTarget { keystrokes: &parse_keystrokes("c"), action_name: "foo::baz", context: Some("SomeOtherContext"), - use_key_equivalents: false, - input: None, + action_arguments: None, }, target_keybind_source: KeybindSource::User, }, @@ -1278,15 +1408,13 @@ mod tests { keystrokes: &parse_keystrokes("a"), action_name: "foo::bar", context: Some("SomeContext"), - use_key_equivalents: false, - input: None, + action_arguments: None, }, source: KeybindUpdateTarget { keystrokes: &parse_keystrokes("c"), action_name: "foo::baz", context: Some("SomeOtherContext"), - use_key_equivalents: false, - input: None, + action_arguments: None, }, target_keybind_source: KeybindSource::User, }, @@ -1300,5 +1428,239 @@ mod tests { ]"# .unindent(), ); + + check_keymap_update( + r#"[ + { + "context": "SomeContext", + "bindings": { + "a": "foo::bar", + "c": "foo::baz", + } + }, + ]"# + .unindent(), + KeybindUpdateOperation::Remove { + target: KeybindUpdateTarget { + context: Some("SomeContext"), + keystrokes: &parse_keystrokes("a"), + action_name: "foo::bar", + action_arguments: None, + }, + target_keybind_source: KeybindSource::User, + }, + r#"[ + { + "context": "SomeContext", + "bindings": { + "c": "foo::baz", + } + }, + ]"# + .unindent(), + ); + + check_keymap_update( + r#"[ + { + "context": "SomeContext", + "bindings": { + "a": ["foo::bar", true], + "c": "foo::baz", + } + }, + ]"# + .unindent(), + KeybindUpdateOperation::Remove { + target: KeybindUpdateTarget { + context: Some("SomeContext"), + keystrokes: &parse_keystrokes("a"), + action_name: "foo::bar", + action_arguments: Some("true"), + }, + target_keybind_source: KeybindSource::User, + }, + r#"[ + { + "context": "SomeContext", + "bindings": { + "c": "foo::baz", + } + }, + ]"# + .unindent(), + ); + + check_keymap_update( + r#"[ + { + "context": "SomeContext", + "bindings": { + "b": "foo::baz", + } + }, + { + "context": "SomeContext", + "bindings": { + "a": ["foo::bar", true], + } + }, + { + "context": "SomeContext", + "bindings": { + "c": "foo::baz", + } + }, + ]"# + .unindent(), + 
KeybindUpdateOperation::Remove { + target: KeybindUpdateTarget { + context: Some("SomeContext"), + keystrokes: &parse_keystrokes("a"), + action_name: "foo::bar", + action_arguments: Some("true"), + }, + target_keybind_source: KeybindSource::User, + }, + r#"[ + { + "context": "SomeContext", + "bindings": { + "b": "foo::baz", + } + }, + { + "context": "SomeContext", + "bindings": { + "c": "foo::baz", + } + }, + ]"# + .unindent(), + ); + check_keymap_update( + r#"[ + { + "context": "SomeOtherContext", + "use_key_equivalents": true, + "bindings": { + "b": "foo::bar", + } + }, + ]"# + .unindent(), + KeybindUpdateOperation::Add { + source: KeybindUpdateTarget { + context: Some("SomeContext"), + keystrokes: &parse_keystrokes("a"), + action_name: "foo::baz", + action_arguments: Some("true"), + }, + from: Some(KeybindUpdateTarget { + context: Some("SomeOtherContext"), + keystrokes: &parse_keystrokes("b"), + action_name: "foo::bar", + action_arguments: None, + }), + }, + r#"[ + { + "context": "SomeOtherContext", + "use_key_equivalents": true, + "bindings": { + "b": "foo::bar", + } + }, + { + "context": "SomeContext", + "use_key_equivalents": true, + "bindings": { + "a": [ + "foo::baz", + true + ] + } + } + ]"# + .unindent(), + ); + + check_keymap_update( + r#"[ + { + "context": "SomeOtherContext", + "use_key_equivalents": true, + "bindings": { + "b": "foo::bar", + } + }, + ]"# + .unindent(), + KeybindUpdateOperation::Remove { + target: KeybindUpdateTarget { + context: Some("SomeContext"), + keystrokes: &parse_keystrokes("a"), + action_name: "foo::baz", + action_arguments: Some("true"), + }, + target_keybind_source: KeybindSource::Default, + }, + r#"[ + { + "context": "SomeOtherContext", + "use_key_equivalents": true, + "bindings": { + "b": "foo::bar", + } + }, + { + "context": "SomeContext", + "bindings": { + "a": null + } + } + ]"# + .unindent(), + ); + } + + #[test] + fn test_keymap_remove() { + zlog::init_test(); + + check_keymap_update( + r#" + [ + { + "context": "Editor", + "bindings": { + "cmd-k cmd-u": "editor::ConvertToUpperCase", + "cmd-k cmd-l": "editor::ConvertToLowerCase", + "cmd-[": "pane::GoBack", + } + }, + ] + "#, + KeybindUpdateOperation::Remove { + target: KeybindUpdateTarget { + context: Some("Editor"), + keystrokes: &parse_keystrokes("cmd-k cmd-l"), + action_name: "editor::ConvertToLowerCase", + action_arguments: None, + }, + target_keybind_source: KeybindSource::User, + }, + r#" + [ + { + "context": "Editor", + "bindings": { + "cmd-k cmd-u": "editor::ConvertToUpperCase", + "cmd-[": "pane::GoBack", + } + }, + ] + "#, + ); } } diff --git a/crates/settings/src/settings.rs b/crates/settings/src/settings.rs index 4e6bd94d92bc09009b2775b7b43f9962341c5f94..afd4ea08907654c9a4fc5edfbfb90bd2f87a3285 100644 --- a/crates/settings/src/settings.rs +++ b/crates/settings/src/settings.rs @@ -7,7 +7,7 @@ mod settings_json; mod settings_store; mod vscode_import; -use gpui::App; +use gpui::{App, Global}; use rust_embed::RustEmbed; use std::{borrow::Cow, fmt, str}; use util::asset_str; @@ -27,6 +27,11 @@ pub use settings_store::{ }; pub use vscode_import::{VsCodeSettings, VsCodeSettingsSource}; +#[derive(Clone, Debug, PartialEq)] +pub struct ActiveSettingsProfileName(pub String); + +impl Global for ActiveSettingsProfileName {} + #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, PartialOrd, Ord)] pub struct WorktreeId(usize); @@ -74,6 +79,7 @@ pub fn init(cx: &mut App) { .unwrap(); cx.set_global(settings); BaseKeymap::register(cx); + SettingsStore::observe_active_settings_profile_name(cx).detach(); } 
pub fn default_settings() -> Cow<'static, str> { diff --git a/crates/settings/src/settings_json.rs b/crates/settings/src/settings_json.rs index f569a187699b764bbac43cca8c3799ab043c373b..e6683857e778e0c9cd052fc9f72407ee5d7787be 100644 --- a/crates/settings/src/settings_json.rs +++ b/crates/settings/src/settings_json.rs @@ -190,6 +190,7 @@ fn replace_value_in_json_text( } } + let mut removed_comma = false; // Look backward for a preceding comma first let preceding_text = text.get(0..removal_start).unwrap_or(""); if let Some(comma_pos) = preceding_text.rfind(',') { @@ -197,10 +198,12 @@ fn replace_value_in_json_text( let between_comma_and_key = text.get(comma_pos + 1..removal_start).unwrap_or(""); if between_comma_and_key.trim().is_empty() { removal_start = comma_pos; + removed_comma = true; } } - - if let Some(remaining_text) = text.get(existing_value_range.end..) { + if let Some(remaining_text) = text.get(existing_value_range.end..) + && !removed_comma + { let mut chars = remaining_text.char_indices(); while let Some((offset, ch)) = chars.next() { if ch == ',' { @@ -353,29 +356,58 @@ pub fn replace_top_level_array_value_in_json_text( let range = cursor.node().range(); let indent_width = range.start_point.column; let offset = range.start_byte; - let value_str = &text[range.start_byte..range.end_byte]; + let text_range = range.start_byte..range.end_byte; + let value_str = &text[text_range.clone()]; let needs_indent = range.start_point.row > 0; - let (mut replace_range, mut replace_value) = - replace_value_in_json_text(value_str, key_path, tab_size, new_value, replace_key); + if new_value.is_none() && key_path.is_empty() { + let mut remove_range = text_range.clone(); + if index == 0 { + while cursor.goto_next_sibling() + && (cursor.node().is_extra() || cursor.node().is_missing()) + {} + if cursor.node().kind() == "," { + remove_range.end = cursor.node().range().end_byte; + } + if let Some(next_newline) = &text[remove_range.end + 1..].find('\n') { + if text[remove_range.end + 1..remove_range.end + next_newline] + .chars() + .all(|c| c.is_ascii_whitespace()) + { + remove_range.end = remove_range.end + next_newline; + } + } + } else { + while cursor.goto_previous_sibling() + && (cursor.node().is_extra() || cursor.node().is_missing()) + {} + if cursor.node().kind() == "," { + remove_range.start = cursor.node().range().start_byte; + } + } + return Ok((remove_range, String::new())); + } else { + let (mut replace_range, mut replace_value) = + replace_value_in_json_text(value_str, key_path, tab_size, new_value, replace_key); - replace_range.start += offset; - replace_range.end += offset; + replace_range.start += offset; + replace_range.end += offset; - if needs_indent { - let increased_indent = format!("\n{space:width$}", space = ' ', width = indent_width); - replace_value = replace_value.replace('\n', &increased_indent); - // replace_value.push('\n'); - } else { - while let Some(idx) = replace_value.find("\n ") { - replace_value.remove(idx + 1); - } - while let Some(idx) = replace_value.find("\n") { - replace_value.replace_range(idx..idx + 1, " "); + if needs_indent { + let increased_indent = format!("\n{space:width$}", space = ' ', width = indent_width); + replace_value = replace_value.replace('\n', &increased_indent); + // replace_value.push('\n'); + } else { + while let Some(idx) = replace_value.find("\n ") { + replace_value.remove(idx + 1); + } + while let Some(idx) = replace_value.find("\n") { + replace_value.replace_range(idx..idx + 1, " "); + } } - } - return Ok((replace_range, 
replace_value)); + return Ok((replace_range, replace_value)); + } } pub fn append_top_level_array_value_in_json_text( @@ -408,17 +440,19 @@ pub fn append_top_level_array_value_in_json_text( ); debug_assert_eq!(cursor.node().kind(), "]"); let close_bracket_start = cursor.node().start_byte(); - cursor.goto_previous_sibling(); - while (cursor.node().is_extra() || cursor.node().is_missing()) && cursor.goto_previous_sibling() - { - } + while cursor.goto_previous_sibling() + && (cursor.node().is_extra() || cursor.node().is_missing()) + && !cursor.node().is_error() + {} let mut comma_range = None; let mut prev_item_range = None; - if cursor.node().kind() == "," { + if cursor.node().kind() == "," || is_error_of_kind(&mut cursor, ",") { comma_range = Some(cursor.node().byte_range()); - while cursor.goto_previous_sibling() && cursor.node().is_extra() {} + while cursor.goto_previous_sibling() + && (cursor.node().is_extra() || cursor.node().is_missing()) + {} debug_assert_ne!(cursor.node().kind(), "["); prev_item_range = Some(cursor.node().range()); @@ -485,6 +519,17 @@ pub fn append_top_level_array_value_in_json_text( replace_value.push('\n'); } return Ok((replace_range, replace_value)); + + fn is_error_of_kind(cursor: &mut tree_sitter::TreeCursor<'_>, kind: &str) -> bool { + if cursor.node().kind() != "ERROR" { + return false; + } + + let descendant_index = cursor.descendant_index(); + let res = cursor.goto_first_child() && cursor.node().kind() == kind; + cursor.goto_descendant(descendant_index); + return res; + } } pub fn to_pretty_json( @@ -1005,14 +1050,14 @@ mod tests { input: impl ToString, index: usize, key_path: &[&str], - value: Value, + value: Option<Value>, expected: impl ToString, ) { let input = input.to_string(); let result = replace_top_level_array_value_in_json_text( &input, key_path, - Some(&value), + value.as_ref(), None, index, 4, @@ -1023,10 +1068,10 @@ mod tests { pretty_assertions::assert_eq!(expected.to_string(), result_str); } - check_array_replace(r#"[1, 3, 3]"#, 1, &[], json!(2), r#"[1, 2, 3]"#); - check_array_replace(r#"[1, 3, 3]"#, 2, &[], json!(2), r#"[1, 3, 2]"#); - check_array_replace(r#"[1, 3, 3,]"#, 3, &[], json!(2), r#"[1, 3, 3, 2]"#); - check_array_replace(r#"[1, 3, 3,]"#, 100, &[], json!(2), r#"[1, 3, 3, 2]"#); + check_array_replace(r#"[1, 3, 3]"#, 1, &[], Some(json!(2)), r#"[1, 2, 3]"#); + check_array_replace(r#"[1, 3, 3]"#, 2, &[], Some(json!(2)), r#"[1, 3, 2]"#); + check_array_replace(r#"[1, 3, 3,]"#, 3, &[], Some(json!(2)), r#"[1, 3, 3, 2]"#); + check_array_replace(r#"[1, 3, 3,]"#, 100, &[], Some(json!(2)), r#"[1, 3, 3, 2]"#); check_array_replace( r#"[ 1, @@ -1036,7 +1081,7 @@ mod tests { .unindent(), 1, &[], - json!({"foo": "bar", "baz": "qux"}), + Some(json!({"foo": "bar", "baz": "qux"})), r#"[ 1, { @@ -1051,7 +1096,7 @@ mod tests { r#"[1, 3, 3,]"#, 1, &[], - json!({"foo": "bar", "baz": "qux"}), + Some(json!({"foo": "bar", "baz": "qux"})), r#"[1, { "foo": "bar", "baz": "qux" }, 3,]"#, ); @@ -1059,7 +1104,7 @@ mod tests { r#"[1, { "foo": "bar", "baz": "qux" }, 3,]"#, 1, &["baz"], - json!({"qux": "quz"}), + Some(json!({"qux": "quz"})), r#"[1, { "foo": "bar", "baz": { "qux": "quz" } }, 3,]"#, ); @@ -1074,7 +1119,7 @@ mod tests { ]"#, 1, &["baz"], - json!({"qux": "quz"}), + Some(json!({"qux": "quz"})), r#"[ 1, { @@ -1100,7 +1145,7 @@ mod tests { ]"#, 1, &["baz"], - json!("qux"), + Some(json!("qux")), r#"[ 1, { @@ -1127,7 +1172,7 @@ mod tests { ]"#, 1, &["baz"], - json!("qux"), + Some(json!("qux")), r#"[ 1, { @@ -1151,7 +1196,7 @@ mod tests { ]"#, 2, &[], - 
json!("replaced"), + Some(json!("replaced")), r#"[ 1, // This is element 2 @@ -1169,7 +1214,7 @@ mod tests { .unindent(), 0, &[], - json!("first"), + Some(json!("first")), r#"[ // Empty array with comment "first" @@ -1180,7 +1225,7 @@ mod tests { r#"[]"#.unindent(), 0, &[], - json!("first"), + Some(json!("first")), r#"[ "first" ]"# @@ -1197,7 +1242,7 @@ mod tests { ]"#, 0, &[], - json!({"new": "object"}), + Some(json!({"new": "object"})), r#"[ // Leading comment // Another leading comment @@ -1217,7 +1262,7 @@ mod tests { ]"#, 1, &[], - json!("deep"), + Some(json!("deep")), r#"[ 1, "deep", @@ -1230,7 +1275,7 @@ mod tests { r#"[1,2, 3, 4]"#, 2, &[], - json!("spaced"), + Some(json!("spaced")), r#"[1,2, "spaced", 4]"#, ); @@ -1243,7 +1288,7 @@ mod tests { ]"#, 1, &[], - json!(["a", "b", "c", "d"]), + Some(json!(["a", "b", "c", "d"])), r#"[ [1, 2, 3], [ @@ -1268,7 +1313,7 @@ mod tests { ]"#, 0, &[], - json!("updated"), + Some(json!("updated")), r#"[ /* * This is a @@ -1284,7 +1329,7 @@ mod tests { r#"[true, false, true]"#, 1, &[], - json!(null), + Some(json!(null)), r#"[true, null, true]"#, ); @@ -1293,7 +1338,7 @@ mod tests { r#"[42]"#, 0, &[], - json!({"answer": 42}), + Some(json!({"answer": 42})), r#"[{ "answer": 42 }]"#, ); @@ -1307,7 +1352,7 @@ mod tests { .unindent(), 10, &[], - json!(123), + Some(json!(123)), r#"[ // Comment 1 // Comment 2 @@ -1316,6 +1361,54 @@ mod tests { ]"# .unindent(), ); + + check_array_replace( + r#"[ + { + "key": "value" + }, + { + "key": "value2" + } + ]"# + .unindent(), + 0, + &[], + None, + r#"[ + { + "key": "value2" + } + ]"# + .unindent(), + ); + + check_array_replace( + r#"[ + { + "key": "value" + }, + { + "key": "value2" + }, + { + "key": "value3" + }, + ]"# + .unindent(), + 1, + &[], + None, + r#"[ + { + "key": "value" + }, + { + "key": "value3" + }, + ]"# + .unindent(), + ); } #[test] diff --git a/crates/settings/src/settings_store.rs b/crates/settings/src/settings_store.rs index 0d23385a682fbf8fc3b8eec97c98748b5664d480..bfdafbffe8e4daf276f76f98a2ab7c535f4e1212 100644 --- a/crates/settings/src/settings_store.rs +++ b/crates/settings/src/settings_store.rs @@ -2,7 +2,11 @@ use anyhow::{Context as _, Result}; use collections::{BTreeMap, HashMap, btree_map, hash_map}; use ec4rs::{ConfigParser, PropertiesSource, Section}; use fs::Fs; -use futures::{FutureExt, StreamExt, channel::mpsc, future::LocalBoxFuture}; +use futures::{ + FutureExt, StreamExt, + channel::{mpsc, oneshot}, + future::LocalBoxFuture, +}; use gpui::{App, AsyncApp, BorrowAppContext, Global, Task, UpdateGlobal}; use paths::{EDITORCONFIG_NAME, local_settings_file_relative_path, task_file_name}; @@ -12,6 +16,7 @@ use serde_json::{Value, json}; use smallvec::SmallVec; use std::{ any::{Any, TypeId, type_name}, + env, fmt::Debug, ops::Range, path::{Path, PathBuf}, @@ -26,8 +31,8 @@ use util::{ pub type EditorconfigProperties = ec4rs::Properties; use crate::{ - ParameterizedJsonSchema, SettingsJsonSchemaParams, VsCodeSettings, WorktreeId, - parse_json_with_comments, update_value_in_json_text, + ActiveSettingsProfileName, ParameterizedJsonSchema, SettingsJsonSchemaParams, VsCodeSettings, + WorktreeId, parse_json_with_comments, update_value_in_json_text, }; /// A value that can be defined as a user setting. @@ -122,6 +127,10 @@ pub struct SettingsSources<'a, T> { pub user: Option<&'a T>, /// The user settings for the current release channel. pub release_channel: Option<&'a T>, + /// The user settings for the current operating system. 
+ pub operating_system: Option<&'a T>, + /// The settings associated with an enabled settings profile + pub profile: Option<&'a T>, /// The server's settings. pub server: Option<&'a T>, /// The project settings, ordered from least specific to most specific. @@ -141,6 +150,8 @@ impl<'a, T: Serialize> SettingsSources<'a, T> { .chain(self.extensions) .chain(self.user) .chain(self.release_channel) + .chain(self.operating_system) + .chain(self.profile) .chain(self.server) .chain(self.project.iter().copied()) } @@ -282,6 +293,14 @@ impl SettingsStore { } } + pub fn observe_active_settings_profile_name(cx: &mut App) -> gpui::Subscription { + cx.observe_global::<ActiveSettingsProfileName>(|cx| { + Self::update_global(cx, |store, cx| { + store.recompute_values(None, cx).log_err(); + }); + }) + } + pub fn update<C, R>(cx: &mut C, f: impl FnOnce(&mut Self, &mut C) -> R) -> R where C: BorrowAppContext, @@ -321,6 +340,22 @@ impl SettingsStore { .log_err(); } + let mut os_settings_value = None; + if let Some(os_settings) = &self.raw_user_settings.get(env::consts::OS) { + os_settings_value = setting_value.deserialize_setting(os_settings).log_err(); + } + + let mut profile_value = None; + if let Some(active_profile) = cx.try_global::<ActiveSettingsProfileName>() { + if let Some(profiles) = self.raw_user_settings.get("profiles") { + if let Some(profile_settings) = profiles.get(&active_profile.0) { + profile_value = setting_value + .deserialize_setting(profile_settings) + .log_err(); + } + } + } + let server_value = self .raw_server_settings .as_ref() @@ -340,6 +375,8 @@ impl SettingsStore { extensions: extension_value.as_ref(), user: user_value.as_ref(), release_channel: release_channel_value.as_ref(), + operating_system: os_settings_value.as_ref(), + profile: profile_value.as_ref(), server: server_value.as_ref(), project: &[], }, @@ -402,6 +439,16 @@ impl SettingsStore { &self.raw_user_settings } + /// Get the configured settings profile names. + pub fn configured_settings_profiles(&self) -> impl Iterator<Item = &str> { + self.raw_user_settings + .get("profiles") + .and_then(|v| v.as_object()) + .into_iter() + .flat_map(|obj| obj.keys()) + .map(|s| s.as_str()) + } + /// Access the raw JSON value of the global settings. 
pub fn raw_global_settings(&self) -> Option<&Value> { self.raw_global_settings.as_ref() @@ -498,41 +545,64 @@ impl SettingsStore { .ok(); } - pub fn import_vscode_settings(&self, fs: Arc<dyn Fs>, vscode_settings: VsCodeSettings) { + pub fn import_vscode_settings( + &self, + fs: Arc<dyn Fs>, + vscode_settings: VsCodeSettings, + ) -> oneshot::Receiver<Result<()>> { + let (tx, rx) = oneshot::channel::<Result<()>>(); self.setting_file_updates_tx .unbounded_send(Box::new(move |cx: AsyncApp| { async move { - let old_text = Self::load_settings(&fs).await?; - let new_text = cx.read_global(|store: &SettingsStore, _cx| { - store.get_vscode_edits(old_text, &vscode_settings) - })?; - let settings_path = paths::settings_file().as_path(); - if fs.is_file(settings_path).await { - let resolved_path = - fs.canonicalize(settings_path).await.with_context(|| { - format!("Failed to canonicalize settings path {:?}", settings_path) - })?; + let res = async move { + let old_text = Self::load_settings(&fs).await?; + let new_text = cx.read_global(|store: &SettingsStore, _cx| { + store.get_vscode_edits(old_text, &vscode_settings) + })?; + let settings_path = paths::settings_file().as_path(); + if fs.is_file(settings_path).await { + let resolved_path = + fs.canonicalize(settings_path).await.with_context(|| { + format!( + "Failed to canonicalize settings path {:?}", + settings_path + ) + })?; + + fs.atomic_write(resolved_path.clone(), new_text) + .await + .with_context(|| { + format!("Failed to write settings to file {:?}", resolved_path) + })?; + } else { + fs.atomic_write(settings_path.to_path_buf(), new_text) + .await + .with_context(|| { + format!("Failed to write settings to file {:?}", settings_path) + })?; + } - fs.atomic_write(resolved_path.clone(), new_text) - .await - .with_context(|| { - format!("Failed to write settings to file {:?}", resolved_path) - })?; - } else { - fs.atomic_write(settings_path.to_path_buf(), new_text) - .await - .with_context(|| { - format!("Failed to write settings to file {:?}", settings_path) - })?; + anyhow::Ok(()) } + .await; - anyhow::Ok(()) + let new_res = match &res { + Ok(_) => anyhow::Ok(()), + Err(e) => Err(anyhow::anyhow!("Failed to write settings to file {:?}", e)), + }; + + _ = tx.send(new_res); + res } .boxed_local() })) .ok(); + + rx } +} +impl SettingsStore { /// Updates the value of a setting in a JSON file, returning the new text /// for that JSON file. pub fn new_text_for_update<T: Settings>( @@ -1001,18 +1071,18 @@ impl SettingsStore { const ZED_SETTINGS: &str = "ZedSettings"; let zed_settings_ref = add_new_subschema(&mut generator, ZED_SETTINGS, combined_schema); - // add `ZedReleaseStageSettings` which is the same as `ZedSettings` except that unknown - // fields are rejected. - let mut zed_release_stage_settings = zed_settings_ref.clone(); - zed_release_stage_settings.insert("unevaluatedProperties".to_string(), false.into()); - let zed_release_stage_settings_ref = add_new_subschema( + // add `ZedSettingsOverride` which is the same as `ZedSettings` except that unknown + // fields are rejected. This is used for release stage settings and profiles. 
+ let mut zed_settings_override = zed_settings_ref.clone(); + zed_settings_override.insert("unevaluatedProperties".to_string(), false.into()); + let zed_settings_override_ref = add_new_subschema( &mut generator, - "ZedReleaseStageSettings", - zed_release_stage_settings.to_value(), + "ZedSettingsOverride", + zed_settings_override.to_value(), ); // Remove `"additionalProperties": false` added by `DefaultDenyUnknownFields` so that - // unknown fields can be handled by the root schema and `ZedReleaseStageSettings`. + // unknown fields can be handled by the root schema and `ZedSettingsOverride`. let mut definitions = generator.take_definitions(true); definitions .get_mut(ZED_SETTINGS) @@ -1032,15 +1102,23 @@ impl SettingsStore { "$schema": meta_schema, "title": "Zed Settings", "unevaluatedProperties": false, - // ZedSettings + settings overrides for each release stage + // ZedSettings + settings overrides for each release stage / OS / profiles "allOf": [ zed_settings_ref, { "properties": { - "dev": zed_release_stage_settings_ref, - "nightly": zed_release_stage_settings_ref, - "stable": zed_release_stage_settings_ref, - "preview": zed_release_stage_settings_ref, + "dev": zed_settings_override_ref, + "nightly": zed_settings_override_ref, + "stable": zed_settings_override_ref, + "preview": zed_settings_override_ref, + "linux": zed_settings_override_ref, + "macos": zed_settings_override_ref, + "windows": zed_settings_override_ref, + "profiles": { + "type": "object", + "description": "Configures any number of settings profiles.", + "additionalProperties": zed_settings_override_ref + } } } ], @@ -1099,6 +1177,23 @@ impl SettingsStore { } } + let mut os_settings = None; + if let Some(settings) = &self.raw_user_settings.get(env::consts::OS) { + if let Some(settings) = setting_value.deserialize_setting(settings).log_err() { + os_settings = Some(settings); + } + } + + let mut profile_settings = None; + if let Some(active_profile) = cx.try_global::<ActiveSettingsProfileName>() { + if let Some(profiles) = self.raw_user_settings.get("profiles") { + if let Some(profile_json) = profiles.get(&active_profile.0) { + profile_settings = + setting_value.deserialize_setting(profile_json).log_err(); + } + } + } + // If the global settings file changed, reload the global value for the field. 
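To make the schema wiring above concrete, here is an illustrative user-settings shape that it accepts, written in the same `serde_json::json!` style as this patch's tests. The key names come from the schema and lookup code in this diff; the concrete values and the `"macos"` override are invented for the example.

// Illustrative only — not from this patch. Key names match the schema above;
// the values (and the "macos" override) are made up.
use serde_json::json;

fn example_user_settings() -> serde_json::Value {
    json!({
        // Base user settings; overridden in turn by release stage, OS, and
        // profile, matching the `defaults()` chain earlier in this diff.
        "buffer_font_size": 10.0,
        // Per-OS override, looked up via `env::consts::OS`
        // ("linux" / "macos" / "windows").
        "macos": { "buffer_font_size": 12.0 },
        // Named profiles; one applies while `ActiveSettingsProfileName` is set.
        "profiles": {
            "Classroom / Streaming": { "buffer_font_size": 20.0 },
            "Demo Videos": { "buffer_font_size": 15.0 }
        }
    })
}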
if changed_local_path.is_none() { if let Some(value) = setting_value @@ -1109,6 +1204,8 @@ impl SettingsStore { extensions: extension_settings.as_ref(), user: user_settings.as_ref(), release_channel: release_channel_settings.as_ref(), + operating_system: os_settings.as_ref(), + profile: profile_settings.as_ref(), server: server_settings.as_ref(), project: &[], }, @@ -1161,6 +1258,8 @@ impl SettingsStore { extensions: extension_settings.as_ref(), user: user_settings.as_ref(), release_channel: release_channel_settings.as_ref(), + operating_system: os_settings.as_ref(), + profile: profile_settings.as_ref(), server: server_settings.as_ref(), project: &project_settings_stack.iter().collect::<Vec<_>>(), }, @@ -1286,6 +1385,12 @@ impl<T: Settings> AnySettingValue for SettingValue<T> { release_channel: values .release_channel .map(|value| value.0.downcast_ref::<T::FileContent>().unwrap()), + operating_system: values + .operating_system + .map(|value| value.0.downcast_ref::<T::FileContent>().unwrap()), + profile: values + .profile + .map(|value| value.0.downcast_ref::<T::FileContent>().unwrap()), server: values .server .map(|value| value.0.downcast_ref::<T::FileContent>().unwrap()), diff --git a/crates/settings_profile_selector/Cargo.toml b/crates/settings_profile_selector/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..189272e54be02ac46840838f6874be64d1e06321 --- /dev/null +++ b/crates/settings_profile_selector/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "settings_profile_selector" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/settings_profile_selector.rs" +doctest = false + +[dependencies] +fuzzy.workspace = true +gpui.workspace = true +picker.workspace = true +settings.workspace = true +ui.workspace = true +workspace-hack.workspace = true +workspace.workspace = true +zed_actions.workspace = true + +[dev-dependencies] +client = { workspace = true, features = ["test-support"] } +editor = { workspace = true, features = ["test-support"] } +gpui = { workspace = true, features = ["test-support"] } +language = { workspace = true, features = ["test-support"] } +menu.workspace = true +project = { workspace = true, features = ["test-support"] } +serde_json.workspace = true +settings = { workspace = true, features = ["test-support"] } +theme = { workspace = true, features = ["test-support"] } +workspace = { workspace = true, features = ["test-support"] } diff --git a/crates/settings_profile_selector/LICENSE-GPL b/crates/settings_profile_selector/LICENSE-GPL new file mode 120000 index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4 --- /dev/null +++ b/crates/settings_profile_selector/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/settings_profile_selector/src/settings_profile_selector.rs b/crates/settings_profile_selector/src/settings_profile_selector.rs new file mode 100644 index 0000000000000000000000000000000000000000..8a34c120512aa82d71c023a0c16308e6a4e1c271 --- /dev/null +++ b/crates/settings_profile_selector/src/settings_profile_selector.rs @@ -0,0 +1,581 @@ +use fuzzy::{StringMatch, StringMatchCandidate, match_strings}; +use gpui::{ + App, Context, DismissEvent, Entity, EventEmitter, Focusable, Render, Task, WeakEntity, Window, +}; +use picker::{Picker, PickerDelegate}; +use settings::{ActiveSettingsProfileName, SettingsStore}; +use ui::{HighlightedLabel, ListItem, ListItemSpacing, 
prelude::*}; +use workspace::{ModalView, Workspace}; + +pub fn init(cx: &mut App) { + cx.on_action(|_: &zed_actions::settings_profile_selector::Toggle, cx| { + workspace::with_active_or_new_workspace(cx, |workspace, window, cx| { + toggle_settings_profile_selector(workspace, window, cx); + }); + }); +} + +fn toggle_settings_profile_selector( + workspace: &mut Workspace, + window: &mut Window, + cx: &mut Context<Workspace>, +) { + workspace.toggle_modal(window, cx, |window, cx| { + let delegate = SettingsProfileSelectorDelegate::new(cx.entity().downgrade(), window, cx); + SettingsProfileSelector::new(delegate, window, cx) + }); +} + +pub struct SettingsProfileSelector { + picker: Entity<Picker<SettingsProfileSelectorDelegate>>, +} + +impl ModalView for SettingsProfileSelector {} + +impl EventEmitter<DismissEvent> for SettingsProfileSelector {} + +impl Focusable for SettingsProfileSelector { + fn focus_handle(&self, cx: &App) -> gpui::FocusHandle { + self.picker.focus_handle(cx) + } +} + +impl Render for SettingsProfileSelector { + fn render(&mut self, _: &mut Window, _: &mut Context<Self>) -> impl IntoElement { + v_flex().w(rems(22.)).child(self.picker.clone()) + } +} + +impl SettingsProfileSelector { + pub fn new( + delegate: SettingsProfileSelectorDelegate, + window: &mut Window, + cx: &mut Context<Self>, + ) -> Self { + let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx)); + Self { picker } + } +} + +pub struct SettingsProfileSelectorDelegate { + matches: Vec<StringMatch>, + profile_names: Vec<Option<String>>, + original_profile_name: Option<String>, + selected_profile_name: Option<String>, + selected_index: usize, + selection_completed: bool, + selector: WeakEntity<SettingsProfileSelector>, +} + +impl SettingsProfileSelectorDelegate { + fn new( + selector: WeakEntity<SettingsProfileSelector>, + _: &mut Window, + cx: &mut Context<SettingsProfileSelector>, + ) -> Self { + let settings_store = cx.global::<SettingsStore>(); + let mut profile_names: Vec<Option<String>> = settings_store + .configured_settings_profiles() + .map(|s| Some(s.to_string())) + .collect(); + profile_names.insert(0, None); + + let matches = profile_names + .iter() + .enumerate() + .map(|(ix, profile_name)| StringMatch { + candidate_id: ix, + score: 0.0, + positions: Default::default(), + string: display_name(profile_name), + }) + .collect(); + + let profile_name = cx + .try_global::<ActiveSettingsProfileName>() + .map(|p| p.0.clone()); + + let mut this = Self { + matches, + profile_names, + original_profile_name: profile_name.clone(), + selected_profile_name: None, + selected_index: 0, + selection_completed: false, + selector, + }; + + if let Some(profile_name) = profile_name { + this.select_if_matching(&profile_name); + } + + this + } + + fn select_if_matching(&mut self, profile_name: &str) { + self.selected_index = self + .matches + .iter() + .position(|mat| mat.string == profile_name) + .unwrap_or(self.selected_index); + } + + fn set_selected_profile( + &self, + cx: &mut Context<Picker<SettingsProfileSelectorDelegate>>, + ) -> Option<String> { + let mat = self.matches.get(self.selected_index)?; + let profile_name = self.profile_names.get(mat.candidate_id)?; + return Self::update_active_profile_name_global(profile_name.clone(), cx); + } + + fn update_active_profile_name_global( + profile_name: Option<String>, + cx: &mut Context<Picker<SettingsProfileSelectorDelegate>>, + ) -> Option<String> { + if let Some(profile_name) = profile_name { + 
cx.set_global(ActiveSettingsProfileName(profile_name.clone())); + return Some(profile_name.clone()); + } + + if cx.has_global::<ActiveSettingsProfileName>() { + cx.remove_global::<ActiveSettingsProfileName>(); + } + + None + } +} + +impl PickerDelegate for SettingsProfileSelectorDelegate { + type ListItem = ListItem; + + fn placeholder_text(&self, _: &mut Window, _: &mut App) -> std::sync::Arc<str> { + "Select a settings profile...".into() + } + + fn match_count(&self) -> usize { + self.matches.len() + } + + fn selected_index(&self) -> usize { + self.selected_index + } + + fn set_selected_index( + &mut self, + ix: usize, + _: &mut Window, + cx: &mut Context<Picker<SettingsProfileSelectorDelegate>>, + ) { + self.selected_index = ix; + self.selected_profile_name = self.set_selected_profile(cx); + } + + fn update_matches( + &mut self, + query: String, + window: &mut Window, + cx: &mut Context<Picker<SettingsProfileSelectorDelegate>>, + ) -> Task<()> { + let background = cx.background_executor().clone(); + let candidates = self + .profile_names + .iter() + .enumerate() + .map(|(id, profile_name)| StringMatchCandidate::new(id, &display_name(profile_name))) + .collect::<Vec<_>>(); + + cx.spawn_in(window, async move |this, cx| { + let matches = if query.is_empty() { + candidates + .into_iter() + .enumerate() + .map(|(index, candidate)| StringMatch { + candidate_id: index, + string: candidate.string, + positions: Vec::new(), + score: 0.0, + }) + .collect() + } else { + match_strings( + &candidates, + &query, + false, + true, + 100, + &Default::default(), + background, + ) + .await + }; + + this.update_in(cx, |this, _, cx| { + this.delegate.matches = matches; + this.delegate.selected_index = this + .delegate + .selected_index + .min(this.delegate.matches.len().saturating_sub(1)); + this.delegate.selected_profile_name = this.delegate.set_selected_profile(cx); + }) + .ok(); + }) + } + + fn confirm( + &mut self, + _: bool, + _: &mut Window, + cx: &mut Context<Picker<SettingsProfileSelectorDelegate>>, + ) { + self.selection_completed = true; + self.selector + .update(cx, |_, cx| { + cx.emit(DismissEvent); + }) + .ok(); + } + + fn dismissed( + &mut self, + _: &mut Window, + cx: &mut Context<Picker<SettingsProfileSelectorDelegate>>, + ) { + if !self.selection_completed { + SettingsProfileSelectorDelegate::update_active_profile_name_global( + self.original_profile_name.clone(), + cx, + ); + } + self.selector.update(cx, |_, cx| cx.emit(DismissEvent)).ok(); + } + + fn render_match( + &self, + ix: usize, + selected: bool, + _: &mut Window, + _: &mut Context<Picker<Self>>, + ) -> Option<Self::ListItem> { + let mat = &self.matches[ix]; + let profile_name = &self.profile_names[mat.candidate_id]; + + Some( + ListItem::new(ix) + .inset(true) + .spacing(ListItemSpacing::Sparse) + .toggle_state(selected) + .child(HighlightedLabel::new( + display_name(profile_name), + mat.positions.clone(), + )), + ) + } +} + +fn display_name(profile_name: &Option<String>) -> String { + profile_name.clone().unwrap_or("Disabled".into()) +} + +#[cfg(test)] +mod tests { + use super::*; + use client; + use editor; + use gpui::{TestAppContext, UpdateGlobal, VisualTestContext}; + use language; + use menu::{Cancel, Confirm, SelectNext, SelectPrevious}; + use project::{FakeFs, Project}; + use serde_json::json; + use settings::Settings; + use theme::{self, ThemeSettings}; + use workspace::{self, AppState}; + use zed_actions::settings_profile_selector; + + async fn init_test( + profiles_json: serde_json::Value, + cx: &mut TestAppContext, + ) 
-> (Entity<Workspace>, &mut VisualTestContext) { + cx.update(|cx| { + let state = AppState::test(cx); + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + settings::init(cx); + theme::init(theme::LoadThemes::JustBase, cx); + ThemeSettings::register(cx); + client::init_settings(cx); + language::init(cx); + super::init(cx); + editor::init(cx); + workspace::init_settings(cx); + Project::init_settings(cx); + state + }); + + cx.update(|cx| { + SettingsStore::update_global(cx, |store, cx| { + let settings_json = json!({ + "buffer_font_size": 10.0, + "profiles": profiles_json, + }); + + store + .set_user_settings(&settings_json.to_string(), cx) + .unwrap(); + }); + }); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, ["/test".as_ref()], cx).await; + let (workspace, cx) = + cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + + cx.update(|_, cx| { + assert!(!cx.has_global::<ActiveSettingsProfileName>()); + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 10.0); + }); + + (workspace, cx) + } + + #[track_caller] + fn active_settings_profile_picker( + workspace: &Entity<Workspace>, + cx: &mut VisualTestContext, + ) -> Entity<Picker<SettingsProfileSelectorDelegate>> { + workspace.update(cx, |workspace, cx| { + workspace + .active_modal::<SettingsProfileSelector>(cx) + .expect("settings profile selector is not open") + .read(cx) + .picker + .clone() + }) + } + + #[gpui::test] + async fn test_settings_profile_selector_state(cx: &mut TestAppContext) { + let classroom_and_streaming_profile_name = "Classroom / Streaming".to_string(); + let demo_videos_profile_name = "Demo Videos".to_string(); + + let profiles_json = json!({ + classroom_and_streaming_profile_name.clone(): { + "buffer_font_size": 20.0, + }, + demo_videos_profile_name.clone(): { + "buffer_font_size": 15.0 + } + }); + let (workspace, cx) = init_test(profiles_json.clone(), cx).await; + + cx.dispatch_action(settings_profile_selector::Toggle); + let picker = active_settings_profile_picker(&workspace, cx); + + picker.read_with(cx, |picker, cx| { + assert_eq!(picker.delegate.matches.len(), 3); + assert_eq!(picker.delegate.matches[0].string, display_name(&None)); + assert_eq!( + picker.delegate.matches[1].string, + classroom_and_streaming_profile_name + ); + assert_eq!(picker.delegate.matches[2].string, demo_videos_profile_name); + assert_eq!(picker.delegate.matches.get(3), None); + + assert_eq!(picker.delegate.selected_index, 0); + assert_eq!(picker.delegate.selected_profile_name, None); + + assert_eq!(cx.try_global::<ActiveSettingsProfileName>(), None); + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 10.0); + }); + + cx.dispatch_action(Confirm); + + cx.update(|_, cx| { + assert_eq!(cx.try_global::<ActiveSettingsProfileName>(), None); + }); + + cx.dispatch_action(settings_profile_selector::Toggle); + let picker = active_settings_profile_picker(&workspace, cx); + cx.dispatch_action(SelectNext); + + picker.read_with(cx, |picker, cx| { + assert_eq!(picker.delegate.selected_index, 1); + assert_eq!( + picker.delegate.selected_profile_name, + Some(classroom_and_streaming_profile_name.clone()) + ); + + assert_eq!( + cx.try_global::<ActiveSettingsProfileName>() + .map(|p| p.0.clone()), + Some(classroom_and_streaming_profile_name.clone()) + ); + + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 20.0); + }); + + cx.dispatch_action(Cancel); + + cx.update(|_, cx| { + assert_eq!(cx.try_global::<ActiveSettingsProfileName>(), 
None); + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 10.0); + }); + + cx.dispatch_action(settings_profile_selector::Toggle); + let picker = active_settings_profile_picker(&workspace, cx); + + cx.dispatch_action(SelectNext); + + picker.read_with(cx, |picker, cx| { + assert_eq!(picker.delegate.selected_index, 1); + assert_eq!( + picker.delegate.selected_profile_name, + Some(classroom_and_streaming_profile_name.clone()) + ); + + assert_eq!( + cx.try_global::<ActiveSettingsProfileName>() + .map(|p| p.0.clone()), + Some(classroom_and_streaming_profile_name.clone()) + ); + + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 20.0); + }); + + cx.dispatch_action(SelectNext); + + picker.read_with(cx, |picker, cx| { + assert_eq!(picker.delegate.selected_index, 2); + assert_eq!( + picker.delegate.selected_profile_name, + Some(demo_videos_profile_name.clone()) + ); + + assert_eq!( + cx.try_global::<ActiveSettingsProfileName>() + .map(|p| p.0.clone()), + Some(demo_videos_profile_name.clone()) + ); + + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 15.0); + }); + + cx.dispatch_action(Confirm); + + cx.update(|_, cx| { + assert_eq!( + cx.try_global::<ActiveSettingsProfileName>() + .map(|p| p.0.clone()), + Some(demo_videos_profile_name.clone()) + ); + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 15.0); + }); + + cx.dispatch_action(settings_profile_selector::Toggle); + let picker = active_settings_profile_picker(&workspace, cx); + + picker.read_with(cx, |picker, cx| { + assert_eq!(picker.delegate.selected_index, 2); + assert_eq!( + picker.delegate.selected_profile_name, + Some(demo_videos_profile_name.clone()) + ); + + assert_eq!( + cx.try_global::<ActiveSettingsProfileName>() + .map(|p| p.0.clone()), + Some(demo_videos_profile_name.clone()) + ); + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 15.0); + }); + + cx.dispatch_action(SelectPrevious); + + picker.read_with(cx, |picker, cx| { + assert_eq!(picker.delegate.selected_index, 1); + assert_eq!( + picker.delegate.selected_profile_name, + Some(classroom_and_streaming_profile_name.clone()) + ); + + assert_eq!( + cx.try_global::<ActiveSettingsProfileName>() + .map(|p| p.0.clone()), + Some(classroom_and_streaming_profile_name.clone()) + ); + + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 20.0); + }); + + cx.dispatch_action(Cancel); + + cx.update(|_, cx| { + assert_eq!( + cx.try_global::<ActiveSettingsProfileName>() + .map(|p| p.0.clone()), + Some(demo_videos_profile_name.clone()) + ); + + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 15.0); + }); + + cx.dispatch_action(settings_profile_selector::Toggle); + let picker = active_settings_profile_picker(&workspace, cx); + + picker.read_with(cx, |picker, cx| { + assert_eq!(picker.delegate.selected_index, 2); + assert_eq!( + picker.delegate.selected_profile_name, + Some(demo_videos_profile_name.clone()) + ); + + assert_eq!( + cx.try_global::<ActiveSettingsProfileName>() + .map(|p| p.0.clone()), + Some(demo_videos_profile_name) + ); + + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 15.0); + }); + + cx.dispatch_action(SelectPrevious); + + picker.read_with(cx, |picker, cx| { + assert_eq!(picker.delegate.selected_index, 1); + assert_eq!( + picker.delegate.selected_profile_name, + Some(classroom_and_streaming_profile_name.clone()) + ); + + assert_eq!( + cx.try_global::<ActiveSettingsProfileName>() + .map(|p| p.0.clone()), + Some(classroom_and_streaming_profile_name) + ); + 
+ assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 20.0); + }); + + cx.dispatch_action(SelectPrevious); + + picker.read_with(cx, |picker, cx| { + assert_eq!(picker.delegate.selected_index, 0); + assert_eq!(picker.delegate.selected_profile_name, None); + + assert_eq!( + cx.try_global::<ActiveSettingsProfileName>() + .map(|p| p.0.clone()), + None + ); + + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 10.0); + }); + + cx.dispatch_action(Confirm); + + cx.update(|_, cx| { + assert_eq!(cx.try_global::<ActiveSettingsProfileName>(), None); + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 10.0); + }); + } +} diff --git a/crates/settings_ui/Cargo.toml b/crates/settings_ui/Cargo.toml index 7af240bd7419610ab7267439bee993ddfb194c5f..a4c47081c60fc6a749a753607c61c94b643a7e00 100644 --- a/crates/settings_ui/Cargo.toml +++ b/crates/settings_ui/Cargo.toml @@ -23,19 +23,31 @@ feature_flags.workspace = true fs.workspace = true fuzzy.workspace = true gpui.workspace = true +itertools.workspace = true language.workspace = true log.workspace = true menu.workspace = true +notifications.workspace = true paths.workspace = true project.workspace = true -schemars.workspace = true search.workspace = true serde.workspace = true +serde_json.workspace = true settings.workspace = true +telemetry.workspace = true +tempfile.workspace = true theme.workspace = true tree-sitter-json.workspace = true tree-sitter-rust.workspace = true ui.workspace = true +ui_input.workspace = true util.workspace = true workspace-hack.workspace = true workspace.workspace = true + +[dev-dependencies] +db = {"workspace"= true, "features" = ["test-support"]} +fs = { workspace = true, features = ["test-support"] } +gpui = { workspace = true, features = ["test-support"] } +project = { workspace = true, features = ["test-support"] } +workspace = { workspace = true, features = ["test-support"] } diff --git a/crates/settings_ui/src/keybindings.rs b/crates/settings_ui/src/keybindings.rs index 1f6c0ba8c777869ede42cb9336a727e558eb2dc0..599bb0b18f523f64647ed5dba6113daa0a04b04d 100644 --- a/crates/settings_ui/src/keybindings.rs +++ b/crates/settings_ui/src/keybindings.rs @@ -1,34 +1,43 @@ use std::{ - ops::{Not, Range}, + cmp::{self}, + ops::{Not as _, Range}, sync::Arc, + time::Duration, }; use anyhow::{Context as _, anyhow}; use collections::{HashMap, HashSet}; use editor::{CompletionProvider, Editor, EditorEvent}; -use feature_flags::FeatureFlagViewExt; use fs::Fs; use fuzzy::{StringMatch, StringMatchCandidate}; use gpui::{ - AppContext as _, AsyncApp, ClickEvent, Context, DismissEvent, Entity, EventEmitter, - FocusHandle, Focusable, Global, KeyContext, Keystroke, ModifiersChangedEvent, ScrollStrategy, - StyledText, Subscription, WeakEntity, actions, div, + Action, AppContext as _, AsyncApp, Axis, ClickEvent, Context, DismissEvent, Entity, + EventEmitter, FocusHandle, Focusable, Global, IsZero, KeyContext, Keystroke, MouseButton, + Point, ScrollStrategy, ScrollWheelEvent, Stateful, StyledText, Subscription, Task, + TextStyleRefinement, WeakEntity, actions, anchored, deferred, div, }; use language::{Language, LanguageConfig, ToOffset as _}; -use settings::{BaseKeymap, KeybindSource, KeymapFile, SettingsAssets}; - -use util::ResultExt; - +use notifications::status_toast::{StatusToast, ToastIcon}; +use project::Project; +use settings::{BaseKeymap, KeybindSource, KeymapFile, Settings as _, SettingsAssets}; use ui::{ - ActiveTheme as _, App, Banner, BorrowAppContext, ContextMenu, ParentElement as _, Render, - 
SharedString, Styled as _, Tooltip, Window, prelude::*, right_click_menu, + ActiveTheme as _, App, Banner, BorrowAppContext, ContextMenu, IconButtonShape, Indicator, + Modal, ModalFooter, ModalHeader, ParentElement as _, Render, Section, SharedString, + Styled as _, Tooltip, Window, prelude::*, +}; +use ui_input::SingleLineInput; +use util::ResultExt; +use workspace::{ + Item, ModalView, SerializableItem, Workspace, notifications::NotifyTaskExt as _, + register_serializable_item, }; -use workspace::{Item, ModalView, SerializableItem, Workspace, register_serializable_item}; use crate::{ - SettingsUiFeatureFlag, keybindings::persistence::KEYBINDING_EDITORS, - ui_components::table::{Table, TableInteractionState}, + ui_components::{ + keystroke_input::{ClearKeystrokes, KeystrokeInput, StartRecording, StopRecording}, + table::{ColumnWidths, ResizeBehavior, Table, TableInteractionState}, + }, }; const NO_ACTION_ARGUMENTS_TEXT: SharedString = SharedString::new_static("<no arguments>"); @@ -41,7 +50,6 @@ actions!( ] ); -const KEYMAP_EDITOR_NAMESPACE: &'static str = "keymap_editor"; actions!( keymap_editor, [ @@ -49,10 +57,20 @@ actions!( EditBinding, /// Creates a new key binding for the selected action. CreateBinding, + /// Deletes the selected key binding. + DeleteBinding, /// Copies the action name to clipboard. CopyAction, /// Copies the context predicate to clipboard. - CopyContext + CopyContext, + /// Toggles Conflict Filtering + ToggleConflictFilter, + /// Toggle Keystroke search + ToggleKeystrokeSearch, + /// Toggles exact matching for keystroke search + ToggleExactKeystrokeMatching, + /// Shows matching keystrokes for the currently selected binding + ShowMatchingKeybinds ] ); @@ -62,58 +80,32 @@ pub fn init(cx: &mut App) { cx.on_action(|_: &OpenKeymapEditor, cx| { workspace::with_active_or_new_workspace(cx, move |workspace, window, cx| { - let existing = workspace - .active_pane() - .read(cx) - .items() - .find_map(|item| item.downcast::<KeymapEditor>()); - - if let Some(existing) = existing { - workspace.activate_item(&existing, true, true, window, cx); - } else { - let keymap_editor = - cx.new(|cx| KeymapEditor::new(workspace.weak_handle(), window, cx)); - workspace.add_item_to_active_pane(Box::new(keymap_editor), None, true, window, cx); - } - }); + workspace + .with_local_workspace(window, cx, |workspace, window, cx| { + let existing = workspace + .active_pane() + .read(cx) + .items() + .find_map(|item| item.downcast::<KeymapEditor>()); + + if let Some(existing) = existing { + workspace.activate_item(&existing, true, true, window, cx); + } else { + let keymap_editor = + cx.new(|cx| KeymapEditor::new(workspace.weak_handle(), window, cx)); + workspace.add_item_to_active_pane( + Box::new(keymap_editor), + None, + true, + window, + cx, + ); + } + }) + .detach(); + }) }); - cx.observe_new(|_workspace: &mut Workspace, window, cx| { - let Some(window) = window else { return }; - - let keymap_ui_actions = [std::any::TypeId::of::<OpenKeymapEditor>()]; - - command_palette_hooks::CommandPaletteFilter::update_global(cx, |filter, _cx| { - filter.hide_action_types(&keymap_ui_actions); - filter.hide_namespace(KEYMAP_EDITOR_NAMESPACE); - }); - - cx.observe_flag::<SettingsUiFeatureFlag, _>( - window, - move |is_enabled, _workspace, _, cx| { - if is_enabled { - command_palette_hooks::CommandPaletteFilter::update_global( - cx, - |filter, _cx| { - filter.show_action_types(keymap_ui_actions.iter()); - filter.show_namespace(KEYMAP_EDITOR_NAMESPACE); - }, - ); - } else { - 
command_palette_hooks::CommandPaletteFilter::update_global( - cx, - |filter, _cx| { - filter.hide_action_types(&keymap_ui_actions); - filter.hide_namespace(KEYMAP_EDITOR_NAMESPACE); - }, - ); - } - }, - ) - .detach(); - }) - .detach(); - register_serializable_item::<KeymapEditor>(cx); } @@ -138,6 +130,31 @@ impl KeymapEventChannel { } #[derive(Default, PartialEq)] +enum SearchMode { + #[default] + Normal, + KeyStroke { + exact_match: bool, + }, +} + +impl SearchMode { + fn invert(&self) -> Self { + match self { + SearchMode::Normal => SearchMode::KeyStroke { exact_match: false }, + SearchMode::KeyStroke { .. } => SearchMode::Normal, + } + } + + fn exact_match(&self) -> bool { + match self { + SearchMode::Normal => false, + SearchMode::KeyStroke { exact_match } => *exact_match, + } + } +} + +#[derive(Default, PartialEq, Copy, Clone)] enum FilterState { #[default] All, @@ -153,59 +170,147 @@ impl FilterState { } } -type ActionMapping = (SharedString, Option<SharedString>); +#[derive(Debug, Default, PartialEq, Eq, Clone, Hash)] +struct ActionMapping { + keystrokes: Vec<Keystroke>, + context: Option<SharedString>, +} + +#[derive(Debug)] +struct KeybindConflict { + first_conflict_index: usize, + remaining_conflict_amount: usize, +} + +impl KeybindConflict { + fn from_iter<'a>(mut indices: impl Iterator<Item = &'a ConflictOrigin>) -> Option<Self> { + indices.next().map(|origin| Self { + first_conflict_index: origin.index, + remaining_conflict_amount: indices.count(), + }) + } +} + +#[derive(Clone, Copy, PartialEq)] +struct ConflictOrigin { + override_source: KeybindSource, + overridden_source: Option<KeybindSource>, + index: usize, +} + +impl ConflictOrigin { + fn new(source: KeybindSource, index: usize) -> Self { + Self { + override_source: source, + index, + overridden_source: None, + } + } + + fn with_overridden_source(self, source: KeybindSource) -> Self { + Self { + overridden_source: Some(source), + ..self + } + } + + fn get_conflict_with(&self, other: &Self) -> Option<Self> { + if self.override_source == KeybindSource::User + && other.override_source == KeybindSource::User + { + Some( + Self::new(KeybindSource::User, other.index) + .with_overridden_source(self.override_source), + ) + } else if self.override_source > other.override_source { + Some(other.with_overridden_source(self.override_source)) + } else { + None + } + } + + fn is_user_keybind_conflict(&self) -> bool { + self.override_source == KeybindSource::User + && self.overridden_source == Some(KeybindSource::User) + } +} #[derive(Default)] struct ConflictState { - conflicts: Vec<usize>, - action_keybind_mapping: HashMap<ActionMapping, Vec<usize>>, + conflicts: Vec<Option<ConflictOrigin>>, + keybind_mapping: HashMap<ActionMapping, Vec<ConflictOrigin>>, + has_user_conflicts: bool, } impl ConflictState { - fn new(key_bindings: &Vec<ProcessedKeybinding>) -> Self { - let mut action_keybind_mapping: HashMap<_, Vec<usize>> = HashMap::default(); + fn new(key_bindings: &[ProcessedBinding]) -> Self { + let mut action_keybind_mapping: HashMap<_, Vec<ConflictOrigin>> = HashMap::default(); - key_bindings + let mut largest_index = 0; + for (index, binding) in key_bindings .iter() .enumerate() - .filter(|(_, binding)| !binding.keystroke_text.is_empty()) - .for_each(|(index, binding)| { - action_keybind_mapping - .entry(binding.get_action_mapping()) - .or_default() - .push(index); - }); + .flat_map(|(index, binding)| Some(index).zip(binding.keybind_information())) + { + action_keybind_mapping + .entry(binding.get_action_mapping()) + .or_default() 
+ .push(ConflictOrigin::new(binding.source, index)); + largest_index = index; + } + + let mut conflicts = vec![None; largest_index + 1]; + let mut has_user_conflicts = false; + + for indices in action_keybind_mapping.values_mut() { + indices.sort_unstable_by_key(|origin| origin.override_source); + let Some((fst, snd)) = indices.get(0).zip(indices.get(1)) else { + continue; + }; + + for origin in indices.iter() { + conflicts[origin.index] = + origin.get_conflict_with(if origin == fst { &snd } else { &fst }) + } + + has_user_conflicts |= fst.override_source == KeybindSource::User + && snd.override_source == KeybindSource::User; + } Self { - conflicts: action_keybind_mapping - .values() - .filter(|indices| indices.len() > 1) - .flatten() - .copied() - .collect(), - action_keybind_mapping, + conflicts, + keybind_mapping: action_keybind_mapping, + has_user_conflicts, } } fn conflicting_indices_for_mapping( &self, - action_mapping: ActionMapping, - keybind_idx: usize, - ) -> Option<Vec<usize>> { - self.action_keybind_mapping - .get(&action_mapping) + action_mapping: &ActionMapping, + keybind_idx: Option<usize>, + ) -> Option<KeybindConflict> { + self.keybind_mapping + .get(action_mapping) .and_then(|indices| { - let mut indices = indices.iter().filter(|&idx| *idx != keybind_idx).peekable(); - indices.peek().is_some().then(|| indices.copied().collect()) + KeybindConflict::from_iter( + indices + .iter() + .filter(|&conflict| Some(conflict.index) != keybind_idx), + ) }) } - fn has_conflict(&self, candidate_idx: &usize) -> bool { - self.conflicts.contains(candidate_idx) + fn conflict_for_idx(&self, idx: usize) -> Option<ConflictOrigin> { + self.conflicts.get(idx).copied().flatten() + } + + fn has_user_conflict(&self, candidate_idx: usize) -> bool { + self.conflict_for_idx(candidate_idx) + .is_some_and(|conflict| conflict.is_user_keybind_conflict()) } - fn any_conflicts(&self) -> bool { - !self.conflicts.is_empty() + fn any_user_binding_conflicts(&self) -> bool { + self.has_user_conflicts } } @@ -213,33 +318,83 @@ struct KeymapEditor { workspace: WeakEntity<Workspace>, focus_handle: FocusHandle, _keymap_subscription: Subscription, - keybindings: Vec<ProcessedKeybinding>, + keybindings: Vec<ProcessedBinding>, keybinding_conflict_state: ConflictState, filter_state: FilterState, + search_mode: SearchMode, + search_query_debounce: Option<Task<()>>, // corresponds 1 to 1 with keybindings string_match_candidates: Arc<Vec<StringMatchCandidate>>, matches: Vec<StringMatch>, table_interaction_state: Entity<TableInteractionState>, filter_editor: Entity<Editor>, + keystroke_editor: Entity<KeystrokeInput>, selected_index: Option<usize>, + context_menu: Option<(Entity<ContextMenu>, Point<Pixels>, Subscription)>, + previous_edit: Option<PreviousEdit>, + humanized_action_names: HumanizedActionNameCache, + current_widths: Entity<ColumnWidths<6>>, + show_hover_menus: bool, + /// In order for the JSON LSP to run in the actions arguments editor, we + /// require a backing file In order to avoid issues (primarily log spam) + /// with drop order between the buffer, file, worktree, etc, we create a + /// temporary directory for these backing files in the keymap editor struct + /// instead of here. This has the added benefit of only having to create a + /// worktree and directory once, although the perf improvement is negligible. 
+ action_args_temp_dir_worktree: Option<Entity<project::Worktree>>, + action_args_temp_dir: Option<tempfile::TempDir>, +} + +enum PreviousEdit { + /// When deleting, we want to maintain the same scroll position + ScrollBarOffset(Point<Pixels>), + /// When editing or creating, because the new keybinding could be in a different position in the sort order + /// we store metadata about the new binding (either the modified version or newly created one) + /// and upon reload, we search for this binding in the list of keybindings, and if we find the one that matches + /// this metadata, we set the selected index to it and scroll to it, + /// and if we don't find it, we scroll to 0 and don't set a selected index + Keybinding { + action_mapping: ActionMapping, + action_name: &'static str, + /// The scrollbar position to fallback to if we don't find the keybinding during a refresh + /// this can happen if there's a filter applied to the search and the keybinding modification + /// filters the binding from the search results + fallback: Point<Pixels>, + }, } impl EventEmitter<()> for KeymapEditor {} impl Focusable for KeymapEditor { fn focus_handle(&self, cx: &App) -> gpui::FocusHandle { - return self.filter_editor.focus_handle(cx); + if self.selected_index.is_some() { + self.focus_handle.clone() + } else { + self.filter_editor.focus_handle(cx) + } } } +/// Helper function to check if two keystroke sequences match exactly +fn keystrokes_match_exactly(keystrokes1: &[Keystroke], keystrokes2: &[Keystroke]) -> bool { + keystrokes1.len() == keystrokes2.len() + && keystrokes1 + .iter() + .zip(keystrokes2) + .all(|(k1, k2)| k1.key == k2.key && k1.modifiers == k2.modifiers) +} impl KeymapEditor { fn new(workspace: WeakEntity<Workspace>, window: &mut Window, cx: &mut Context<Self>) -> Self { - let focus_handle = cx.focus_handle(); - let _keymap_subscription = - cx.observe_global::<KeymapEventChannel>(Self::update_keybindings); + cx.observe_global_in::<KeymapEventChannel>(window, Self::on_keymap_changed); let table_interaction_state = TableInteractionState::new(window, cx); + let keystroke_editor = cx.new(|cx| { + let mut keystroke_editor = KeystrokeInput::new(None, window, cx); + keystroke_editor.set_search(true); + keystroke_editor + }); + let filter_editor = cx.new(|cx| { let mut editor = Editor::single_line(window, cx); editor.set_placeholder_text("Filter action names…", cx); @@ -251,7 +406,34 @@ impl KeymapEditor { return; } - this.update_matches(cx); + this.on_query_changed(cx); + }) + .detach(); + + cx.subscribe(&keystroke_editor, |this, _, _, cx| { + if matches!(this.search_mode, SearchMode::Normal) { + return; + } + + this.on_query_changed(cx); + }) + .detach(); + + cx.spawn({ + let workspace = workspace.clone(); + async move |this, cx| { + let temp_dir = tempfile::tempdir_in(paths::temp_dir())?; + let worktree = workspace + .update(cx, |ws, cx| { + ws.project() + .update(cx, |p, cx| p.create_worktree(temp_dir.path(), false, cx)) + })? 
+ .await?; + this.update(cx, |this, _| { + this.action_args_temp_dir = Some(temp_dir); + this.action_args_temp_dir_worktree = Some(worktree); + }) + } }) .detach(); @@ -260,44 +442,96 @@ impl KeymapEditor { keybindings: vec![], keybinding_conflict_state: ConflictState::default(), filter_state: FilterState::default(), + search_mode: SearchMode::default(), string_match_candidates: Arc::new(vec![]), matches: vec![], - focus_handle: focus_handle.clone(), + focus_handle: cx.focus_handle(), _keymap_subscription, table_interaction_state, filter_editor, + keystroke_editor, selected_index: None, + context_menu: None, + previous_edit: None, + search_query_debounce: None, + humanized_action_names: HumanizedActionNameCache::new(cx), + show_hover_menus: true, + action_args_temp_dir: None, + action_args_temp_dir_worktree: None, + current_widths: cx.new(|cx| ColumnWidths::new(cx)), }; - this.update_keybindings(cx); + this.on_keymap_changed(window, cx); this } - fn current_query(&self, cx: &mut Context<Self>) -> String { + fn current_action_query(&self, cx: &App) -> String { self.filter_editor.read(cx).text(cx) } - fn update_matches(&self, cx: &mut Context<Self>) { - let query = self.current_query(cx); + fn current_keystroke_query(&self, cx: &App) -> Vec<Keystroke> { + match self.search_mode { + SearchMode::KeyStroke { .. } => self + .keystroke_editor + .read(cx) + .keystrokes() + .iter() + .cloned() + .collect(), + SearchMode::Normal => Default::default(), + } + } - cx.spawn(async move |this, cx| Self::process_query(this, query, cx).await) - .detach(); + fn on_query_changed(&mut self, cx: &mut Context<Self>) { + let action_query = self.current_action_query(cx); + let keystroke_query = self.current_keystroke_query(cx); + let exact_match = self.search_mode.exact_match(); + + let timer = cx.background_executor().timer(Duration::from_secs(1)); + self.search_query_debounce = Some(cx.background_spawn({ + let action_query = action_query.clone(); + let keystroke_query = keystroke_query.clone(); + async move { + timer.await; + + let keystroke_query = keystroke_query + .into_iter() + .map(|keystroke| keystroke.unparse()) + .collect::<Vec<String>>() + .join(" "); + + telemetry::event!( + "Keystroke Search Completed", + action_query = action_query, + keystroke_query = keystroke_query, + keystroke_exact_match = exact_match + ) + } + })); + cx.spawn(async move |this, cx| { + Self::update_matches(this.clone(), action_query, keystroke_query, cx).await?; + this.update(cx, |this, cx| { + this.scroll_to_item(0, ScrollStrategy::Top, cx) + }) + }) + .detach(); } - async fn process_query( + async fn update_matches( this: WeakEntity<Self>, - query: String, + action_query: String, + keystroke_query: Vec<Keystroke>, cx: &mut AsyncApp, ) -> anyhow::Result<()> { - let query = command_palette::normalize_action_query(&query); + let action_query = command_palette::normalize_action_query(&action_query); let (string_match_candidates, keybind_count) = this.read_with(cx, |this, _| { (this.string_match_candidates.clone(), this.keybindings.len()) })?; let executor = cx.background_executor().clone(); let mut matches = fuzzy::match_strings( &string_match_candidates, - &query, + &action_query, true, true, keybind_count, @@ -310,41 +544,89 @@ impl KeymapEditor { FilterState::Conflicts => { matches.retain(|candidate| { this.keybinding_conflict_state - .has_conflict(&candidate.candidate_id) + .has_user_conflict(candidate.candidate_id) }); } FilterState::All => {} } - if query.is_empty() { - // apply default sort - // sorts by source precedence, 
and alphabetically by action name within each source - matches.sort_by_key(|match_item| { - let keybind = &this.keybindings[match_item.candidate_id]; - let source = keybind.source.as_ref().map(|s| s.0); - use KeybindSource::*; - let source_precedence = match source { - Some(User) => 0, - Some(Vim) => 1, - Some(Base) => 2, - Some(Default) => 3, - None => 4, - }; - return (source_precedence, keybind.action_name.as_ref()); + match this.search_mode { + SearchMode::KeyStroke { exact_match } => { + matches.retain(|item| { + this.keybindings[item.candidate_id] + .keystrokes() + .is_some_and(|keystrokes| { + if exact_match { + keystrokes_match_exactly(&keystroke_query, keystrokes) + } else if keystroke_query.len() > keystrokes.len() { + return false; + } else { + for keystroke_offset in 0..keystrokes.len() { + let mut found_count = 0; + let mut query_cursor = 0; + let mut keystroke_cursor = keystroke_offset; + while query_cursor < keystroke_query.len() + && keystroke_cursor < keystrokes.len() + { + let query = &keystroke_query[query_cursor]; + let keystroke = &keystrokes[keystroke_cursor]; + let matches = + query.modifiers.is_subset_of(&keystroke.modifiers) + && ((query.key.is_empty() + || query.key == keystroke.key) + && query + .key_char + .as_ref() + .map_or(true, |q_kc| { + q_kc == &keystroke.key + })); + if matches { + found_count += 1; + query_cursor += 1; + } + keystroke_cursor += 1; + } + + if found_count == keystroke_query.len() { + return true; + } + } + return false; + } + }) + }); + } + SearchMode::Normal => {} + } + + if action_query.is_empty() { + matches.sort_by(|item1, item2| { + let binding1 = &this.keybindings[item1.candidate_id]; + let binding2 = &this.keybindings[item2.candidate_id]; + + binding1.cmp(binding2) }); } this.selected_index.take(); - this.scroll_to_item(0, ScrollStrategy::Top, cx); this.matches = matches; + cx.notify(); }) } + fn get_conflict(&self, row_index: usize) -> Option<ConflictOrigin> { + self.matches.get(row_index).and_then(|candidate| { + self.keybinding_conflict_state + .conflict_for_idx(candidate.candidate_id) + }) + } + fn process_bindings( json_language: Arc<Language>, - rust_language: Arc<Language>, + zed_keybind_context_language: Arc<Language>, + humanized_action_names: &HumanizedActionNameCache, cx: &mut App, - ) -> (Vec<ProcessedKeybinding>, Vec<StringMatchCandidate>) { + ) -> (Vec<ProcessedBinding>, Vec<StringMatchCandidate>) { let key_bindings_ptr = cx.key_bindings(); let lock = key_bindings_ptr.borrow(); let key_bindings = lock.bindings(); @@ -352,91 +634,98 @@ impl KeymapEditor { HashSet::from_iter(cx.all_action_names().into_iter().copied()); let action_documentation = cx.action_documentation(); let mut generator = KeymapFile::action_schema_generator(); - let action_schema = HashMap::from_iter( + let actions_with_schemas = HashSet::from_iter( cx.action_schemas(&mut generator) .into_iter() - .filter_map(|(name, schema)| schema.map(|schema| (name, schema))), + .filter_map(|(name, schema)| schema.is_some().then_some(name)), ); let mut processed_bindings = Vec::new(); let mut string_match_candidates = Vec::new(); for key_binding in key_bindings { - let source = key_binding.meta().map(settings::KeybindSource::from_meta); + let source = key_binding + .meta() + .map(KeybindSource::from_meta) + .unwrap_or(KeybindSource::Unknown); let keystroke_text = ui::text_for_keystrokes(key_binding.keystrokes(), cx); - let ui_key_binding = Some( - ui::KeyBinding::new_from_gpui(key_binding.clone(), cx) - .vim_mode(source == Some(settings::KeybindSource::Vim)), - ); 
+ let ui_key_binding = ui::KeyBinding::new_from_gpui(key_binding.clone(), cx) + .vim_mode(source == KeybindSource::Vim); let context = key_binding .predicate() .map(|predicate| { - KeybindContextString::Local(predicate.to_string().into(), rust_language.clone()) + KeybindContextString::Local( + predicate.to_string().into(), + zed_keybind_context_language.clone(), + ) }) .unwrap_or(KeybindContextString::Global); - let source = source.map(|source| (source, source.name().into())); - let action_name = key_binding.action().name(); unmapped_action_names.remove(&action_name); - let action_input = key_binding + + let action_arguments = key_binding .action_input() - .map(|input| SyntaxHighlightedText::new(input, json_language.clone())); - let action_docs = action_documentation.get(action_name).copied(); + .map(|arguments| SyntaxHighlightedText::new(arguments, json_language.clone())); + let action_information = ActionInformation::new( + action_name, + action_arguments, + &actions_with_schemas, + &action_documentation, + &humanized_action_names, + ); let index = processed_bindings.len(); - let string_match_candidate = StringMatchCandidate::new(index, &action_name); - processed_bindings.push(ProcessedKeybinding { - keystroke_text: keystroke_text.into(), + let string_match_candidate = + StringMatchCandidate::new(index, &action_information.humanized_name); + processed_bindings.push(ProcessedBinding::new_mapped( + keystroke_text, ui_key_binding, - action_name: action_name.into(), - action_input, - action_docs, - action_schema: action_schema.get(action_name).cloned(), - context: Some(context), + context, source, - }); + action_information, + )); string_match_candidates.push(string_match_candidate); } - let empty = SharedString::new_static(""); for action_name in unmapped_action_names.into_iter() { let index = processed_bindings.len(); - let string_match_candidate = StringMatchCandidate::new(index, &action_name); - processed_bindings.push(ProcessedKeybinding { - keystroke_text: empty.clone(), - ui_key_binding: None, - action_name: action_name.into(), - action_input: None, - action_docs: action_documentation.get(action_name).copied(), - action_schema: action_schema.get(action_name).cloned(), - context: None, - source: None, - }); + let action_information = ActionInformation::new( + action_name, + None, + &actions_with_schemas, + &action_documentation, + &humanized_action_names, + ); + let string_match_candidate = + StringMatchCandidate::new(index, &action_information.humanized_name); + + processed_bindings.push(ProcessedBinding::Unmapped(action_information)); string_match_candidates.push(string_match_candidate); } (processed_bindings, string_match_candidates) } - fn update_keybindings(&mut self, cx: &mut Context<KeymapEditor>) { + fn on_keymap_changed(&mut self, window: &mut Window, cx: &mut Context<KeymapEditor>) { let workspace = self.workspace.clone(); - cx.spawn(async move |this, cx| { + cx.spawn_in(window, async move |this, cx| { let json_language = load_json_language(workspace.clone(), cx).await; - let rust_language = load_rust_language(workspace.clone(), cx).await; + let zed_keybind_context_language = + load_keybind_context_language(workspace.clone(), cx).await; - let query = this.update(cx, |this, cx| { - let (key_bindings, string_match_candidates) = - Self::process_bindings(json_language, rust_language, cx); + let (action_query, keystroke_query) = this.update(cx, |this, cx| { + let (key_bindings, string_match_candidates) = Self::process_bindings( + json_language, + zed_keybind_context_language, + 
&this.humanized_action_names, + cx, + ); this.keybinding_conflict_state = ConflictState::new(&key_bindings); - if !this.keybinding_conflict_state.any_conflicts() { - this.filter_state = FilterState::All; - } - this.keybindings = key_bindings; this.string_match_candidates = Arc::new(string_match_candidates); this.matches = this @@ -450,15 +739,63 @@ impl KeymapEditor { string: candidate.string.clone(), }) .collect(); - this.current_query(cx) + ( + this.current_action_query(cx), + this.current_keystroke_query(cx), + ) })?; // calls cx.notify - Self::process_query(this, query, cx).await + Self::update_matches(this.clone(), action_query, keystroke_query, cx).await?; + this.update_in(cx, |this, window, cx| { + if let Some(previous_edit) = this.previous_edit.take() { + match previous_edit { + // should remove scroll from process_query + PreviousEdit::ScrollBarOffset(offset) => { + this.table_interaction_state.update(cx, |table, _| { + table.set_scrollbar_offset(Axis::Vertical, offset) + }) + // set selected index and scroll + } + PreviousEdit::Keybinding { + action_mapping, + action_name, + fallback, + } => { + let scroll_position = + this.matches.iter().enumerate().find_map(|(index, item)| { + let binding = &this.keybindings[item.candidate_id]; + if binding.get_action_mapping().is_some_and(|binding_mapping| { + binding_mapping == action_mapping + }) && binding.action().name == action_name + { + Some(index) + } else { + None + } + }); + + if let Some(scroll_position) = scroll_position { + this.select_index( + scroll_position, + Some(ScrollStrategy::Top), + window, + cx, + ); + } else { + this.table_interaction_state.update(cx, |table, _| { + table.set_scrollbar_offset(Axis::Vertical, fallback) + }); + } + cx.notify(); + } + } + } + }) }) .detach_and_log_err(cx); } - fn dispatch_context(&self, _window: &Window, _cx: &Context<Self>) -> KeyContext { + fn key_context(&self) -> KeyContext { let mut dispatch_context = KeyContext::new_with_defaults(); dispatch_context.add("KeymapEditor"); dispatch_context.add("menu"); @@ -493,26 +830,220 @@ impl KeymapEditor { self.selected_index.take(); } - fn selected_keybind_idx(&self) -> Option<usize> { + fn selected_keybind_index(&self) -> Option<usize> { self.selected_index .and_then(|match_index| self.matches.get(match_index)) .map(|r#match| r#match.candidate_id) } - fn selected_binding(&self) -> Option<&ProcessedKeybinding> { - self.selected_keybind_idx() + fn selected_keybind_and_index(&self) -> Option<(&ProcessedBinding, usize)> { + self.selected_keybind_index() + .map(|keybind_index| (&self.keybindings[keybind_index], keybind_index)) + } + + fn selected_binding(&self) -> Option<&ProcessedBinding> { + self.selected_keybind_index() .and_then(|keybind_index| self.keybindings.get(keybind_index)) } + fn select_index( + &mut self, + index: usize, + scroll: Option<ScrollStrategy>, + window: &mut Window, + cx: &mut Context<Self>, + ) { + if self.selected_index != Some(index) { + self.selected_index = Some(index); + if let Some(scroll_strategy) = scroll { + self.scroll_to_item(index, scroll_strategy, cx); + } + window.focus(&self.focus_handle); + cx.notify(); + } + } + + fn create_context_menu( + &mut self, + position: Point<Pixels>, + window: &mut Window, + cx: &mut Context<Self>, + ) { + self.context_menu = self.selected_binding().map(|selected_binding| { + let selected_binding_has_no_context = selected_binding + .context() + .and_then(KeybindContextString::local) + .is_none(); + + let selected_binding_is_unbound = selected_binding.is_unbound(); + + let 
context_menu = ContextMenu::build(window, cx, |menu, _window, _cx| { + menu.context(self.focus_handle.clone()) + .action_disabled_when( + selected_binding_is_unbound, + "Edit", + Box::new(EditBinding), + ) + .action("Create", Box::new(CreateBinding)) + .action_disabled_when( + selected_binding_is_unbound, + "Delete", + Box::new(DeleteBinding), + ) + .separator() + .action("Copy Action", Box::new(CopyAction)) + .action_disabled_when( + selected_binding_has_no_context, + "Copy Context", + Box::new(CopyContext), + ) + .separator() + .action_disabled_when( + selected_binding_has_no_context, + "Show Matching Keybindings", + Box::new(ShowMatchingKeybinds), + ) + }); + + let context_menu_handle = context_menu.focus_handle(cx); + window.defer(cx, move |window, _cx| window.focus(&context_menu_handle)); + let subscription = cx.subscribe_in( + &context_menu, + window, + |this, _, _: &DismissEvent, window, cx| { + this.dismiss_context_menu(window, cx); + }, + ); + (context_menu, position, subscription) + }); + + cx.notify(); + } + + fn dismiss_context_menu(&mut self, window: &mut Window, cx: &mut Context<Self>) { + self.context_menu.take(); + window.focus(&self.focus_handle); + cx.notify(); + } + + fn context_menu_deployed(&self) -> bool { + self.context_menu.is_some() + } + + fn create_row_button( + &self, + index: usize, + conflict: Option<ConflictOrigin>, + cx: &mut Context<Self>, + ) -> IconButton { + if self.filter_state != FilterState::Conflicts + && let Some(conflict) = conflict + { + if conflict.is_user_keybind_conflict() { + base_button_style(index, IconName::Warning) + .icon_color(Color::Warning) + .tooltip(|window, cx| { + Tooltip::with_meta( + "View conflicts", + Some(&ToggleConflictFilter), + "Use alt+click to show all conflicts", + window, + cx, + ) + }) + .on_click(cx.listener(move |this, click: &ClickEvent, window, cx| { + if click.modifiers().alt { + this.set_filter_state(FilterState::Conflicts, cx); + } else { + this.select_index(index, None, window, cx); + this.open_edit_keybinding_modal(false, window, cx); + cx.stop_propagation(); + } + })) + } else if self.search_mode.exact_match() { + base_button_style(index, IconName::Info) + .tooltip(|window, cx| { + Tooltip::with_meta( + "Edit this binding", + Some(&ShowMatchingKeybinds), + "This binding is overridden by other bindings.", + window, + cx, + ) + }) + .on_click(cx.listener(move |this, _: &ClickEvent, window, cx| { + this.select_index(index, None, window, cx); + this.open_edit_keybinding_modal(false, window, cx); + cx.stop_propagation(); + })) + } else { + base_button_style(index, IconName::Info) + .tooltip(|window, cx| { + Tooltip::with_meta( + "Show matching keybinds", + Some(&ShowMatchingKeybinds), + "This binding is overridden by other bindings.\nUse alt+click to edit this binding", + window, + cx, + ) + }) + .on_click(cx.listener(move |this, click: &ClickEvent, window, cx| { + if click.modifiers().alt { + this.select_index(index, None, window, cx); + this.open_edit_keybinding_modal(false, window, cx); + cx.stop_propagation(); + } else { + this.show_matching_keystrokes(&Default::default(), window, cx); + } + })) + } + } else { + base_button_style(index, IconName::Pencil) + .visible_on_hover(if self.selected_index == Some(index) { + "".into() + } else if self.show_hover_menus { + row_group_id(index) + } else { + "never-show".into() + }) + .when( + self.show_hover_menus && !self.context_menu_deployed(), + |this| this.tooltip(Tooltip::for_action_title("Edit Keybinding", &EditBinding)), + ) + .on_click(cx.listener(move |this, _, 
window, cx| { + this.select_index(index, None, window, cx); + this.open_edit_keybinding_modal(false, window, cx); + cx.stop_propagation(); + })) + } + } + + fn render_no_matches_hint(&self, _window: &mut Window, _cx: &App) -> AnyElement { + let hint = match (self.filter_state, &self.search_mode) { + (FilterState::Conflicts, _) => { + if self.keybinding_conflict_state.any_user_binding_conflicts() { + "No conflicting keybinds found that match the provided query" + } else { + "No conflicting keybinds found" + } + } + (FilterState::All, SearchMode::KeyStroke { .. }) => { + "No keybinds found matching the entered keystrokes" + } + (FilterState::All, SearchMode::Normal) => "No matches found for the provided query", + }; + + Label::new(hint).color(Color::Muted).into_any_element() + } + fn select_next(&mut self, _: &menu::SelectNext, window: &mut Window, cx: &mut Context<Self>) { + self.show_hover_menus = false; if let Some(selected) = self.selected_index { let selected = selected + 1; if selected >= self.matches.len() { self.select_last(&Default::default(), window, cx); } else { - self.selected_index = Some(selected); - self.scroll_to_item(selected, ScrollStrategy::Center, cx); - cx.notify(); + self.select_index(selected, Some(ScrollStrategy::Center), window, cx); } } else { self.select_first(&Default::default(), window, cx); @@ -525,6 +1056,7 @@ impl KeymapEditor { window: &mut Window, cx: &mut Context<Self>, ) { + self.show_hover_menus = false; if let Some(selected) = self.selected_index { if selected == 0 { return; @@ -535,54 +1067,64 @@ impl KeymapEditor { if selected >= self.matches.len() { self.select_last(&Default::default(), window, cx); } else { - self.selected_index = Some(selected); - self.scroll_to_item(selected, ScrollStrategy::Center, cx); - cx.notify(); + self.select_index(selected, Some(ScrollStrategy::Center), window, cx); } } else { self.select_last(&Default::default(), window, cx); } } - fn select_first( - &mut self, - _: &menu::SelectFirst, - _window: &mut Window, - cx: &mut Context<Self>, - ) { + fn select_first(&mut self, _: &menu::SelectFirst, window: &mut Window, cx: &mut Context<Self>) { + self.show_hover_menus = false; if self.matches.get(0).is_some() { - self.selected_index = Some(0); - self.scroll_to_item(0, ScrollStrategy::Center, cx); - cx.notify(); + self.select_index(0, Some(ScrollStrategy::Center), window, cx); } } - fn select_last(&mut self, _: &menu::SelectLast, _window: &mut Window, cx: &mut Context<Self>) { + fn select_last(&mut self, _: &menu::SelectLast, window: &mut Window, cx: &mut Context<Self>) { + self.show_hover_menus = false; if self.matches.last().is_some() { let index = self.matches.len() - 1; - self.selected_index = Some(index); - self.scroll_to_item(index, ScrollStrategy::Center, cx); - cx.notify(); + self.select_index(index, Some(ScrollStrategy::Center), window, cx); } } - fn confirm(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) { - self.open_edit_keybinding_modal(false, window, cx); - } - fn open_edit_keybinding_modal( &mut self, create: bool, window: &mut Window, cx: &mut Context<Self>, ) { - let Some((keybind_idx, keybind)) = self - .selected_keybind_idx() - .zip(self.selected_binding().cloned()) - else { + self.show_hover_menus = false; + let Some((keybind, keybind_index)) = self.selected_keybind_and_index() else { return; }; + let keybind = keybind.clone(); let keymap_editor = cx.entity(); + + let keystroke = keybind.keystroke_text().cloned().unwrap_or_default(); + let arguments = keybind + .action() + .arguments 
+ .as_ref() + .map(|arguments| arguments.text.clone()); + let context = keybind + .context() + .map(|context| context.local_str().unwrap_or("global")); + let action = keybind.action().name; + let source = keybind.keybind_source().map(|source| source.name()); + + telemetry::event!( + "Edit Keybinding Modal Opened", + keystroke = keystroke, + action = action, + source = source, + context = context, + arguments = arguments, + ); + + let temp_dir = self.action_args_temp_dir.as_ref().map(|dir| dir.path()); + self.workspace .update(cx, |workspace, cx| { let fs = workspace.app_state().fs.clone(); @@ -591,8 +1133,9 @@ impl KeymapEditor { let modal = KeybindingEditorModal::new( create, keybind, - keybind_idx, + keybind_index, keymap_editor, + temp_dir, workspace_weak, fs, window, @@ -613,6 +1156,27 @@ impl KeymapEditor { self.open_edit_keybinding_modal(true, window, cx); } + fn delete_binding(&mut self, _: &DeleteBinding, window: &mut Window, cx: &mut Context<Self>) { + let Some(to_remove) = self.selected_binding().cloned() else { + return; + }; + + let std::result::Result::Ok(fs) = self + .workspace + .read_with(cx, |workspace, _| workspace.app_state().fs.clone()) + else { + return; + }; + let tab_size = cx.global::<settings::SettingsStore>().json_tab_size(); + self.previous_edit = Some(PreviousEdit::ScrollBarOffset( + self.table_interaction_state + .read(cx) + .get_scrollbar_offset(Axis::Vertical), + )); + cx.spawn(async move |_, _| remove_keybinding(to_remove, &fs, tab_size).await) + .detach_and_notify_err(window, cx); + } + fn copy_context_to_clipboard( &mut self, _: &CopyContext, @@ -621,12 +1185,14 @@ impl KeymapEditor { ) { let context = self .selected_binding() - .and_then(|binding| binding.context.as_ref()) + .and_then(|binding| binding.context()) .and_then(KeybindContextString::local_str) .map(|context| context.to_string()); let Some(context) = context else { return; }; + + telemetry::event!("Keybinding Context Copied", context = context.clone()); cx.write_to_clipboard(gpui::ClipboardItem::new_string(context.clone())); } @@ -638,36 +1204,246 @@ impl KeymapEditor { ) { let action = self .selected_binding() - .map(|binding| binding.action_name.to_string()); + .map(|binding| binding.action().name.to_string()); let Some(action) = action else { return; }; + + telemetry::event!("Keybinding Action Copied", action = action.clone()); cx.write_to_clipboard(gpui::ClipboardItem::new_string(action.clone())); } -} - -#[derive(Clone)] -struct ProcessedKeybinding { - keystroke_text: SharedString, - ui_key_binding: Option<ui::KeyBinding>, - action_name: SharedString, - action_input: Option<SyntaxHighlightedText>, - action_docs: Option<&'static str>, - action_schema: Option<schemars::Schema>, - context: Option<KeybindContextString>, - source: Option<(KeybindSource, SharedString)>, -} -impl ProcessedKeybinding { - fn get_action_mapping(&self) -> ActionMapping { - ( - self.keystroke_text.clone(), - self.context - .as_ref() - .and_then(|context| context.local()) - .cloned(), + fn toggle_conflict_filter( + &mut self, + _: &ToggleConflictFilter, + _: &mut Window, + cx: &mut Context<Self>, + ) { + self.set_filter_state(self.filter_state.invert(), cx); + } + + fn set_filter_state(&mut self, filter_state: FilterState, cx: &mut Context<Self>) { + if self.filter_state != filter_state { + self.filter_state = filter_state; + self.on_query_changed(cx); + } + } + + fn toggle_keystroke_search( + &mut self, + _: &ToggleKeystrokeSearch, + window: &mut Window, + cx: &mut Context<Self>, + ) { + self.search_mode = 
self.search_mode.invert(); + self.on_query_changed(cx); + + match self.search_mode { + SearchMode::KeyStroke { .. } => { + self.keystroke_editor.update(cx, |editor, cx| { + editor.start_recording(&StartRecording, window, cx); + }); + } + SearchMode::Normal => { + self.keystroke_editor.update(cx, |editor, cx| { + editor.stop_recording(&StopRecording, window, cx); + editor.clear_keystrokes(&ClearKeystrokes, window, cx); + }); + window.focus(&self.filter_editor.focus_handle(cx)); + } + } + } + + fn toggle_exact_keystroke_matching( + &mut self, + _: &ToggleExactKeystrokeMatching, + _: &mut Window, + cx: &mut Context<Self>, + ) { + let SearchMode::KeyStroke { exact_match } = &mut self.search_mode else { + return; + }; + + *exact_match = !(*exact_match); + self.on_query_changed(cx); + } + + fn show_matching_keystrokes( + &mut self, + _: &ShowMatchingKeybinds, + _: &mut Window, + cx: &mut Context<Self>, + ) { + let Some(selected_binding) = self.selected_binding() else { + return; + }; + + let keystrokes = selected_binding + .keystrokes() + .map(Vec::from) + .unwrap_or_default(); + + self.filter_state = FilterState::All; + self.search_mode = SearchMode::KeyStroke { exact_match: true }; + + self.keystroke_editor.update(cx, |editor, cx| { + editor.set_keystrokes(keystrokes, cx); + }); + } +} + +struct HumanizedActionNameCache { + cache: HashMap<&'static str, SharedString>, +} + +impl HumanizedActionNameCache { + fn new(cx: &App) -> Self { + let cache = HashMap::from_iter(cx.all_action_names().into_iter().map(|&action_name| { + ( + action_name, + command_palette::humanize_action_name(action_name).into(), + ) + })); + Self { cache } + } + + fn get(&self, action_name: &'static str) -> SharedString { + match self.cache.get(action_name) { + Some(name) => name.clone(), + None => action_name.into(), + } + } +} + +#[derive(Clone)] +struct KeybindInformation { + keystroke_text: SharedString, + ui_binding: ui::KeyBinding, + context: KeybindContextString, + source: KeybindSource, +} + +impl KeybindInformation { + fn get_action_mapping(&self) -> ActionMapping { + ActionMapping { + keystrokes: self.ui_binding.keystrokes.clone(), + context: self.context.local().cloned(), + } + } +} + +#[derive(Clone)] +struct ActionInformation { + name: &'static str, + humanized_name: SharedString, + arguments: Option<SyntaxHighlightedText>, + documentation: Option<&'static str>, + has_schema: bool, +} + +impl ActionInformation { + fn new( + action_name: &'static str, + action_arguments: Option<SyntaxHighlightedText>, + actions_with_schemas: &HashSet<&'static str>, + action_documentation: &HashMap<&'static str, &'static str>, + action_name_cache: &HumanizedActionNameCache, + ) -> Self { + Self { + humanized_name: action_name_cache.get(action_name), + has_schema: actions_with_schemas.contains(action_name), + arguments: action_arguments, + documentation: action_documentation.get(action_name).copied(), + name: action_name, + } + } +} + +#[derive(Clone)] +enum ProcessedBinding { + Mapped(KeybindInformation, ActionInformation), + Unmapped(ActionInformation), +} + +impl ProcessedBinding { + fn new_mapped( + keystroke_text: impl Into<SharedString>, + ui_key_binding: ui::KeyBinding, + context: KeybindContextString, + source: KeybindSource, + action_information: ActionInformation, + ) -> Self { + Self::Mapped( + KeybindInformation { + keystroke_text: keystroke_text.into(), + ui_binding: ui_key_binding, + context, + source, + }, + action_information, ) } + + fn is_unbound(&self) -> bool { + matches!(self, Self::Unmapped(_)) + } + + fn 
get_action_mapping(&self) -> Option<ActionMapping> { + self.keybind_information() + .map(|keybind| keybind.get_action_mapping()) + } + + fn keystrokes(&self) -> Option<&[Keystroke]> { + self.ui_key_binding() + .map(|binding| binding.keystrokes.as_slice()) + } + + fn keybind_information(&self) -> Option<&KeybindInformation> { + match self { + Self::Mapped(keybind_information, _) => Some(keybind_information), + Self::Unmapped(_) => None, + } + } + + fn keybind_source(&self) -> Option<KeybindSource> { + self.keybind_information().map(|keybind| keybind.source) + } + + fn context(&self) -> Option<&KeybindContextString> { + self.keybind_information().map(|keybind| &keybind.context) + } + + fn ui_key_binding(&self) -> Option<&ui::KeyBinding> { + self.keybind_information() + .map(|keybind| &keybind.ui_binding) + } + + fn keystroke_text(&self) -> Option<&SharedString> { + self.keybind_information() + .map(|binding| &binding.keystroke_text) + } + + fn action(&self) -> &ActionInformation { + match self { + Self::Mapped(_, action) | Self::Unmapped(action) => action, + } + } + + fn cmp(&self, other: &Self) -> cmp::Ordering { + match (self, other) { + (Self::Mapped(keybind1, action1), Self::Mapped(keybind2, action2)) => { + match keybind1.source.cmp(&keybind2.source) { + cmp::Ordering::Equal => action1.humanized_name.cmp(&action2.humanized_name), + ordering => ordering, + } + } + (Self::Mapped(_, _), Self::Unmapped(_)) => cmp::Ordering::Less, + (Self::Unmapped(_), Self::Mapped(_, _)) => cmp::Ordering::Greater, + (Self::Unmapped(action1), Self::Unmapped(action2)) => { + action1.humanized_name.cmp(&action2.humanized_name) + } + } + } } #[derive(Clone, Debug, IntoElement, PartialEq, Eq, Hash)] @@ -724,102 +1500,278 @@ impl Item for KeymapEditor { } impl Render for KeymapEditor { - fn render(&mut self, window: &mut Window, cx: &mut ui::Context<Self>) -> impl ui::IntoElement { + fn render(&mut self, _window: &mut Window, cx: &mut ui::Context<Self>) -> impl ui::IntoElement { let row_count = self.matches.len(); let theme = cx.theme(); + let focus_handle = &self.focus_handle; v_flex() .id("keymap-editor") - .track_focus(&self.focus_handle) - .key_context(self.dispatch_context(window, cx)) + .track_focus(focus_handle) + .key_context(self.key_context()) .on_action(cx.listener(Self::select_next)) .on_action(cx.listener(Self::select_previous)) .on_action(cx.listener(Self::select_first)) .on_action(cx.listener(Self::select_last)) .on_action(cx.listener(Self::focus_search)) - .on_action(cx.listener(Self::confirm)) .on_action(cx.listener(Self::edit_binding)) .on_action(cx.listener(Self::create_binding)) + .on_action(cx.listener(Self::delete_binding)) .on_action(cx.listener(Self::copy_action_to_clipboard)) .on_action(cx.listener(Self::copy_context_to_clipboard)) + .on_action(cx.listener(Self::toggle_conflict_filter)) + .on_action(cx.listener(Self::toggle_keystroke_search)) + .on_action(cx.listener(Self::toggle_exact_keystroke_matching)) + .on_action(cx.listener(Self::show_matching_keystrokes)) + .on_mouse_move(cx.listener(|this, _, _window, _cx| { + this.show_hover_menus = true; + })) .size_full() .p_2() .gap_1() .bg(theme.colors().editor_background) .child( - h_flex() - .key_context({ - let mut context = KeyContext::new_with_defaults(); - context.add("BufferSearchBar"); - context - }) - .h_8() - .pl_2() - .pr_1() - .py_1() - .border_1() - .border_color(theme.colors().border) - .rounded_lg() - .child(self.filter_editor.clone()) - .when(self.keybinding_conflict_state.any_conflicts(), |this| { - this.child( - 
IconButton::new("KeymapEditorConflictIcon", IconName::Warning) - .tooltip(Tooltip::text(match self.filter_state { - FilterState::All => "Show conflicts", - FilterState::Conflicts => "Hide conflicts", - })) - .selected_icon_color(Color::Error) - .toggle_state(matches!(self.filter_state, FilterState::Conflicts)) - .on_click(cx.listener(|this, _, _, cx| { - this.filter_state = this.filter_state.invert(); - this.update_matches(cx); - })), - ) - }), + v_flex() + .gap_2() + .child( + h_flex() + .gap_2() + .child( + div() + .key_context({ + let mut context = KeyContext::new_with_defaults(); + context.add("BufferSearchBar"); + context + }) + .size_full() + .h_8() + .pl_2() + .pr_1() + .py_1() + .border_1() + .border_color(theme.colors().border) + .rounded_lg() + .child(self.filter_editor.clone()), + ) + .child( + IconButton::new( + "KeymapEditorToggleFiltersIcon", + IconName::Keyboard, + ) + .shape(ui::IconButtonShape::Square) + .tooltip({ + let focus_handle = focus_handle.clone(); + + move |window, cx| { + Tooltip::for_action_in( + "Search by Keystroke", + &ToggleKeystrokeSearch, + &focus_handle.clone(), + window, + cx, + ) + } + }) + .toggle_state(matches!( + self.search_mode, + SearchMode::KeyStroke { .. } + )) + .on_click(|_, window, cx| { + window.dispatch_action(ToggleKeystrokeSearch.boxed_clone(), cx); + }), + ) + .child( + IconButton::new("KeymapEditorConflictIcon", IconName::Warning) + .shape(ui::IconButtonShape::Square) + .when( + self.keybinding_conflict_state.any_user_binding_conflicts(), + |this| { + this.indicator(Indicator::dot().color(Color::Warning)) + }, + ) + .tooltip({ + let filter_state = self.filter_state; + let focus_handle = focus_handle.clone(); + + move |window, cx| { + Tooltip::for_action_in( + match filter_state { + FilterState::All => "Show Conflicts", + FilterState::Conflicts => "Hide Conflicts", + }, + &ToggleConflictFilter, + &focus_handle.clone(), + window, + cx, + ) + } + }) + .selected_icon_color(Color::Warning) + .toggle_state(matches!( + self.filter_state, + FilterState::Conflicts + )) + .on_click(|_, window, cx| { + window.dispatch_action( + ToggleConflictFilter.boxed_clone(), + cx, + ); + }), + ), + ) + .when_some( + match self.search_mode { + SearchMode::Normal => None, + SearchMode::KeyStroke { exact_match } => Some(exact_match), + }, + |this, exact_match| { + this.child( + h_flex() + .map(|this| { + if self + .keybinding_conflict_state + .any_user_binding_conflicts() + { + this.pr(rems_from_px(54.)) + } else { + this.pr_7() + } + }) + .gap_2() + .child(self.keystroke_editor.clone()) + .child( + IconButton::new( + "keystrokes-exact-match", + IconName::CaseSensitive, + ) + .tooltip({ + let keystroke_focus_handle = + self.keystroke_editor.read(cx).focus_handle(cx); + + move |window, cx| { + Tooltip::for_action_in( + "Toggle Exact Match Mode", + &ToggleExactKeystrokeMatching, + &keystroke_focus_handle, + window, + cx, + ) + } + }) + .shape(IconButtonShape::Square) + .toggle_state(exact_match) + .on_click( + cx.listener(|_, _, window, cx| { + window.dispatch_action( + ToggleExactKeystrokeMatching.boxed_clone(), + cx, + ); + }), + ), + ), + ) + }, + ), ) .child( Table::new() .interactable(&self.table_interaction_state) .striped() - .column_widths([rems(16.), rems(16.), rems(16.), rems(32.), rems(8.)]) - .header(["Action", "Arguments", "Keystrokes", "Context", "Source"]) + .empty_table_callback({ + let this = cx.entity(); + move |window, cx| this.read(cx).render_no_matches_hint(window, cx) + }) + .column_widths([ + 
DefiniteLength::Absolute(AbsoluteLength::Pixels(px(36.))), + DefiniteLength::Fraction(0.25), + DefiniteLength::Fraction(0.20), + DefiniteLength::Fraction(0.14), + DefiniteLength::Fraction(0.45), + DefiniteLength::Fraction(0.08), + ]) + .resizable_columns( + [ + ResizeBehavior::None, + ResizeBehavior::Resizable, + ResizeBehavior::Resizable, + ResizeBehavior::Resizable, + ResizeBehavior::Resizable, + ResizeBehavior::Resizable, // this column doesn't matter + ], + &self.current_widths, + cx, + ) + .header(["", "Action", "Arguments", "Keystrokes", "Context", "Source"]) .uniform_list( "keymap-editor-table", row_count, cx.processor(move |this, range: Range<usize>, _window, cx| { + let context_menu_deployed = this.context_menu_deployed(); range .filter_map(|index| { let candidate_id = this.matches.get(index)?.candidate_id; let binding = &this.keybindings[candidate_id]; + let action_name = binding.action().name; + let conflict = this.get_conflict(index); + let is_overridden = conflict.is_some_and(|conflict| { + !conflict.is_user_keybind_conflict() + }); + + let icon = this.create_row_button(index, conflict, cx); let action = div() - .child(binding.action_name.clone()) .id(("keymap action", index)) - .tooltip({ - let action_name = binding.action_name.clone(); - let action_docs = binding.action_docs; - move |_, cx| { - let action_tooltip = Tooltip::new( - command_palette::humanize_action_name( - &action_name, - ), - ); - let action_tooltip = match action_docs { - Some(docs) => action_tooltip.meta(docs), - None => action_tooltip, - }; - cx.new(|_| action_tooltip).into() + .child({ + if action_name != gpui::NoAction.name() { + binding + .action() + .humanized_name + .clone() + .into_any_element() + } else { + const NULL: SharedString = + SharedString::new_static("<null>"); + muted_styled_text(NULL.clone(), cx) + .into_any_element() } }) + .when( + !context_menu_deployed + && this.show_hover_menus + && !is_overridden, + |this| { + this.tooltip({ + let action_name = binding.action().name; + let action_docs = + binding.action().documentation; + move |_, cx| { + let action_tooltip = + Tooltip::new(action_name); + let action_tooltip = match action_docs { + Some(docs) => action_tooltip.meta(docs), + None => action_tooltip, + }; + cx.new(|_| action_tooltip).into() + } + }) + }, + ) .into_any_element(); - let keystrokes = binding.ui_key_binding.clone().map_or( - binding.keystroke_text.clone().into_any_element(), + + let keystrokes = binding.ui_key_binding().cloned().map_or( + binding + .keystroke_text() + .cloned() + .unwrap_or_default() + .into_any_element(), IntoElement::into_any_element, ); - let action_input = match binding.action_input.clone() { - Some(input) => input.into_any_element(), + + let action_arguments = match binding.action().arguments.clone() + { + Some(arguments) => arguments.into_any_element(), None => { - if binding.action_schema.is_some() { + if binding.action().has_schema { muted_styled_text(NO_ACTION_ARGUMENTS_TEXT, cx) .into_any_element() } else { @@ -827,65 +1779,167 @@ impl Render for KeymapEditor { } } }; - let context = binding - .context - .clone() - .map_or(gpui::Empty.into_any_element(), |context| { - context.into_any_element() - }); + + let context = binding.context().cloned().map_or( + gpui::Empty.into_any_element(), + |context| { + let is_local = context.local().is_some(); + + div() + .id(("keymap context", index)) + .child(context.clone()) + .when( + is_local + && !context_menu_deployed + && !is_overridden + && this.show_hover_menus, + |this| { + 
this.tooltip(Tooltip::element({ + move |_, _| { + context.clone().into_any_element() + } + })) + }, + ) + .into_any_element() + }, + ); + let source = binding - .source - .clone() - .map(|(_source, name)| name) + .keybind_source() + .map(|source| source.name()) .unwrap_or_default() .into_any_element(); - Some([action, action_input, keystrokes, context, source]) + + Some([ + icon.into_any_element(), + action, + action_arguments, + keystrokes, + context, + source, + ]) }) .collect() }), ) - .map_row( - cx.processor(|this, (row_index, row): (usize, Div), _window, cx| { - let is_conflict = this - .matches - .get(row_index) - .map(|candidate| candidate.candidate_id) - .is_some_and(|id| this.keybinding_conflict_state.has_conflict(&id)); + .map_row(cx.processor( + |this, (row_index, row): (usize, Stateful<Div>), _window, cx| { + let conflict = this.get_conflict(row_index); let is_selected = this.selected_index == Some(row_index); - let row = row - .id(("keymap-table-row", row_index)) - .on_click(cx.listener( - move |this, event: &ClickEvent, window, cx| { - this.selected_index = Some(row_index); - if event.up.click_count == 2 { - this.open_edit_keybinding_modal(false, window, cx); - } - }, - )) + let row_id = row_group_id(row_index); + + div() + .id(("keymap-row-wrapper", row_index)) + .child( + row.id(row_id.clone()) + .on_any_mouse_down(cx.listener( + move |this, + mouse_down_event: &gpui::MouseDownEvent, + window, + cx| { + match mouse_down_event.button { + MouseButton::Right => { + this.select_index( + row_index, None, window, cx, + ); + this.create_context_menu( + mouse_down_event.position, + window, + cx, + ); + } + _ => {} + } + }, + )) + .on_click(cx.listener( + move |this, event: &ClickEvent, window, cx| { + this.select_index(row_index, None, window, cx); + if event.click_count() == 2 { + this.open_edit_keybinding_modal( + false, window, cx, + ); + } + }, + )) + .group(row_id) + .when( + conflict.is_some_and(|conflict| { + !conflict.is_user_keybind_conflict() + }), + |row| { + const OVERRIDDEN_OPACITY: f32 = 0.5; + row.opacity(OVERRIDDEN_OPACITY) + }, + ) + .when_some( + conflict.filter(|conflict| { + !this.context_menu_deployed() && + !conflict.is_user_keybind_conflict() + }), + |row, conflict| { + let overriding_binding = this.keybindings.get(conflict.index); + let context = overriding_binding.and_then(|binding| { + match conflict.override_source { + KeybindSource::User => Some("your keymap"), + KeybindSource::Vim => Some("the vim keymap"), + KeybindSource::Base => Some("your base keymap"), + _ => { + log::error!("Unexpected override from the {} keymap", conflict.override_source.name()); + None + } + }.map(|source| format!("This keybinding is overridden by the '{}' binding from {}.", binding.action().humanized_name, source)) + }).unwrap_or_else(|| "This binding is overridden.".to_string()); + + row.tooltip(Tooltip::text(context))}, + ), + ) .border_2() - .when(is_conflict, |row| { - row.bg(cx.theme().status().error_background) - }) + .when( + conflict.is_some_and(|conflict| { + conflict.is_user_keybind_conflict() + }), + |row| row.bg(cx.theme().status().error_background), + ) .when(is_selected, |row| { row.border_color(cx.theme().colors().panel_focused_border) - }); - - right_click_menu(("keymap-table-row-menu", row_index)) - .trigger(move |_, _, _| row) - .menu({ - let this = cx.weak_entity(); - move |window, cx| { - build_keybind_context_menu(&this, row_index, window, cx) - } }) .into_any_element() }), ), ) + .on_scroll_wheel(cx.listener(|this, event: &ScrollWheelEvent, _, cx| { + // 
This ensures that the menu is not dismissed in cases where scroll events + // with a delta of zero are emitted + if !event.delta.pixel_delta(px(1.)).y.is_zero() { + this.context_menu.take(); + cx.notify(); + } + })) + .children(self.context_menu.as_ref().map(|(menu, position, _)| { + deferred( + anchored() + .position(*position) + .anchor(gpui::Corner::TopLeft) + .child(menu.clone()), + ) + .with_priority(1) + })) } } +fn row_group_id(row_index: usize) -> SharedString { + SharedString::new(format!("keymap-table-row-{}", row_index)) +} + +fn base_button_style(row_index: usize, icon: IconName) -> IconButton { + IconButton::new(("keymap-icon", row_index), icon) + .shape(IconButtonShape::Square) + .size(ButtonSize::Compact) +} + #[derive(Debug, Clone, IntoElement)] struct SyntaxHighlightedText { text: SharedString, @@ -934,46 +1988,44 @@ impl RenderOnce for SyntaxHighlightedText { runs.push(text_style.to_run(text.len() - offset)); } - return StyledText::new(text).with_runs(runs); + StyledText::new(text).with_runs(runs) } } #[derive(PartialEq)] -enum InputError { - Warning(SharedString), - Error(SharedString), +struct InputError { + severity: ui::Severity, + content: SharedString, } impl InputError { fn warning(message: impl Into<SharedString>) -> Self { - Self::Warning(message.into()) - } - - fn error(message: impl Into<SharedString>) -> Self { - Self::Error(message.into()) - } - - fn content(&self) -> &SharedString { - match self { - InputError::Warning(content) | InputError::Error(content) => content, + Self { + severity: ui::Severity::Warning, + content: message.into(), } } - fn is_warning(&self) -> bool { - matches!(self, InputError::Warning(_)) + fn error(message: anyhow::Error) -> Self { + Self { + severity: ui::Severity::Error, + content: message.to_string().into(), + } } } struct KeybindingEditorModal { creating: bool, - editing_keybind: ProcessedKeybinding, + editing_keybind: ProcessedBinding, editing_keybind_idx: usize, keybind_editor: Entity<KeystrokeInput>, - context_editor: Entity<Editor>, - input_editor: Option<Entity<Editor>>, + context_editor: Entity<SingleLineInput>, + action_arguments_editor: Option<Entity<ActionArgumentsEditor>>, fs: Arc<dyn Fs>, error: Option<InputError>, keymap_editor: Entity<KeymapEditor>, + workspace: WeakEntity<Workspace>, + focus_state: KeybindingEditorModalFocusState, } impl ModalView for KeybindingEditorModal {} @@ -989,36 +2041,47 @@ impl Focusable for KeybindingEditorModal { impl KeybindingEditorModal { pub fn new( create: bool, - editing_keybind: ProcessedKeybinding, + editing_keybind: ProcessedBinding, editing_keybind_idx: usize, keymap_editor: Entity<KeymapEditor>, + action_args_temp_dir: Option<&std::path::Path>, workspace: WeakEntity<Workspace>, fs: Arc<dyn Fs>, window: &mut Window, cx: &mut App, ) -> Self { - let keybind_editor = cx.new(|cx| KeystrokeInput::new(window, cx)); + let keybind_editor = cx + .new(|cx| KeystrokeInput::new(editing_keybind.keystrokes().map(Vec::from), window, cx)); - let context_editor = cx.new(|cx| { - let mut editor = Editor::single_line(window, cx); + let context_editor: Entity<SingleLineInput> = cx.new(|cx| { + let input = SingleLineInput::new(window, cx, "Keybinding Context") + .label("Edit Context") + .label_size(LabelSize::Default); if let Some(context) = editing_keybind - .context - .as_ref() + .context() .and_then(KeybindContextString::local) { - editor.set_text(context.clone(), window, cx); - } else { - editor.set_placeholder_text("Keybinding context", cx); + input.editor().update(cx, |editor, cx| { + 
editor.set_text(context.clone(), window, cx); + }); } - cx.spawn(async |editor, cx| { + let editor_entity = input.editor().clone(); + let workspace = workspace.clone(); + cx.spawn(async move |_input_handle, cx| { let contexts = cx .background_spawn(async { collect_contexts_from_assets() }) .await; - editor - .update(cx, |editor, _cx| { + let language = load_keybind_context_language(workspace, cx).await; + editor_entity + .update(cx, |editor, cx| { + if let Some(buffer) = editor.buffer().read(cx).as_singleton() { + buffer.update(cx, |buffer, cx| { + buffer.set_language(Some(language), cx); + }); + } editor.set_completion_provider(Some(std::rc::Rc::new( KeyContextCompletionProvider { contexts }, ))); @@ -1027,35 +2090,35 @@ impl KeybindingEditorModal { }) .detach_and_log_err(cx); - editor + input }); - let input_editor = editing_keybind.action_schema.clone().map(|_schema| { + let action_arguments_editor = editing_keybind.action().has_schema.then(|| { + let arguments = editing_keybind + .action() + .arguments + .as_ref() + .map(|args| args.text.clone()); cx.new(|cx| { - let mut editor = Editor::auto_height_unbounded(1, window, cx); - if let Some(input) = editing_keybind.action_input.clone() { - editor.set_text(input.text, window, cx); - } else { - // TODO: default value from schema? - editor.set_placeholder_text("Action input", cx); - } - cx.spawn(async |editor, cx| { - let json_language = load_json_language(workspace, cx).await; - editor - .update(cx, |editor, cx| { - if let Some(buffer) = editor.buffer().read(cx).as_singleton() { - buffer.update(cx, |buffer, cx| { - buffer.set_language(Some(json_language), cx) - }); - } - }) - .context("Failed to load JSON language for editing keybinding action input") - }) - .detach_and_log_err(cx); - editor + ActionArgumentsEditor::new( + editing_keybind.action().name, + arguments, + action_args_temp_dir, + workspace.clone(), + window, + cx, + ) }) }); + let focus_state = KeybindingEditorModalFocusState::new( + keybind_editor.focus_handle(cx), + action_arguments_editor + .as_ref() + .map(|args_editor| args_editor.focus_handle(cx)), + context_editor.focus_handle(cx), + ); + Self { creating: create, editing_keybind, @@ -1063,18 +2126,18 @@ impl KeybindingEditorModal { fs, keybind_editor, context_editor, - input_editor, + action_arguments_editor, error: None, keymap_editor, + workspace, + focus_state, } } fn set_error(&mut self, error: InputError, cx: &mut Context<Self>) -> bool { - if self - .error - .as_ref() - .is_some_and(|old_error| old_error.is_warning() && *old_error == error) - { + if self.error.as_ref().is_some_and(|old_error| { + old_error.severity == ui::Severity::Warning && *old_error == error + }) { false } else { self.error = Some(error); @@ -1083,62 +2146,97 @@ impl KeybindingEditorModal { } } - fn save(&mut self, cx: &mut Context<Self>) { - let existing_keybind = self.editing_keybind.clone(); - let fs = self.fs.clone(); + fn validate_action_arguments(&self, cx: &App) -> anyhow::Result<Option<String>> { + let action_arguments = self + .action_arguments_editor + .as_ref() + .map(|editor| editor.read(cx).editor.read(cx).text(cx)); + + let value = action_arguments + .as_ref() + .map(|args| { + serde_json::from_str(args).context("Failed to parse action arguments as JSON") + }) + .transpose()?; + + cx.build_action(&self.editing_keybind.action().name, value) + .context("Failed to validate action arguments")?; + Ok(action_arguments) + } + + fn validate_keystrokes(&self, cx: &App) -> anyhow::Result<Vec<Keystroke>> { let new_keystrokes = self 
.keybind_editor .read_with(cx, |editor, _| editor.keystrokes().to_vec()); - if new_keystrokes.is_empty() { - self.set_error(InputError::error("Keystrokes cannot be empty"), cx); - return; - } - let tab_size = cx.global::<settings::SettingsStore>().json_tab_size(); + anyhow::ensure!(!new_keystrokes.is_empty(), "Keystrokes cannot be empty"); + Ok(new_keystrokes) + } + + fn validate_context(&self, cx: &App) -> anyhow::Result<Option<String>> { let new_context = self .context_editor - .read_with(cx, |editor, cx| editor.text(cx)); - let new_context = new_context.is_empty().not().then_some(new_context); - let new_context_err = new_context.as_deref().and_then(|context| { - gpui::KeyBindingContextPredicate::parse(context) - .context("Failed to parse key context") - .err() - }); - if let Some(err) = new_context_err { - // TODO: store and display as separate error - // TODO: also, should be validating on keystroke - self.set_error(InputError::error(err.to_string()), cx); - return; - } + .read_with(cx, |input, cx| input.editor().read(cx).text(cx)); + let Some(context) = new_context.is_empty().not().then_some(new_context) else { + return Ok(None); + }; + gpui::KeyBindingContextPredicate::parse(&context).context("Failed to parse key context")?; - let action_mapping: ActionMapping = ( - ui::text_for_keystrokes(&new_keystrokes, cx).into(), - new_context - .as_ref() - .map(Into::into) - .or_else(|| existing_keybind.get_action_mapping().1), - ); + Ok(Some(context)) + } + + fn save_or_display_error(&mut self, cx: &mut Context<Self>) { + self.save(cx).map_err(|err| self.set_error(err, cx)).ok(); + } - if let Some(conflicting_indices) = self + fn save(&mut self, cx: &mut Context<Self>) -> Result<(), InputError> { + let existing_keybind = self.editing_keybind.clone(); + let fs = self.fs.clone(); + let tab_size = cx.global::<settings::SettingsStore>().json_tab_size(); + + let new_keystrokes = self + .validate_keystrokes(cx) + .map_err(InputError::error)? 
+ .into_iter() + .map(remove_key_char) + .collect::<Vec<_>>(); + + let new_context = self.validate_context(cx).map_err(InputError::error)?; + let new_action_args = self + .validate_action_arguments(cx) + .map_err(InputError::error)?; + + let action_mapping = ActionMapping { + keystrokes: new_keystrokes, + context: new_context.map(SharedString::from), + }; + + let conflicting_indices = self .keymap_editor .read(cx) .keybinding_conflict_state - .conflicting_indices_for_mapping(action_mapping, self.editing_keybind_idx) + .conflicting_indices_for_mapping( + &action_mapping, + self.creating.not().then_some(self.editing_keybind_idx), + ); + + conflicting_indices.map(|KeybindConflict { + first_conflict_index, + remaining_conflict_amount, + }| { - let first_conflicting_index = conflicting_indices[0]; let conflicting_action_name = self .keymap_editor .read(cx) .keybindings - .get(first_conflicting_index) - .map(|keybind| keybind.action_name.clone()); + .get(first_conflict_index) + .map(|keybind| keybind.action().name); let warning_message = match conflicting_action_name { Some(name) => { - let confliction_action_amount = conflicting_indices.len() - 1; - if confliction_action_amount > 0 { + if remaining_conflict_amount > 0 { format!( "Your keybind would conflict with the \"{}\" action and {} other bindings", - name, confliction_action_amount + name, remaining_conflict_amount ) } else { format!("Your keybind would conflict with the \"{}\" action", name) @@ -1147,135 +2245,572 @@ impl KeybindingEditorModal { None => { log::info!( "Could not find action in keybindings with index {}", - first_conflicting_index + first_conflict_index ); "Your keybind would conflict with other actions".to_string() } }; - if self.set_error(InputError::warning(warning_message), cx) { - return; + let warning = InputError::warning(warning_message); + if self.error.as_ref().is_some_and(|old_error| *old_error == warning) { + Ok(()) + } else { + Err(warning) } - } + }).unwrap_or(Ok(()))?; let create = self.creating; + let status_toast = StatusToast::new( + format!( + "Saved edits to the {} action.", + &self.editing_keybind.action().humanized_name + ), + cx, + move |this, _cx| { + this.icon(ToastIcon::new(IconName::Check).color(Color::Success)) + .dismiss_button(true) + // .action("Undo", f) todo: wire the undo functionality + }, + ); + + self.workspace + .update(cx, |workspace, cx| { + workspace.toggle_status_toast(status_toast, cx); + }) + .log_err(); + cx.spawn(async move |this, cx| { + let action_name = existing_keybind.action().name; + if let Err(err) = save_keybinding_update( create, existing_keybind, - &new_keystrokes, - new_context.as_deref(), + &action_mapping, + new_action_args.as_deref(), &fs, tab_size, ) .await { this.update(cx, |this, cx| { - this.set_error(InputError::error(err.to_string()), cx); + this.set_error(InputError::error(err), cx); }) .log_err(); } else { - this.update(cx, |_this, cx| { + this.update(cx, |this, cx| { + this.keymap_editor.update(cx, |keymap, cx| { + keymap.previous_edit = Some(PreviousEdit::Keybinding { + action_mapping, + action_name, + fallback: keymap + .table_interaction_state + .read(cx) + .get_scrollbar_offset(Axis::Vertical), + }) + }); cx.emit(DismissEvent); }) .ok(); } }) .detach(); + + Ok(()) + } + + fn key_context(&self) -> KeyContext { + let mut key_context = KeyContext::new_with_defaults(); + key_context.add("KeybindEditorModal"); + key_context + } + + fn focus_next(&mut self, _: &menu::SelectNext, window: &mut Window, cx: &mut Context<Self>) { + 
self.focus_state.focus_next(window, cx); + } + + fn focus_prev( + &mut self, + _: &menu::SelectPrevious, + window: &mut Window, + cx: &mut Context<Self>, + ) { + self.focus_state.focus_previous(window, cx); + } + + fn confirm(&mut self, _: &menu::Confirm, _window: &mut Window, cx: &mut Context<Self>) { + self.save_or_display_error(cx); + } + + fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) { + cx.emit(DismissEvent); + } + + fn get_matching_bindings_count(&self, cx: &Context<Self>) -> usize { + let current_keystrokes = self.keybind_editor.read(cx).keystrokes().to_vec(); + + if current_keystrokes.is_empty() { + return 0; + } + + self.keymap_editor + .read(cx) + .keybindings + .iter() + .enumerate() + .filter(|(idx, binding)| { + // Don't count the binding we're currently editing + if !self.creating && *idx == self.editing_keybind_idx { + return false; + } + + binding + .keystrokes() + .map(|keystrokes| keystrokes_match_exactly(keystrokes, ¤t_keystrokes)) + .unwrap_or(false) + }) + .count() + } + + fn show_matching_bindings(&mut self, _window: &mut Window, cx: &mut Context<Self>) { + let keystrokes = self.keybind_editor.read(cx).keystrokes().to_vec(); + + // Dismiss the modal + cx.emit(DismissEvent); + + // Update the keymap editor to show matching keystrokes + self.keymap_editor.update(cx, |editor, cx| { + editor.filter_state = FilterState::All; + editor.search_mode = SearchMode::KeyStroke { exact_match: true }; + editor.keystroke_editor.update(cx, |keystroke_editor, cx| { + keystroke_editor.set_keystrokes(keystrokes, cx); + }); + }); + } +} + +fn remove_key_char(Keystroke { modifiers, key, .. }: Keystroke) -> Keystroke { + Keystroke { + modifiers, + key, + ..Default::default() } } impl Render for KeybindingEditorModal { fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { let theme = cx.theme().colors(); - let input_base = || { - div() - .w_full() - .py_2() - .px_3() - .min_h_8() - .rounded_md() - .bg(theme.editor_background) - .border_1() - .border_color(theme.border_variant) - }; + let matching_bindings_count = self.get_matching_bindings_count(cx); v_flex() .w(rems(34.)) .elevation_3(cx) + .key_context(self.key_context()) + .on_action(cx.listener(Self::focus_next)) + .on_action(cx.listener(Self::focus_prev)) + .on_action(cx.listener(Self::confirm)) + .on_action(cx.listener(Self::cancel)) .child( - v_flex() - .p_3() - .child(Label::new("Edit Keystroke")) - .child( - Label::new("Input the desired keystroke for the selected action.") - .color(Color::Muted) - .mb_2(), - ) - .child(self.keybind_editor.clone()), - ) - .when_some(self.input_editor.clone(), |this, editor| { - this.child( - v_flex() - .p_3() - .pt_0() - .child(Label::new("Edit Input")) - .child( - Label::new("Input the desired input to the binding.") - .color(Color::Muted) - .mb_2(), - ) - .child(input_base().child(editor)), - ) - }) - .child( - v_flex() - .p_3() - .pt_0() - .child(Label::new("Edit Context")) - .child( - Label::new("Input the desired context for the binding.") - .color(Color::Muted) - .mb_2(), + Modal::new("keybinding_editor_modal", None) + .header( + ModalHeader::new().child( + v_flex() + .w_full() + .pb_1p5() + .mb_1() + .gap_0p5() + .border_b_1() + .border_color(theme.border_variant) + .child(Label::new( + self.editing_keybind.action().humanized_name.clone(), + )) + .when_some( + self.editing_keybind.action().documentation, + |this, docs| { + this.child( + Label::new(docs) + .size(LabelSize::Small) + .color(Color::Muted), + ) + }, + ), + ), ) - 
.child(input_base().child(self.context_editor.clone())), - ) - .when_some(self.error.as_ref(), |this, error| { - this.child( - div().p_2().child( - Banner::new() - .map(|banner| match error { - InputError::Error(_) => banner.severity(ui::Severity::Error), - InputError::Warning(_) => banner.severity(ui::Severity::Warning), - }) - // For some reason, the div overflows its container to the - // right. The padding accounts for that. - .child(div().size_full().pr_2().child(Label::new(error.content()))), - ), - ) - }) - .child( - h_flex() - .p_2() - .w_full() - .gap_1() - .justify_end() - .border_t_1() - .border_color(theme.border_variant) - .child( - Button::new("cancel", "Cancel") - .on_click(cx.listener(|_, _, _, cx| cx.emit(DismissEvent))), + .section( + Section::new().child( + v_flex() + .gap_2p5() + .child( + v_flex() + .gap_1() + .child(Label::new("Edit Keystroke")) + .child(self.keybind_editor.clone()) + .child(h_flex().gap_px().when( + matching_bindings_count > 0, + |this| { + let label = format!( + "There {} {} {} with the same keystrokes.", + if matching_bindings_count == 1 { + "is" + } else { + "are" + }, + matching_bindings_count, + if matching_bindings_count == 1 { + "binding" + } else { + "bindings" + } + ); + + this.child( + Label::new(label) + .size(LabelSize::Small) + .color(Color::Muted), + ) + .child( + Button::new("show_matching", "View") + .label_size(LabelSize::Small) + .icon(IconName::ArrowUpRight) + .icon_color(Color::Muted) + .icon_size(IconSize::XSmall) + .on_click(cx.listener( + |this, _, window, cx| { + this.show_matching_bindings( + window, cx, + ); + }, + )), + ) + }, + )), + ) + .when_some(self.action_arguments_editor.clone(), |this, editor| { + this.child( + v_flex() + .gap_1() + .child(Label::new("Edit Arguments")) + .child(editor), + ) + }) + .child(self.context_editor.clone()) + .when_some(self.error.as_ref(), |this, error| { + this.child( + Banner::new() + .severity(error.severity) + .child(Label::new(error.content.clone())), + ) + }), + ), ) - .child( - Button::new("save-btn", "Save").on_click( - cx.listener(|this, _event, _window, cx| Self::save(this, cx)), + .footer( + ModalFooter::new().end_slot( + h_flex() + .gap_1() + .child( + Button::new("cancel", "Cancel") + .on_click(cx.listener(|_, _, _, cx| cx.emit(DismissEvent))), + ) + .child(Button::new("save-btn", "Save").on_click(cx.listener( + |this, _event, _window, cx| { + this.save_or_display_error(cx); + }, + ))), ), ), ) } } +struct KeybindingEditorModalFocusState { + handles: Vec<FocusHandle>, +} + +impl KeybindingEditorModalFocusState { + fn new( + keystrokes: FocusHandle, + action_input: Option<FocusHandle>, + context: FocusHandle, + ) -> Self { + Self { + handles: Vec::from_iter( + [Some(keystrokes), action_input, Some(context)] + .into_iter() + .flatten(), + ), + } + } + + fn focused_index(&self, window: &Window, cx: &App) -> Option<i32> { + self.handles + .iter() + .position(|handle| handle.contains_focused(window, cx)) + .map(|i| i as i32) + } + + fn focus_index(&self, mut index: i32, window: &mut Window) { + if index < 0 { + index = self.handles.len() as i32 - 1; + } + if index >= self.handles.len() as i32 { + index = 0; + } + window.focus(&self.handles[index as usize]); + } + + fn focus_next(&self, window: &mut Window, cx: &App) { + let index_to_focus = if let Some(index) = self.focused_index(window, cx) { + index + 1 + } else { + 0 + }; + self.focus_index(index_to_focus, window); + } + + fn focus_previous(&self, window: &mut Window, cx: &App) { + let index_to_focus = if let Some(index) = 
self.focused_index(window, cx) { + index - 1 + } else { + self.handles.len() as i32 - 1 + }; + self.focus_index(index_to_focus, window); + } +} + +struct ActionArgumentsEditor { + editor: Entity<Editor>, + focus_handle: FocusHandle, + is_loading: bool, + /// See documentation in `KeymapEditor` for why a temp dir is needed. + /// This field exists because the keymap editor temp dir creation may fail, + /// and rather than implement a complicated retry mechanism, we simply + /// fallback to trying to create a temporary directory in this editor on + /// demand. Of note is that the TempDir struct will remove the directory + /// when dropped. + backup_temp_dir: Option<tempfile::TempDir>, +} + +impl Focusable for ActionArgumentsEditor { + fn focus_handle(&self, _cx: &App) -> FocusHandle { + self.focus_handle.clone() + } +} + +impl ActionArgumentsEditor { + fn new( + action_name: &'static str, + arguments: Option<SharedString>, + temp_dir: Option<&std::path::Path>, + workspace: WeakEntity<Workspace>, + window: &mut Window, + cx: &mut Context<Self>, + ) -> Self { + let focus_handle = cx.focus_handle(); + cx.on_focus_in(&focus_handle, window, |this, window, cx| { + this.editor.focus_handle(cx).focus(window); + }) + .detach(); + let editor = cx.new(|cx| { + let mut editor = Editor::auto_height_unbounded(1, window, cx); + Self::set_editor_text(&mut editor, arguments.clone(), window, cx); + editor.set_read_only(true); + editor + }); + + let temp_dir = temp_dir.map(|path| path.to_owned()); + cx.spawn_in(window, async move |this, cx| { + let result = async { + let (project, fs) = workspace.read_with(cx, |workspace, _cx| { + ( + workspace.project().downgrade(), + workspace.app_state().fs.clone(), + ) + })?; + + let file_name = + project::lsp_store::json_language_server_ext::normalized_action_file_name( + action_name, + ); + + let (buffer, backup_temp_dir) = + Self::create_temp_buffer(temp_dir, file_name.clone(), project.clone(), fs, cx) + .await + .context(concat!( + "Failed to create temporary buffer for action arguments. 
", + "Auto-complete will not work" + ))?; + + let editor = cx.new_window_entity(|window, cx| { + let multi_buffer = cx.new(|cx| editor::MultiBuffer::singleton(buffer, cx)); + let mut editor = Editor::new( + editor::EditorMode::Full { + scale_ui_elements_with_buffer_font_size: true, + show_active_line_background: false, + sized_by_content: true, + }, + multi_buffer, + project.upgrade(), + window, + cx, + ); + editor.set_searchable(false); + editor.disable_scrollbars_and_minimap(window, cx); + editor.set_show_edit_predictions(Some(false), window, cx); + editor.set_show_gutter(false, cx); + Self::set_editor_text(&mut editor, arguments, window, cx); + editor + })?; + + this.update_in(cx, |this, window, cx| { + if this.editor.focus_handle(cx).is_focused(window) { + editor.focus_handle(cx).focus(window); + } + this.editor = editor; + this.backup_temp_dir = backup_temp_dir; + this.is_loading = false; + })?; + + anyhow::Ok(()) + } + .await; + if result.is_err() { + let json_language = load_json_language(workspace.clone(), cx).await; + this.update(cx, |this, cx| { + this.editor.update(cx, |editor, cx| { + if let Some(buffer) = editor.buffer().read(cx).as_singleton() { + buffer.update(cx, |buffer, cx| { + buffer.set_language(Some(json_language.clone()), cx) + }); + } + }) + // .context("Failed to load JSON language for editing keybinding action arguments input") + }) + .ok(); + this.update(cx, |this, _cx| { + this.is_loading = false; + }) + .ok(); + } + return result; + }) + .detach_and_log_err(cx); + Self { + editor, + focus_handle, + is_loading: true, + backup_temp_dir: None, + } + } + + fn set_editor_text( + editor: &mut Editor, + arguments: Option<SharedString>, + window: &mut Window, + cx: &mut Context<Editor>, + ) { + if let Some(arguments) = arguments { + editor.set_text(arguments, window, cx); + } else { + // TODO: default value from schema? + editor.set_placeholder_text("Action Arguments", cx); + } + } + + async fn create_temp_buffer( + temp_dir: Option<std::path::PathBuf>, + file_name: String, + project: WeakEntity<Project>, + fs: Arc<dyn Fs>, + cx: &mut AsyncApp, + ) -> anyhow::Result<(Entity<language::Buffer>, Option<tempfile::TempDir>)> { + let (temp_file_path, temp_dir) = { + let file_name = file_name.clone(); + async move { + let temp_dir_backup = match temp_dir.as_ref() { + Some(_) => None, + None => { + let temp_dir = paths::temp_dir(); + let sub_temp_dir = tempfile::Builder::new() + .tempdir_in(temp_dir) + .context("Failed to create temporary directory")?; + Some(sub_temp_dir) + } + }; + let dir_path = temp_dir.as_deref().unwrap_or_else(|| { + temp_dir_backup + .as_ref() + .expect("created backup tempdir") + .path() + }); + let path = dir_path.join(file_name); + fs.create_file( + &path, + fs::CreateOptions { + ignore_if_exists: true, + overwrite: true, + }, + ) + .await + .context("Failed to create temporary file")?; + anyhow::Ok((path, temp_dir_backup)) + } + } + .await + .context("Failed to create backing file")?; + + project + .update(cx, |project, cx| { + project.open_local_buffer(temp_file_path, cx) + })? 
+ .await + .context("Failed to create buffer") + .map(|buffer| (buffer, temp_dir)) + } +} + +impl Render for ActionArgumentsEditor { + fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { + let background_color; + let border_color; + let text_style = { + let colors = cx.theme().colors(); + let settings = theme::ThemeSettings::get_global(cx); + background_color = colors.editor_background; + border_color = if self.is_loading { + colors.border_disabled + } else { + colors.border_variant + }; + TextStyleRefinement { + font_size: Some(rems(0.875).into()), + font_weight: Some(settings.buffer_font.weight), + line_height: Some(relative(1.2)), + font_style: Some(gpui::FontStyle::Normal), + color: self.is_loading.then_some(colors.text_disabled), + ..Default::default() + } + }; + + self.editor + .update(cx, |editor, _| editor.set_text_style_refinement(text_style)); + + return v_flex().w_full().child( + h_flex() + .min_h_8() + .min_w_48() + .px_2() + .py_1p5() + .flex_grow() + .rounded_lg() + .bg(background_color) + .border_1() + .border_color(border_color) + .track_focus(&self.focus_handle) + .child(self.editor.clone()), + ); + } +} + struct KeyContextCompletionProvider { contexts: Vec<SharedString>, } @@ -1365,25 +2900,31 @@ async fn load_json_language(workspace: WeakEntity<Workspace>, cx: &mut AsyncApp) }); } -async fn load_rust_language(workspace: WeakEntity<Workspace>, cx: &mut AsyncApp) -> Arc<Language> { - let rust_language_task = workspace +async fn load_keybind_context_language( + workspace: WeakEntity<Workspace>, + cx: &mut AsyncApp, +) -> Arc<Language> { + let language_task = workspace .read_with(cx, |workspace, cx| { workspace .project() .read(cx) .languages() - .language_for_name("Rust") + .language_for_name("Zed Keybind Context") }) - .context("Failed to load Rust language") + .context("Failed to load Zed Keybind Context language") .log_err(); - let rust_language = match rust_language_task { - Some(task) => task.await.context("Failed to load Rust language").log_err(), + let language = match language_task { + Some(task) => task + .await + .context("Failed to load Zed Keybind Context language") + .log_err(), None => None, }; - return rust_language.unwrap_or_else(|| { + return language.unwrap_or_else(|| { Arc::new(Language::new( LanguageConfig { - name: "Rust".into(), + name: "Zed Keybind Context".into(), ..Default::default() }, Some(tree_sitter_rust::LANGUAGE.into()), @@ -1393,9 +2934,9 @@ async fn load_rust_language(workspace: WeakEntity<Workspace>, cx: &mut AsyncApp) async fn save_keybinding_update( create: bool, - existing: ProcessedKeybinding, - new_keystrokes: &[Keystroke], - new_context: Option<&str>, + existing: ProcessedBinding, + action_mapping: &ActionMapping, + new_args: Option<&str>, fs: &Arc<dyn Fs>, tab_size: usize, ) -> anyhow::Result<()> { @@ -1403,281 +2944,106 @@ async fn save_keybinding_update( .await .context("Failed to load keymap file")?; - let existing_keystrokes = existing - .ui_key_binding + let existing_keystrokes = existing.keystrokes().unwrap_or_default(); + let existing_context = existing.context().and_then(KeybindContextString::local_str); + let existing_args = existing + .action() + .arguments .as_ref() - .map(|keybinding| keybinding.keystrokes.as_slice()) - .unwrap_or_default(); + .map(|args| args.text.as_ref()); - let existing_context = existing - .context - .as_ref() - .and_then(KeybindContextString::local_str); + let target = settings::KeybindUpdateTarget { + context: existing_context, + keystrokes: existing_keystrokes, 
+ action_name: &existing.action().name, + action_arguments: existing_args, + }; - let input = existing - .action_input - .as_ref() - .map(|input| input.text.as_ref()); + let source = settings::KeybindUpdateTarget { + context: action_mapping.context.as_ref().map(|a| &***a), + keystrokes: &action_mapping.keystrokes, + action_name: &existing.action().name, + action_arguments: new_args, + }; let operation = if !create { settings::KeybindUpdateOperation::Replace { - target: settings::KeybindUpdateTarget { - context: existing_context, - keystrokes: existing_keystrokes, - action_name: &existing.action_name, - use_key_equivalents: false, - input, - }, - target_keybind_source: existing - .source - .map(|(source, _name)| source) - .unwrap_or(KeybindSource::User), - source: settings::KeybindUpdateTarget { - context: new_context, - keystrokes: new_keystrokes, - action_name: &existing.action_name, - use_key_equivalents: false, - input, - }, + target, + target_keybind_source: existing.keybind_source().unwrap_or(KeybindSource::User), + source, } } else { - settings::KeybindUpdateOperation::Add(settings::KeybindUpdateTarget { - context: new_context, - keystrokes: new_keystrokes, - action_name: &existing.action_name, - use_key_equivalents: false, - input, - }) + settings::KeybindUpdateOperation::Add { + source, + from: Some(target), + } }; + + let (new_keybinding, removed_keybinding, source) = operation.generate_telemetry(); + let updated_keymap_contents = settings::KeymapFile::update_keybinding(operation, keymap_contents, tab_size) .context("Failed to update keybinding")?; - fs.atomic_write(paths::keymap_file().clone(), updated_keymap_contents) - .await - .context("Failed to write keymap file")?; + fs.write( + paths::keymap_file().as_path(), + updated_keymap_contents.as_bytes(), + ) + .await + .context("Failed to write keymap file")?; + + telemetry::event!( + "Keybinding Updated", + new_keybinding = new_keybinding, + removed_keybinding = removed_keybinding, + source = source + ); Ok(()) } -struct KeystrokeInput { - keystrokes: Vec<Keystroke>, - focus_handle: FocusHandle, - intercept_subscription: Option<Subscription>, - _focus_subscriptions: [Subscription; 2], -} - -impl KeystrokeInput { - fn new(window: &mut Window, cx: &mut Context<Self>) -> Self { - let focus_handle = cx.focus_handle(); - let _focus_subscriptions = [ - cx.on_focus_in(&focus_handle, window, Self::on_focus_in), - cx.on_focus_out(&focus_handle, window, Self::on_focus_out), - ]; - Self { - keystrokes: Vec::new(), - focus_handle, - intercept_subscription: None, - _focus_subscriptions, - } - } - - fn on_modifiers_changed( - &mut self, - event: &ModifiersChangedEvent, - _window: &mut Window, - cx: &mut Context<Self>, - ) { - if let Some(last) = self.keystrokes.last_mut() - && last.key.is_empty() - { - if !event.modifiers.modified() { - self.keystrokes.pop(); - } else { - last.modifiers = event.modifiers; - } - } else { - self.keystrokes.push(Keystroke { - modifiers: event.modifiers, - key: "".to_string(), - key_char: None, - }); - } - cx.stop_propagation(); - cx.notify(); - } - - fn handle_keystroke(&mut self, keystroke: &Keystroke, cx: &mut Context<Self>) { - if let Some(last) = self.keystrokes.last_mut() - && last.key.is_empty() - { - *last = keystroke.clone(); - } else if Some(keystroke) != self.keystrokes.last() { - self.keystrokes.push(keystroke.clone()); - } - cx.stop_propagation(); - cx.notify(); - } - - fn on_key_up( - &mut self, - event: &gpui::KeyUpEvent, - _window: &mut Window, - cx: &mut Context<Self>, - ) { - if let Some(last) = 
self.keystrokes.last_mut() - && !last.key.is_empty() - && last.modifiers == event.keystroke.modifiers - { - self.keystrokes.push(Keystroke { - modifiers: event.keystroke.modifiers, - key: "".to_string(), - key_char: None, - }); - } - cx.stop_propagation(); - cx.notify(); - } - - fn on_focus_in(&mut self, _window: &mut Window, cx: &mut Context<Self>) { - if self.intercept_subscription.is_none() { - let listener = cx.listener(|this, event: &gpui::KeystrokeEvent, _window, cx| { - this.handle_keystroke(&event.keystroke, cx); - }); - self.intercept_subscription = Some(cx.intercept_keystrokes(listener)) - } - } - - fn on_focus_out( - &mut self, - _event: gpui::FocusOutEvent, - _window: &mut Window, - _cx: &mut Context<Self>, - ) { - self.intercept_subscription.take(); - } - - fn keystrokes(&self) -> &[Keystroke] { - if self - .keystrokes - .last() - .map_or(false, |last| last.key.is_empty()) - { - return &self.keystrokes[..self.keystrokes.len() - 1]; - } - return &self.keystrokes; - } -} - -impl Focusable for KeystrokeInput { - fn focus_handle(&self, _cx: &App) -> FocusHandle { - self.focus_handle.clone() - } -} - -impl Render for KeystrokeInput { - fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { - let colors = cx.theme().colors(); - let is_focused = self.focus_handle.is_focused(window); - - return h_flex() - .id("keybinding_input") - .track_focus(&self.focus_handle) - .on_modifiers_changed(cx.listener(Self::on_modifiers_changed)) - .on_key_up(cx.listener(Self::on_key_up)) - .focus(|mut style| { - style.border_color = Some(colors.border_focused); - style - }) - .py_2() - .px_3() - .gap_2() - .min_h_8() - .w_full() - .flex_1() - .justify_between() - .rounded_md() - .overflow_hidden() - .bg(colors.editor_background) - .border_1() - .border_color(colors.border_variant) - .child( - h_flex() - .w_full() - .min_w_0() - .justify_center() - .flex_wrap() - .gap(ui::DynamicSpacing::Base04.rems(cx)) - .children(self.keystrokes.iter().map(|keystroke| { - h_flex().children(ui::render_keystroke( - keystroke, - None, - Some(rems(0.875).into()), - ui::PlatformStyle::platform(), - false, - )) - })), - ) - .child( - h_flex() - .gap_0p5() - .flex_none() - .child( - IconButton::new("backspace-btn", IconName::Delete) - .tooltip(Tooltip::text("Delete Keystroke")) - .when(!is_focused, |this| this.icon_color(Color::Muted)) - .on_click(cx.listener(|this, _event, _window, cx| { - this.keystrokes.pop(); - cx.notify(); - })), - ) - .child( - IconButton::new("clear-btn", IconName::Eraser) - .tooltip(Tooltip::text("Clear Keystrokes")) - .when(!is_focused, |this| this.icon_color(Color::Muted)) - .on_click(cx.listener(|this, _event, _window, cx| { - this.keystrokes.clear(); - cx.notify(); - })), - ), - ); - } -} - -fn build_keybind_context_menu( - this: &WeakEntity<KeymapEditor>, - item_idx: usize, - window: &mut Window, - cx: &mut App, -) -> Entity<ContextMenu> { - ContextMenu::build(window, cx, |menu, _window, cx| { - let selected_binding = this - .update(cx, |this, _cx| { - this.selected_index = Some(item_idx); - this.selected_binding().cloned() - }) - .ok() - .flatten(); +async fn remove_keybinding( + existing: ProcessedBinding, + fs: &Arc<dyn Fs>, + tab_size: usize, +) -> anyhow::Result<()> { + let Some(keystrokes) = existing.keystrokes() else { + anyhow::bail!("Cannot remove a keybinding that does not exist"); + }; + let keymap_contents = settings::KeymapFile::load_keymap_file(fs) + .await + .context("Failed to load keymap file")?; - let Some(selected_binding) = selected_binding 
else { - return menu; - }; + let operation = settings::KeybindUpdateOperation::Remove { + target: settings::KeybindUpdateTarget { + context: existing.context().and_then(KeybindContextString::local_str), + keystrokes, + action_name: &existing.action().name, + action_arguments: existing + .action() + .arguments + .as_ref() + .map(|arguments| arguments.text.as_ref()), + }, + target_keybind_source: existing.keybind_source().unwrap_or(KeybindSource::User), + }; - let selected_binding_has_no_context = selected_binding - .context - .as_ref() - .and_then(KeybindContextString::local) - .is_none(); - - let selected_binding_is_unbound = selected_binding.ui_key_binding.is_none(); - - menu.action_disabled_when(selected_binding_is_unbound, "Edit", Box::new(EditBinding)) - .action("Create", Box::new(CreateBinding)) - .action("Copy action", Box::new(CopyAction)) - .action_disabled_when( - selected_binding_has_no_context, - "Copy Context", - Box::new(CopyContext), - ) - }) + let (new_keybinding, removed_keybinding, source) = operation.generate_telemetry(); + let updated_keymap_contents = + settings::KeymapFile::update_keybinding(operation, keymap_contents, tab_size) + .context("Failed to update keybinding")?; + fs.write( + paths::keymap_file().as_path(), + updated_keymap_contents.as_bytes(), + ) + .await + .context("Failed to write keymap file")?; + + telemetry::event!( + "Keybinding Removed", + new_keybinding = new_keybinding, + removed_keybinding = removed_keybinding, + source = source + ); + Ok(()) } fn collect_contexts_from_assets() -> Vec<SharedString> { @@ -1720,7 +3086,7 @@ fn collect_contexts_from_assets() -> Vec<SharedString> { contexts.insert(ident_a); contexts.insert(ident_b); } - gpui::KeyBindingContextPredicate::Child(ctx_a, ctx_b) => { + gpui::KeyBindingContextPredicate::Descendant(ctx_a, ctx_b) => { queue.push(*ctx_a); queue.push(*ctx_b); } diff --git a/crates/settings_ui/src/settings_ui.rs b/crates/settings_ui/src/settings_ui.rs index 2f0abb478933c215048b64b5fa7399981f7beebe..3022cc714268f641b7b6f30021b5e86d6072b7b6 100644 --- a/crates/settings_ui/src/settings_ui.rs +++ b/crates/settings_ui/src/settings_ui.rs @@ -1,20 +1,12 @@ mod appearance_settings_controls; use std::any::TypeId; -use std::sync::Arc; use command_palette_hooks::CommandPaletteFilter; use editor::EditorSettingsControls; use feature_flags::{FeatureFlag, FeatureFlagViewExt}; -use fs::Fs; -use gpui::{ - Action, App, AsyncWindowContext, Entity, EventEmitter, FocusHandle, Focusable, Task, actions, -}; -use schemars::JsonSchema; -use serde::Deserialize; -use settings::{SettingsStore, VsCodeSettingsSource}; +use gpui::{App, Entity, EventEmitter, FocusHandle, Focusable, actions}; use ui::prelude::*; -use util::truncate_and_remove_front; use workspace::item::{Item, ItemEvent}; use workspace::{Workspace, with_active_or_new_workspace}; @@ -29,23 +21,6 @@ impl FeatureFlag for SettingsUiFeatureFlag { const NAME: &'static str = "settings-ui"; } -/// Imports settings from Visual Studio Code. -#[derive(Copy, Clone, Debug, Default, PartialEq, Deserialize, JsonSchema, Action)] -#[action(namespace = zed)] -#[serde(deny_unknown_fields)] -pub struct ImportVsCodeSettings { - #[serde(default)] - pub skip_prompt: bool, -} - -/// Imports settings from Cursor editor. 
-#[derive(Copy, Clone, Debug, Default, PartialEq, Deserialize, JsonSchema, Action)] -#[action(namespace = zed)] -#[serde(deny_unknown_fields)] -pub struct ImportCursorSettings { - #[serde(default)] - pub skip_prompt: bool, -} actions!( zed, [ @@ -72,45 +47,11 @@ pub fn init(cx: &mut App) { }); }); - cx.observe_new(|workspace: &mut Workspace, window, cx| { + cx.observe_new(|_workspace: &mut Workspace, window, cx| { let Some(window) = window else { return; }; - workspace.register_action(|_workspace, action: &ImportVsCodeSettings, window, cx| { - let fs = <dyn Fs>::global(cx); - let action = *action; - - window - .spawn(cx, async move |cx: &mut AsyncWindowContext| { - handle_import_vscode_settings( - VsCodeSettingsSource::VsCode, - action.skip_prompt, - fs, - cx, - ) - .await - }) - .detach(); - }); - - workspace.register_action(|_workspace, action: &ImportCursorSettings, window, cx| { - let fs = <dyn Fs>::global(cx); - let action = *action; - - window - .spawn(cx, async move |cx: &mut AsyncWindowContext| { - handle_import_vscode_settings( - VsCodeSettingsSource::Cursor, - action.skip_prompt, - fs, - cx, - ) - .await - }) - .detach(); - }); - let settings_ui_actions = [TypeId::of::<OpenSettingsEditor>()]; CommandPaletteFilter::update_global(cx, |filter, _cx| { @@ -138,57 +79,6 @@ pub fn init(cx: &mut App) { keybindings::init(cx); } -async fn handle_import_vscode_settings( - source: VsCodeSettingsSource, - skip_prompt: bool, - fs: Arc<dyn Fs>, - cx: &mut AsyncWindowContext, -) { - let vscode_settings = - match settings::VsCodeSettings::load_user_settings(source, fs.clone()).await { - Ok(vscode_settings) => vscode_settings, - Err(err) => { - log::error!("{err}"); - let _ = cx.prompt( - gpui::PromptLevel::Info, - &format!("Could not find or load a {source} settings file"), - None, - &["Ok"], - ); - return; - } - }; - - let prompt = if skip_prompt { - Task::ready(Some(0)) - } else { - let prompt = cx.prompt( - gpui::PromptLevel::Warning, - &format!( - "Importing {} settings may overwrite your existing settings. 
\ - Will import settings from {}", - vscode_settings.source, - truncate_and_remove_front(&vscode_settings.path.to_string_lossy(), 128), - ), - None, - &["Ok", "Cancel"], - ); - cx.spawn(async move |_| prompt.await.ok()) - }; - if prompt.await != Some(0) { - return; - } - - cx.update(|_, cx| { - let source = vscode_settings.source; - let path = vscode_settings.path.clone(); - cx.global::<SettingsStore>() - .import_vscode_settings(fs, vscode_settings); - log::info!("Imported {source} settings from {}", path.display()); - }) - .ok(); -} - pub struct SettingsPage { focus_handle: FocusHandle, } diff --git a/crates/settings_ui/src/ui_components/keystroke_input.rs b/crates/settings_ui/src/ui_components/keystroke_input.rs new file mode 100644 index 0000000000000000000000000000000000000000..ee5c4036eae585f3702ecd9b590f22a367e49e0f --- /dev/null +++ b/crates/settings_ui/src/ui_components/keystroke_input.rs @@ -0,0 +1,1388 @@ +use gpui::{ + Animation, AnimationExt, Context, EventEmitter, FocusHandle, Focusable, FontWeight, KeyContext, + Keystroke, Modifiers, ModifiersChangedEvent, Subscription, Task, actions, +}; +use ui::{ + ActiveTheme as _, Color, IconButton, IconButtonShape, IconName, IconSize, Label, LabelSize, + ParentElement as _, Render, Styled as _, Tooltip, Window, prelude::*, +}; + +actions!( + keystroke_input, + [ + /// Starts recording keystrokes + StartRecording, + /// Stops recording keystrokes + StopRecording, + /// Clears the recorded keystrokes + ClearKeystrokes, + ] +); + +const KEY_CONTEXT_VALUE: &'static str = "KeystrokeInput"; + +const CLOSE_KEYSTROKE_CAPTURE_END_TIMEOUT: std::time::Duration = + std::time::Duration::from_millis(300); + +enum CloseKeystrokeResult { + Partial, + Close, + None, +} + +impl PartialEq for CloseKeystrokeResult { + fn eq(&self, other: &Self) -> bool { + matches!( + (self, other), + (CloseKeystrokeResult::Partial, CloseKeystrokeResult::Partial) + | (CloseKeystrokeResult::Close, CloseKeystrokeResult::Close) + | (CloseKeystrokeResult::None, CloseKeystrokeResult::None) + ) + } +} + +pub struct KeystrokeInput { + keystrokes: Vec<Keystroke>, + placeholder_keystrokes: Option<Vec<Keystroke>>, + outer_focus_handle: FocusHandle, + inner_focus_handle: FocusHandle, + intercept_subscription: Option<Subscription>, + _focus_subscriptions: [Subscription; 2], + search: bool, + /// The sequence of close keystrokes being typed + close_keystrokes: Option<Vec<Keystroke>>, + close_keystrokes_start: Option<usize>, + previous_modifiers: Modifiers, + /// In order to support inputting keystrokes that end with a prefix of the + /// close keybind keystrokes, we clear the close keystroke capture info + /// on a timeout after a close keystroke is pressed + /// + /// e.g. 
if close binding is `esc esc esc` and user wants to search for + /// `ctrl-g esc`, after entering the `ctrl-g esc`, hitting `esc` twice would + /// stop recording because of the sequence of three escapes making it + /// impossible to search for anything ending in `esc` + clear_close_keystrokes_timer: Option<Task<()>>, + #[cfg(test)] + recording: bool, +} + +impl KeystrokeInput { + const KEYSTROKE_COUNT_MAX: usize = 3; + + pub fn new( + placeholder_keystrokes: Option<Vec<Keystroke>>, + window: &mut Window, + cx: &mut Context<Self>, + ) -> Self { + let outer_focus_handle = cx.focus_handle(); + let inner_focus_handle = cx.focus_handle(); + let _focus_subscriptions = [ + cx.on_focus_in(&inner_focus_handle, window, Self::on_inner_focus_in), + cx.on_focus_out(&inner_focus_handle, window, Self::on_inner_focus_out), + ]; + Self { + keystrokes: Vec::new(), + placeholder_keystrokes, + inner_focus_handle, + outer_focus_handle, + intercept_subscription: None, + _focus_subscriptions, + search: false, + close_keystrokes: None, + close_keystrokes_start: None, + previous_modifiers: Modifiers::default(), + clear_close_keystrokes_timer: None, + #[cfg(test)] + recording: false, + } + } + + pub fn set_keystrokes(&mut self, keystrokes: Vec<Keystroke>, cx: &mut Context<Self>) { + self.keystrokes = keystrokes; + self.keystrokes_changed(cx); + } + + pub fn set_search(&mut self, search: bool) { + self.search = search; + } + + pub fn keystrokes(&self) -> &[Keystroke] { + if let Some(placeholders) = self.placeholder_keystrokes.as_ref() + && self.keystrokes.is_empty() + { + return placeholders; + } + if !self.search + && self + .keystrokes + .last() + .map_or(false, |last| last.key.is_empty()) + { + return &self.keystrokes[..self.keystrokes.len() - 1]; + } + return &self.keystrokes; + } + + fn dummy(modifiers: Modifiers) -> Keystroke { + return Keystroke { + modifiers, + key: "".to_string(), + key_char: None, + }; + } + + fn keystrokes_changed(&self, cx: &mut Context<Self>) { + cx.emit(()); + cx.notify(); + } + + fn key_context() -> KeyContext { + let mut key_context = KeyContext::default(); + key_context.add(KEY_CONTEXT_VALUE); + key_context + } + + fn determine_stop_recording_binding(window: &mut Window) -> Option<gpui::KeyBinding> { + if cfg!(test) { + Some(gpui::KeyBinding::new( + "escape escape escape", + StopRecording, + Some(KEY_CONTEXT_VALUE), + )) + } else { + window.highest_precedence_binding_for_action_in_context( + &StopRecording, + Self::key_context(), + ) + } + } + + fn upsert_close_keystrokes_start(&mut self, start: usize, cx: &mut Context<Self>) { + if self.close_keystrokes_start.is_some() { + return; + } + self.close_keystrokes_start = Some(start); + self.update_clear_close_keystrokes_timer(cx); + } + + fn update_clear_close_keystrokes_timer(&mut self, cx: &mut Context<Self>) { + self.clear_close_keystrokes_timer = Some(cx.spawn(async |this, cx| { + cx.background_executor() + .timer(CLOSE_KEYSTROKE_CAPTURE_END_TIMEOUT) + .await; + this.update(cx, |this, _cx| { + this.end_close_keystrokes_capture(); + }) + .ok(); + })); + } + + /// Interrupt the capture of close keystrokes, but do not clear the close keystrokes + /// from the input + fn end_close_keystrokes_capture(&mut self) -> Option<usize> { + self.close_keystrokes.take(); + self.clear_close_keystrokes_timer.take(); + return self.close_keystrokes_start.take(); + } + + fn handle_possible_close_keystroke( + &mut self, + keystroke: &Keystroke, + window: &mut Window, + cx: &mut Context<Self>, + ) -> CloseKeystrokeResult { + let 
Some(keybind_for_close_action) = Self::determine_stop_recording_binding(window) else { + log::trace!("No keybinding to stop recording keystrokes in keystroke input"); + self.end_close_keystrokes_capture(); + return CloseKeystrokeResult::None; + }; + let action_keystrokes = keybind_for_close_action.keystrokes(); + + if let Some(mut close_keystrokes) = self.close_keystrokes.take() { + let mut index = 0; + + while index < action_keystrokes.len() && index < close_keystrokes.len() { + if !close_keystrokes[index].should_match(&action_keystrokes[index]) { + break; + } + index += 1; + } + if index == close_keystrokes.len() { + if index >= action_keystrokes.len() { + self.end_close_keystrokes_capture(); + return CloseKeystrokeResult::None; + } + if keystroke.should_match(&action_keystrokes[index]) { + close_keystrokes.push(keystroke.clone()); + if close_keystrokes.len() == action_keystrokes.len() { + return CloseKeystrokeResult::Close; + } else { + self.close_keystrokes = Some(close_keystrokes); + self.update_clear_close_keystrokes_timer(cx); + return CloseKeystrokeResult::Partial; + } + } else { + self.end_close_keystrokes_capture(); + return CloseKeystrokeResult::None; + } + } + } else if let Some(first_action_keystroke) = action_keystrokes.first() + && keystroke.should_match(first_action_keystroke) + { + self.close_keystrokes = Some(vec![keystroke.clone()]); + return CloseKeystrokeResult::Partial; + } + self.end_close_keystrokes_capture(); + return CloseKeystrokeResult::None; + } + + fn on_modifiers_changed( + &mut self, + event: &ModifiersChangedEvent, + window: &mut Window, + cx: &mut Context<Self>, + ) { + cx.stop_propagation(); + let keystrokes_len = self.keystrokes.len(); + + if self.previous_modifiers.modified() + && event.modifiers.is_subset_of(&self.previous_modifiers) + { + self.previous_modifiers &= event.modifiers; + return; + } + self.keystrokes_changed(cx); + + if let Some(last) = self.keystrokes.last_mut() + && last.key.is_empty() + && keystrokes_len <= Self::KEYSTROKE_COUNT_MAX + { + if !self.search && !event.modifiers.modified() { + self.keystrokes.pop(); + return; + } + if self.search { + if self.previous_modifiers.modified() { + last.modifiers |= event.modifiers; + } else { + self.keystrokes.push(Self::dummy(event.modifiers)); + } + self.previous_modifiers |= event.modifiers; + } else { + last.modifiers = event.modifiers; + return; + } + } else if keystrokes_len < Self::KEYSTROKE_COUNT_MAX { + self.keystrokes.push(Self::dummy(event.modifiers)); + if self.search { + self.previous_modifiers |= event.modifiers; + } + } + if keystrokes_len >= Self::KEYSTROKE_COUNT_MAX { + self.clear_keystrokes(&ClearKeystrokes, window, cx); + } + } + + fn handle_keystroke( + &mut self, + keystroke: &Keystroke, + window: &mut Window, + cx: &mut Context<Self>, + ) { + cx.stop_propagation(); + + let close_keystroke_result = self.handle_possible_close_keystroke(keystroke, window, cx); + if close_keystroke_result == CloseKeystrokeResult::Close { + self.stop_recording(&StopRecording, window, cx); + return; + } + + let mut keystroke = keystroke.clone(); + if let Some(last) = self.keystrokes.last() + && last.key.is_empty() + && (!self.search || self.previous_modifiers.modified()) + { + let key = keystroke.key.clone(); + keystroke = last.clone(); + keystroke.key = key; + self.keystrokes.pop(); + } + + if close_keystroke_result == CloseKeystrokeResult::Partial { + self.upsert_close_keystrokes_start(self.keystrokes.len(), cx); + if self.keystrokes.len() >= Self::KEYSTROKE_COUNT_MAX { + return; + } + } + + if 
self.keystrokes.len() >= Self::KEYSTROKE_COUNT_MAX { + self.clear_keystrokes(&ClearKeystrokes, window, cx); + return; + } + + self.keystrokes.push(keystroke.clone()); + self.keystrokes_changed(cx); + + if self.search { + self.previous_modifiers = keystroke.modifiers; + return; + } + if self.keystrokes.len() < Self::KEYSTROKE_COUNT_MAX && keystroke.modifiers.modified() { + self.keystrokes.push(Self::dummy(keystroke.modifiers)); + } + } + + fn on_inner_focus_in(&mut self, _window: &mut Window, cx: &mut Context<Self>) { + if self.intercept_subscription.is_none() { + let listener = cx.listener(|this, event: &gpui::KeystrokeEvent, window, cx| { + this.handle_keystroke(&event.keystroke, window, cx); + }); + self.intercept_subscription = Some(cx.intercept_keystrokes(listener)) + } + } + + fn on_inner_focus_out( + &mut self, + _event: gpui::FocusOutEvent, + _window: &mut Window, + cx: &mut Context<Self>, + ) { + self.intercept_subscription.take(); + cx.notify(); + } + + fn render_keystrokes(&self, is_recording: bool) -> impl Iterator<Item = Div> { + let keystrokes = if let Some(placeholders) = self.placeholder_keystrokes.as_ref() + && self.keystrokes.is_empty() + { + if is_recording { + &[] + } else { + placeholders.as_slice() + } + } else { + &self.keystrokes + }; + keystrokes.iter().map(move |keystroke| { + h_flex().children(ui::render_keystroke( + keystroke, + Some(Color::Default), + Some(rems(0.875).into()), + ui::PlatformStyle::platform(), + false, + )) + }) + } + + pub fn start_recording( + &mut self, + _: &StartRecording, + window: &mut Window, + cx: &mut Context<Self>, + ) { + window.focus(&self.inner_focus_handle); + self.clear_keystrokes(&ClearKeystrokes, window, cx); + self.previous_modifiers = window.modifiers(); + #[cfg(test)] + { + self.recording = true; + } + cx.stop_propagation(); + } + + pub fn stop_recording( + &mut self, + _: &StopRecording, + window: &mut Window, + cx: &mut Context<Self>, + ) { + if !self.is_recording(window) { + return; + } + window.focus(&self.outer_focus_handle); + if let Some(close_keystrokes_start) = self.close_keystrokes_start.take() + && close_keystrokes_start < self.keystrokes.len() + { + self.keystrokes.drain(close_keystrokes_start..); + self.keystrokes_changed(cx); + } + self.end_close_keystrokes_capture(); + #[cfg(test)] + { + self.recording = false; + } + cx.notify(); + } + + pub fn clear_keystrokes( + &mut self, + _: &ClearKeystrokes, + _window: &mut Window, + cx: &mut Context<Self>, + ) { + self.keystrokes.clear(); + self.keystrokes_changed(cx); + self.end_close_keystrokes_capture(); + } + + fn is_recording(&self, window: &Window) -> bool { + #[cfg(test)] + { + if true { + // in tests, we just need a simple bool that is toggled on start and stop recording + return self.recording; + } + } + // however, in the real world, checking if the inner focus handle is focused + // is a much more reliable check, as the intercept keystroke handlers are installed + // on focus of the inner focus handle, thereby ensuring our recording state does + // not get de-synced + return self.inner_focus_handle.is_focused(window); + } +} + +impl EventEmitter<()> for KeystrokeInput {} + +impl Focusable for KeystrokeInput { + fn focus_handle(&self, _cx: &gpui::App) -> FocusHandle { + self.outer_focus_handle.clone() + } +} + +impl Render for KeystrokeInput { + fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { + let colors = cx.theme().colors(); + let is_focused = self.outer_focus_handle.contains_focused(window, cx); + let is_recording = 
self.is_recording(window); + + let horizontal_padding = rems_from_px(64.); + + let recording_bg_color = colors + .editor_background + .blend(colors.text_accent.opacity(0.1)); + + let recording_pulse = |color: Color| { + Icon::new(IconName::Circle) + .size(IconSize::Small) + .color(Color::Error) + .with_animation( + "recording-pulse", + Animation::new(std::time::Duration::from_secs(2)) + .repeat() + .with_easing(gpui::pulsating_between(0.4, 0.8)), + { + let color = color.color(cx); + move |this, delta| this.color(Color::Custom(color.opacity(delta))) + }, + ) + }; + + let recording_indicator = h_flex() + .h_4() + .pr_1() + .gap_0p5() + .border_1() + .border_color(colors.border) + .bg(colors + .editor_background + .blend(colors.text_accent.opacity(0.1))) + .rounded_sm() + .child(recording_pulse(Color::Error)) + .child( + Label::new("REC") + .size(LabelSize::XSmall) + .weight(FontWeight::SEMIBOLD) + .color(Color::Error), + ); + + let search_indicator = h_flex() + .h_4() + .pr_1() + .gap_0p5() + .border_1() + .border_color(colors.border) + .bg(colors + .editor_background + .blend(colors.text_accent.opacity(0.1))) + .rounded_sm() + .child(recording_pulse(Color::Accent)) + .child( + Label::new("SEARCH") + .size(LabelSize::XSmall) + .weight(FontWeight::SEMIBOLD) + .color(Color::Accent), + ); + + let record_icon = if self.search { + IconName::MagnifyingGlass + } else { + IconName::PlayFilled + }; + + h_flex() + .id("keystroke-input") + .track_focus(&self.outer_focus_handle) + .py_2() + .px_3() + .gap_2() + .min_h_10() + .w_full() + .flex_1() + .justify_between() + .rounded_sm() + .overflow_hidden() + .map(|this| { + if is_recording { + this.bg(recording_bg_color) + } else { + this.bg(colors.editor_background) + } + }) + .border_1() + .border_color(colors.border_variant) + .when(is_focused, |parent| { + parent.border_color(colors.border_focused) + }) + .key_context(Self::key_context()) + .on_action(cx.listener(Self::start_recording)) + .on_action(cx.listener(Self::clear_keystrokes)) + .child( + h_flex() + .w(horizontal_padding) + .gap_0p5() + .justify_start() + .flex_none() + .when(is_recording, |this| { + this.map(|this| { + if self.search { + this.child(search_indicator) + } else { + this.child(recording_indicator) + } + }) + }), + ) + .child( + h_flex() + .id("keystroke-input-inner") + .track_focus(&self.inner_focus_handle) + .on_modifiers_changed(cx.listener(Self::on_modifiers_changed)) + .size_full() + .when(!self.search, |this| { + this.focus(|mut style| { + style.border_color = Some(colors.border_focused); + style + }) + }) + .w_full() + .min_w_0() + .justify_center() + .flex_wrap() + .gap(ui::DynamicSpacing::Base04.rems(cx)) + .children(self.render_keystrokes(is_recording)), + ) + .child( + h_flex() + .w(horizontal_padding) + .gap_0p5() + .justify_end() + .flex_none() + .map(|this| { + if is_recording { + this.child( + IconButton::new("stop-record-btn", IconName::StopFilled) + .shape(IconButtonShape::Square) + .map(|this| { + this.tooltip(Tooltip::for_action_title( + if self.search { + "Stop Searching" + } else { + "Stop Recording" + }, + &StopRecording, + )) + }) + .icon_color(Color::Error) + .on_click(cx.listener(|this, _event, window, cx| { + this.stop_recording(&StopRecording, window, cx); + })), + ) + } else { + this.child( + IconButton::new("record-btn", record_icon) + .shape(IconButtonShape::Square) + .map(|this| { + this.tooltip(Tooltip::for_action_title( + if self.search { + "Start Searching" + } else { + "Start Recording" + }, + &StartRecording, + )) + }) + .when(!is_focused, 
|this| this.icon_color(Color::Muted)) + .on_click(cx.listener(|this, _event, window, cx| { + this.start_recording(&StartRecording, window, cx); + })), + ) + } + }) + .child( + IconButton::new("clear-btn", IconName::Delete) + .shape(IconButtonShape::Square) + .tooltip(Tooltip::for_action_title( + "Clear Keystrokes", + &ClearKeystrokes, + )) + .when(!is_recording || !is_focused, |this| { + this.icon_color(Color::Muted) + }) + .on_click(cx.listener(|this, _event, window, cx| { + this.clear_keystrokes(&ClearKeystrokes, window, cx); + })), + ), + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use fs::FakeFs; + use gpui::{Entity, TestAppContext, VisualTestContext}; + use itertools::Itertools as _; + use project::Project; + use settings::SettingsStore; + use workspace::Workspace; + + pub struct KeystrokeInputTestHelper { + input: Entity<KeystrokeInput>, + current_modifiers: Modifiers, + cx: VisualTestContext, + } + + impl KeystrokeInputTestHelper { + /// Creates a new test helper with default settings + pub fn new(mut cx: VisualTestContext) -> Self { + let input = cx.new_window_entity(|window, cx| KeystrokeInput::new(None, window, cx)); + + let mut helper = Self { + input, + current_modifiers: Modifiers::default(), + cx, + }; + + helper.start_recording(); + helper + } + + /// Sets search mode on the input + pub fn with_search_mode(&mut self, search: bool) -> &mut Self { + self.input.update(&mut self.cx, |input, _| { + input.set_search(search); + }); + self + } + + /// Sends a keystroke event based on string description + /// Examples: "a", "ctrl-a", "cmd-shift-z", "escape" + #[track_caller] + pub fn send_keystroke(&mut self, keystroke_input: &str) -> &mut Self { + self.expect_is_recording(true); + let keystroke_str = if keystroke_input.ends_with('-') { + format!("{}_", keystroke_input) + } else { + keystroke_input.to_string() + }; + + let mut keystroke = Keystroke::parse(&keystroke_str) + .unwrap_or_else(|_| panic!("Invalid keystroke: {}", keystroke_input)); + + // Remove the dummy key if we added it for modifier-only keystrokes + if keystroke_input.ends_with('-') && keystroke_str.ends_with("_") { + keystroke.key = "".to_string(); + } + + // Combine current modifiers with keystroke modifiers + keystroke.modifiers |= self.current_modifiers; + + self.update_input(|input, window, cx| { + input.handle_keystroke(&keystroke, window, cx); + }); + + // Don't update current_modifiers for keystrokes with actual keys + if keystroke.key.is_empty() { + self.current_modifiers = keystroke.modifiers; + } + self + } + + /// Sends a modifier change event based on string description + /// Examples: "+ctrl", "-ctrl", "+cmd+shift", "-all" + #[track_caller] + pub fn send_modifiers(&mut self, modifiers: &str) -> &mut Self { + self.expect_is_recording(true); + let new_modifiers = if modifiers == "-all" { + Modifiers::default() + } else { + self.parse_modifier_change(modifiers) + }; + + let event = ModifiersChangedEvent { + modifiers: new_modifiers, + capslock: gpui::Capslock::default(), + }; + + self.update_input(|input, window, cx| { + input.on_modifiers_changed(&event, window, cx); + }); + + self.current_modifiers = new_modifiers; + self + } + + /// Sends multiple events in sequence + /// Each event string is either a keystroke or modifier change + #[track_caller] + pub fn send_events(&mut self, events: &[&str]) -> &mut Self { + self.expect_is_recording(true); + for event in events { + if event.starts_with('+') || event.starts_with('-') { + self.send_modifiers(event); + } else { + self.send_keystroke(event); 
+ } + } + self + } + + #[track_caller] + fn expect_keystrokes_equal(actual: &[Keystroke], expected: &[&str]) { + let expected_keystrokes: Result<Vec<Keystroke>, _> = expected + .iter() + .map(|s| { + let keystroke_str = if s.ends_with('-') { + format!("{}_", s) + } else { + s.to_string() + }; + + let mut keystroke = Keystroke::parse(&keystroke_str)?; + + // Remove the dummy key if we added it for modifier-only keystrokes + if s.ends_with('-') && keystroke_str.ends_with("_") { + keystroke.key = "".to_string(); + } + + Ok(keystroke) + }) + .collect(); + + let expected_keystrokes = expected_keystrokes + .unwrap_or_else(|e: anyhow::Error| panic!("Invalid expected keystroke: {}", e)); + + assert_eq!( + actual.len(), + expected_keystrokes.len(), + "Keystroke count mismatch. Expected: {:?}, Actual: {:?}", + expected_keystrokes + .iter() + .map(|k| k.unparse()) + .collect::<Vec<_>>(), + actual.iter().map(|k| k.unparse()).collect::<Vec<_>>() + ); + + for (i, (actual, expected)) in actual.iter().zip(expected_keystrokes.iter()).enumerate() + { + assert_eq!( + actual.unparse(), + expected.unparse(), + "Keystroke {} mismatch. Expected: '{}', Actual: '{}'", + i, + expected.unparse(), + actual.unparse() + ); + } + } + + /// Verifies that the keystrokes match the expected strings + #[track_caller] + pub fn expect_keystrokes(&mut self, expected: &[&str]) -> &mut Self { + let actual = self + .input + .read_with(&mut self.cx, |input, _| input.keystrokes.clone()); + Self::expect_keystrokes_equal(&actual, expected); + self + } + + #[track_caller] + pub fn expect_close_keystrokes(&mut self, expected: &[&str]) -> &mut Self { + let actual = self + .input + .read_with(&mut self.cx, |input, _| input.close_keystrokes.clone()) + .unwrap_or_default(); + Self::expect_keystrokes_equal(&actual, expected); + self + } + + /// Verifies that there are no keystrokes + #[track_caller] + pub fn expect_empty(&mut self) -> &mut Self { + self.expect_keystrokes(&[]) + } + + /// Starts recording keystrokes + #[track_caller] + pub fn start_recording(&mut self) -> &mut Self { + self.expect_is_recording(false); + self.input.update_in(&mut self.cx, |input, window, cx| { + input.start_recording(&StartRecording, window, cx); + }); + self + } + + /// Stops recording keystrokes + pub fn stop_recording(&mut self) -> &mut Self { + self.expect_is_recording(true); + self.input.update_in(&mut self.cx, |input, window, cx| { + input.stop_recording(&StopRecording, window, cx); + }); + self + } + + /// Clears all keystrokes + #[track_caller] + pub fn clear_keystrokes(&mut self) -> &mut Self { + let change_tracker = KeystrokeUpdateTracker::new(self.input.clone(), &mut self.cx); + self.input.update_in(&mut self.cx, |input, window, cx| { + input.clear_keystrokes(&ClearKeystrokes, window, cx); + }); + KeystrokeUpdateTracker::finish(change_tracker, &self.cx); + self.current_modifiers = Default::default(); + self + } + + /// Verifies the recording state + #[track_caller] + pub fn expect_is_recording(&mut self, expected: bool) -> &mut Self { + let actual = self + .input + .update_in(&mut self.cx, |input, window, _| input.is_recording(window)); + assert_eq!( + actual, expected, + "Recording state mismatch. 
Expected: {}, Actual: {}", + expected, actual + ); + self + } + + pub async fn wait_for_close_keystroke_capture_end(&mut self) -> &mut Self { + let task = self.input.update_in(&mut self.cx, |input, _, _| { + input.clear_close_keystrokes_timer.take() + }); + let task = task.expect("No close keystroke capture end timer task"); + self.cx + .executor() + .advance_clock(CLOSE_KEYSTROKE_CAPTURE_END_TIMEOUT); + task.await; + self + } + + /// Parses modifier change strings like "+ctrl", "-shift", "+cmd+alt" + #[track_caller] + fn parse_modifier_change(&self, modifiers_str: &str) -> Modifiers { + let mut modifiers = self.current_modifiers; + + assert!(!modifiers_str.is_empty(), "Empty modifier string"); + + let value; + let split_char; + let remaining; + if let Some(to_add) = modifiers_str.strip_prefix('+') { + value = true; + split_char = '+'; + remaining = to_add; + } else { + let to_remove = modifiers_str + .strip_prefix('-') + .expect("Modifier string must start with '+' or '-'"); + value = false; + split_char = '-'; + remaining = to_remove; + } + + for modifier in remaining.split(split_char) { + match modifier { + "ctrl" | "control" => modifiers.control = value, + "alt" | "option" => modifiers.alt = value, + "shift" => modifiers.shift = value, + "cmd" | "command" | "platform" => modifiers.platform = value, + "fn" | "function" => modifiers.function = value, + _ => panic!("Unknown modifier: {}", modifier), + } + } + + modifiers + } + + #[track_caller] + fn update_input<R>( + &mut self, + cb: impl FnOnce(&mut KeystrokeInput, &mut Window, &mut Context<KeystrokeInput>) -> R, + ) -> R { + let change_tracker = KeystrokeUpdateTracker::new(self.input.clone(), &mut self.cx); + let result = self.input.update_in(&mut self.cx, cb); + KeystrokeUpdateTracker::finish(change_tracker, &self.cx); + return result; + } + } + + struct KeystrokeUpdateTracker { + initial_keystrokes: Vec<Keystroke>, + _subscription: Subscription, + input: Entity<KeystrokeInput>, + received_keystrokes_updated: bool, + } + + impl KeystrokeUpdateTracker { + fn new(input: Entity<KeystrokeInput>, cx: &mut VisualTestContext) -> Entity<Self> { + cx.new(|cx| Self { + initial_keystrokes: input.read_with(cx, |input, _| input.keystrokes.clone()), + _subscription: cx.subscribe(&input, |this: &mut Self, _, _, _| { + this.received_keystrokes_updated = true; + }), + input, + received_keystrokes_updated: false, + }) + } + #[track_caller] + fn finish(this: Entity<Self>, cx: &VisualTestContext) { + let (received_keystrokes_updated, initial_keystrokes_str, updated_keystrokes_str) = + this.read_with(cx, |this, cx| { + let updated_keystrokes = this + .input + .read_with(cx, |input, _| input.keystrokes.clone()); + let initial_keystrokes_str = keystrokes_str(&this.initial_keystrokes); + let updated_keystrokes_str = keystrokes_str(&updated_keystrokes); + ( + this.received_keystrokes_updated, + initial_keystrokes_str, + updated_keystrokes_str, + ) + }); + if received_keystrokes_updated { + assert_ne!( + initial_keystrokes_str, updated_keystrokes_str, + "Received keystrokes_updated event, expected different keystrokes" + ); + } else { + assert_eq!( + initial_keystrokes_str, updated_keystrokes_str, + "Received no keystrokes_updated event, expected same keystrokes" + ); + } + + fn keystrokes_str(ks: &[Keystroke]) -> String { + ks.iter().map(|ks| ks.unparse()).join(" ") + } + } + } + + async fn init_test(cx: &mut TestAppContext) -> KeystrokeInputTestHelper { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + 
theme::init(theme::LoadThemes::JustBase, cx); + language::init(cx); + project::Project::init_settings(cx); + workspace::init_settings(cx); + }); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, [], cx).await; + let workspace = + cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx)); + let cx = VisualTestContext::from_window(*workspace, cx); + KeystrokeInputTestHelper::new(cx) + } + + #[gpui::test] + async fn test_basic_keystroke_input(cx: &mut TestAppContext) { + init_test(cx) + .await + .send_keystroke("a") + .clear_keystrokes() + .expect_empty(); + } + + #[gpui::test] + async fn test_modifier_handling(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl", "a", "-ctrl"]) + .expect_keystrokes(&["ctrl-a"]); + } + + #[gpui::test] + async fn test_multiple_modifiers(cx: &mut TestAppContext) { + init_test(cx) + .await + .send_keystroke("cmd-shift-z") + .expect_keystrokes(&["cmd-shift-z", "cmd-shift-"]); + } + + #[gpui::test] + async fn test_search_mode_behavior(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+cmd", "shift-f", "-cmd"]) + // In search mode, when completing a modifier-only keystroke with a key, + // only the original modifiers are preserved, not the keystroke's modifiers + .expect_keystrokes(&["cmd-f"]); + } + + #[gpui::test] + async fn test_keystroke_limit(cx: &mut TestAppContext) { + init_test(cx) + .await + .send_keystroke("a") + .send_keystroke("b") + .send_keystroke("c") + .expect_keystrokes(&["a", "b", "c"]) // At max limit + .send_keystroke("d") + .expect_empty(); // Should clear when exceeding limit + } + + #[gpui::test] + async fn test_modifier_release_all(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl+shift", "a", "-all"]) + .expect_keystrokes(&["ctrl-shift-a"]); + } + + #[gpui::test] + async fn test_search_new_modifiers_not_added_until_all_released(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl+shift", "a", "-ctrl"]) + .expect_keystrokes(&["ctrl-shift-a"]) + .send_events(&["+ctrl"]) + .expect_keystrokes(&["ctrl-shift-a", "ctrl-shift-"]); + } + + #[gpui::test] + async fn test_previous_modifiers_no_effect_when_not_search(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(false) + .send_events(&["+ctrl+shift", "a", "-all"]) + .expect_keystrokes(&["ctrl-shift-a"]); + } + + #[gpui::test] + async fn test_keystroke_limit_overflow_non_search_mode(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(false) + .send_events(&["a", "b", "c", "d"]) // 4 keystrokes, exceeds limit of 3 + .expect_empty(); // Should clear when exceeding limit + } + + #[gpui::test] + async fn test_complex_modifier_sequences(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl", "+shift", "+alt", "a", "-ctrl", "-shift", "-alt"]) + .expect_keystrokes(&["ctrl-shift-alt-a"]); + } + + #[gpui::test] + async fn test_modifier_only_keystrokes_search_mode(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl", "+shift", "-ctrl", "-shift"]) + .expect_keystrokes(&["ctrl-shift-"]); // Modifier-only sequences create modifier-only keystrokes + } + + #[gpui::test] + async fn test_modifier_only_keystrokes_non_search_mode(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(false) + .send_events(&["+ctrl", "+shift", 
"-ctrl", "-shift"]) + .expect_empty(); // Modifier-only sequences get filtered in non-search mode + } + + #[gpui::test] + async fn test_rapid_modifier_changes(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl", "-ctrl", "+shift", "-shift", "+alt", "a", "-alt"]) + .expect_keystrokes(&["ctrl-", "shift-", "alt-a"]); + } + + #[gpui::test] + async fn test_clear_keystrokes_search_mode(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl", "a", "-ctrl", "b"]) + .expect_keystrokes(&["ctrl-a", "b"]) + .clear_keystrokes() + .expect_empty(); + } + + #[gpui::test] + async fn test_non_search_mode_modifier_key_sequence(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(false) + .send_events(&["+ctrl", "a"]) + .expect_keystrokes(&["ctrl-a", "ctrl-"]) + .send_events(&["-ctrl"]) + .expect_keystrokes(&["ctrl-a"]); // Non-search mode filters trailing empty keystrokes + } + + #[gpui::test] + async fn test_all_modifiers_at_once(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl+shift+alt+cmd", "a", "-all"]) + .expect_keystrokes(&["ctrl-shift-alt-cmd-a"]); + } + + #[gpui::test] + async fn test_keystrokes_at_exact_limit(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["a", "b", "c"]) // exactly 3 keystrokes (at limit) + .expect_keystrokes(&["a", "b", "c"]) + .send_events(&["d"]) // should clear when exceeding + .expect_empty(); + } + + #[gpui::test] + async fn test_function_modifier_key(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+fn", "f1", "-fn"]) + .expect_keystrokes(&["fn-f1"]); + } + + #[gpui::test] + async fn test_start_stop_recording(cx: &mut TestAppContext) { + init_test(cx) + .await + .send_events(&["a", "b"]) + .expect_keystrokes(&["a", "b"]) // start_recording clears existing keystrokes + .stop_recording() + .expect_is_recording(false) + .start_recording() + .send_events(&["c"]) + .expect_keystrokes(&["c"]); + } + + #[gpui::test] + async fn test_modifier_sequence_with_interruption(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl", "+shift", "a", "-shift", "b", "-ctrl"]) + .expect_keystrokes(&["ctrl-shift-a", "ctrl-b"]); + } + + #[gpui::test] + async fn test_empty_key_sequence_search_mode(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&[]) // No events at all + .expect_empty(); + } + + #[gpui::test] + async fn test_modifier_sequence_completion_search_mode(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl", "+shift", "-shift", "a", "-ctrl"]) + .expect_keystrokes(&["ctrl-shift-a"]); + } + + #[gpui::test] + async fn test_triple_escape_stops_recording_search_mode(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["a", "escape", "escape", "escape"]) + .expect_keystrokes(&["a"]) // Triple escape removes final escape, stops recording + .expect_is_recording(false); + } + + #[gpui::test] + async fn test_triple_escape_stops_recording_non_search_mode(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(false) + .send_events(&["a", "escape", "escape", "escape"]) + .expect_keystrokes(&["a"]); // Triple escape stops recording but only removes final escape + } + + #[gpui::test] + async fn test_triple_escape_at_keystroke_limit(cx: &mut 
TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["a", "b", "c", "escape", "escape", "escape"]) // 6 keystrokes total, exceeds limit + .expect_keystrokes(&["a", "b", "c"]); // Triple escape stops recording and removes escapes, leaves original keystrokes + } + + #[gpui::test] + async fn test_interrupted_escape_sequence(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["escape", "escape", "a", "escape"]) // Partial escape sequence interrupted by 'a' + .expect_keystrokes(&["escape", "escape", "a"]); // Escape sequence interrupted by 'a', no close triggered + } + + #[gpui::test] + async fn test_interrupted_escape_sequence_within_limit(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["escape", "escape", "a"]) // Partial escape sequence interrupted by 'a' (3 keystrokes, at limit) + .expect_keystrokes(&["escape", "escape", "a"]); // Should not trigger close, interruption resets escape detection + } + + #[gpui::test] + async fn test_partial_escape_sequence_no_close(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["escape", "escape"]) // Only 2 escapes, not enough to close + .expect_keystrokes(&["escape", "escape"]) + .expect_is_recording(true); // Should remain in keystrokes, no close triggered + } + + #[gpui::test] + async fn test_recording_state_after_triple_escape(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["a", "escape", "escape", "escape"]) + .expect_keystrokes(&["a"]) // Triple escape stops recording, removes final escape + .expect_is_recording(false); + } + + #[gpui::test] + async fn test_triple_escape_mixed_with_other_keystrokes(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["a", "escape", "b", "escape", "escape"]) // Mixed sequence, should not trigger close + .expect_keystrokes(&["a", "escape", "b"]); // No complete triple escape sequence, stays at limit + } + + #[gpui::test] + async fn test_triple_escape_only(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["escape", "escape", "escape"]) // Pure triple escape sequence + .expect_empty(); + } + + #[gpui::test] + async fn test_end_close_keystroke_capture(cx: &mut TestAppContext) { + init_test(cx) + .await + .send_events(&["+ctrl", "g", "-ctrl", "escape"]) + .expect_keystrokes(&["ctrl-g", "escape"]) + .wait_for_close_keystroke_capture_end() + .await + .send_events(&["escape", "escape"]) + .expect_keystrokes(&["ctrl-g", "escape", "escape"]) + .expect_close_keystrokes(&["escape", "escape"]) + .send_keystroke("escape") + .expect_keystrokes(&["ctrl-g", "escape"]); + } + + #[gpui::test] + async fn test_search_previous_modifiers_are_sticky(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl+alt", "-ctrl", "j"]) + .expect_keystrokes(&["ctrl-alt-j"]); + } + + #[gpui::test] + async fn test_previous_modifiers_can_be_entered_separately(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl", "-ctrl"]) + .expect_keystrokes(&["ctrl-"]) + .send_events(&["+alt", "-alt"]) + .expect_keystrokes(&["ctrl-", "alt-"]); + } + + #[gpui::test] + async fn test_previous_modifiers_reset_on_key(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl+alt", "-ctrl", "+shift"]) + .expect_keystrokes(&["ctrl-shift-alt-"]) 
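+            // Assumed reading: the pending modifier-only keystroke above is completed by the next real key.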
+ .send_keystroke("j") + .expect_keystrokes(&["ctrl-shift-alt-j"]) + .send_keystroke("i") + .expect_keystrokes(&["ctrl-shift-alt-j", "shift-alt-i"]) + .send_events(&["-shift-alt", "+cmd"]) + .expect_keystrokes(&["ctrl-shift-alt-j", "shift-alt-i", "cmd-"]); + } + + #[gpui::test] + async fn test_previous_modifiers_reset_on_release_all(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl+alt", "-ctrl", "+shift"]) + .expect_keystrokes(&["ctrl-shift-alt-"]) + .send_events(&["-all", "j"]) + .expect_keystrokes(&["ctrl-shift-alt-", "j"]); + } + + #[gpui::test] + async fn test_search_repeat_modifiers(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl", "-ctrl", "+alt", "-alt", "+shift", "-shift"]) + .expect_keystrokes(&["ctrl-", "alt-", "shift-"]) + .send_events(&["+cmd"]) + .expect_empty(); + } + + #[gpui::test] + async fn test_not_search_repeat_modifiers(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(false) + .send_events(&["+ctrl", "-ctrl", "+alt", "-alt", "+shift", "-shift"]) + .expect_empty(); + } +} diff --git a/crates/settings_ui/src/ui_components/mod.rs b/crates/settings_ui/src/ui_components/mod.rs index 13971b0a5df8e3b188de1df94faab3df94aa86da..5d6463a61a21afd5208b75af0362f6f7956f5e56 100644 --- a/crates/settings_ui/src/ui_components/mod.rs +++ b/crates/settings_ui/src/ui_components/mod.rs @@ -1 +1,2 @@ +pub mod keystroke_input; pub mod table; diff --git a/crates/settings_ui/src/ui_components/table.rs b/crates/settings_ui/src/ui_components/table.rs index c3b70d7d4f166ff3b34cd2b52146e4dc7408badc..2b3e815f369a96235c8628935df433737d58b0ce 100644 --- a/crates/settings_ui/src/ui_components/table.rs +++ b/crates/settings_ui/src/ui_components/table.rs @@ -2,19 +2,26 @@ use std::{ops::Range, rc::Rc, time::Duration}; use editor::{EditorSettings, ShowScrollbar, scroll::ScrollbarAutoHide}; use gpui::{ - AppContext, Axis, Context, Entity, FocusHandle, Length, ListHorizontalSizingBehavior, - ListSizingBehavior, MouseButton, Task, UniformListScrollHandle, WeakEntity, transparent_black, - uniform_list, + AbsoluteLength, AppContext, Axis, Context, DefiniteLength, DragMoveEvent, Entity, EntityId, + FocusHandle, Length, ListHorizontalSizingBehavior, ListSizingBehavior, MouseButton, Point, + Stateful, Task, UniformListScrollHandle, WeakEntity, transparent_black, uniform_list, }; + +use itertools::intersperse_with; use settings::Settings as _; use ui::{ ActiveTheme as _, AnyElement, App, Button, ButtonCommon as _, ButtonStyle, Color, Component, ComponentScope, Div, ElementId, FixedWidth as _, FluentBuilder as _, Indicator, - InteractiveElement as _, IntoElement, ParentElement, Pixels, RegisterComponent, RenderOnce, - Scrollbar, ScrollbarState, StatefulInteractiveElement as _, Styled, StyledExt as _, + InteractiveElement, IntoElement, ParentElement, Pixels, RegisterComponent, RenderOnce, + Scrollbar, ScrollbarState, SharedString, StatefulInteractiveElement, Styled, StyledExt as _, StyledTypography, Window, div, example_group_with_title, h_flex, px, single_example, v_flex, }; +const RESIZE_COLUMN_WIDTH: f32 = 8.0; + +#[derive(Debug)] +struct DraggedColumn(usize); + struct UniformListData<const COLS: usize> { render_item_fn: Box<dyn Fn(Range<usize>, &mut Window, &mut App) -> Vec<[AnyElement; COLS]>>, element_id: ElementId, @@ -40,6 +47,10 @@ impl<const COLS: usize> TableContents<COLS> { TableContents::UniformList(data) => data.row_count, } } + + fn is_empty(&self) -> bool { + self.len() == 
0 + } } pub struct TableInteractionState { @@ -90,6 +101,28 @@ impl TableInteractionState { }) } + pub fn get_scrollbar_offset(&self, axis: Axis) -> Point<Pixels> { + match axis { + Axis::Vertical => self.vertical_scrollbar.state.scroll_handle().offset(), + Axis::Horizontal => self.horizontal_scrollbar.state.scroll_handle().offset(), + } + } + + pub fn set_scrollbar_offset(&self, axis: Axis, offset: Point<Pixels>) { + match axis { + Axis::Vertical => self + .vertical_scrollbar + .state + .scroll_handle() + .set_offset(offset), + Axis::Horizontal => self + .horizontal_scrollbar + .state + .scroll_handle() + .set_offset(offset), + } + } + fn update_scrollbar_visibility(&mut self, cx: &mut Context<Self>) { let show_setting = EditorSettings::get_global(cx).scrollbar.show; @@ -165,6 +198,89 @@ impl TableInteractionState { } } + fn render_resize_handles<const COLS: usize>( + &self, + column_widths: &[Length; COLS], + resizable_columns: &[ResizeBehavior; COLS], + initial_sizes: [DefiniteLength; COLS], + columns: Option<Entity<ColumnWidths<COLS>>>, + window: &mut Window, + cx: &mut App, + ) -> AnyElement { + let spacers = column_widths + .iter() + .map(|width| base_cell_style(Some(*width)).into_any_element()); + + let mut column_ix = 0; + let resizable_columns_slice = *resizable_columns; + let mut resizable_columns = resizable_columns.into_iter(); + + let dividers = intersperse_with(spacers, || { + window.with_id(column_ix, |window| { + let mut resize_divider = div() + // This is required because this is evaluated at a different time than the use_state call above + .id(column_ix) + .relative() + .top_0() + .w_px() + .h_full() + .bg(cx.theme().colors().border.opacity(0.8)); + + let mut resize_handle = div() + .id("column-resize-handle") + .absolute() + .left_neg_0p5() + .w(px(RESIZE_COLUMN_WIDTH)) + .h_full(); + + if resizable_columns + .next() + .is_some_and(ResizeBehavior::is_resizable) + { + let hovered = window.use_state(cx, |_window, _cx| false); + + resize_divider = resize_divider.when(*hovered.read(cx), |div| { + div.bg(cx.theme().colors().border_focused) + }); + + resize_handle = resize_handle + .on_hover(move |&was_hovered, _, cx| hovered.write(cx, was_hovered)) + .cursor_col_resize() + .when_some(columns.clone(), |this, columns| { + this.on_click(move |event, window, cx| { + if event.click_count() >= 2 { + columns.update(cx, |columns, _| { + columns.on_double_click( + column_ix, + &initial_sizes, + &resizable_columns_slice, + window, + ); + }) + } + + cx.stop_propagation(); + }) + }) + .on_drag(DraggedColumn(column_ix), |_, _offset, _window, cx| { + cx.new(|_cx| gpui::Empty) + }) + } + + column_ix += 1; + resize_divider.child(resize_handle).into_any_element() + }) + }); + + h_flex() + .id("resize-handles") + .absolute() + .inset_0() + .w_full() + .children(dividers) + .into_any_element() + } + fn render_vertical_scrollbar_track( this: &Entity<Self>, parent: Div, @@ -343,6 +459,307 @@ impl TableInteractionState { } } +#[derive(Debug, Copy, Clone, PartialEq)] +pub enum ResizeBehavior { + None, + Resizable, + MinSize(f32), +} + +impl ResizeBehavior { + pub fn is_resizable(&self) -> bool { + *self != ResizeBehavior::None + } + + pub fn min_size(&self) -> Option<f32> { + match self { + ResizeBehavior::None => None, + ResizeBehavior::Resizable => Some(0.05), + ResizeBehavior::MinSize(min_size) => Some(*min_size), + } + } +} + +pub struct ColumnWidths<const COLS: usize> { + widths: [DefiniteLength; COLS], + visible_widths: [DefiniteLength; COLS], + cached_bounds_width: Pixels, + initialized: 
bool, +} + +impl<const COLS: usize> ColumnWidths<COLS> { + pub fn new(_: &mut App) -> Self { + Self { + widths: [DefiniteLength::default(); COLS], + visible_widths: [DefiniteLength::default(); COLS], + cached_bounds_width: Default::default(), + initialized: false, + } + } + + fn get_fraction(length: &DefiniteLength, bounds_width: Pixels, rem_size: Pixels) -> f32 { + match length { + DefiniteLength::Absolute(AbsoluteLength::Pixels(pixels)) => *pixels / bounds_width, + DefiniteLength::Absolute(AbsoluteLength::Rems(rems_width)) => { + rems_width.to_pixels(rem_size) / bounds_width + } + DefiniteLength::Fraction(fraction) => *fraction, + } + } + + fn on_double_click( + &mut self, + double_click_position: usize, + initial_sizes: &[DefiniteLength; COLS], + resize_behavior: &[ResizeBehavior; COLS], + window: &mut Window, + ) { + let bounds_width = self.cached_bounds_width; + let rem_size = window.rem_size(); + let initial_sizes = + initial_sizes.map(|length| Self::get_fraction(&length, bounds_width, rem_size)); + let widths = self + .widths + .map(|length| Self::get_fraction(&length, bounds_width, rem_size)); + + let updated_widths = Self::reset_to_initial_size( + double_click_position, + widths, + initial_sizes, + resize_behavior, + ); + self.widths = updated_widths.map(DefiniteLength::Fraction); + self.visible_widths = self.widths; + } + + fn reset_to_initial_size( + col_idx: usize, + mut widths: [f32; COLS], + initial_sizes: [f32; COLS], + resize_behavior: &[ResizeBehavior; COLS], + ) -> [f32; COLS] { + // RESET: + // Part 1: + // Figure out if we should shrink/grow the selected column + // Get diff which represents the change in column we want to make initial size delta curr_size = diff + // + // Part 2: We need to decide which side column we should move and where + // + // If we want to grow our column we should check the left/right columns diff to see what side + // has a greater delta than their initial size. Likewise, if we shrink our column we should check + // the left/right column diffs to see what side has the smallest delta. 
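+        //
+        // Hypothetical example of Parts 1 and 2 (numbers invented for illustration, and
+        // assuming all three columns are resizable): initial widths [0.3, 0.4, 0.3] that
+        // were dragged to [0.2, 0.5, 0.3]. Resetting column 1 gives diff = 0.4 - 0.5 = -0.1
+        // (the column must shrink by 0.1); the left side is 0.1 below its initial total
+        // while the right side matches its initial, so the freed 0.1 goes to the left,
+        // restoring [0.3, 0.4, 0.3].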
+        //
+        // Part 3: resize
+        //
+        // col_idx is the column being reset; its resize handle sits immediately to its
+        // right, so double-clicking handle i (or header i) resets column i.
+        //
+        // Restore widths[col_idx] to its initial size and push the difference into the
+        // neighboring columns. When the column has to grow, start on the side that is
+        // furthest above its own initial total (it has the most surplus to give back);
+        // when it has to shrink, start on the side that is furthest below its initial
+        // total (it has the most room to fill). Whatever the first side cannot absorb
+        // because of minimum sizes spills over to the other side.
+
+        // DRAGGING
+        // diff represents the change in the drag handle position:
+        // -diff => dragging left: grow the column to the right of the handle by shrinking
+        //          columns to the left of the handle as much as their minimum sizes allow
+        // +diff => dragging right: grow the column to the left of the handle by shrinking
+        //          columns to the right of the handle as much as their minimum sizes allow
+
+        let diff = initial_sizes[col_idx] - widths[col_idx];
+
+        let left_diff =
+            initial_sizes[..col_idx].iter().sum::<f32>() - widths[..col_idx].iter().sum::<f32>();
+        let right_diff = initial_sizes[col_idx + 1..].iter().sum::<f32>()
+            - widths[col_idx + 1..].iter().sum::<f32>();
+
+        let go_left_first = if diff < 0.0 {
+            left_diff > right_diff
+        } else {
+            left_diff < right_diff
+        };
+
+        if !go_left_first {
+            let diff_remaining =
+                Self::propagate_resize_diff(diff, col_idx, &mut widths, resize_behavior, 1);
+
+            if diff_remaining != 0.0 && col_idx > 0 {
+                Self::propagate_resize_diff(
+                    diff_remaining,
+                    col_idx,
+                    &mut widths,
+                    resize_behavior,
+                    -1,
+                );
+            }
+        } else {
+            let diff_remaining =
+                Self::propagate_resize_diff(diff, col_idx, &mut widths, resize_behavior, -1);
+
+            if diff_remaining != 0.0 {
+                Self::propagate_resize_diff(
+                    diff_remaining,
+                    col_idx,
+                    &mut widths,
+                    resize_behavior,
+                    1,
+                );
+            }
+        }
+
+        widths
+    }
+
+    fn on_drag_move(
+        &mut self,
+        drag_event: &DragMoveEvent<DraggedColumn>,
+        resize_behavior: &[ResizeBehavior; COLS],
+        window: &mut Window,
+        cx: &mut Context<Self>,
+    ) {
+        let drag_position = drag_event.event.position;
+        let bounds = drag_event.bounds;
+
+        let mut col_position = 0.0;
+        let rem_size = window.rem_size();
+        let bounds_width = bounds.right() - bounds.left();
+        let col_idx = drag_event.drag(cx).0;
+
+        let column_handle_width = Self::get_fraction(
+            &DefiniteLength::Absolute(AbsoluteLength::Pixels(px(RESIZE_COLUMN_WIDTH))),
+            bounds_width,
+            rem_size,
+        );
+
+        let mut widths = self
+            .widths
+            .map(|length| Self::get_fraction(&length, bounds_width, rem_size));
+
+        for length in widths[0..=col_idx].iter() {
+            col_position += length + column_handle_width;
+        }
+
+        let mut total_length_ratio = col_position;
+        for length in widths[col_idx + 1..].iter() {
+            total_length_ratio += length;
+        }
+        total_length_ratio += (COLS - 1 - col_idx) as f32 * column_handle_width;
+
+        let drag_fraction = (drag_position.x - bounds.left()) / bounds_width;
+        let drag_fraction = drag_fraction * total_length_ratio;
+        let diff = drag_fraction - col_position - column_handle_width / 2.0;
+
+        Self::drag_column_handle(diff, col_idx, &mut widths, resize_behavior);
+
+        self.visible_widths = widths.map(DefiniteLength::Fraction);
+    }
+
+    fn drag_column_handle(
+        diff: f32,
+        col_idx: usize,
+        widths: &mut [f32; COLS],
+        resize_behavior: &[ResizeBehavior; COLS],
+    ) {
+        // diff > 0.0 means the handle was dragged to the right
+        if diff > 0.0 {
+            Self::propagate_resize_diff(diff, col_idx, widths, resize_behavior, 1);
+        } else {
+            Self::propagate_resize_diff(-diff, col_idx + 1, widths, resize_behavior, -1);
+        }
+    }
+
+    fn propagate_resize_diff(
+        diff: f32,
+        col_idx: usize,
+        widths: &mut [f32; 
COLS], + resize_behavior: &[ResizeBehavior; COLS], + direction: i8, + ) -> f32 { + let mut diff_remaining = diff; + if resize_behavior[col_idx].min_size().is_none() { + return diff; + } + + let step_right; + let step_left; + if direction < 0 { + step_right = 0; + step_left = 1; + } else { + step_right = 1; + step_left = 0; + } + if col_idx == 0 && direction < 0 { + return diff; + } + let mut curr_column = col_idx + step_right - step_left; + + while diff_remaining != 0.0 && curr_column < COLS { + let Some(min_size) = resize_behavior[curr_column].min_size() else { + if curr_column == 0 { + break; + } + curr_column -= step_left; + curr_column += step_right; + continue; + }; + + let curr_width = widths[curr_column] - diff_remaining; + widths[curr_column] = curr_width; + + if min_size > curr_width { + diff_remaining = min_size - curr_width; + widths[curr_column] = min_size; + } else { + diff_remaining = 0.0; + break; + } + if curr_column == 0 { + break; + } + curr_column -= step_left; + curr_column += step_right; + } + widths[col_idx] = widths[col_idx] + (diff - diff_remaining); + + return diff_remaining; + } +} + +pub struct TableWidths<const COLS: usize> { + initial: [DefiniteLength; COLS], + current: Option<Entity<ColumnWidths<COLS>>>, + resizable: [ResizeBehavior; COLS], +} + +impl<const COLS: usize> TableWidths<COLS> { + pub fn new(widths: [impl Into<DefiniteLength>; COLS]) -> Self { + let widths = widths.map(Into::into); + + TableWidths { + initial: widths, + current: None, + resizable: [ResizeBehavior::None; COLS], + } + } + + fn lengths(&self, cx: &App) -> [Length; COLS] { + self.current + .as_ref() + .map(|entity| entity.read(cx).visible_widths.map(Length::Definite)) + .unwrap_or(self.initial.map(Length::Definite)) + } +} + /// A table component #[derive(RegisterComponent, IntoElement)] pub struct Table<const COLS: usize = 3> { @@ -351,21 +768,23 @@ pub struct Table<const COLS: usize = 3> { headers: Option<[AnyElement; COLS]>, rows: TableContents<COLS>, interaction_state: Option<WeakEntity<TableInteractionState>>, - column_widths: Option<[Length; COLS]>, - map_row: Option<Rc<dyn Fn((usize, Div), &mut Window, &mut App) -> AnyElement>>, + col_widths: Option<TableWidths<COLS>>, + map_row: Option<Rc<dyn Fn((usize, Stateful<Div>), &mut Window, &mut App) -> AnyElement>>, + empty_table_callback: Option<Rc<dyn Fn(&mut Window, &mut App) -> AnyElement>>, } impl<const COLS: usize> Table<COLS> { /// number of headers provided. 
pub fn new() -> Self { - Table { + Self { striped: false, width: None, headers: None, rows: TableContents::Vec(Vec::new()), interaction_state: None, - column_widths: None, map_row: None, + empty_table_callback: None, + col_widths: None, } } @@ -426,32 +845,68 @@ impl<const COLS: usize> Table<COLS> { self } - pub fn column_widths(mut self, widths: [impl Into<Length>; COLS]) -> Self { - self.column_widths = Some(widths.map(Into::into)); + pub fn column_widths(mut self, widths: [impl Into<DefiniteLength>; COLS]) -> Self { + if self.col_widths.is_none() { + self.col_widths = Some(TableWidths::new(widths)); + } + self + } + + pub fn resizable_columns( + mut self, + resizable: [ResizeBehavior; COLS], + column_widths: &Entity<ColumnWidths<COLS>>, + cx: &mut App, + ) -> Self { + if let Some(table_widths) = self.col_widths.as_mut() { + table_widths.resizable = resizable; + let column_widths = table_widths + .current + .get_or_insert_with(|| column_widths.clone()); + + column_widths.update(cx, |widths, _| { + if !widths.initialized { + widths.initialized = true; + widths.widths = table_widths.initial; + widths.visible_widths = widths.widths; + } + }) + } self } pub fn map_row( mut self, - callback: impl Fn((usize, Div), &mut Window, &mut App) -> AnyElement + 'static, + callback: impl Fn((usize, Stateful<Div>), &mut Window, &mut App) -> AnyElement + 'static, ) -> Self { self.map_row = Some(Rc::new(callback)); self } + + /// Provide a callback that is invoked when the table is rendered without any rows + pub fn empty_table_callback( + mut self, + callback: impl Fn(&mut Window, &mut App) -> AnyElement + 'static, + ) -> Self { + self.empty_table_callback = Some(Rc::new(callback)); + self + } } -fn base_cell_style(width: Option<Length>, cx: &App) -> Div { +fn base_cell_style(width: Option<Length>) -> Div { div() .px_1p5() .when_some(width, |this, width| this.w(width)) .when(width.is_none(), |this| this.flex_1()) - .justify_start() - .text_ui(cx) .whitespace_nowrap() .text_ellipsis() .overflow_hidden() } +fn base_cell_style_text(width: Option<Length>, cx: &App) -> Div { + base_cell_style(width).text_ui(cx) +} + pub fn render_row<const COLS: usize>( row_index: usize, items: [impl IntoElement; COLS], @@ -470,43 +925,56 @@ pub fn render_row<const COLS: usize>( .column_widths .map_or([None; COLS], |widths| widths.map(Some)); - let row = div().w_full().child( - h_flex() - .id("table_row") - .w_full() - .justify_between() - .px_1p5() - .py_1() - .when_some(bg, |row, bg| row.bg(bg)) - .when(!is_striped, |row| { - row.border_b_1() - .border_color(transparent_black()) - .when(!is_last, |row| row.border_color(cx.theme().colors().border)) - }) - .children( - items - .map(IntoElement::into_any_element) - .into_iter() - .zip(column_widths) - .map(|(cell, width)| base_cell_style(width, cx).child(cell)), - ), + let mut row = h_flex() + .h_full() + .id(("table_row", row_index)) + .w_full() + .justify_between() + .when_some(bg, |row, bg| row.bg(bg)) + .when(!is_striped, |row| { + row.border_b_1() + .border_color(transparent_black()) + .when(!is_last, |row| row.border_color(cx.theme().colors().border)) + }); + + row = row.children( + items + .map(IntoElement::into_any_element) + .into_iter() + .zip(column_widths) + .map(|(cell, width)| base_cell_style_text(width, cx).px_1().py_0p5().child(cell)), ); - if let Some(map_row) = table_context.map_row { + let row = if let Some(map_row) = table_context.map_row { map_row((row_index, row), window, cx) } else { row.into_any_element() - } + }; + + 
div().size_full().child(row).into_any_element() } pub fn render_header<const COLS: usize>( headers: [impl IntoElement; COLS], table_context: TableRenderContext<COLS>, + columns_widths: Option<( + WeakEntity<ColumnWidths<COLS>>, + [ResizeBehavior; COLS], + [DefiniteLength; COLS], + )>, + entity_id: Option<EntityId>, cx: &mut App, ) -> impl IntoElement { let column_widths = table_context .column_widths .map_or([None; COLS], |widths| widths.map(Some)); + + let element_id = entity_id + .map(|entity| entity.to_string()) + .unwrap_or_default(); + + let shared_element_id: SharedString = format!("table-{}", element_id).into(); + div() .flex() .flex_row() @@ -516,12 +984,39 @@ pub fn render_header<const COLS: usize>( .p_2() .border_b_1() .border_color(cx.theme().colors().border) - .children( - headers - .into_iter() - .zip(column_widths) - .map(|(h, width)| base_cell_style(width, cx).child(h)), - ) + .children(headers.into_iter().enumerate().zip(column_widths).map( + |((header_idx, h), width)| { + base_cell_style_text(width, cx) + .child(h) + .id(ElementId::NamedInteger( + shared_element_id.clone(), + header_idx as u64, + )) + .when_some( + columns_widths.as_ref().cloned(), + |this, (column_widths, resizables, initial_sizes)| { + if resizables[header_idx].is_resizable() { + this.on_click(move |event, window, cx| { + if event.click_count() > 1 { + column_widths + .update(cx, |column, _| { + column.on_double_click( + header_idx, + &initial_sizes, + &resizables, + window, + ); + }) + .ok(); + } + }) + } else { + this + } + }, + ) + }, + )) } #[derive(Clone)] @@ -529,15 +1024,15 @@ pub struct TableRenderContext<const COLS: usize> { pub striped: bool, pub total_row_count: usize, pub column_widths: Option<[Length; COLS]>, - pub map_row: Option<Rc<dyn Fn((usize, Div), &mut Window, &mut App) -> AnyElement>>, + pub map_row: Option<Rc<dyn Fn((usize, Stateful<Div>), &mut Window, &mut App) -> AnyElement>>, } impl<const COLS: usize> TableRenderContext<COLS> { - fn new(table: &Table<COLS>) -> Self { + fn new(table: &Table<COLS>, cx: &App) -> Self { Self { striped: table.striped, total_row_count: table.rows.len(), - column_widths: table.column_widths, + column_widths: table.col_widths.as_ref().map(|widths| widths.lengths(cx)), map_row: table.map_row.clone(), } } @@ -545,8 +1040,19 @@ impl<const COLS: usize> TableRenderContext<COLS> { impl<const COLS: usize> RenderOnce for Table<COLS> { fn render(mut self, window: &mut Window, cx: &mut App) -> impl IntoElement { - let table_context = TableRenderContext::new(&self); + let table_context = TableRenderContext::new(&self, cx); let interaction_state = self.interaction_state.and_then(|state| state.upgrade()); + let current_widths = self + .col_widths + .as_ref() + .and_then(|widths| Some((widths.current.as_ref()?, widths.resizable))) + .map(|(curr, resize_behavior)| (curr.downgrade(), resize_behavior)); + + let current_widths_with_initial_sizes = self + .col_widths + .as_ref() + .and_then(|widths| Some((widths.current.as_ref()?, widths.resizable, widths.initial))) + .map(|(curr, resize_behavior, initial)| (curr.downgrade(), resize_behavior, initial)); let scroll_track_size = px(16.); let h_scroll_offset = if interaction_state @@ -560,13 +1066,54 @@ impl<const COLS: usize> RenderOnce for Table<COLS> { }; let width = self.width; + let no_rows_rendered = self.rows.is_empty(); let table = div() .when_some(width, |this, width| this.w(width)) .h_full() .v_flex() .when_some(self.headers.take(), |this, headers| { - this.child(render_header(headers, table_context.clone(), cx)) + 
this.child(render_header( + headers, + table_context.clone(), + current_widths_with_initial_sizes, + interaction_state.as_ref().map(Entity::entity_id), + cx, + )) + }) + .when_some(current_widths, { + |this, (widths, resize_behavior)| { + this.on_drag_move::<DraggedColumn>({ + let widths = widths.clone(); + move |e, window, cx| { + widths + .update(cx, |widths, cx| { + widths.on_drag_move(e, &resize_behavior, window, cx); + }) + .ok(); + } + }) + .on_children_prepainted({ + let widths = widths.clone(); + move |bounds, _, cx| { + widths + .update(cx, |widths, _| { + // This works because all children x axis bounds are the same + widths.cached_bounds_width = + bounds[0].right() - bounds[0].left(); + }) + .ok(); + } + }) + .on_drop::<DraggedColumn>(move |_, _, cx| { + widths + .update(cx, |widths, _| { + widths.widths = widths.visible_widths; + }) + .ok(); + // Finish the resize operation + }) + } }) .child( div() @@ -622,6 +1169,25 @@ impl<const COLS: usize> RenderOnce for Table<COLS> { ), ), }) + .when_some( + self.col_widths.as_ref().zip(interaction_state.as_ref()), + |parent, (table_widths, state)| { + parent.child(state.update(cx, |state, cx| { + let resizable_columns = table_widths.resizable; + let column_widths = table_widths.lengths(cx); + let columns = table_widths.current.clone(); + let initial_sizes = table_widths.initial; + state.render_resize_handles( + &column_widths, + &resizable_columns, + initial_sizes, + columns, + window, + cx, + ) + })) + }, + ) .when_some(interaction_state.as_ref(), |this, interaction_state| { this.map(|this| { TableInteractionState::render_vertical_scrollbar_track( @@ -640,6 +1206,21 @@ impl<const COLS: usize> RenderOnce for Table<COLS> { }) }), ) + .when_some( + no_rows_rendered + .then_some(self.empty_table_callback) + .flatten(), + |this, callback| { + this.child( + h_flex() + .size_full() + .p_3() + .items_start() + .justify_center() + .child(callback(window, cx)), + ) + }, + ) .when_some( width.and(interaction_state.as_ref()), |this, interaction_state| { @@ -862,3 +1443,323 @@ impl Component for Table<3> { ) } } + +#[cfg(test)] +mod test { + use super::*; + + fn is_almost_eq(a: &[f32], b: &[f32]) -> bool { + a.len() == b.len() && a.iter().zip(b).all(|(x, y)| (x - y).abs() < 1e-6) + } + + fn cols_to_str<const COLS: usize>(cols: &[f32; COLS], total_size: f32) -> String { + cols.map(|f| "*".repeat(f32::round(f * total_size) as usize)) + .join("|") + } + + fn parse_resize_behavior<const COLS: usize>( + input: &str, + total_size: f32, + ) -> [ResizeBehavior; COLS] { + let mut resize_behavior = [ResizeBehavior::None; COLS]; + let mut max_index = 0; + for (index, col) in input.split('|').enumerate() { + if col.starts_with('X') || col.is_empty() { + resize_behavior[index] = ResizeBehavior::None; + } else if col.starts_with('*') { + resize_behavior[index] = ResizeBehavior::MinSize(col.len() as f32 / total_size); + } else { + panic!("invalid test input: unrecognized resize behavior: {}", col); + } + max_index = index; + } + + if max_index + 1 != COLS { + panic!("invalid test input: too many columns"); + } + resize_behavior + } + + mod reset_column_size { + use super::*; + + fn parse<const COLS: usize>(input: &str) -> ([f32; COLS], f32, Option<usize>) { + let mut widths = [f32::NAN; COLS]; + let mut column_index = None; + for (index, col) in input.split('|').enumerate() { + widths[index] = col.len() as f32; + if col.starts_with('X') { + column_index = Some(index); + } + } + + for w in widths { + assert!(w.is_finite(), "incorrect number of columns"); + } + let 
total = widths.iter().sum::<f32>(); + for width in &mut widths { + *width /= total; + } + (widths, total, column_index) + } + + #[track_caller] + fn check_reset_size<const COLS: usize>( + initial_sizes: &str, + widths: &str, + expected: &str, + resize_behavior: &str, + ) { + let (initial_sizes, total_1, None) = parse::<COLS>(initial_sizes) else { + panic!("invalid test input: initial sizes should not be marked"); + }; + let (widths, total_2, Some(column_index)) = parse::<COLS>(widths) else { + panic!("invalid test input: widths should be marked"); + }; + assert_eq!( + total_1, total_2, + "invalid test input: total width not the same {total_1}, {total_2}" + ); + let (expected, total_3, None) = parse::<COLS>(expected) else { + panic!("invalid test input: expected should not be marked: {expected:?}"); + }; + assert_eq!( + total_2, total_3, + "invalid test input: total width not the same" + ); + let resize_behavior = parse_resize_behavior::<COLS>(resize_behavior, total_1); + let result = ColumnWidths::reset_to_initial_size( + column_index, + widths, + initial_sizes, + &resize_behavior, + ); + let is_eq = is_almost_eq(&result, &expected); + if !is_eq { + let result_str = cols_to_str(&result, total_1); + let expected_str = cols_to_str(&expected, total_1); + panic!( + "resize failed\ncomputed: {result_str}\nexpected: {expected_str}\n\ncomputed values: {result:?}\nexpected values: {expected:?}\n:minimum widths: {resize_behavior:?}" + ); + } + } + + macro_rules! check_reset_size { + (columns: $cols:expr, starting: $initial:expr, snapshot: $current:expr, expected: $expected:expr, resizing: $resizing:expr $(,)?) => { + check_reset_size::<$cols>($initial, $current, $expected, $resizing); + }; + ($name:ident, columns: $cols:expr, starting: $initial:expr, snapshot: $current:expr, expected: $expected:expr, minimums: $resizing:expr $(,)?) 
=> { + #[test] + fn $name() { + check_reset_size::<$cols>($initial, $current, $expected, $resizing); + } + }; + } + + check_reset_size!( + basic_right, + columns: 5, + starting: "**|**|**|**|**", + snapshot: "**|**|X|***|**", + expected: "**|**|**|**|**", + minimums: "X|*|*|*|*", + ); + + check_reset_size!( + basic_left, + columns: 5, + starting: "**|**|**|**|**", + snapshot: "**|**|***|X|**", + expected: "**|**|**|**|**", + minimums: "X|*|*|*|**", + ); + + check_reset_size!( + squashed_left_reset_col2, + columns: 6, + starting: "*|***|**|**|****|*", + snapshot: "*|*|X|*|*|********", + expected: "*|*|**|*|*|*******", + minimums: "X|*|*|*|*|*", + ); + + check_reset_size!( + grow_cascading_right, + columns: 6, + starting: "*|***|****|**|***|*", + snapshot: "*|***|X|**|**|*****", + expected: "*|***|****|*|*|****", + minimums: "X|*|*|*|*|*", + ); + + check_reset_size!( + squashed_right_reset_col4, + columns: 6, + starting: "*|***|**|**|****|*", + snapshot: "*|********|*|*|X|*", + expected: "*|*****|*|*|****|*", + minimums: "X|*|*|*|*|*", + ); + + check_reset_size!( + reset_col6_right, + columns: 6, + starting: "*|***|**|***|***|**", + snapshot: "*|***|**|***|**|XXX", + expected: "*|***|**|***|***|**", + minimums: "X|*|*|*|*|*", + ); + + check_reset_size!( + reset_col6_left, + columns: 6, + starting: "*|***|**|***|***|**", + snapshot: "*|***|**|***|****|X", + expected: "*|***|**|***|***|**", + minimums: "X|*|*|*|*|*", + ); + + check_reset_size!( + last_column_grow_cascading, + columns: 6, + starting: "*|***|**|**|**|***", + snapshot: "*|*******|*|**|*|X", + expected: "*|******|*|*|*|***", + minimums: "X|*|*|*|*|*", + ); + + check_reset_size!( + goes_left_when_left_has_extreme_diff, + columns: 6, + starting: "*|***|****|**|**|***", + snapshot: "*|********|X|*|**|**", + expected: "*|*****|****|*|**|**", + minimums: "X|*|*|*|*|*", + ); + + check_reset_size!( + basic_shrink_right, + columns: 6, + starting: "**|**|**|**|**|**", + snapshot: "**|**|XXX|*|**|**", + expected: "**|**|**|**|**|**", + minimums: "X|*|*|*|*|*", + ); + + check_reset_size!( + shrink_should_go_left, + columns: 6, + starting: "*|***|**|*|*|*", + snapshot: "*|*|XXX|**|*|*", + expected: "*|**|**|**|*|*", + minimums: "X|*|*|*|*|*", + ); + + check_reset_size!( + shrink_should_go_right, + columns: 6, + starting: "*|***|**|**|**|*", + snapshot: "*|****|XXX|*|*|*", + expected: "*|****|**|**|*|*", + minimums: "X|*|*|*|*|*", + ); + } + + mod drag_handle { + use super::*; + + fn parse<const COLS: usize>(input: &str) -> ([f32; COLS], f32, Option<usize>) { + let mut widths = [f32::NAN; COLS]; + let column_index = input.replace("*", "").find("I"); + for (index, col) in input.replace("I", "|").split('|').enumerate() { + widths[index] = col.len() as f32; + } + + for w in widths { + assert!(w.is_finite(), "incorrect number of columns"); + } + let total = widths.iter().sum::<f32>(); + for width in &mut widths { + *width /= total; + } + (widths, total, column_index) + } + + #[track_caller] + fn check<const COLS: usize>( + distance: i32, + widths: &str, + expected: &str, + resize_behavior: &str, + ) { + let (mut widths, total_1, Some(column_index)) = parse::<COLS>(widths) else { + panic!("invalid test input: widths should be marked"); + }; + let (expected, total_2, None) = parse::<COLS>(expected) else { + panic!("invalid test input: expected should not be marked: {expected:?}"); + }; + assert_eq!( + total_1, total_2, + "invalid test input: total width not the same" + ); + let resize_behavior = parse_resize_behavior::<COLS>(resize_behavior, 
total_1); + + let distance = distance as f32 / total_1; + + let result = ColumnWidths::drag_column_handle( + distance, + column_index, + &mut widths, + &resize_behavior, + ); + + let is_eq = is_almost_eq(&widths, &expected); + if !is_eq { + let result_str = cols_to_str(&widths, total_1); + let expected_str = cols_to_str(&expected, total_1); + panic!( + "resize failed\ncomputed: {result_str}\nexpected: {expected_str}\n\ncomputed values: {result:?}\nexpected values: {expected:?}\n:minimum widths: {resize_behavior:?}" + ); + } + } + + macro_rules! check { + (columns: $cols:expr, distance: $dist:expr, snapshot: $current:expr, expected: $expected:expr, resizing: $resizing:expr $(,)?) => { + check!($cols, $dist, $snapshot, $expected, $resizing); + }; + ($name:ident, columns: $cols:expr, distance: $dist:expr, snapshot: $current:expr, expected: $expected:expr, minimums: $resizing:expr $(,)?) => { + #[test] + fn $name() { + check::<$cols>($dist, $current, $expected, $resizing); + } + }; + } + + check!( + basic_right_drag, + columns: 3, + distance: 1, + snapshot: "**|**I**", + expected: "**|***|*", + minimums: "X|*|*", + ); + + check!( + drag_left_against_mins, + columns: 5, + distance: -1, + snapshot: "*|*|*|*I*******", + expected: "*|*|*|*|*******", + minimums: "X|*|*|*|*", + ); + + check!( + drag_left, + columns: 5, + distance: -2, + snapshot: "*|*|*|*****I***", + expected: "*|*|*|***|*****", + minimums: "X|*|*|*|*", + ); + } +} diff --git a/crates/snippets_ui/src/snippets_ui.rs b/crates/snippets_ui/src/snippets_ui.rs index 1cc16c55761508c11470b35715b8085447032114..a8710d1672c16964545a454362fd0eeb431714a2 100644 --- a/crates/snippets_ui/src/snippets_ui.rs +++ b/crates/snippets_ui/src/snippets_ui.rs @@ -149,13 +149,12 @@ impl ScopeSelectorDelegate { scope_selector: WeakEntity<ScopeSelector>, language_registry: Arc<LanguageRegistry>, ) -> Self { - let candidates = Vec::from([GLOBAL_SCOPE_NAME.to_string()]).into_iter(); let languages = language_registry.language_names().into_iter(); - let candidates = candidates + let candidates = std::iter::once(LanguageName::new(GLOBAL_SCOPE_NAME)) .chain(languages) .enumerate() - .map(|(candidate_id, name)| StringMatchCandidate::new(candidate_id, &name)) + .map(|(candidate_id, name)| StringMatchCandidate::new(candidate_id, name.as_ref())) .collect::<Vec<_>>(); let mut existing_scopes = HashSet::new(); diff --git a/crates/sum_tree/src/cursor.rs b/crates/sum_tree/src/cursor.rs index 8edd04afcef12781f1acc43eb2edb5805128c3b6..50a556a6d279d0b7f733d0d80c6c2e7e3d6c61cd 100644 --- a/crates/sum_tree/src/cursor.rs +++ b/crates/sum_tree/src/cursor.rs @@ -25,6 +25,7 @@ pub struct Cursor<'a, T: Item, D> { position: D, did_seek: bool, at_end: bool, + cx: &'a <T::Summary as Summary>::Context, } impl<T: Item + fmt::Debug, D: fmt::Debug> fmt::Debug for Cursor<'_, T, D> @@ -52,21 +53,22 @@ where T: Item, D: Dimension<'a, T::Summary>, { - pub fn new(tree: &'a SumTree<T>, cx: &<T::Summary as Summary>::Context) -> Self { + pub fn new(tree: &'a SumTree<T>, cx: &'a <T::Summary as Summary>::Context) -> Self { Self { tree, stack: ArrayVec::new(), position: D::zero(cx), did_seek: false, at_end: tree.is_empty(), + cx, } } - fn reset(&mut self, cx: &<T::Summary as Summary>::Context) { + fn reset(&mut self) { self.did_seek = false; self.at_end = self.tree.is_empty(); self.stack.truncate(0); - self.position = D::zero(cx); + self.position = D::zero(self.cx); } pub fn start(&self) -> &D { @@ -74,10 +76,10 @@ where } #[track_caller] - pub fn end(&self, cx: &<T::Summary as Summary>::Context) -> D 
{ + pub fn end(&self) -> D { if let Some(item_summary) = self.item_summary() { let mut end = self.start().clone(); - end.add_summary(item_summary, cx); + end.add_summary(item_summary, self.cx); end } else { self.start().clone() @@ -202,12 +204,12 @@ where } #[track_caller] - pub fn prev(&mut self, cx: &<T::Summary as Summary>::Context) { - self.search_backward(|_| true, cx) + pub fn prev(&mut self) { + self.search_backward(|_| true) } #[track_caller] - pub fn search_backward<F>(&mut self, mut filter_node: F, cx: &<T::Summary as Summary>::Context) + pub fn search_backward<F>(&mut self, mut filter_node: F) where F: FnMut(&T::Summary) -> bool, { @@ -217,13 +219,13 @@ where } if self.at_end { - self.position = D::zero(cx); + self.position = D::zero(self.cx); self.at_end = self.tree.is_empty(); if !self.tree.is_empty() { self.stack.push(StackEntry { tree: self.tree, index: self.tree.0.child_summaries().len(), - position: D::from_summary(self.tree.summary(), cx), + position: D::from_summary(self.tree.summary(), self.cx), }); } } @@ -233,7 +235,7 @@ where if let Some(StackEntry { position, .. }) = self.stack.iter().rev().nth(1) { self.position = position.clone(); } else { - self.position = D::zero(cx); + self.position = D::zero(self.cx); } let entry = self.stack.last_mut().unwrap(); @@ -247,7 +249,7 @@ where } for summary in &entry.tree.0.child_summaries()[..entry.index] { - self.position.add_summary(summary, cx); + self.position.add_summary(summary, self.cx); } entry.position = self.position.clone(); @@ -257,7 +259,7 @@ where if descending { let tree = &child_trees[entry.index]; self.stack.push(StackEntry { - position: D::zero(cx), + position: D::zero(self.cx), tree, index: tree.0.child_summaries().len() - 1, }) @@ -273,12 +275,12 @@ where } #[track_caller] - pub fn next(&mut self, cx: &<T::Summary as Summary>::Context) { - self.search_forward(|_| true, cx) + pub fn next(&mut self) { + self.search_forward(|_| true) } #[track_caller] - pub fn search_forward<F>(&mut self, mut filter_node: F, cx: &<T::Summary as Summary>::Context) + pub fn search_forward<F>(&mut self, mut filter_node: F) where F: FnMut(&T::Summary) -> bool, { @@ -289,7 +291,7 @@ where self.stack.push(StackEntry { tree: self.tree, index: 0, - position: D::zero(cx), + position: D::zero(self.cx), }); descend = true; } @@ -316,8 +318,8 @@ where break; } else { entry.index += 1; - entry.position.add_summary(next_summary, cx); - self.position.add_summary(next_summary, cx); + entry.position.add_summary(next_summary, self.cx); + self.position.add_summary(next_summary, self.cx); } } @@ -327,8 +329,8 @@ where if !descend { let item_summary = &item_summaries[entry.index]; entry.index += 1; - entry.position.add_summary(item_summary, cx); - self.position.add_summary(item_summary, cx); + entry.position.add_summary(item_summary, self.cx); + self.position.add_summary(item_summary, self.cx); } loop { @@ -337,8 +339,8 @@ where return; } else { entry.index += 1; - entry.position.add_summary(next_item_summary, cx); - self.position.add_summary(next_item_summary, cx); + entry.position.add_summary(next_item_summary, self.cx); + self.position.add_summary(next_item_summary, self.cx); } } else { break None; @@ -380,71 +382,51 @@ where D: Dimension<'a, T::Summary>, { #[track_caller] - pub fn seek<Target>( - &mut self, - pos: &Target, - bias: Bias, - cx: &<T::Summary as Summary>::Context, - ) -> bool + pub fn seek<Target>(&mut self, pos: &Target, bias: Bias) -> bool where Target: SeekTarget<'a, T::Summary, D>, { - self.reset(cx); - self.seek_internal(pos, 
bias, &mut (), cx) + self.reset(); + self.seek_internal(pos, bias, &mut ()) } #[track_caller] - pub fn seek_forward<Target>( - &mut self, - pos: &Target, - bias: Bias, - cx: &<T::Summary as Summary>::Context, - ) -> bool + pub fn seek_forward<Target>(&mut self, pos: &Target, bias: Bias) -> bool where Target: SeekTarget<'a, T::Summary, D>, { - self.seek_internal(pos, bias, &mut (), cx) + self.seek_internal(pos, bias, &mut ()) } /// Advances the cursor and returns traversed items as a tree. #[track_caller] - pub fn slice<Target>( - &mut self, - end: &Target, - bias: Bias, - cx: &<T::Summary as Summary>::Context, - ) -> SumTree<T> + pub fn slice<Target>(&mut self, end: &Target, bias: Bias) -> SumTree<T> where Target: SeekTarget<'a, T::Summary, D>, { let mut slice = SliceSeekAggregate { - tree: SumTree::new(cx), + tree: SumTree::new(self.cx), leaf_items: ArrayVec::new(), leaf_item_summaries: ArrayVec::new(), - leaf_summary: <T::Summary as Summary>::zero(cx), + leaf_summary: <T::Summary as Summary>::zero(self.cx), }; - self.seek_internal(end, bias, &mut slice, cx); + self.seek_internal(end, bias, &mut slice); slice.tree } #[track_caller] - pub fn suffix(&mut self, cx: &<T::Summary as Summary>::Context) -> SumTree<T> { - self.slice(&End::new(), Bias::Right, cx) + pub fn suffix(&mut self) -> SumTree<T> { + self.slice(&End::new(), Bias::Right) } #[track_caller] - pub fn summary<Target, Output>( - &mut self, - end: &Target, - bias: Bias, - cx: &<T::Summary as Summary>::Context, - ) -> Output + pub fn summary<Target, Output>(&mut self, end: &Target, bias: Bias) -> Output where Target: SeekTarget<'a, T::Summary, D>, Output: Dimension<'a, T::Summary>, { - let mut summary = SummarySeekAggregate(Output::zero(cx)); - self.seek_internal(end, bias, &mut summary, cx); + let mut summary = SummarySeekAggregate(Output::zero(self.cx)); + self.seek_internal(end, bias, &mut summary); summary.0 } @@ -455,10 +437,9 @@ where target: &dyn SeekTarget<'a, T::Summary, D>, bias: Bias, aggregate: &mut dyn SeekAggregate<'a, T>, - cx: &<T::Summary as Summary>::Context, ) -> bool { assert!( - target.cmp(&self.position, cx) >= Ordering::Equal, + target.cmp(&self.position, self.cx) >= Ordering::Equal, "cannot seek backward", ); @@ -467,7 +448,7 @@ where self.stack.push(StackEntry { tree: self.tree, index: 0, - position: D::zero(cx), + position: D::zero(self.cx), }); } @@ -489,14 +470,14 @@ where .zip(&child_summaries[entry.index..]) { let mut child_end = self.position.clone(); - child_end.add_summary(child_summary, cx); + child_end.add_summary(child_summary, self.cx); - let comparison = target.cmp(&child_end, cx); + let comparison = target.cmp(&child_end, self.cx); if comparison == Ordering::Greater || (comparison == Ordering::Equal && bias == Bias::Right) { self.position = child_end; - aggregate.push_tree(child_tree, child_summary, cx); + aggregate.push_tree(child_tree, child_summary, self.cx); entry.index += 1; entry.position = self.position.clone(); } else { @@ -522,22 +503,22 @@ where .zip(&item_summaries[entry.index..]) { let mut child_end = self.position.clone(); - child_end.add_summary(item_summary, cx); + child_end.add_summary(item_summary, self.cx); - let comparison = target.cmp(&child_end, cx); + let comparison = target.cmp(&child_end, self.cx); if comparison == Ordering::Greater || (comparison == Ordering::Equal && bias == Bias::Right) { self.position = child_end; - aggregate.push_item(item, item_summary, cx); + aggregate.push_item(item, item_summary, self.cx); entry.index += 1; } else { - aggregate.end_leaf(cx); + 
aggregate.end_leaf(self.cx); break 'outer; } } - aggregate.end_leaf(cx); + aggregate.end_leaf(self.cx); } } @@ -551,11 +532,11 @@ where let mut end = self.position.clone(); if bias == Bias::Left { if let Some(summary) = self.item_summary() { - end.add_summary(summary, cx); + end.add_summary(summary, self.cx); } } - target.cmp(&end, cx) == Ordering::Equal + target.cmp(&end, self.cx) == Ordering::Equal } } @@ -624,21 +605,19 @@ impl<'a, T: Item> Iterator for Iter<'a, T> { } } -impl<'a, T, S, D> Iterator for Cursor<'a, T, D> +impl<'a, T: Item, D> Iterator for Cursor<'a, T, D> where - T: Item<Summary = S>, - S: Summary<Context = ()>, D: Dimension<'a, T::Summary>, { type Item = &'a T; fn next(&mut self) -> Option<Self::Item> { if !self.did_seek { - self.next(&()); + self.next(); } if let Some(item) = self.item() { - self.next(&()); + self.next(); Some(item) } else { None @@ -651,7 +630,7 @@ pub struct FilterCursor<'a, F, T: Item, D> { filter_node: F, } -impl<'a, F, T, D> FilterCursor<'a, F, T, D> +impl<'a, F, T: Item, D> FilterCursor<'a, F, T, D> where F: FnMut(&T::Summary) -> bool, T: Item, @@ -659,7 +638,7 @@ where { pub fn new( tree: &'a SumTree<T>, - cx: &<T::Summary as Summary>::Context, + cx: &'a <T::Summary as Summary>::Context, filter_node: F, ) -> Self { let cursor = tree.cursor::<D>(cx); @@ -673,8 +652,8 @@ where self.cursor.start() } - pub fn end(&self, cx: &<T::Summary as Summary>::Context) -> D { - self.cursor.end(cx) + pub fn end(&self) -> D { + self.cursor.end() } pub fn item(&self) -> Option<&'a T> { @@ -685,31 +664,29 @@ where self.cursor.item_summary() } - pub fn next(&mut self, cx: &<T::Summary as Summary>::Context) { - self.cursor.search_forward(&mut self.filter_node, cx); + pub fn next(&mut self) { + self.cursor.search_forward(&mut self.filter_node); } - pub fn prev(&mut self, cx: &<T::Summary as Summary>::Context) { - self.cursor.search_backward(&mut self.filter_node, cx); + pub fn prev(&mut self) { + self.cursor.search_backward(&mut self.filter_node); } } -impl<'a, F, T, S, U> Iterator for FilterCursor<'a, F, T, U> +impl<'a, F, T: Item, U> Iterator for FilterCursor<'a, F, T, U> where F: FnMut(&T::Summary) -> bool, - T: Item<Summary = S>, - S: Summary<Context = ()>, //Context for the summary must be unit type, as .next() doesn't take arguments U: Dimension<'a, T::Summary>, { type Item = &'a T; fn next(&mut self) -> Option<Self::Item> { if !self.cursor.did_seek { - self.next(&()); + self.next(); } if let Some(item) = self.item() { - self.cursor.search_forward(&mut self.filter_node, &()); + self.cursor.search_forward(&mut self.filter_node); Some(item) } else { None @@ -795,3 +772,23 @@ where self.0.add_summary(summary, cx); } } + +struct End<D>(PhantomData<D>); + +impl<D> End<D> { + fn new() -> Self { + Self(PhantomData) + } +} + +impl<'a, S: Summary, D: Dimension<'a, S>> SeekTarget<'a, S, D> for End<D> { + fn cmp(&self, _: &D, _: &S::Context) -> Ordering { + Ordering::Greater + } +} + +impl<D> fmt::Debug for End<D> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("End").finish() + } +} diff --git a/crates/sum_tree/src/sum_tree.rs b/crates/sum_tree/src/sum_tree.rs index 82022d668554e904fe52f445dfa17dd72b0dd6bf..3a12e3a681f7bd289e4e4a9fa9036d5f307aa1d7 100644 --- a/crates/sum_tree/src/sum_tree.rs +++ b/crates/sum_tree/src/sum_tree.rs @@ -38,20 +38,17 @@ pub trait Summary: Clone { type Context; fn zero(cx: &Self::Context) -> Self; - fn add_summary(&mut self, summary: &Self, cx: &Self::Context); } -/// This type exists because we can't implement 
Summary for () without causing -/// type resolution errors -#[derive(Copy, Clone, PartialEq, Eq, Debug)] -pub struct Unit; - -impl Summary for Unit { +/// Catch-all implementation for when you need something that implements [`Summary`] without a specific type. +/// We implement it on a &'static, as that avoids blanket impl collisions with `impl<T: Summary> Dimension for T` +/// (as we also need unit type to be a fill-in dimension) +impl Summary for &'static () { type Context = (); fn zero(_: &()) -> Self { - Unit + &() } fn add_summary(&mut self, _: &Self, _: &()) {} @@ -104,57 +101,32 @@ impl<'a, T: Summary> Dimension<'a, T> for () { fn add_summary(&mut self, _: &'a T, _: &T::Context) {} } -impl<'a, T: Summary, D1: Dimension<'a, T>, D2: Dimension<'a, T>> Dimension<'a, T> for (D1, D2) { +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)] +pub struct Dimensions<D1, D2, D3 = ()>(pub D1, pub D2, pub D3); + +impl<'a, T: Summary, D1: Dimension<'a, T>, D2: Dimension<'a, T>, D3: Dimension<'a, T>> + Dimension<'a, T> for Dimensions<D1, D2, D3> +{ fn zero(cx: &T::Context) -> Self { - (D1::zero(cx), D2::zero(cx)) + Dimensions(D1::zero(cx), D2::zero(cx), D3::zero(cx)) } fn add_summary(&mut self, summary: &'a T, cx: &T::Context) { self.0.add_summary(summary, cx); self.1.add_summary(summary, cx); + self.2.add_summary(summary, cx); } } -impl<'a, S, D1, D2> SeekTarget<'a, S, (D1, D2)> for D1 -where - S: Summary, - D1: SeekTarget<'a, S, D1> + Dimension<'a, S>, - D2: Dimension<'a, S>, -{ - fn cmp(&self, cursor_location: &(D1, D2), cx: &S::Context) -> Ordering { - self.cmp(&cursor_location.0, cx) - } -} - -impl<'a, S, D1, D2, D3> SeekTarget<'a, S, ((D1, D2), D3)> for D1 +impl<'a, S, D1, D2, D3> SeekTarget<'a, S, Dimensions<D1, D2, D3>> for D1 where S: Summary, D1: SeekTarget<'a, S, D1> + Dimension<'a, S>, D2: Dimension<'a, S>, D3: Dimension<'a, S>, { - fn cmp(&self, cursor_location: &((D1, D2), D3), cx: &S::Context) -> Ordering { - self.cmp(&cursor_location.0.0, cx) - } -} - -struct End<D>(PhantomData<D>); - -impl<D> End<D> { - fn new() -> Self { - Self(PhantomData) - } -} - -impl<'a, S: Summary, D: Dimension<'a, S>> SeekTarget<'a, S, D> for End<D> { - fn cmp(&self, _: &D, _: &S::Context) -> Ordering { - Ordering::Greater - } -} - -impl<D> fmt::Debug for End<D> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_tuple("End").finish() + fn cmp(&self, cursor_location: &Dimensions<D1, D2, D3>, cx: &S::Context) -> Ordering { + self.cmp(&cursor_location.0, cx) } } @@ -372,10 +344,10 @@ impl<T: Item> SumTree<T> { pub fn items(&self, cx: &<T::Summary as Summary>::Context) -> Vec<T> { let mut items = Vec::new(); let mut cursor = self.cursor::<()>(cx); - cursor.next(cx); + cursor.next(); while let Some(item) = cursor.item() { items.push(item.clone()); - cursor.next(cx); + cursor.next(); } items } @@ -384,7 +356,7 @@ impl<T: Item> SumTree<T> { Iter::new(self) } - pub fn cursor<'a, S>(&'a self, cx: &<T::Summary as Summary>::Context) -> Cursor<'a, T, S> + pub fn cursor<'a, S>(&'a self, cx: &'a <T::Summary as Summary>::Context) -> Cursor<'a, T, S> where S: Dimension<'a, T::Summary>, { @@ -395,7 +367,7 @@ impl<T: Item> SumTree<T> { /// that is returned cannot be used with Rust's iterators. 
pub fn filter<'a, F, U>( &'a self, - cx: &<T::Summary as Summary>::Context, + cx: &'a <T::Summary as Summary>::Context, filter_node: F, ) -> FilterCursor<'a, F, T, U> where @@ -525,10 +497,6 @@ impl<T: Item> SumTree<T> { } } - pub fn ptr_eq(&self, other: &Self) -> bool { - Arc::ptr_eq(&self.0, &other.0) - } - fn push_tree_recursive( &mut self, other: SumTree<T>, @@ -686,11 +654,6 @@ impl<T: Item> SumTree<T> { } => child_trees.last().unwrap().rightmost_leaf(), } } - - #[cfg(debug_assertions)] - pub fn _debug_entries(&self) -> Vec<&T> { - self.iter().collect::<Vec<_>>() - } } impl<T: Item + PartialEq> PartialEq for SumTree<T> { @@ -710,15 +673,15 @@ impl<T: KeyedItem> SumTree<T> { let mut replaced = None; *self = { let mut cursor = self.cursor::<T::Key>(cx); - let mut new_tree = cursor.slice(&item.key(), Bias::Left, cx); + let mut new_tree = cursor.slice(&item.key(), Bias::Left); if let Some(cursor_item) = cursor.item() { if cursor_item.key() == item.key() { replaced = Some(cursor_item.clone()); - cursor.next(cx); + cursor.next(); } } new_tree.push(item, cx); - new_tree.append(cursor.suffix(cx), cx); + new_tree.append(cursor.suffix(), cx); new_tree }; replaced @@ -728,14 +691,14 @@ impl<T: KeyedItem> SumTree<T> { let mut removed = None; *self = { let mut cursor = self.cursor::<T::Key>(cx); - let mut new_tree = cursor.slice(key, Bias::Left, cx); + let mut new_tree = cursor.slice(key, Bias::Left); if let Some(item) = cursor.item() { if item.key() == *key { removed = Some(item.clone()); - cursor.next(cx); + cursor.next(); } } - new_tree.append(cursor.suffix(cx), cx); + new_tree.append(cursor.suffix(), cx); new_tree }; removed @@ -758,7 +721,7 @@ impl<T: KeyedItem> SumTree<T> { let mut new_tree = SumTree::new(cx); let mut buffered_items = Vec::new(); - cursor.seek(&T::Key::zero(cx), Bias::Left, cx); + cursor.seek(&T::Key::zero(cx), Bias::Left); for edit in edits { let new_key = edit.key(); let mut old_item = cursor.item(); @@ -768,7 +731,7 @@ impl<T: KeyedItem> SumTree<T> { .map_or(false, |old_item| old_item.key() < new_key) { new_tree.extend(buffered_items.drain(..), cx); - let slice = cursor.slice(&new_key, Bias::Left, cx); + let slice = cursor.slice(&new_key, Bias::Left); new_tree.append(slice, cx); old_item = cursor.item(); } @@ -776,7 +739,7 @@ impl<T: KeyedItem> SumTree<T> { if let Some(old_item) = old_item { if old_item.key() == new_key { removed.push(old_item.clone()); - cursor.next(cx); + cursor.next(); } } @@ -789,70 +752,25 @@ impl<T: KeyedItem> SumTree<T> { } new_tree.extend(buffered_items, cx); - new_tree.append(cursor.suffix(cx), cx); + new_tree.append(cursor.suffix(), cx); new_tree }; removed } - pub fn get(&self, key: &T::Key, cx: &<T::Summary as Summary>::Context) -> Option<&T> { + pub fn get<'a>( + &'a self, + key: &T::Key, + cx: &'a <T::Summary as Summary>::Context, + ) -> Option<&'a T> { let mut cursor = self.cursor::<T::Key>(cx); - if cursor.seek(key, Bias::Left, cx) { + if cursor.seek(key, Bias::Left) { cursor.item() } else { None } } - - #[inline] - pub fn contains(&self, key: &T::Key, cx: &<T::Summary as Summary>::Context) -> bool { - self.get(key, cx).is_some() - } - - pub fn update<F, R>( - &mut self, - key: &T::Key, - cx: &<T::Summary as Summary>::Context, - f: F, - ) -> Option<R> - where - F: FnOnce(&mut T) -> R, - { - let mut cursor = self.cursor::<T::Key>(cx); - let mut new_tree = cursor.slice(key, Bias::Left, cx); - let mut result = None; - if Ord::cmp(key, &cursor.end(cx)) == Ordering::Equal { - let mut updated = cursor.item().unwrap().clone(); - result = 
Some(f(&mut updated)); - new_tree.push(updated, cx); - cursor.next(cx); - } - new_tree.append(cursor.suffix(cx), cx); - drop(cursor); - *self = new_tree; - result - } - - pub fn retain<F: FnMut(&T) -> bool>( - &mut self, - cx: &<T::Summary as Summary>::Context, - mut predicate: F, - ) { - let mut new_map = SumTree::new(cx); - - let mut cursor = self.cursor::<T::Key>(cx); - cursor.next(cx); - while let Some(item) = cursor.item() { - if predicate(&item) { - new_map.push(item.clone(), cx); - } - cursor.next(cx); - } - drop(cursor); - - *self = new_map; - } } impl<T, S> Default for SumTree<T> @@ -1061,14 +979,14 @@ mod tests { tree = { let mut cursor = tree.cursor::<Count>(&()); - let mut new_tree = cursor.slice(&Count(splice_start), Bias::Right, &()); + let mut new_tree = cursor.slice(&Count(splice_start), Bias::Right); if rng.r#gen() { new_tree.extend(new_items, &()); } else { new_tree.par_extend(new_items, &()); } - cursor.seek(&Count(splice_end), Bias::Right, &()); - new_tree.append(cursor.slice(&tree_end, Bias::Right, &()), &()); + cursor.seek(&Count(splice_end), Bias::Right); + new_tree.append(cursor.slice(&tree_end, Bias::Right), &()); new_tree }; @@ -1090,10 +1008,10 @@ mod tests { .collect::<Vec<_>>(); let mut item_ix = if rng.r#gen() { - filter_cursor.next(&()); + filter_cursor.next(); 0 } else { - filter_cursor.prev(&()); + filter_cursor.prev(); expected_filtered_items.len().saturating_sub(1) }; while item_ix < expected_filtered_items.len() { @@ -1103,19 +1021,19 @@ mod tests { assert_eq!(actual_item, &reference_item); assert_eq!(filter_cursor.start().0, reference_index); log::info!("next"); - filter_cursor.next(&()); + filter_cursor.next(); item_ix += 1; while item_ix > 0 && rng.gen_bool(0.2) { log::info!("prev"); - filter_cursor.prev(&()); + filter_cursor.prev(); item_ix -= 1; if item_ix == 0 && rng.gen_bool(0.2) { - filter_cursor.prev(&()); + filter_cursor.prev(); assert_eq!(filter_cursor.item(), None); assert_eq!(filter_cursor.start().0, 0); - filter_cursor.next(&()); + filter_cursor.next(); } } } @@ -1124,9 +1042,9 @@ mod tests { let mut before_start = false; let mut cursor = tree.cursor::<Count>(&()); let start_pos = rng.gen_range(0..=reference_items.len()); - cursor.seek(&Count(start_pos), Bias::Right, &()); + cursor.seek(&Count(start_pos), Bias::Right); let mut pos = rng.gen_range(start_pos..=reference_items.len()); - cursor.seek_forward(&Count(pos), Bias::Right, &()); + cursor.seek_forward(&Count(pos), Bias::Right); for i in 0..10 { assert_eq!(cursor.start().0, pos); @@ -1152,13 +1070,13 @@ mod tests { } if i < 5 { - cursor.next(&()); + cursor.next(); if pos < reference_items.len() { pos += 1; before_start = false; } } else { - cursor.prev(&()); + cursor.prev(); if pos == 0 { before_start = true; } @@ -1174,11 +1092,11 @@ mod tests { let end_bias = if rng.r#gen() { Bias::Left } else { Bias::Right }; let mut cursor = tree.cursor::<Count>(&()); - cursor.seek(&Count(start), start_bias, &()); - let slice = cursor.slice(&Count(end), end_bias, &()); + cursor.seek(&Count(start), start_bias); + let slice = cursor.slice(&Count(end), end_bias); - cursor.seek(&Count(start), start_bias, &()); - let summary = cursor.summary::<_, Sum>(&Count(end), end_bias, &()); + cursor.seek(&Count(start), start_bias); + let summary = cursor.summary::<_, Sum>(&Count(end), end_bias); assert_eq!(summary.0, slice.summary().sum); } @@ -1191,19 +1109,19 @@ mod tests { let tree = SumTree::<u8>::default(); let mut cursor = tree.cursor::<IntegersSummary>(&()); assert_eq!( - cursor.slice(&Count(0), Bias::Right, 
&()).items(&()), + cursor.slice(&Count(0), Bias::Right).items(&()), Vec::<u8>::new() ); assert_eq!(cursor.item(), None); assert_eq!(cursor.prev_item(), None); assert_eq!(cursor.next_item(), None); assert_eq!(cursor.start().sum, 0); - cursor.prev(&()); + cursor.prev(); assert_eq!(cursor.item(), None); assert_eq!(cursor.prev_item(), None); assert_eq!(cursor.next_item(), None); assert_eq!(cursor.start().sum, 0); - cursor.next(&()); + cursor.next(); assert_eq!(cursor.item(), None); assert_eq!(cursor.prev_item(), None); assert_eq!(cursor.next_item(), None); @@ -1214,7 +1132,7 @@ mod tests { tree.extend(vec![1], &()); let mut cursor = tree.cursor::<IntegersSummary>(&()); assert_eq!( - cursor.slice(&Count(0), Bias::Right, &()).items(&()), + cursor.slice(&Count(0), Bias::Right).items(&()), Vec::<u8>::new() ); assert_eq!(cursor.item(), Some(&1)); @@ -1222,29 +1140,29 @@ mod tests { assert_eq!(cursor.next_item(), None); assert_eq!(cursor.start().sum, 0); - cursor.next(&()); + cursor.next(); assert_eq!(cursor.item(), None); assert_eq!(cursor.prev_item(), Some(&1)); assert_eq!(cursor.next_item(), None); assert_eq!(cursor.start().sum, 1); - cursor.prev(&()); + cursor.prev(); assert_eq!(cursor.item(), Some(&1)); assert_eq!(cursor.prev_item(), None); assert_eq!(cursor.next_item(), None); assert_eq!(cursor.start().sum, 0); let mut cursor = tree.cursor::<IntegersSummary>(&()); - assert_eq!(cursor.slice(&Count(1), Bias::Right, &()).items(&()), [1]); + assert_eq!(cursor.slice(&Count(1), Bias::Right).items(&()), [1]); assert_eq!(cursor.item(), None); assert_eq!(cursor.prev_item(), Some(&1)); assert_eq!(cursor.next_item(), None); assert_eq!(cursor.start().sum, 1); - cursor.seek(&Count(0), Bias::Right, &()); + cursor.seek(&Count(0), Bias::Right); assert_eq!( cursor - .slice(&tree.extent::<Count>(&()), Bias::Right, &()) + .slice(&tree.extent::<Count>(&()), Bias::Right) .items(&()), [1] ); @@ -1258,80 +1176,80 @@ mod tests { tree.extend(vec![1, 2, 3, 4, 5, 6], &()); let mut cursor = tree.cursor::<IntegersSummary>(&()); - assert_eq!(cursor.slice(&Count(2), Bias::Right, &()).items(&()), [1, 2]); + assert_eq!(cursor.slice(&Count(2), Bias::Right).items(&()), [1, 2]); assert_eq!(cursor.item(), Some(&3)); assert_eq!(cursor.prev_item(), Some(&2)); assert_eq!(cursor.next_item(), Some(&4)); assert_eq!(cursor.start().sum, 3); - cursor.next(&()); + cursor.next(); assert_eq!(cursor.item(), Some(&4)); assert_eq!(cursor.prev_item(), Some(&3)); assert_eq!(cursor.next_item(), Some(&5)); assert_eq!(cursor.start().sum, 6); - cursor.next(&()); + cursor.next(); assert_eq!(cursor.item(), Some(&5)); assert_eq!(cursor.prev_item(), Some(&4)); assert_eq!(cursor.next_item(), Some(&6)); assert_eq!(cursor.start().sum, 10); - cursor.next(&()); + cursor.next(); assert_eq!(cursor.item(), Some(&6)); assert_eq!(cursor.prev_item(), Some(&5)); assert_eq!(cursor.next_item(), None); assert_eq!(cursor.start().sum, 15); - cursor.next(&()); - cursor.next(&()); + cursor.next(); + cursor.next(); assert_eq!(cursor.item(), None); assert_eq!(cursor.prev_item(), Some(&6)); assert_eq!(cursor.next_item(), None); assert_eq!(cursor.start().sum, 21); - cursor.prev(&()); + cursor.prev(); assert_eq!(cursor.item(), Some(&6)); assert_eq!(cursor.prev_item(), Some(&5)); assert_eq!(cursor.next_item(), None); assert_eq!(cursor.start().sum, 15); - cursor.prev(&()); + cursor.prev(); assert_eq!(cursor.item(), Some(&5)); assert_eq!(cursor.prev_item(), Some(&4)); assert_eq!(cursor.next_item(), Some(&6)); assert_eq!(cursor.start().sum, 10); - cursor.prev(&()); + cursor.prev(); 
assert_eq!(cursor.item(), Some(&4)); assert_eq!(cursor.prev_item(), Some(&3)); assert_eq!(cursor.next_item(), Some(&5)); assert_eq!(cursor.start().sum, 6); - cursor.prev(&()); + cursor.prev(); assert_eq!(cursor.item(), Some(&3)); assert_eq!(cursor.prev_item(), Some(&2)); assert_eq!(cursor.next_item(), Some(&4)); assert_eq!(cursor.start().sum, 3); - cursor.prev(&()); + cursor.prev(); assert_eq!(cursor.item(), Some(&2)); assert_eq!(cursor.prev_item(), Some(&1)); assert_eq!(cursor.next_item(), Some(&3)); assert_eq!(cursor.start().sum, 1); - cursor.prev(&()); + cursor.prev(); assert_eq!(cursor.item(), Some(&1)); assert_eq!(cursor.prev_item(), None); assert_eq!(cursor.next_item(), Some(&2)); assert_eq!(cursor.start().sum, 0); - cursor.prev(&()); + cursor.prev(); assert_eq!(cursor.item(), None); assert_eq!(cursor.prev_item(), None); assert_eq!(cursor.next_item(), Some(&1)); assert_eq!(cursor.start().sum, 0); - cursor.next(&()); + cursor.next(); assert_eq!(cursor.item(), Some(&1)); assert_eq!(cursor.prev_item(), None); assert_eq!(cursor.next_item(), Some(&2)); @@ -1340,7 +1258,7 @@ mod tests { let mut cursor = tree.cursor::<IntegersSummary>(&()); assert_eq!( cursor - .slice(&tree.extent::<Count>(&()), Bias::Right, &()) + .slice(&tree.extent::<Count>(&()), Bias::Right) .items(&()), tree.items(&()) ); @@ -1349,10 +1267,10 @@ mod tests { assert_eq!(cursor.next_item(), None); assert_eq!(cursor.start().sum, 21); - cursor.seek(&Count(3), Bias::Right, &()); + cursor.seek(&Count(3), Bias::Right); assert_eq!( cursor - .slice(&tree.extent::<Count>(&()), Bias::Right, &()) + .slice(&tree.extent::<Count>(&()), Bias::Right) .items(&()), [4, 5, 6] ); @@ -1362,25 +1280,16 @@ mod tests { assert_eq!(cursor.start().sum, 21); // Seeking can bias left or right - cursor.seek(&Count(1), Bias::Left, &()); + cursor.seek(&Count(1), Bias::Left); assert_eq!(cursor.item(), Some(&1)); - cursor.seek(&Count(1), Bias::Right, &()); + cursor.seek(&Count(1), Bias::Right); assert_eq!(cursor.item(), Some(&2)); // Slicing without resetting starts from where the cursor is parked at. 
- cursor.seek(&Count(1), Bias::Right, &()); - assert_eq!( - cursor.slice(&Count(3), Bias::Right, &()).items(&()), - vec![2, 3] - ); - assert_eq!( - cursor.slice(&Count(6), Bias::Left, &()).items(&()), - vec![4, 5] - ); - assert_eq!( - cursor.slice(&Count(6), Bias::Right, &()).items(&()), - vec![6] - ); + cursor.seek(&Count(1), Bias::Right); + assert_eq!(cursor.slice(&Count(3), Bias::Right).items(&()), vec![2, 3]); + assert_eq!(cursor.slice(&Count(6), Bias::Left).items(&()), vec![4, 5]); + assert_eq!(cursor.slice(&Count(6), Bias::Right).items(&()), vec![6]); } #[test] diff --git a/crates/sum_tree/src/tree_map.rs b/crates/sum_tree/src/tree_map.rs index 884042b722aef0bb84db180ac02fe795a6b8b45e..54e8ae8343f4778e04a37a7ebd3dbe2b6da587cd 100644 --- a/crates/sum_tree/src/tree_map.rs +++ b/crates/sum_tree/src/tree_map.rs @@ -54,7 +54,7 @@ impl<K: Clone + Ord, V: Clone> TreeMap<K, V> { pub fn get(&self, key: &K) -> Option<&V> { let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&()); - cursor.seek(&MapKeyRef(Some(key)), Bias::Left, &()); + cursor.seek(&MapKeyRef(Some(key)), Bias::Left); if let Some(item) = cursor.item() { if Some(key) == item.key().0.as_ref() { Some(&item.value) @@ -71,10 +71,10 @@ impl<K: Clone + Ord, V: Clone> TreeMap<K, V> { } pub fn extend(&mut self, iter: impl IntoIterator<Item = (K, V)>) { - let mut edits = Vec::new(); - for (key, value) in iter { - edits.push(Edit::Insert(MapEntry { key, value })); - } + let edits: Vec<_> = iter + .into_iter() + .map(|(key, value)| Edit::Insert(MapEntry { key, value })) + .collect(); self.0.edit(edits, &()); } @@ -86,12 +86,12 @@ impl<K: Clone + Ord, V: Clone> TreeMap<K, V> { let mut removed = None; let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&()); let key = MapKeyRef(Some(key)); - let mut new_tree = cursor.slice(&key, Bias::Left, &()); - if key.cmp(&cursor.end(&()), &()) == Ordering::Equal { + let mut new_tree = cursor.slice(&key, Bias::Left); + if key.cmp(&cursor.end(), &()) == Ordering::Equal { removed = Some(cursor.item().unwrap().value.clone()); - cursor.next(&()); + cursor.next(); } - new_tree.append(cursor.suffix(&()), &()); + new_tree.append(cursor.suffix(), &()); drop(cursor); self.0 = new_tree; removed @@ -101,9 +101,9 @@ impl<K: Clone + Ord, V: Clone> TreeMap<K, V> { let start = MapSeekTargetAdaptor(start); let end = MapSeekTargetAdaptor(end); let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&()); - let mut new_tree = cursor.slice(&start, Bias::Left, &()); - cursor.seek(&end, Bias::Left, &()); - new_tree.append(cursor.suffix(&()), &()); + let mut new_tree = cursor.slice(&start, Bias::Left); + cursor.seek(&end, Bias::Left); + new_tree.append(cursor.suffix(), &()); drop(cursor); self.0 = new_tree; } @@ -112,15 +112,15 @@ impl<K: Clone + Ord, V: Clone> TreeMap<K, V> { pub fn closest(&self, key: &K) -> Option<(&K, &V)> { let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&()); let key = MapKeyRef(Some(key)); - cursor.seek(&key, Bias::Right, &()); - cursor.prev(&()); + cursor.seek(&key, Bias::Right); + cursor.prev(); cursor.item().map(|item| (&item.key, &item.value)) } pub fn iter_from<'a>(&'a self, from: &K) -> impl Iterator<Item = (&'a K, &'a V)> + 'a { let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&()); let from_key = MapKeyRef(Some(from)); - cursor.seek(&from_key, Bias::Left, &()); + cursor.seek(&from_key, Bias::Left); cursor.map(|map_entry| (&map_entry.key, &map_entry.value)) } @@ -131,15 +131,15 @@ impl<K: Clone + Ord, V: Clone> TreeMap<K, V> { { let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&()); let key = 
MapKeyRef(Some(key)); - let mut new_tree = cursor.slice(&key, Bias::Left, &()); + let mut new_tree = cursor.slice(&key, Bias::Left); let mut result = None; - if key.cmp(&cursor.end(&()), &()) == Ordering::Equal { + if key.cmp(&cursor.end(), &()) == Ordering::Equal { let mut updated = cursor.item().unwrap().clone(); result = Some(f(&mut updated.value)); new_tree.push(updated, &()); - cursor.next(&()); + cursor.next(); } - new_tree.append(cursor.suffix(&()), &()); + new_tree.append(cursor.suffix(), &()); drop(cursor); self.0 = new_tree; result @@ -149,12 +149,12 @@ impl<K: Clone + Ord, V: Clone> TreeMap<K, V> { let mut new_map = SumTree::<MapEntry<K, V>>::default(); let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&()); - cursor.next(&()); + cursor.next(); while let Some(item) = cursor.item() { if predicate(&item.key, &item.value) { new_map.push(item.clone(), &()); } - cursor.next(&()); + cursor.next(); } drop(cursor); diff --git a/crates/supermaven/Cargo.toml b/crates/supermaven/Cargo.toml index d0451f34f2275cb81e95bb347219ee8ac79422e4..4fc6a618ff1b585d9365357dc3a33c1b148feb99 100644 --- a/crates/supermaven/Cargo.toml +++ b/crates/supermaven/Cargo.toml @@ -16,9 +16,9 @@ doctest = false anyhow.workspace = true client.workspace = true collections.workspace = true +edit_prediction.workspace = true futures.workspace = true gpui.workspace = true -inline_completion.workspace = true language.workspace = true log.workspace = true postage.workspace = true diff --git a/crates/supermaven/src/supermaven.rs b/crates/supermaven/src/supermaven.rs index ab500fb79d0584f07dd12a9b25b97c0a4393c01b..a31b96d8825334a3aed5fceed0efb86db4fac9f5 100644 --- a/crates/supermaven/src/supermaven.rs +++ b/crates/supermaven/src/supermaven.rs @@ -234,16 +234,14 @@ fn find_relevant_completion<'a>( } let original_cursor_offset = buffer.clip_offset(state.prefix_offset, text::Bias::Left); - let text_inserted_since_completion_request = - buffer.text_for_range(original_cursor_offset..current_cursor_offset); - let mut trimmed_completion = state_completion; - for chunk in text_inserted_since_completion_request { - if let Some(suffix) = trimmed_completion.strip_prefix(chunk) { - trimmed_completion = suffix; - } else { - continue 'completions; - } - } + let text_inserted_since_completion_request: String = buffer + .text_for_range(original_cursor_offset..current_cursor_offset) + .collect(); + let trimmed_completion = + match state_completion.strip_prefix(&text_inserted_since_completion_request) { + Some(suffix) => suffix, + None => continue 'completions, + }; if best_completion.map_or(false, |best| best.len() > trimmed_completion.len()) { continue; @@ -439,3 +437,77 @@ pub struct SupermavenCompletion { pub id: SupermavenCompletionStateId, pub updates: watch::Receiver<()>, } + +#[cfg(test)] +mod tests { + use super::*; + use collections::BTreeMap; + use gpui::TestAppContext; + use language::Buffer; + + #[gpui::test] + async fn test_find_relevant_completion_no_first_letter_skip(cx: &mut TestAppContext) { + let buffer = cx.new(|cx| Buffer::local("hello world", cx)); + let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot()); + + let mut states = BTreeMap::new(); + let state_id = SupermavenCompletionStateId(1); + let (updates_tx, _) = watch::channel(); + + states.insert( + state_id, + SupermavenCompletionState { + buffer_id: buffer.entity_id(), + prefix_anchor: buffer_snapshot.anchor_before(0), // Start of buffer + prefix_offset: 0, + text: "hello".to_string(), + dedent: String::new(), + updates_tx, + }, + ); + + let 
cursor_position = buffer_snapshot.anchor_after(1); + + let result = find_relevant_completion( + &states, + buffer.entity_id(), + &buffer_snapshot, + cursor_position, + ); + + assert_eq!(result, Some("ello")); + } + + #[gpui::test] + async fn test_find_relevant_completion_with_multiple_chars(cx: &mut TestAppContext) { + let buffer = cx.new(|cx| Buffer::local("hello world", cx)); + let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot()); + + let mut states = BTreeMap::new(); + let state_id = SupermavenCompletionStateId(1); + let (updates_tx, _) = watch::channel(); + + states.insert( + state_id, + SupermavenCompletionState { + buffer_id: buffer.entity_id(), + prefix_anchor: buffer_snapshot.anchor_before(0), // Start of buffer + prefix_offset: 0, + text: "hello".to_string(), + dedent: String::new(), + updates_tx, + }, + ); + + let cursor_position = buffer_snapshot.anchor_after(3); + + let result = find_relevant_completion( + &states, + buffer.entity_id(), + &buffer_snapshot, + cursor_position, + ); + + assert_eq!(result, Some("lo")); + } +} diff --git a/crates/supermaven/src/supermaven_completion_provider.rs b/crates/supermaven/src/supermaven_completion_provider.rs index c49272e66e4bb762c5aad5204c93cf932ec851bd..1b1fc54a7a335ac436038fe9c254678a6628cb78 100644 --- a/crates/supermaven/src/supermaven_completion_provider.rs +++ b/crates/supermaven/src/supermaven_completion_provider.rs @@ -1,8 +1,8 @@ use crate::{Supermaven, SupermavenCompletionStateId}; use anyhow::Result; +use edit_prediction::{Direction, EditPrediction, EditPredictionProvider}; use futures::StreamExt as _; use gpui::{App, Context, Entity, EntityId, Task}; -use inline_completion::{Direction, EditPredictionProvider, InlineCompletion}; use language::{Anchor, Buffer, BufferSnapshot}; use project::Project; use std::{ @@ -44,7 +44,7 @@ fn completion_from_diff( completion_text: &str, position: Anchor, delete_range: Range<Anchor>, -) -> InlineCompletion { +) -> EditPrediction { let buffer_text = snapshot .text_for_range(delete_range.clone()) .collect::<String>(); @@ -91,7 +91,7 @@ fn completion_from_diff( edits.push((edit_range, edit_text)); } - InlineCompletion { + EditPrediction { id: None, edits, edit_preview: None, @@ -108,6 +108,14 @@ impl EditPredictionProvider for SupermavenCompletionProvider { } fn show_completions_in_menu() -> bool { + true + } + + fn show_tab_accept_marker() -> bool { + true + } + + fn supports_jump_to_edit() -> bool { false } @@ -116,7 +124,7 @@ impl EditPredictionProvider for SupermavenCompletionProvider { } fn is_refreshing(&self) -> bool { - self.pending_refresh.is_some() + self.pending_refresh.is_some() && self.completion_id.is_none() } fn refresh( @@ -182,7 +190,7 @@ impl EditPredictionProvider for SupermavenCompletionProvider { buffer: &Entity<Buffer>, cursor_position: Anchor, cx: &mut Context<Self>, - ) -> Option<InlineCompletion> { + ) -> Option<EditPrediction> { let completion_text = self .supermaven .read(cx) @@ -197,6 +205,7 @@ impl EditPredictionProvider for SupermavenCompletionProvider { let mut point = cursor_position.to_point(&snapshot); point.column = snapshot.line_len(point.row); let range = cursor_position..snapshot.anchor_after(point); + Some(completion_from_diff( snapshot, completion_text, diff --git a/crates/tasks_ui/src/modal.rs b/crates/tasks_ui/src/modal.rs index 1510f613e34ef7bfc78bbfad23b7843787432491..c4b0931c353a5651906dbc26c2eba77f55a080b2 100644 --- a/crates/tasks_ui/src/modal.rs +++ b/crates/tasks_ui/src/modal.rs @@ -500,7 +500,7 @@ impl PickerDelegate for 
TasksModalDelegate { .map(|icon| icon.color(Color::Muted).size(IconSize::Small)); let indicator = if matches!(source_kind, TaskSourceKind::Lsp { .. }) { Some(Indicator::icon( - Icon::new(IconName::Bolt).size(IconSize::Small), + Icon::new(IconName::BoltOutlined).size(IconSize::Small), )) } else { None diff --git a/crates/telemetry_events/src/telemetry_events.rs b/crates/telemetry_events/src/telemetry_events.rs index dfe167fcd44c2fe3c163af260090f10eb94d4a71..735a1310ae063befb056563fe8050e8fda153941 100644 --- a/crates/telemetry_events/src/telemetry_events.rs +++ b/crates/telemetry_events/src/telemetry_events.rs @@ -94,8 +94,8 @@ impl Display for AssistantPhase { pub enum Event { Flexible(FlexibleEvent), Editor(EditorEvent), - InlineCompletion(InlineCompletionEvent), - InlineCompletionRating(InlineCompletionRatingEvent), + EditPrediction(EditPredictionEvent), + EditPredictionRating(EditPredictionRatingEvent), Call(CallEvent), Assistant(AssistantEventData), Cpu(CpuEvent), @@ -132,7 +132,7 @@ pub struct EditorEvent { } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct InlineCompletionEvent { +pub struct EditPredictionEvent { /// Provider of the completion suggestion (e.g. copilot, supermaven) pub provider: String, pub suggestion_accepted: bool, @@ -140,14 +140,14 @@ pub struct InlineCompletionEvent { } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub enum InlineCompletionRating { +pub enum EditPredictionRating { Positive, Negative, } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct InlineCompletionRatingEvent { - pub rating: InlineCompletionRating, +pub struct EditPredictionRatingEvent { + pub rating: EditPredictionRating, pub input_events: Arc<str>, pub input_excerpt: Arc<str>, pub output_excerpt: Arc<str>, diff --git a/crates/terminal/src/pty_info.rs b/crates/terminal/src/pty_info.rs index d9515afbf751a0027bc8e911334b2db13bf03935..802470493cc6c883008640af2b8f374254e3299d 100644 --- a/crates/terminal/src/pty_info.rs +++ b/crates/terminal/src/pty_info.rs @@ -121,6 +121,10 @@ impl PtyProcessInfo { } } + pub(crate) fn kill_current_process(&mut self) -> bool { + self.refresh().map_or(false, |process| process.kill()) + } + fn load(&mut self) -> Option<ProcessInfo> { let process = self.refresh()?; let cwd = process.cwd().map_or(PathBuf::new(), |p| p.to_owned()); diff --git a/crates/terminal/src/terminal.rs b/crates/terminal/src/terminal.rs index 032a750d1a58cce592ac64e8e8aa35cd0b07df99..6e359414d76b0a10a1cbfdef8bfe868fff3eb1fa 100644 --- a/crates/terminal/src/terminal.rs +++ b/crates/terminal/src/terminal.rs @@ -162,7 +162,8 @@ enum InternalEvent { UpdateSelection(Point<Pixels>), // Adjusted mouse position, should open FindHyperlink(Point<Pixels>, bool), - Copy, + // Whether keep selection when copy + Copy(Option<bool>), // Vi mode events ToggleViMode, ViMotion(ViMotion), @@ -931,13 +932,13 @@ impl Terminal { } } - InternalEvent::Copy => { + InternalEvent::Copy(keep_selection) => { if let Some(txt) = term.selection_to_string() { cx.write_to_clipboard(ClipboardItem::new_string(txt)); - - let settings = TerminalSettings::get_global(cx); - - if !settings.keep_selection_on_copy { + if !keep_selection.unwrap_or_else(|| { + let settings = TerminalSettings::get_global(cx); + settings.keep_selection_on_copy + }) { self.events.push_back(InternalEvent::SetSelection(None)); } } @@ -1108,8 +1109,8 @@ impl Terminal { .push_back(InternalEvent::SetSelection(selection)); } - pub fn copy(&mut self) { - self.events.push_back(InternalEvent::Copy); + pub fn 
copy(&mut self, keep_selection: Option<bool>) { + self.events.push_back(InternalEvent::Copy(keep_selection)); } pub fn clear(&mut self) { @@ -1267,8 +1268,7 @@ impl Terminal { } "y" => { - self.events.push_back(InternalEvent::Copy); - self.events.push_back(InternalEvent::SetSelection(None)); + self.copy(Some(false)); return; } @@ -1653,7 +1653,7 @@ impl Terminal { } } else { if e.button == MouseButton::Left && setting.copy_on_select { - self.copy(); + self.copy(Some(true)); } //Hyperlinks @@ -1824,6 +1824,14 @@ impl Terminal { } } + pub fn kill_active_task(&mut self) { + if let Some(task) = self.task() { + if task.status == TaskStatus::Running { + self.pty_info.kill_current_process(); + } + } + } + pub fn task(&self) -> Option<&TaskState> { self.task.as_ref() } diff --git a/crates/terminal/src/terminal_settings.rs b/crates/terminal/src/terminal_settings.rs index 31c32dbdca22a73fddda7cd9334c6cde76a99c8b..3f89afffab766126d5f1ef33f7d12b109d0198ca 100644 --- a/crates/terminal/src/terminal_settings.rs +++ b/crates/terminal/src/terminal_settings.rs @@ -95,12 +95,14 @@ pub enum VenvSettings { /// to the current working directory. We recommend overriding this /// in your project's settings, rather than globally. activate_script: Option<ActivateScript>, + venv_name: Option<String>, directories: Option<Vec<PathBuf>>, }, } pub struct VenvSettingsContent<'a> { pub activate_script: ActivateScript, + pub venv_name: &'a str, pub directories: &'a [PathBuf], } @@ -110,16 +112,18 @@ impl VenvSettings { VenvSettings::Off => None, VenvSettings::On { activate_script, + venv_name, directories, } => Some(VenvSettingsContent { activate_script: activate_script.unwrap_or(ActivateScript::Default), + venv_name: venv_name.as_deref().unwrap_or(""), directories: directories.as_deref().unwrap_or(&[]), }), } } } -#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize, JsonSchema)] +#[derive(Clone, Copy, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "snake_case")] pub enum ActivateScript { #[default] @@ -128,6 +132,7 @@ pub enum ActivateScript { Fish, Nushell, PowerShell, + Pyenv, } #[derive(Clone, Debug, Default, Serialize, Deserialize, JsonSchema)] @@ -243,7 +248,7 @@ pub struct TerminalSettingsContent { /// - 75: Minimum for body text /// - 90: Preferred for body text /// - /// Default: 0 (no adjustment) + /// Default: 45 pub minimum_contrast: Option<f32>, } diff --git a/crates/terminal_view/src/terminal_element.rs b/crates/terminal_view/src/terminal_element.rs index c34d8926440287ca684d7c93527516c41e2869df..6c1be9d5e790e9e3b83aa843c1bd9b6ed9153b77 100644 --- a/crates/terminal_view/src/terminal_element.rs +++ b/crates/terminal_view/src/terminal_element.rs @@ -127,7 +127,7 @@ impl BatchedTextRun { cx: &mut App, ) { let pos = Point::new( - (origin.x + self.start_point.column as f32 * dimensions.cell_width).floor(), + origin.x + self.start_point.column as f32 * dimensions.cell_width, origin.y + self.start_point.line as f32 * dimensions.line_height, ); @@ -136,7 +136,7 @@ impl BatchedTextRun { .shape_line( self.text.clone().into(), self.font_size.to_pixels(window.rem_size()), - &[self.style.clone()], + std::slice::from_ref(&self.style), Some(dimensions.cell_width), ) .paint(pos, dimensions.line_height, window, cx); @@ -494,6 +494,30 @@ impl TerminalElement { } } + /// Checks if a character is a decorative block/box-like character that should + /// preserve its exact colors without contrast adjustment. 
+ /// + /// This specifically targets characters used as visual connectors, separators, + /// and borders where color matching with adjacent backgrounds is critical. + /// Regular icons (git, folders, etc.) are excluded as they need to remain readable. + /// + /// Fixes https://github.com/zed-industries/zed/issues/34234 + fn is_decorative_character(ch: char) -> bool { + matches!( + ch as u32, + // Unicode Box Drawing and Block Elements + 0x2500..=0x257F // Box Drawing (└ ┐ ─ │ etc.) + | 0x2580..=0x259F // Block Elements (▀ ▄ █ ░ ▒ ▓ etc.) + | 0x25A0..=0x25FF // Geometric Shapes (■ ▶ ● etc. - includes triangular/circular separators) + + // Private Use Area - Powerline separator symbols only + | 0xE0B0..=0xE0B7 // Powerline separators: triangles (E0B0-E0B3) and half circles (E0B4-E0B7) + | 0xE0B8..=0xE0BF // Additional Powerline separators: angles, flames, etc. + | 0xE0C0..=0xE0C8 // Powerline separators: pixelated triangles, curves + | 0xE0CC..=0xE0D4 // Powerline separators: rounded triangles, ice/lego style + ) + } + /// Converts the Alacritty cell styles to GPUI text styles and background color. fn cell_style( indexed: &IndexedCell, @@ -508,7 +532,10 @@ impl TerminalElement { let mut fg = convert_color(&fg, colors); let bg = convert_color(&bg, colors); - fg = color_contrast::ensure_minimum_contrast(fg, bg, minimum_contrast); + // Only apply contrast adjustment to non-decorative characters + if !Self::is_decorative_character(indexed.c) { + fg = color_contrast::ensure_minimum_contrast(fg, bg, minimum_contrast); + } // Ghostty uses (175/255) as the multiplier (~0.69), Alacritty uses 0.66, Kitty // uses 0.75. We're using 0.7 because it's pretty well in the middle of that. @@ -1575,6 +1602,101 @@ mod tests { use super::*; use gpui::{AbsoluteLength, Hsla, font}; + #[test] + fn test_is_decorative_character() { + // Box Drawing characters (U+2500 to U+257F) + assert!(TerminalElement::is_decorative_character('─')); // U+2500 + assert!(TerminalElement::is_decorative_character('│')); // U+2502 + assert!(TerminalElement::is_decorative_character('┌')); // U+250C + assert!(TerminalElement::is_decorative_character('┐')); // U+2510 + assert!(TerminalElement::is_decorative_character('└')); // U+2514 + assert!(TerminalElement::is_decorative_character('┘')); // U+2518 + assert!(TerminalElement::is_decorative_character('┼')); // U+253C + + // Block Elements (U+2580 to U+259F) + assert!(TerminalElement::is_decorative_character('▀')); // U+2580 + assert!(TerminalElement::is_decorative_character('▄')); // U+2584 + assert!(TerminalElement::is_decorative_character('█')); // U+2588 + assert!(TerminalElement::is_decorative_character('░')); // U+2591 + assert!(TerminalElement::is_decorative_character('▒')); // U+2592 + assert!(TerminalElement::is_decorative_character('▓')); // U+2593 + + // Geometric Shapes - block/box-like subset (U+25A0 to U+25D7) + assert!(TerminalElement::is_decorative_character('■')); // U+25A0 + assert!(TerminalElement::is_decorative_character('□')); // U+25A1 + assert!(TerminalElement::is_decorative_character('▲')); // U+25B2 + assert!(TerminalElement::is_decorative_character('▼')); // U+25BC + assert!(TerminalElement::is_decorative_character('◆')); // U+25C6 + assert!(TerminalElement::is_decorative_character('●')); // U+25CF + + // The specific character from the issue + assert!(TerminalElement::is_decorative_character('◗')); // U+25D7 + assert!(TerminalElement::is_decorative_character('◘')); // U+25D8 (now included in Geometric Shapes) + 
assert!(TerminalElement::is_decorative_character('◙')); // U+25D9 (now included in Geometric Shapes) + + // Powerline symbols (Private Use Area) + assert!(TerminalElement::is_decorative_character('\u{E0B0}')); // Powerline right triangle + assert!(TerminalElement::is_decorative_character('\u{E0B2}')); // Powerline left triangle + assert!(TerminalElement::is_decorative_character('\u{E0B4}')); // Powerline right half circle (the actual issue!) + assert!(TerminalElement::is_decorative_character('\u{E0B6}')); // Powerline left half circle + + // Characters that should NOT be considered decorative + assert!(!TerminalElement::is_decorative_character('A')); // Regular letter + assert!(!TerminalElement::is_decorative_character('$')); // Symbol + assert!(!TerminalElement::is_decorative_character(' ')); // Space + assert!(!TerminalElement::is_decorative_character('←')); // U+2190 (Arrow, not in our ranges) + assert!(!TerminalElement::is_decorative_character('→')); // U+2192 (Arrow, not in our ranges) + assert!(!TerminalElement::is_decorative_character('\u{F00C}')); // Font Awesome check (icon, needs contrast) + assert!(!TerminalElement::is_decorative_character('\u{E711}')); // Devicons (icon, needs contrast) + assert!(!TerminalElement::is_decorative_character('\u{EA71}')); // Codicons folder (icon, needs contrast) + assert!(!TerminalElement::is_decorative_character('\u{F401}')); // Octicons (icon, needs contrast) + assert!(!TerminalElement::is_decorative_character('\u{1F600}')); // Emoji (not in our ranges) + } + + #[test] + fn test_decorative_character_boundary_cases() { + // Test exact boundaries of our ranges + // Box Drawing range boundaries + assert!(TerminalElement::is_decorative_character('\u{2500}')); // First char + assert!(TerminalElement::is_decorative_character('\u{257F}')); // Last char + assert!(!TerminalElement::is_decorative_character('\u{24FF}')); // Just before + + // Block Elements range boundaries + assert!(TerminalElement::is_decorative_character('\u{2580}')); // First char + assert!(TerminalElement::is_decorative_character('\u{259F}')); // Last char + + // Geometric Shapes subset boundaries + assert!(TerminalElement::is_decorative_character('\u{25A0}')); // First char + assert!(TerminalElement::is_decorative_character('\u{25FF}')); // Last char + assert!(!TerminalElement::is_decorative_character('\u{2600}')); // Just after + } + + #[test] + fn test_decorative_characters_bypass_contrast_adjustment() { + // Decorative characters should not be affected by contrast adjustment + + // The specific character from issue #34234 + let problematic_char = '◗'; // U+25D7 + assert!( + TerminalElement::is_decorative_character(problematic_char), + "Character ◗ (U+25D7) should be recognized as decorative" + ); + + // Verify some other commonly used decorative characters + assert!(TerminalElement::is_decorative_character('│')); // Vertical line + assert!(TerminalElement::is_decorative_character('─')); // Horizontal line + assert!(TerminalElement::is_decorative_character('█')); // Full block + assert!(TerminalElement::is_decorative_character('▓')); // Dark shade + assert!(TerminalElement::is_decorative_character('■')); // Black square + assert!(TerminalElement::is_decorative_character('●')); // Black circle + + // Verify normal text characters are NOT decorative + assert!(!TerminalElement::is_decorative_character('A')); + assert!(!TerminalElement::is_decorative_character('1')); + assert!(!TerminalElement::is_decorative_character('$')); + assert!(!TerminalElement::is_decorative_character(' ')); + 
} + #[test] fn test_contrast_adjustment_logic() { // Test the core contrast adjustment logic without needing full app context diff --git a/crates/terminal_view/src/terminal_panel.rs b/crates/terminal_view/src/terminal_panel.rs index f6eee3065ca974449315ab2ac519de1acb5da11e..cb1e3628848e9e850fa0c13f8b659259a1e6fd48 100644 --- a/crates/terminal_view/src/terminal_panel.rs +++ b/crates/terminal_view/src/terminal_panel.rs @@ -1437,7 +1437,7 @@ impl Panel for TerminalPanel { if (self.is_enabled(cx) || !self.has_no_terminals(cx)) && TerminalSettings::get_global(cx).button { - Some(IconName::Terminal) + Some(IconName::TerminalAlt) } else { None } diff --git a/crates/terminal_view/src/terminal_scrollbar.rs b/crates/terminal_view/src/terminal_scrollbar.rs index 18e135be2eef3b8e7ec71c070f2a60a46792a271..c8565a42bee0858e0928e557b9fae590dba319fb 100644 --- a/crates/terminal_view/src/terminal_scrollbar.rs +++ b/crates/terminal_view/src/terminal_scrollbar.rs @@ -46,9 +46,16 @@ impl TerminalScrollHandle { } impl ScrollableHandle for TerminalScrollHandle { - fn content_size(&self) -> Size<Pixels> { + fn max_offset(&self) -> Size<Pixels> { let state = self.state.borrow(); - size(Pixels::ZERO, state.total_lines as f32 * state.line_height) + size( + Pixels::ZERO, + state + .total_lines + .checked_sub(state.viewport_lines) + .unwrap_or(0) as f32 + * state.line_height, + ) } fn offset(&self) -> Point<Pixels> { diff --git a/crates/terminal_view/src/terminal_view.rs b/crates/terminal_view/src/terminal_view.rs index 76ec9dcb2591a3d0f5483507f25a7c036464e93d..2e6be5aaf46ed4dbdf4f0b0f29f1c6501f8f134c 100644 --- a/crates/terminal_view/src/terminal_view.rs +++ b/crates/terminal_view/src/terminal_view.rs @@ -25,11 +25,11 @@ use terminal::{ TaskStatus, Terminal, TerminalBounds, ToggleViMode, alacritty_terminal::{ index::Point, - term::{TermMode, search::RegexSearch}, + term::{TermMode, point_to_viewport, search::RegexSearch}, }, terminal_settings::{self, CursorShape, TerminalBlink, TerminalSettings, WorkingDirectory}, }; -use terminal_element::{TerminalElement, is_blank}; +use terminal_element::TerminalElement; use terminal_panel::TerminalPanel; use terminal_scrollbar::TerminalScrollHandle; use terminal_slash_command::TerminalSlashCommand; @@ -430,6 +430,7 @@ impl TerminalView { fn settings_changed(&mut self, cx: &mut Context<Self>) { let settings = TerminalSettings::get_global(cx); + let breadcrumb_visibility_changed = self.show_breadcrumbs != settings.toolbar.breadcrumbs; self.show_breadcrumbs = settings.toolbar.breadcrumbs; let new_cursor_shape = settings.cursor_shape.unwrap_or_default(); @@ -441,6 +442,9 @@ impl TerminalView { }); } + if breadcrumb_visibility_changed { + cx.emit(ItemEvent::UpdateBreadcrumbs); + } cx.notify(); } @@ -497,25 +501,14 @@ impl TerminalView { }; let line_height = terminal.last_content().terminal_bounds.line_height; - let mut terminal_lines = terminal.total_lines(); let viewport_lines = terminal.viewport_lines(); - if terminal.total_lines() == terminal.viewport_lines() { - let mut last_line = None; - for cell in terminal.last_content.cells.iter().rev() { - if !is_blank(cell) { - break; - } - - let last_line = last_line.get_or_insert(cell.point.line); - if *last_line != cell.point.line { - terminal_lines -= 1; - } - *last_line = cell.point.line; - } - } - + let cursor = point_to_viewport( + terminal.last_content.display_offset, + terminal.last_content.cursor.point, + ) + .unwrap_or_default(); let max_scroll_top_in_lines = - (block.height as 
usize).saturating_sub(viewport_lines.saturating_sub(terminal_lines)); + (block.height as usize).saturating_sub(viewport_lines.saturating_sub(cursor.line + 1)); max_scroll_top_in_lines as f32 * line_height } @@ -715,7 +708,7 @@ impl TerminalView { ///Attempt to paste the clipboard into the terminal fn copy(&mut self, _: &Copy, _: &mut Window, cx: &mut Context<Self>) { - self.terminal.update(cx, |term, _| term.copy()); + self.terminal.update(cx, |term, _| term.copy(None)); cx.notify(); } @@ -1598,7 +1591,7 @@ impl Item for TerminalView { let (icon, icon_color, rerun_button) = match terminal.task() { Some(terminal_task) => match &terminal_task.status { TaskStatus::Running => ( - IconName::Play, + IconName::PlayOutlined, Color::Disabled, TerminalView::rerun_button(&terminal_task), ), diff --git a/crates/text/src/anchor.rs b/crates/text/src/anchor.rs index 83a4fc84298f855fc4185a0b8bdc428cfe67856b..c4778216e05ee811e46fabea7eed22171c2498ff 100644 --- a/crates/text/src/anchor.rs +++ b/crates/text/src/anchor.rs @@ -3,7 +3,7 @@ use crate::{ locator::Locator, }; use std::{cmp::Ordering, fmt::Debug, ops::Range}; -use sum_tree::Bias; +use sum_tree::{Bias, Dimensions}; /// A timestamped position in a buffer #[derive(Copy, Clone, Eq, PartialEq, Debug, Hash, Default)] @@ -99,9 +99,13 @@ impl Anchor { } else if self.buffer_id != Some(buffer.remote_id) { false } else { - let fragment_id = buffer.fragment_id_for_anchor(self); - let mut fragment_cursor = buffer.fragments.cursor::<(Option<&Locator>, usize)>(&None); - fragment_cursor.seek(&Some(fragment_id), Bias::Left, &None); + let Some(fragment_id) = buffer.try_fragment_id_for_anchor(self) else { + return false; + }; + let mut fragment_cursor = buffer + .fragments + .cursor::<Dimensions<Option<&Locator>, usize>>(&None); + fragment_cursor.seek(&Some(fragment_id), Bias::Left); fragment_cursor .item() .map_or(false, |fragment| fragment.visible) diff --git a/crates/text/src/text.rs b/crates/text/src/text.rs index a2742081f4b79eeff92cd2fb8a02890d1523fa5a..9f7e49d24dfe8b14ccae5cffb18bb1c65382e3cb 100644 --- a/crates/text/src/text.rs +++ b/crates/text/src/text.rs @@ -37,7 +37,7 @@ use std::{ }; pub use subscription::*; pub use sum_tree::Bias; -use sum_tree::{FilterCursor, SumTree, TreeMap, TreeSet}; +use sum_tree::{Dimensions, FilterCursor, SumTree, TreeMap, TreeSet}; use undo_map::UndoMap; #[cfg(any(test, feature = "test-support"))] @@ -320,7 +320,39 @@ impl History { last_edit_at: now, suppress_grouping: false, }); - self.redo_stack.clear(); + } + + /// Differs from `push_transaction` in that it does not clear the redo + /// stack. Intended to be used to create a parent transaction to merge + /// potential child transactions into. + /// + /// The caller is responsible for removing it from the undo history using + /// `forget_transaction` if no edits are merged into it. Otherwise, if edits + /// are merged into this transaction, the caller is responsible for ensuring + /// the redo stack is cleared. 
The easiest way to ensure the redo stack is + /// cleared is to create transactions with the usual `start_transaction` and + /// `end_transaction` methods and merging the resulting transactions into + /// the transaction created by this method + fn push_empty_transaction( + &mut self, + start: clock::Global, + now: Instant, + clock: &mut clock::Lamport, + ) -> TransactionId { + assert_eq!(self.transaction_depth, 0); + let id = clock.tick(); + let transaction = Transaction { + id, + start, + edit_ids: Vec::new(), + }; + self.undo_stack.push(HistoryEntry { + transaction, + first_edit_at: now, + last_edit_at: now, + suppress_grouping: false, + }); + id } fn push_undo(&mut self, op_id: clock::Lamport) { @@ -681,7 +713,7 @@ impl Buffer { let mut base_text = base_text.into(); let line_ending = LineEnding::detect(&base_text); LineEnding::normalize(&mut base_text); - Self::new_normalized(replica_id, remote_id, line_ending, Rope::from(base_text)) + Self::new_normalized(replica_id, remote_id, line_ending, Rope::from(&*base_text)) } pub fn new_normalized( @@ -824,14 +856,13 @@ impl Buffer { let mut new_ropes = RopeBuilder::new(self.visible_text.cursor(0), self.deleted_text.cursor(0)); let mut old_fragments = self.fragments.cursor::<FragmentTextSummary>(&None); - let mut new_fragments = - old_fragments.slice(&edits.peek().unwrap().0.start, Bias::Right, &None); + let mut new_fragments = old_fragments.slice(&edits.peek().unwrap().0.start, Bias::Right); new_ropes.append(new_fragments.summary().text); let mut fragment_start = old_fragments.start().visible; for (range, new_text) in edits { let new_text = LineEnding::normalize_arc(new_text.into()); - let fragment_end = old_fragments.end(&None).visible; + let fragment_end = old_fragments.end().visible; // If the current fragment ends before this range, then jump ahead to the first fragment // that extends past the start of this range, reusing any intervening fragments. @@ -847,10 +878,10 @@ impl Buffer { new_ropes.push_fragment(&suffix, suffix.visible); new_fragments.push(suffix, &None); } - old_fragments.next(&None); + old_fragments.next(); } - let slice = old_fragments.slice(&range.start, Bias::Right, &None); + let slice = old_fragments.slice(&range.start, Bias::Right); new_ropes.append(slice.summary().text); new_fragments.append(slice, &None); fragment_start = old_fragments.start().visible; @@ -903,7 +934,7 @@ impl Buffer { // portions as deleted. while fragment_start < range.end { let fragment = old_fragments.item().unwrap(); - let fragment_end = old_fragments.end(&None).visible; + let fragment_end = old_fragments.end().visible; let mut intersection = fragment.clone(); let intersection_end = cmp::min(range.end, fragment_end); if fragment.visible { @@ -930,7 +961,7 @@ impl Buffer { fragment_start = intersection_end; } if fragment_end <= range.end { - old_fragments.next(&None); + old_fragments.next(); } } @@ -942,7 +973,7 @@ impl Buffer { // If the current fragment has been partially consumed, then consume the rest of it // and advance to the next fragment before slicing. 
if fragment_start > old_fragments.start().visible { - let fragment_end = old_fragments.end(&None).visible; + let fragment_end = old_fragments.end().visible; if fragment_end > fragment_start { let mut suffix = old_fragments.item().unwrap().clone(); suffix.len = fragment_end - fragment_start; @@ -951,10 +982,10 @@ impl Buffer { new_ropes.push_fragment(&suffix, suffix.visible); new_fragments.push(suffix, &None); } - old_fragments.next(&None); + old_fragments.next(); } - let suffix = old_fragments.suffix(&None); + let suffix = old_fragments.suffix(); new_ropes.append(suffix.summary().text); new_fragments.append(suffix, &None); let (visible_text, deleted_text) = new_ropes.finish(); @@ -1040,17 +1071,16 @@ impl Buffer { let mut insertion_offset = 0; let mut new_ropes = RopeBuilder::new(self.visible_text.cursor(0), self.deleted_text.cursor(0)); - let mut old_fragments = self.fragments.cursor::<(VersionedFullOffset, usize)>(&cx); - let mut new_fragments = old_fragments.slice( - &VersionedFullOffset::Offset(ranges[0].start), - Bias::Left, - &cx, - ); + let mut old_fragments = self + .fragments + .cursor::<Dimensions<VersionedFullOffset, usize>>(&cx); + let mut new_fragments = + old_fragments.slice(&VersionedFullOffset::Offset(ranges[0].start), Bias::Left); new_ropes.append(new_fragments.summary().text); let mut fragment_start = old_fragments.start().0.full_offset(); for (range, new_text) in edits { - let fragment_end = old_fragments.end(&cx).0.full_offset(); + let fragment_end = old_fragments.end().0.full_offset(); // If the current fragment ends before this range, then jump ahead to the first fragment // that extends past the start of this range, reusing any intervening fragments. @@ -1067,18 +1097,18 @@ impl Buffer { new_ropes.push_fragment(&suffix, suffix.visible); new_fragments.push(suffix, &None); } - old_fragments.next(&cx); + old_fragments.next(); } let slice = - old_fragments.slice(&VersionedFullOffset::Offset(range.start), Bias::Left, &cx); + old_fragments.slice(&VersionedFullOffset::Offset(range.start), Bias::Left); new_ropes.append(slice.summary().text); new_fragments.append(slice, &None); fragment_start = old_fragments.start().0.full_offset(); } // If we are at the end of a non-concurrent fragment, advance to the next one. - let fragment_end = old_fragments.end(&cx).0.full_offset(); + let fragment_end = old_fragments.end().0.full_offset(); if fragment_end == range.start && fragment_end > fragment_start { let mut fragment = old_fragments.item().unwrap().clone(); fragment.len = fragment_end.0 - fragment_start.0; @@ -1086,7 +1116,7 @@ impl Buffer { new_insertions.push(InsertionFragment::insert_new(&fragment)); new_ropes.push_fragment(&fragment, fragment.visible); new_fragments.push(fragment, &None); - old_fragments.next(&cx); + old_fragments.next(); fragment_start = old_fragments.start().0.full_offset(); } @@ -1096,7 +1126,7 @@ impl Buffer { if fragment_start == range.start && fragment.timestamp > timestamp { new_ropes.push_fragment(fragment, fragment.visible); new_fragments.push(fragment.clone(), &None); - old_fragments.next(&cx); + old_fragments.next(); debug_assert_eq!(fragment_start, range.start); } else { break; @@ -1152,7 +1182,7 @@ impl Buffer { // portions as deleted. 
while fragment_start < range.end { let fragment = old_fragments.item().unwrap(); - let fragment_end = old_fragments.end(&cx).0.full_offset(); + let fragment_end = old_fragments.end().0.full_offset(); let mut intersection = fragment.clone(); let intersection_end = cmp::min(range.end, fragment_end); if fragment.was_visible(version, &self.undo_map) { @@ -1181,7 +1211,7 @@ impl Buffer { fragment_start = intersection_end; } if fragment_end <= range.end { - old_fragments.next(&cx); + old_fragments.next(); } } } @@ -1189,7 +1219,7 @@ impl Buffer { // If the current fragment has been partially consumed, then consume the rest of it // and advance to the next fragment before slicing. if fragment_start > old_fragments.start().0.full_offset() { - let fragment_end = old_fragments.end(&cx).0.full_offset(); + let fragment_end = old_fragments.end().0.full_offset(); if fragment_end > fragment_start { let mut suffix = old_fragments.item().unwrap().clone(); suffix.len = fragment_end.0 - fragment_start.0; @@ -1198,10 +1228,10 @@ impl Buffer { new_ropes.push_fragment(&suffix, suffix.visible); new_fragments.push(suffix, &None); } - old_fragments.next(&cx); + old_fragments.next(); } - let suffix = old_fragments.suffix(&cx); + let suffix = old_fragments.suffix(); new_ropes.append(suffix.summary().text); new_fragments.append(suffix, &None); let (visible_text, deleted_text) = new_ropes.finish(); @@ -1250,7 +1280,6 @@ impl Buffer { split_offset: insertion_slice.range.start, }, Bias::Left, - &(), ); } while let Some(item) = insertions_cursor.item() { @@ -1260,7 +1289,7 @@ impl Buffer { break; } fragment_ids.push(&item.fragment_id); - insertions_cursor.next(&()); + insertions_cursor.next(); } } fragment_ids.sort_unstable(); @@ -1271,13 +1300,15 @@ impl Buffer { self.snapshot.undo_map.insert(undo); let mut edits = Patch::default(); - let mut old_fragments = self.fragments.cursor::<(Option<&Locator>, usize)>(&None); + let mut old_fragments = self + .fragments + .cursor::<Dimensions<Option<&Locator>, usize>>(&None); let mut new_fragments = SumTree::new(&None); let mut new_ropes = RopeBuilder::new(self.visible_text.cursor(0), self.deleted_text.cursor(0)); for fragment_id in self.fragment_ids_for_edits(undo.counts.keys()) { - let preceding_fragments = old_fragments.slice(&Some(fragment_id), Bias::Left, &None); + let preceding_fragments = old_fragments.slice(&Some(fragment_id), Bias::Left); new_ropes.append(preceding_fragments.summary().text); new_fragments.append(preceding_fragments, &None); @@ -1304,11 +1335,11 @@ impl Buffer { new_ropes.push_fragment(&fragment, fragment_was_visible); new_fragments.push(fragment, &None); - old_fragments.next(&None); + old_fragments.next(); } } - let suffix = old_fragments.suffix(&None); + let suffix = old_fragments.suffix(); new_ropes.append(suffix.summary().text); new_fragments.append(suffix, &None); @@ -1495,6 +1526,24 @@ impl Buffer { self.history.push_transaction(transaction, now); } + /// Differs from `push_transaction` in that it does not clear the redo stack. + /// The caller responsible for + /// Differs from `push_transaction` in that it does not clear the redo + /// stack. Intended to be used to create a parent transaction to merge + /// potential child transactions into. + /// + /// The caller is responsible for removing it from the undo history using + /// `forget_transaction` if no edits are merged into it. Otherwise, if edits + /// are merged into this transaction, the caller is responsible for ensuring + /// the redo stack is cleared. 
The easiest way to ensure the redo stack is + /// cleared is to create transactions with the usual `start_transaction` and + /// `end_transaction` methods and merging the resulting transactions into + /// the transaction created by this method + pub fn push_empty_transaction(&mut self, now: Instant) -> TransactionId { + self.history + .push_empty_transaction(self.version.clone(), now, &mut self.lamport_clock) + } + pub fn edited_ranges_for_transaction_id<D>( &self, transaction_id: TransactionId, @@ -1516,12 +1565,14 @@ impl Buffer { D: TextDimension, { // get fragment ranges - let mut cursor = self.fragments.cursor::<(Option<&Locator>, usize)>(&None); + let mut cursor = self + .fragments + .cursor::<Dimensions<Option<&Locator>, usize>>(&None); let offset_ranges = self .fragment_ids_for_edits(edit_ids.into_iter()) .into_iter() .filter_map(move |fragment_id| { - cursor.seek_forward(&Some(fragment_id), Bias::Left, &None); + cursor.seek_forward(&Some(fragment_id), Bias::Left); let fragment = cursor.item()?; let start_offset = cursor.start().1; let end_offset = start_offset + if fragment.visible { fragment.len } else { 0 }; @@ -1743,7 +1794,7 @@ impl Buffer { let mut cursor = self.snapshot.fragments.cursor::<Option<&Locator>>(&None); for insertion_fragment in self.snapshot.insertions.cursor::<()>(&()) { - cursor.seek(&Some(&insertion_fragment.fragment_id), Bias::Left, &None); + cursor.seek(&Some(&insertion_fragment.fragment_id), Bias::Left); let fragment = cursor.item().unwrap(); assert_eq!(insertion_fragment.fragment_id, fragment.id); assert_eq!(insertion_fragment.split_offset, fragment.insertion_offset); @@ -1862,7 +1913,7 @@ impl BufferSnapshot { .filter::<_, FragmentTextSummary>(&None, move |summary| { !version.observed_all(&summary.max_version) }); - cursor.next(&None); + cursor.next(); let mut visible_cursor = self.visible_text.cursor(0); let mut deleted_cursor = self.deleted_text.cursor(0); @@ -1875,18 +1926,18 @@ impl BufferSnapshot { if fragment.was_visible(version, &self.undo_map) { if fragment.visible { - let text = visible_cursor.slice(cursor.end(&None).visible); + let text = visible_cursor.slice(cursor.end().visible); rope.append(text); } else { deleted_cursor.seek_forward(cursor.start().deleted); - let text = deleted_cursor.slice(cursor.end(&None).deleted); + let text = deleted_cursor.slice(cursor.end().deleted); rope.append(text); } } else if fragment.visible { - visible_cursor.seek_forward(cursor.end(&None).visible); + visible_cursor.seek_forward(cursor.end().visible); } - cursor.next(&None); + cursor.next(); } if cursor.start().visible > visible_cursor.offset() { @@ -2187,7 +2238,9 @@ impl BufferSnapshot { { let anchors = anchors.into_iter(); let mut insertion_cursor = self.insertions.cursor::<InsertionFragmentKey>(&()); - let mut fragment_cursor = self.fragments.cursor::<(Option<&Locator>, usize)>(&None); + let mut fragment_cursor = self + .fragments + .cursor::<Dimensions<Option<&Locator>, usize>>(&None); let mut text_cursor = self.visible_text.cursor(0); let mut position = D::zero(&()); @@ -2202,7 +2255,7 @@ impl BufferSnapshot { timestamp: anchor.timestamp, split_offset: anchor.offset, }; - insertion_cursor.seek(&anchor_key, anchor.bias, &()); + insertion_cursor.seek(&anchor_key, anchor.bias); if let Some(insertion) = insertion_cursor.item() { let comparison = sum_tree::KeyedItem::key(insertion).cmp(&anchor_key); if comparison == Ordering::Greater @@ -2210,15 +2263,15 @@ impl BufferSnapshot { && comparison == Ordering::Equal && anchor.offset > 0) { - 
insertion_cursor.prev(&()); + insertion_cursor.prev(); } } else { - insertion_cursor.prev(&()); + insertion_cursor.prev(); } let insertion = insertion_cursor.item().expect("invalid insertion"); assert_eq!(insertion.timestamp, anchor.timestamp, "invalid insertion"); - fragment_cursor.seek_forward(&Some(&insertion.fragment_id), Bias::Left, &None); + fragment_cursor.seek_forward(&Some(&insertion.fragment_id), Bias::Left); let fragment = fragment_cursor.item().unwrap(); let mut fragment_offset = fragment_cursor.start().1; if fragment.visible { @@ -2249,7 +2302,7 @@ impl BufferSnapshot { split_offset: anchor.offset, }; let mut insertion_cursor = self.insertions.cursor::<InsertionFragmentKey>(&()); - insertion_cursor.seek(&anchor_key, anchor.bias, &()); + insertion_cursor.seek(&anchor_key, anchor.bias); if let Some(insertion) = insertion_cursor.item() { let comparison = sum_tree::KeyedItem::key(insertion).cmp(&anchor_key); if comparison == Ordering::Greater @@ -2257,10 +2310,10 @@ impl BufferSnapshot { && comparison == Ordering::Equal && anchor.offset > 0) { - insertion_cursor.prev(&()); + insertion_cursor.prev(); } } else { - insertion_cursor.prev(&()); + insertion_cursor.prev(); } let Some(insertion) = insertion_cursor @@ -2273,8 +2326,10 @@ impl BufferSnapshot { ); }; - let mut fragment_cursor = self.fragments.cursor::<(Option<&Locator>, usize)>(&None); - fragment_cursor.seek(&Some(&insertion.fragment_id), Bias::Left, &None); + let mut fragment_cursor = self + .fragments + .cursor::<Dimensions<Option<&Locator>, usize>>(&None); + fragment_cursor.seek(&Some(&insertion.fragment_id), Bias::Left); let fragment = fragment_cursor.item().unwrap(); let mut fragment_offset = fragment_cursor.start().1; if fragment.visible { @@ -2285,17 +2340,26 @@ impl BufferSnapshot { } fn fragment_id_for_anchor(&self, anchor: &Anchor) -> &Locator { + self.try_fragment_id_for_anchor(anchor).unwrap_or_else(|| { + panic!( + "invalid anchor {:?}. buffer id: {}, version: {:?}", + anchor, self.remote_id, self.version, + ) + }) + } + + fn try_fragment_id_for_anchor(&self, anchor: &Anchor) -> Option<&Locator> { if *anchor == Anchor::MIN { - Locator::min_ref() + Some(Locator::min_ref()) } else if *anchor == Anchor::MAX { - Locator::max_ref() + Some(Locator::max_ref()) } else { let anchor_key = InsertionFragmentKey { timestamp: anchor.timestamp, split_offset: anchor.offset, }; let mut insertion_cursor = self.insertions.cursor::<InsertionFragmentKey>(&()); - insertion_cursor.seek(&anchor_key, anchor.bias, &()); + insertion_cursor.seek(&anchor_key, anchor.bias); if let Some(insertion) = insertion_cursor.item() { let comparison = sum_tree::KeyedItem::key(insertion).cmp(&anchor_key); if comparison == Ordering::Greater @@ -2303,26 +2367,18 @@ impl BufferSnapshot { && comparison == Ordering::Equal && anchor.offset > 0) { - insertion_cursor.prev(&()); + insertion_cursor.prev(); } } else { - insertion_cursor.prev(&()); + insertion_cursor.prev(); } - let Some(insertion) = insertion_cursor.item().filter(|insertion| { - if cfg!(debug_assertions) { - insertion.timestamp == anchor.timestamp - } else { - true - } - }) else { - panic!( - "invalid anchor {:?}. 
buffer id: {}, version: {:?}", - anchor, self.remote_id, self.version - ); - }; - - &insertion.fragment_id + insertion_cursor + .item() + .filter(|insertion| { + !cfg!(debug_assertions) || insertion.timestamp == anchor.timestamp + }) + .map(|insertion| &insertion.fragment_id) } } @@ -2345,7 +2401,7 @@ impl BufferSnapshot { Anchor::MAX } else { let mut fragment_cursor = self.fragments.cursor::<usize>(&None); - fragment_cursor.seek(&offset, bias, &None); + fragment_cursor.seek(&offset, bias); let fragment = fragment_cursor.item().unwrap(); let overshoot = offset - *fragment_cursor.start(); Anchor { @@ -2425,15 +2481,15 @@ impl BufferSnapshot { let mut cursor = self.fragments.filter(&None, move |summary| { !since.observed_all(&summary.max_version) }); - cursor.next(&None); + cursor.next(); Some(cursor) }; let mut cursor = self .fragments - .cursor::<(Option<&Locator>, FragmentTextSummary)>(&None); + .cursor::<Dimensions<Option<&Locator>, FragmentTextSummary>>(&None); let start_fragment_id = self.fragment_id_for_anchor(&range.start); - cursor.seek(&Some(start_fragment_id), Bias::Left, &None); + cursor.seek(&Some(start_fragment_id), Bias::Left); let mut visible_start = cursor.start().1.visible; let mut deleted_start = cursor.start().1.deleted; if let Some(fragment) = cursor.item() { @@ -2466,7 +2522,7 @@ impl BufferSnapshot { let mut cursor = self.fragments.filter::<_, usize>(&None, move |summary| { !since.observed_all(&summary.max_version) }); - cursor.next(&None); + cursor.next(); while let Some(fragment) = cursor.item() { if fragment.id > *end_fragment_id { break; @@ -2478,7 +2534,7 @@ impl BufferSnapshot { return true; } } - cursor.next(&None); + cursor.next(); } } false @@ -2489,14 +2545,14 @@ impl BufferSnapshot { let mut cursor = self.fragments.filter::<_, usize>(&None, move |summary| { !since.observed_all(&summary.max_version) }); - cursor.next(&None); + cursor.next(); while let Some(fragment) = cursor.item() { let was_visible = fragment.was_visible(since, &self.undo_map); let is_visible = fragment.visible; if was_visible != is_visible { return true; } - cursor.next(&None); + cursor.next(); } } false @@ -2601,7 +2657,7 @@ impl<D: TextDimension + Ord, F: FnMut(&FragmentSummary) -> bool> Iterator for Ed while let Some(fragment) = cursor.item() { if fragment.id < *self.range.start.0 { - cursor.next(&None); + cursor.next(); continue; } else if fragment.id > *self.range.end.0 { break; @@ -2634,7 +2690,7 @@ impl<D: TextDimension + Ord, F: FnMut(&FragmentSummary) -> bool> Iterator for Ed }; if !fragment.was_visible(self.since, self.undos) && fragment.visible { - let mut visible_end = cursor.end(&None).visible; + let mut visible_end = cursor.end().visible; if fragment.id == *self.range.end.0 { visible_end = cmp::min( visible_end, @@ -2660,7 +2716,7 @@ impl<D: TextDimension + Ord, F: FnMut(&FragmentSummary) -> bool> Iterator for Ed self.new_end = new_end; } else if fragment.was_visible(self.since, self.undos) && !fragment.visible { - let mut deleted_end = cursor.end(&None).deleted; + let mut deleted_end = cursor.end().deleted; if fragment.id == *self.range.end.0 { deleted_end = cmp::min( deleted_end, @@ -2690,7 +2746,7 @@ impl<D: TextDimension + Ord, F: FnMut(&FragmentSummary) -> bool> Iterator for Ed self.old_end = old_end; } - cursor.next(&None); + cursor.next(); } pending_edit diff --git a/crates/text/src/undo_map.rs b/crates/text/src/undo_map.rs index ed363cfc6b6d77aa6a7e091acac0b0a76824b61c..6a409189fa8d2a9bd3bc821e37b9923b5ed884dd 100644 --- a/crates/text/src/undo_map.rs +++ 
b/crates/text/src/undo_map.rs @@ -74,7 +74,6 @@ impl UndoMap { undo_id: Default::default(), }, Bias::Left, - &(), ); let mut undo_count = 0; @@ -99,7 +98,6 @@ impl UndoMap { undo_id: Default::default(), }, Bias::Left, - &(), ); let mut undo_count = 0; diff --git a/crates/theme/src/default_colors.rs b/crates/theme/src/default_colors.rs index 3424e0fe04cdbc11544fa81018edba4ff2b357c1..1c3f48b548d3fdd4a2a554b476afaa08dcbae150 100644 --- a/crates/theme/src/default_colors.rs +++ b/crates/theme/src/default_colors.rs @@ -83,6 +83,8 @@ impl ThemeColors { panel_indent_guide: neutral().light_alpha().step_5(), panel_indent_guide_hover: neutral().light_alpha().step_6(), panel_indent_guide_active: neutral().light_alpha().step_6(), + panel_overlay_background: neutral().light().step_2(), + panel_overlay_hover: neutral().light_alpha().step_4(), pane_focused_border: blue().light().step_5(), pane_group_border: neutral().light().step_6(), scrollbar_thumb_background: neutral().light_alpha().step_3(), @@ -206,6 +208,8 @@ impl ThemeColors { panel_indent_guide: neutral().dark_alpha().step_4(), panel_indent_guide_hover: neutral().dark_alpha().step_6(), panel_indent_guide_active: neutral().dark_alpha().step_6(), + panel_overlay_background: neutral().dark().step_2(), + panel_overlay_hover: neutral().dark_alpha().step_4(), pane_focused_border: blue().dark().step_5(), pane_group_border: neutral().dark().step_6(), scrollbar_thumb_background: neutral().dark_alpha().step_3(), diff --git a/crates/theme/src/fallback_themes.rs b/crates/theme/src/fallback_themes.rs index 5e9967d4603a5bac8c9f1a7e461c7319f52f82d7..4d77dd5d81dfc45427bda4034ff7a2085dbcb489 100644 --- a/crates/theme/src/fallback_themes.rs +++ b/crates/theme/src/fallback_themes.rs @@ -59,6 +59,7 @@ pub(crate) fn zed_default_dark() -> Theme { let bg = hsla(215. / 360., 12. / 100., 15. / 100., 1.); let editor = hsla(220. / 360., 12. / 100., 18. / 100., 1.); let elevated_surface = hsla(225. / 360., 12. / 100., 17. / 100., 1.); + let hover = hsla(225.0 / 360., 11.8 / 100., 26.7 / 100., 1.0); let blue = hsla(207.8 / 360., 81. / 100., 66. / 100., 1.0); let gray = hsla(218.8 / 360., 10. / 100., 40. / 100., 1.0); @@ -108,14 +109,14 @@ pub(crate) fn zed_default_dark() -> Theme { surface_background: bg, background: bg, element_background: hsla(223.0 / 360., 13. / 100., 21. / 100., 1.0), - element_hover: hsla(225.0 / 360., 11.8 / 100., 26.7 / 100., 1.0), + element_hover: hover, element_active: hsla(220.0 / 360., 11.8 / 100., 20.0 / 100., 1.0), element_selected: hsla(224.0 / 360., 11.3 / 100., 26.1 / 100., 1.0), element_disabled: SystemColors::default().transparent, element_selection_background: player.local().selection.alpha(0.25), drop_target_background: hsla(220.0 / 360., 8.3 / 100., 21.4 / 100., 1.0), ghost_element_background: SystemColors::default().transparent, - ghost_element_hover: hsla(225.0 / 360., 11.8 / 100., 26.7 / 100., 1.0), + ghost_element_hover: hover, ghost_element_active: hsla(220.0 / 360., 11.8 / 100., 20.0 / 100., 1.0), ghost_element_selected: hsla(224.0 / 360., 11.3 / 100., 26.1 / 100., 1.0), ghost_element_disabled: SystemColors::default().transparent, @@ -202,10 +203,12 @@ pub(crate) fn zed_default_dark() -> Theme { panel_indent_guide: hsla(228. / 360., 8. / 100., 25. / 100., 1.), panel_indent_guide_hover: hsla(225. / 360., 13. / 100., 12. / 100., 1.), panel_indent_guide_active: hsla(225. / 360., 13. / 100., 12. / 100., 1.), + panel_overlay_background: bg, + panel_overlay_hover: hover, pane_focused_border: blue, pane_group_border: hsla(225. / 360., 13. 
/ 100., 12. / 100., 1.), scrollbar_thumb_background: gpui::transparent_black(), - scrollbar_thumb_hover_background: hsla(225.0 / 360., 11.8 / 100., 26.7 / 100., 1.0), + scrollbar_thumb_hover_background: hover, scrollbar_thumb_active_background: hsla( 225.0 / 360., 11.8 / 100., diff --git a/crates/theme/src/icon_theme.rs b/crates/theme/src/icon_theme.rs index 09f5df06b05bfa47b3a9d0e7b32a54f10d999d76..5bd69c173340fa0cb31aa334072084f33a7c7281 100644 --- a/crates/theme/src/icon_theme.rs +++ b/crates/theme/src/icon_theme.rs @@ -152,6 +152,7 @@ const FILE_SUFFIXES_BY_ICON_KEY: &[(&str, &[&str])] = &[ ("javascript", &["cjs", "js", "mjs"]), ("json", &["json"]), ("julia", &["jl"]), + ("kdl", &["kdl"]), ("kotlin", &["kt"]), ("lock", &["lock"]), ("log", &["log"]), @@ -182,6 +183,7 @@ const FILE_SUFFIXES_BY_ICON_KEY: &[(&str, &[&str])] = &[ ], ), ("prisma", &["prisma"]), + ("puppet", &["pp"]), ("python", &["py"]), ("r", &["r", "R"]), ("react", &["cjsx", "ctsx", "jsx", "mjsx", "mtsx", "tsx"]), @@ -216,6 +218,7 @@ const FILE_SUFFIXES_BY_ICON_KEY: &[(&str, &[&str])] = &[ "stylelintrc.yml", ], ), + ("surrealql", &["surql"]), ("svelte", &["svelte"]), ("swift", &["swift"]), ("tcl", &["tcl"]), @@ -314,6 +317,7 @@ const FILE_ICONS: &[(&str, &str)] = &[ ("javascript", "icons/file_icons/javascript.svg"), ("json", "icons/file_icons/code.svg"), ("julia", "icons/file_icons/julia.svg"), + ("kdl", "icons/file_icons/kdl.svg"), ("kotlin", "icons/file_icons/kotlin.svg"), ("lock", "icons/file_icons/lock.svg"), ("log", "icons/file_icons/info.svg"), @@ -328,6 +332,7 @@ const FILE_ICONS: &[(&str, &str)] = &[ ("php", "icons/file_icons/php.svg"), ("prettier", "icons/file_icons/prettier.svg"), ("prisma", "icons/file_icons/prisma.svg"), + ("puppet", "icons/file_icons/puppet.svg"), ("python", "icons/file_icons/python.svg"), ("r", "icons/file_icons/r.svg"), ("react", "icons/file_icons/react.svg"), @@ -340,6 +345,7 @@ const FILE_ICONS: &[(&str, &str)] = &[ ("solidity", "icons/file_icons/file.svg"), ("storage", "icons/file_icons/database.svg"), ("stylelint", "icons/file_icons/javascript.svg"), + ("surrealql", "icons/file_icons/surrealql.svg"), ("svelte", "icons/file_icons/html.svg"), ("swift", "icons/file_icons/swift.svg"), ("tcl", "icons/file_icons/tcl.svg"), diff --git a/crates/theme/src/schema.rs b/crates/theme/src/schema.rs index b2a13b54b662f106018667de9635a4c896e1993c..bfa2adcedf73ec9d51c25d30785b1e81cd83173e 100644 --- a/crates/theme/src/schema.rs +++ b/crates/theme/src/schema.rs @@ -4,11 +4,10 @@ use anyhow::Result; use gpui::{FontStyle, FontWeight, HighlightStyle, Hsla, WindowBackgroundAppearance}; use indexmap::IndexMap; use palette::FromColor; -use schemars::{JsonSchema, json_schema}; +use schemars::{JsonSchema, JsonSchema_repr}; use serde::{Deserialize, Deserializer, Serialize}; use serde_json::Value; use serde_repr::{Deserialize_repr, Serialize_repr}; -use std::borrow::Cow; use crate::{StatusColorsRefinement, ThemeColorsRefinement}; @@ -352,6 +351,12 @@ pub struct ThemeColorsContent { #[serde(rename = "panel.indent_guide_active")] pub panel_indent_guide_active: Option<String>, + #[serde(rename = "panel.overlay_background")] + pub panel_overlay_background: Option<String>, + + #[serde(rename = "panel.overlay_hover")] + pub panel_overlay_hover: Option<String>, + #[serde(rename = "pane.focused_border")] pub pane_focused_border: Option<String>, @@ -675,6 +680,14 @@ impl ThemeColorsContent { .scrollbar_thumb_border .as_ref() .and_then(|color| try_parse_color(color).ok()); + let element_hover = self + .element_hover + 
.as_ref() + .and_then(|color| try_parse_color(color).ok()); + let panel_background = self + .panel_background + .as_ref() + .and_then(|color| try_parse_color(color).ok()); ThemeColorsRefinement { border, border_variant: self @@ -713,10 +726,7 @@ impl ThemeColorsContent { .element_background .as_ref() .and_then(|color| try_parse_color(color).ok()), - element_hover: self - .element_hover - .as_ref() - .and_then(|color| try_parse_color(color).ok()), + element_hover, element_active: self .element_active .as_ref() @@ -833,10 +843,7 @@ impl ThemeColorsContent { .search_match_background .as_ref() .and_then(|color| try_parse_color(color).ok()), - panel_background: self - .panel_background - .as_ref() - .and_then(|color| try_parse_color(color).ok()), + panel_background, panel_focused_border: self .panel_focused_border .as_ref() @@ -853,6 +860,16 @@ impl ThemeColorsContent { .panel_indent_guide_active .as_ref() .and_then(|color| try_parse_color(color).ok()), + panel_overlay_background: self + .panel_overlay_background + .as_ref() + .and_then(|color| try_parse_color(color).ok()) + .or(panel_background), + panel_overlay_hover: self + .panel_overlay_hover + .as_ref() + .and_then(|color| try_parse_color(color).ok()) + .or(element_hover), pane_focused_border: self .pane_focused_border .as_ref() @@ -1486,7 +1503,7 @@ impl From<FontStyleContent> for FontStyle { } } -#[derive(Debug, Clone, Copy, Serialize_repr, Deserialize_repr, PartialEq)] +#[derive(Debug, Clone, Copy, Serialize_repr, Deserialize_repr, JsonSchema_repr, PartialEq)] #[repr(u16)] pub enum FontWeightContent { Thin = 100, @@ -1500,19 +1517,6 @@ pub enum FontWeightContent { Black = 900, } -impl JsonSchema for FontWeightContent { - fn schema_name() -> Cow<'static, str> { - "FontWeightContent".into() - } - - fn json_schema(_: &mut schemars::SchemaGenerator) -> schemars::Schema { - json_schema!({ - "type": "integer", - "enum": [100, 200, 300, 400, 500, 600, 700, 800, 900] - }) - } -} - impl From<FontWeightContent> for FontWeight { fn from(value: FontWeightContent) -> Self { match value { diff --git a/crates/theme/src/settings.rs b/crates/theme/src/settings.rs index 1c4c90a475ca3fa155d4d7169f3f72d37193a747..6d19494f4009e92768e91dd2b7db0ba7a3880ecb 100644 --- a/crates/theme/src/settings.rs +++ b/crates/theme/src/settings.rs @@ -438,7 +438,7 @@ fn default_font_fallbacks() -> Option<FontFallbacks> { impl ThemeSettingsContent { /// Sets the theme for the given appearance to the theme with the specified name. - pub fn set_theme(&mut self, theme_name: String, appearance: Appearance) { + pub fn set_theme(&mut self, theme_name: impl Into<Arc<str>>, appearance: Appearance) { if let Some(selection) = self.theme.as_mut() { let theme_to_update = match selection { ThemeSelection::Static(theme) => theme, @@ -867,6 +867,8 @@ impl settings::Settings for ThemeSettings { .user .into_iter() .chain(sources.release_channel) + .chain(sources.operating_system) + .chain(sources.profile) .chain(sources.server) { if let Some(value) = value.ui_density { diff --git a/crates/theme/src/styles/colors.rs b/crates/theme/src/styles/colors.rs index 7c5270e3612dfbe1fb6b1ec45dc4787dac0e9463..aab11803f4d810453f5bfc286624ea8e4efb4a61 100644 --- a/crates/theme/src/styles/colors.rs +++ b/crates/theme/src/styles/colors.rs @@ -131,6 +131,12 @@ pub struct ThemeColors { pub panel_indent_guide: Hsla, pub panel_indent_guide_hover: Hsla, pub panel_indent_guide_active: Hsla, + + /// The color of the overlay surface on top of panel. 
+ pub panel_overlay_background: Hsla, + /// The color of the overlay surface on top of panel when hovered over. + pub panel_overlay_hover: Hsla, + pub pane_focused_border: Hsla, pub pane_group_border: Hsla, /// The color of the scrollbar thumb. @@ -326,6 +332,8 @@ pub enum ThemeColorField { PanelIndentGuide, PanelIndentGuideHover, PanelIndentGuideActive, + PanelOverlayBackground, + PanelOverlayHover, PaneFocusedBorder, PaneGroupBorder, ScrollbarThumbBackground, @@ -438,6 +446,8 @@ impl ThemeColors { ThemeColorField::PanelIndentGuide => self.panel_indent_guide, ThemeColorField::PanelIndentGuideHover => self.panel_indent_guide_hover, ThemeColorField::PanelIndentGuideActive => self.panel_indent_guide_active, + ThemeColorField::PanelOverlayBackground => self.panel_overlay_background, + ThemeColorField::PanelOverlayHover => self.panel_overlay_hover, ThemeColorField::PaneFocusedBorder => self.pane_focused_border, ThemeColorField::PaneGroupBorder => self.pane_group_border, ThemeColorField::ScrollbarThumbBackground => self.scrollbar_thumb_background, diff --git a/crates/theme_importer/src/vscode/converter.rs b/crates/theme_importer/src/vscode/converter.rs index 9a17a4cdd2b13e116b81c86c753ccab83a965c79..0249bdc7c94a5008240bde25153203c10d247a82 100644 --- a/crates/theme_importer/src/vscode/converter.rs +++ b/crates/theme_importer/src/vscode/converter.rs @@ -175,6 +175,8 @@ impl VsCodeThemeConverter { scrollbar_track_background: vscode_editor_background.clone(), scrollbar_track_border: vscode_colors.editor_overview_ruler.border.clone(), minimap_thumb_background: vscode_colors.minimap_slider.background.clone(), + minimap_thumb_hover_background: vscode_colors.minimap_slider.hover_background.clone(), + minimap_thumb_active_background: vscode_colors.minimap_slider.active_background.clone(), editor_foreground: vscode_editor_foreground .clone() .or(vscode_token_colors_foreground.clone()), diff --git a/crates/theme_selector/src/icon_theme_selector.rs b/crates/theme_selector/src/icon_theme_selector.rs index 40ba7bd5a6e8381f3c11331d73aa9215f555ec8f..2d0b9480d58ee5e163674cb0da9cea083ddf54fd 100644 --- a/crates/theme_selector/src/icon_theme_selector.rs +++ b/crates/theme_selector/src/icon_theme_selector.rs @@ -40,7 +40,10 @@ impl IconThemeSelector { impl Render for IconThemeSelector { fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement { - v_flex().w(rems(34.)).child(self.picker.clone()) + v_flex() + .key_context("IconThemeSelector") + .w(rems(34.)) + .child(self.picker.clone()) } } @@ -327,6 +330,7 @@ impl PickerDelegate for IconThemeSelectorDelegate { window.dispatch_action( Box::new(Extensions { category_filter: Some(ExtensionCategoryFilter::IconThemes), + id: None, }), cx, ); diff --git a/crates/theme_selector/src/theme_selector.rs b/crates/theme_selector/src/theme_selector.rs index 09d9877df874f192365a7bd595a62ee3cb108846..ba8bde243ba7b1b02b2ae2a67af5709cc5c94b7e 100644 --- a/crates/theme_selector/src/theme_selector.rs +++ b/crates/theme_selector/src/theme_selector.rs @@ -92,7 +92,10 @@ impl Focusable for ThemeSelector { impl Render for ThemeSelector { fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement { - v_flex().w(rems(34.)).child(self.picker.clone()) + v_flex() + .key_context("ThemeSelector") + .w(rems(34.)) + .child(self.picker.clone()) } } @@ -385,6 +388,7 @@ impl PickerDelegate for ThemeSelectorDelegate { window.dispatch_action( Box::new(Extensions { category_filter: Some(ExtensionCategoryFilter::Themes), + id: None, }), cx, ); 
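
The two new `panel.overlay_*` keys above are optional in theme JSON: when a theme omits them, the refinement reuses the already-parsed `panel.background` and `element.hover` values. A minimal, self-contained sketch of that fallback shape (the helper below is illustrative; the crate's real code uses `try_parse_color` and `Hsla` as shown in the diff):

```rust
// Illustrative sketch of resolving an optional theme color with a fallback.
// `parse_hex` stands in for the theme crate's color parser.
fn parse_hex(s: &str) -> Option<u32> {
    u32::from_str_radix(s.trim_start_matches('#'), 16).ok()
}

fn resolve(specified: Option<&str>, fallback: Option<u32>) -> Option<u32> {
    specified.and_then(parse_hex).or(fallback)
}

fn main() {
    let element_hover = parse_hex("#3d4452");
    // A theme that doesn't define "panel.overlay_hover" reuses "element.hover"...
    assert_eq!(resolve(None, element_hover), element_hover);
    // ...while an explicit value wins over the fallback.
    assert_eq!(resolve(Some("#202830"), element_hover), parse_hex("#202830"));
}
```
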
diff --git a/crates/title_bar/Cargo.toml b/crates/title_bar/Cargo.toml index 123d0468ac86d6d37d428e73b4fd8de37dce429c..cf178e2850397c5a2398033a02addb73ab615ec9 100644 --- a/crates/title_bar/Cargo.toml +++ b/crates/title_bar/Cargo.toml @@ -27,10 +27,12 @@ test-support = [ ] [dependencies] +anyhow.workspace = true auto_update.workspace = true call.workspace = true chrono.workspace = true client.workspace = true +cloud_llm_client.workspace = true db.workspace = true gpui = { workspace = true, features = ["screen-capture"] } notifications.workspace = true @@ -40,6 +42,7 @@ rpc.workspace = true schemars.workspace = true serde.workspace = true settings.workspace = true +settings_ui.workspace = true smallvec.workspace = true story = { workspace = true, optional = true } telemetry.workspace = true diff --git a/crates/title_bar/src/collab.rs b/crates/title_bar/src/collab.rs index b2a37a4f1c11c00139abe5c555f7ef254cc69f4c..d026b4de14263f442c1ede308da0fe467fe69bba 100644 --- a/crates/title_bar/src/collab.rs +++ b/crates/title_bar/src/collab.rs @@ -1,12 +1,20 @@ +use std::rc::Rc; use std::sync::Arc; use call::{ActiveCall, ParticipantLocation, Room}; use client::{User, proto::PeerId}; -use gpui::{AnyElement, Hsla, IntoElement, MouseButton, Path, Styled, canvas, point}; +use gpui::{ + AnyElement, Hsla, IntoElement, MouseButton, Path, ScreenCaptureSource, Styled, WeakEntity, + canvas, point, +}; use gpui::{App, Task, Window, actions}; use rpc::proto::{self}; use theme::ActiveTheme; -use ui::{Avatar, AvatarAudioStatusIndicator, Facepile, TintColor, Tooltip, prelude::*}; +use ui::{ + Avatar, AvatarAudioStatusIndicator, ContextMenu, ContextMenuItem, Divider, DividerColor, + Facepile, PopoverMenu, SplitButton, SplitButtonStyle, TintColor, Tooltip, prelude::*, +}; +use util::maybe; use workspace::notifications::DetachAndPromptErr; use crate::TitleBar; @@ -23,24 +31,49 @@ actions!( ] ); -fn toggle_screen_sharing(_: &ToggleScreenSharing, window: &mut Window, cx: &mut App) { +fn toggle_screen_sharing( + screen: Option<Rc<dyn ScreenCaptureSource>>, + window: &mut Window, + cx: &mut App, +) { let call = ActiveCall::global(cx).read(cx); if let Some(room) = call.room().cloned() { let toggle_screen_sharing = room.update(cx, |room, cx| { - if room.is_screen_sharing() { + let clicked_on_currently_shared_screen = + room.shared_screen_id().is_some_and(|screen_id| { + Some(screen_id) + == screen + .as_deref() + .and_then(|s| s.metadata().ok().map(|meta| meta.id)) + }); + let should_unshare_current_screen = room.is_sharing_screen(); + let unshared_current_screen = should_unshare_current_screen.then(|| { telemetry::event!( "Screen Share Disabled", room_id = room.id(), channel_id = room.channel_id(), ); - Task::ready(room.unshare_screen(cx)) + room.unshare_screen(clicked_on_currently_shared_screen || screen.is_none(), cx) + }); + if let Some(screen) = screen { + if !should_unshare_current_screen { + telemetry::event!( + "Screen Share Enabled", + room_id = room.id(), + channel_id = room.channel_id(), + ); + } + cx.spawn(async move |room, cx| { + unshared_current_screen.transpose()?; + if !clicked_on_currently_shared_screen { + room.update(cx, |room, cx| room.share_screen(screen, cx))? 
+ .await + } else { + Ok(()) + } + }) } else { - telemetry::event!( - "Screen Share Enabled", - room_id = room.id(), - channel_id = room.channel_id(), - ); - room.share_screen(cx) + Task::ready(Ok(())) } }); toggle_screen_sharing.detach_and_prompt_err("Sharing Screen Failed", window, cx, |e, _, _| Some(format!("{:?}\n\nPlease check that you have given Zed permissions to record your screen in Settings.", e))); @@ -303,13 +336,31 @@ impl TitleBar { let is_muted = room.is_muted(); let muted_by_user = room.muted_by_user(); let is_deafened = room.is_deafened().unwrap_or(false); - let is_screen_sharing = room.is_screen_sharing(); + let is_screen_sharing = room.is_sharing_screen(); let can_use_microphone = room.can_use_microphone(); let can_share_projects = room.can_share_projects(); let screen_sharing_supported = cx.is_screen_capture_supported(); let mut children = Vec::new(); + children.push( + h_flex() + .gap_1() + .child( + IconButton::new("leave-call", IconName::Exit) + .style(ButtonStyle::Subtle) + .tooltip(Tooltip::text("Leave Call")) + .icon_size(IconSize::Small) + .on_click(move |_, _window, cx| { + ActiveCall::global(cx) + .update(cx, |call, cx| call.hang_up(cx)) + .detach_and_log_err(cx); + }), + ) + .child(Divider::vertical().color(DividerColor::Border)) + .into_any_element(), + ); + if is_local && can_share_projects && !is_connecting_to_project { children.push( Button::new( @@ -336,31 +387,14 @@ impl TitleBar { ); } - children.push( - div() - .pr_2() - .child( - IconButton::new("leave-call", ui::IconName::Exit) - .style(ButtonStyle::Subtle) - .tooltip(Tooltip::text("Leave call")) - .icon_size(IconSize::Small) - .on_click(move |_, _window, cx| { - ActiveCall::global(cx) - .update(cx, |call, cx| call.hang_up(cx)) - .detach_and_log_err(cx); - }), - ) - .into_any_element(), - ); - if can_use_microphone { children.push( IconButton::new( "mute-microphone", if is_muted { - ui::IconName::MicMute + IconName::MicMute } else { - ui::IconName::Mic + IconName::Mic }, ) .tooltip(move |window, cx| { @@ -395,9 +429,9 @@ impl TitleBar { IconButton::new( "mute-sound", if is_deafened { - ui::IconName::AudioOff + IconName::AudioOff } else { - ui::IconName::AudioOn + IconName::AudioOn }, ) .style(ButtonStyle::Subtle) @@ -428,21 +462,44 @@ impl TitleBar { ); if can_use_microphone && screen_sharing_supported { + let trigger = IconButton::new("screen-share", IconName::Screen) + .style(ButtonStyle::Subtle) + .icon_size(IconSize::Small) + .toggle_state(is_screen_sharing) + .selected_style(ButtonStyle::Tinted(TintColor::Accent)) + .tooltip(Tooltip::text(if is_screen_sharing { + "Stop Sharing Screen" + } else { + "Share Screen" + })) + .on_click(move |_, window, cx| { + let should_share = ActiveCall::global(cx) + .read(cx) + .room() + .is_some_and(|room| !room.read(cx).is_sharing_screen()); + + window + .spawn(cx, async move |cx| { + let screen = if should_share { + cx.update(|_, cx| pick_default_screen(cx))?.await + } else { + None + }; + + cx.update(|window, cx| toggle_screen_sharing(screen, window, cx))?; + + Result::<_, anyhow::Error>::Ok(()) + }) + .detach(); + }); + children.push( - IconButton::new("screen-share", ui::IconName::Screen) - .style(ButtonStyle::Subtle) - .icon_size(IconSize::Small) - .toggle_state(is_screen_sharing) - .selected_style(ButtonStyle::Tinted(TintColor::Accent)) - .tooltip(Tooltip::text(if is_screen_sharing { - "Stop Sharing Screen" - } else { - "Share Screen" - })) - .on_click(move |_, window, cx| { - toggle_screen_sharing(&Default::default(), window, cx) - }) - 
.into_any_element(), + SplitButton::new( + trigger.render(window, cx), + self.render_screen_list().into_any_element(), + ) + .style(SplitButtonStyle::Transparent) + .into_any_element(), ); } @@ -450,4 +507,96 @@ impl TitleBar { children } + + fn render_screen_list(&self) -> impl IntoElement { + PopoverMenu::new("screen-share-screen-list") + .with_handle(self.screen_share_popover_handle.clone()) + .trigger( + ui::ButtonLike::new_rounded_right("screen-share-screen-list-trigger") + .child( + h_flex() + .mx_neg_0p5() + .h_full() + .justify_center() + .child(Icon::new(IconName::ChevronDownSmall).size(IconSize::XSmall)), + ) + .toggle_state(self.screen_share_popover_handle.is_deployed()), + ) + .menu(|window, cx| { + let screens = cx.screen_capture_sources(); + Some(ContextMenu::build(window, cx, |context_menu, _, cx| { + cx.spawn(async move |this: WeakEntity<ContextMenu>, cx| { + let screens = screens.await??; + this.update(cx, |this, cx| { + let active_screenshare_id = ActiveCall::global(cx) + .read(cx) + .room() + .and_then(|room| room.read(cx).shared_screen_id()); + for screen in screens { + let Ok(meta) = screen.metadata() else { + continue; + }; + + let label = meta + .label + .clone() + .unwrap_or_else(|| SharedString::from("Unknown screen")); + let resolution = SharedString::from(format!( + "{} × {}", + meta.resolution.width.0, meta.resolution.height.0 + )); + this.push_item(ContextMenuItem::CustomEntry { + entry_render: Box::new(move |_, _| { + h_flex() + .gap_2() + .child( + Icon::new(IconName::Screen) + .size(IconSize::XSmall) + .map(|this| { + if active_screenshare_id == Some(meta.id) { + this.color(Color::Accent) + } else { + this.color(Color::Muted) + } + }), + ) + .child(Label::new(label.clone())) + .child( + Label::new(resolution.clone()) + .color(Color::Muted) + .size(LabelSize::Small), + ) + .into_any() + }), + selectable: true, + documentation_aside: None, + handler: Rc::new(move |_, window, cx| { + toggle_screen_sharing(Some(screen.clone()), window, cx); + }), + }); + } + }) + }) + .detach_and_log_err(cx); + context_menu + })) + }) + } +} + +/// Picks the screen to share when clicking on the main screen sharing button. +fn pick_default_screen(cx: &App) -> Task<Option<Rc<dyn ScreenCaptureSource>>> { + let source = cx.screen_capture_sources(); + cx.spawn(async move |_| { + let available_sources = maybe!(async move { source.await? 
}).await.ok()?; + available_sources + .iter() + .find(|it| { + it.as_ref() + .metadata() + .is_ok_and(|meta| meta.is_main.unwrap_or_default()) + }) + .or_else(|| available_sources.iter().next()) + .cloned() + }) } diff --git a/crates/title_bar/src/onboarding_banner.rs b/crates/title_bar/src/onboarding_banner.rs index 8ed6e956af4a5789708d1f1995f6fe82aee5dc96..e7cf0cd2d9326b68973935a7815ef281a01b03c3 100644 --- a/crates/title_bar/src/onboarding_banner.rs +++ b/crates/title_bar/src/onboarding_banner.rs @@ -51,7 +51,6 @@ impl OnboardingBanner { } fn dismiss(&mut self, cx: &mut Context<Self>) { - telemetry::event!("Banner Dismissed", source = self.source); persist_dismissed(&self.source, cx); self.dismissed = true; cx.notify(); @@ -144,7 +143,10 @@ impl Render for OnboardingBanner { div().border_l_1().border_color(border_color).child( IconButton::new("close", IconName::Close) .icon_size(IconSize::Indicator) - .on_click(cx.listener(|this, _, _window, cx| this.dismiss(cx))) + .on_click(cx.listener(|this, _, _window, cx| { + telemetry::event!("Banner Dismissed", source = this.source); + this.dismiss(cx) + })) .tooltip(|window, cx| { Tooltip::with_meta( "Close Announcement Banner", diff --git a/crates/title_bar/src/platform_title_bar.rs b/crates/title_bar/src/platform_title_bar.rs index 30b1b4c3f8d1b31e99d38032b7d9ed3d348a5c68..ef6ef93eed9ecd648bd5689eb14cb5cd5481463e 100644 --- a/crates/title_bar/src/platform_title_bar.rs +++ b/crates/title_bar/src/platform_title_bar.rs @@ -106,14 +106,14 @@ impl Render for PlatformTitleBar { // Note: On Windows the title bar behavior is handled by the platform implementation. .when(self.platform_style == PlatformStyle::Mac, |this| { this.on_click(|event, window, _| { - if event.up.click_count == 2 { + if event.click_count() == 2 { window.titlebar_double_click(); } }) }) .when(self.platform_style == PlatformStyle::Linux, |this| { this.on_click(|event, window, _| { - if event.up.click_count == 2 { + if event.click_count() == 2 { window.zoom_window(); } }) diff --git a/crates/title_bar/src/title_bar.rs b/crates/title_bar/src/title_bar.rs index 5c916254125672850b2d9a403554fcb8ff140567..a8b16d881f42e6d3185987b1d036193f505f1d5c 100644 --- a/crates/title_bar/src/title_bar.rs +++ b/crates/title_bar/src/title_bar.rs @@ -20,22 +20,23 @@ use crate::application_menu::{ use auto_update::AutoUpdateStatus; use call::ActiveCall; -use client::{Client, UserStore}; +use client::{Client, UserStore, zed_urls}; +use cloud_llm_client::Plan; use gpui::{ - Action, AnyElement, App, Context, Corner, Element, Entity, InteractiveElement, IntoElement, - MouseButton, ParentElement, Render, StatefulInteractiveElement, Styled, Subscription, - WeakEntity, Window, actions, div, + Action, AnyElement, App, Context, Corner, Element, Entity, Focusable, InteractiveElement, + IntoElement, MouseButton, ParentElement, Render, StatefulInteractiveElement, Styled, + Subscription, WeakEntity, Window, actions, div, }; use onboarding_banner::OnboardingBanner; use project::Project; -use rpc::proto; use settings::Settings as _; +use settings_ui::keybindings; use std::sync::Arc; use theme::ActiveTheme; use title_bar_settings::TitleBarSettings; use ui::{ - Avatar, Button, ButtonLike, ButtonStyle, ContextMenu, Icon, IconName, IconSize, - IconWithIndicator, Indicator, PopoverMenu, Tooltip, h_flex, prelude::*, + Avatar, Button, ButtonLike, ButtonStyle, Chip, ContextMenu, Icon, IconName, IconSize, + IconWithIndicator, Indicator, PopoverMenu, PopoverMenuHandle, Tooltip, h_flex, prelude::*, }; use util::ResultExt; use 
workspace::{Workspace, notifications::NotifyResultExt}; @@ -130,6 +131,7 @@ pub struct TitleBar { application_menu: Option<Entity<ApplicationMenu>>, _subscriptions: Vec<Subscription>, banner: Entity<OnboardingBanner>, + screen_share_popover_handle: PopoverMenuHandle<ContextMenu>, } impl Render for TitleBar { @@ -177,24 +179,23 @@ impl Render for TitleBar { children.push(self.banner.clone().into_any_element()) } + let status = self.client.status(); + let status = &*status.borrow(); + let user = self.user_store.read(cx).current_user(); + children.push( h_flex() .gap_1() .pr_1() .on_mouse_down(MouseButton::Left, |_, _, cx| cx.stop_propagation()) .children(self.render_call_controls(window, cx)) - .map(|el| { - let status = self.client.status(); - let status = &*status.borrow(); - if matches!(status, client::Status::Connected { .. }) { - el.child(self.render_user_menu_button(cx)) - } else { - el.children(self.render_connection_status(status, cx)) - .when(TitleBarSettings::get_global(cx).show_sign_in, |el| { - el.child(self.render_sign_in_button(cx)) - }) - .child(self.render_user_menu_button(cx)) - } + .children(self.render_connection_status(status, cx)) + .when( + user.is_none() && TitleBarSettings::get_global(cx).show_sign_in, + |el| el.child(self.render_sign_in_button(cx)), + ) + .when(user.is_some(), |parent| { + parent.child(self.render_user_menu_button(cx)) }) .into_any_element(), ); @@ -294,6 +295,7 @@ impl TitleBar { client, _subscriptions: subscriptions, banner, + screen_share_popover_handle: Default::default(), } } @@ -503,7 +505,8 @@ impl TitleBar { ) }) .on_click(move |_, window, cx| { - let _ = workspace.update(cx, |_this, cx| { + let _ = workspace.update(cx, |this, cx| { + window.focus(&this.active_pane().focus_handle(cx)); window.dispatch_action(zed_actions::git::Branch.boxed_clone(), cx); }); }) @@ -614,9 +617,8 @@ impl TitleBar { window .spawn(cx, async move |cx| { client - .authenticate_and_connect(true, &cx) + .sign_in_with_optional_connect(true, &cx) .await - .into_response() .notify_async_err(cx); }) .detach(); @@ -626,30 +628,65 @@ impl TitleBar { pub fn render_user_menu_button(&mut self, cx: &mut Context<Self>) -> impl Element { let user_store = self.user_store.read(cx); if let Some(user) = user_store.current_user() { - let has_subscription_period = self.user_store.read(cx).subscription_period().is_some(); - let plan = self.user_store.read(cx).current_plan().filter(|_| { + let has_subscription_period = user_store.subscription_period().is_some(); + let plan = user_store.plan().filter(|_| { // Since the user might be on the legacy free plan we filter based on whether we have a subscription period. 
has_subscription_period }); + + let user_avatar = user.avatar_uri.clone(); + let free_chip_bg = cx + .theme() + .colors() + .editor_background + .opacity(0.5) + .blend(cx.theme().colors().text_accent.opacity(0.05)); + + let pro_chip_bg = cx + .theme() + .colors() + .editor_background + .opacity(0.5) + .blend(cx.theme().colors().text_accent.opacity(0.2)); + PopoverMenu::new("user-menu") .anchor(Corner::TopRight) .menu(move |window, cx| { ContextMenu::build(window, cx, |menu, _, _cx| { - menu.link( - format!( - "Current Plan: {}", - match plan { - None => "None", - Some(proto::Plan::Free) => "Zed Free", - Some(proto::Plan::ZedPro) => "Zed Pro", - Some(proto::Plan::ZedProTrial) => "Zed Pro (Trial)", - } - ), - zed_actions::OpenAccountSettings.boxed_clone(), + let user_login = user.github_login.clone(); + + let (plan_name, label_color, bg_color) = match plan { + None | Some(Plan::ZedFree) => ("Free", Color::Default, free_chip_bg), + Some(Plan::ZedProTrial) => ("Pro Trial", Color::Accent, pro_chip_bg), + Some(Plan::ZedPro) => ("Pro", Color::Accent, pro_chip_bg), + }; + + menu.custom_entry( + move |_window, _cx| { + let user_login = user_login.clone(); + + h_flex() + .w_full() + .justify_between() + .child(Label::new(user_login)) + .child( + Chip::new(plan_name.to_string()) + .bg_color(bg_color) + .label_color(label_color), + ) + .into_any_element() + }, + move |_, cx| { + cx.open_url(&zed_urls::account_url(cx)); + }, ) .separator() .action("Settings", zed_actions::OpenSettings.boxed_clone()) - .action("Key Bindings", Box::new(zed_actions::OpenKeymap)) + .action( + "Settings Profiles", + zed_actions::settings_profile_selector::Toggle.boxed_clone(), + ) + .action("Key Bindings", Box::new(keybindings::OpenKeymapEditor)) .action( "Themes…", zed_actions::theme_selector::Toggle::default().boxed_clone(), @@ -675,7 +712,7 @@ impl TitleBar { .children( TitleBarSettings::get_global(cx) .show_user_picture - .then(|| Avatar::new(user.avatar_uri.clone())), + .then(|| Avatar::new(user_avatar)), ) .child( Icon::new(IconName::ChevronDown) @@ -693,7 +730,11 @@ impl TitleBar { .menu(|window, cx| { ContextMenu::build(window, cx, |menu, _, _| { menu.action("Settings", zed_actions::OpenSettings.boxed_clone()) - .action("Key Bindings", Box::new(zed_actions::OpenKeymap)) + .action( + "Settings Profiles", + zed_actions::settings_profile_selector::Toggle.boxed_clone(), + ) + .action("Key Bindings", Box::new(keybindings::OpenKeymapEditor)) .action( "Themes…", zed_actions::theme_selector::Toggle::default().boxed_clone(), diff --git a/crates/ui/src/components.rs b/crates/ui/src/components.rs index 88676e8a2bbe383538e91499a71ca908b2057203..486673e73354b488753cead1f187a5f7ce7687cc 100644 --- a/crates/ui/src/components.rs +++ b/crates/ui/src/components.rs @@ -1,7 +1,9 @@ mod avatar; +mod badge; mod banner; mod button; mod callout; +mod chip; mod content_group; mod context_menu; mod disclosure; @@ -40,9 +42,11 @@ mod tooltip; mod stories; pub use avatar::*; +pub use badge::*; pub use banner::*; pub use button::*; pub use callout::*; +pub use chip::*; pub use content_group::*; pub use context_menu::*; pub use disclosure::*; diff --git a/crates/ui/src/components/badge.rs b/crates/ui/src/components/badge.rs new file mode 100644 index 0000000000000000000000000000000000000000..f36e03291c5915f70e8370c6cc1e037d097622b0 --- /dev/null +++ b/crates/ui/src/components/badge.rs @@ -0,0 +1,94 @@ +use std::rc::Rc; + +use crate::Divider; +use crate::DividerColor; +use crate::Tooltip; +use crate::component_prelude::*; +use crate::prelude::*; 
+use gpui::AnyView; +use gpui::{AnyElement, IntoElement, SharedString, Window}; + +#[derive(IntoElement, RegisterComponent)] +pub struct Badge { + label: SharedString, + icon: IconName, + tooltip: Option<Rc<dyn Fn(&mut Window, &mut App) -> AnyView>>, +} + +impl Badge { + pub fn new(label: impl Into<SharedString>) -> Self { + Self { + label: label.into(), + icon: IconName::Check, + tooltip: None, + } + } + + pub fn icon(mut self, icon: IconName) -> Self { + self.icon = icon; + self + } + + pub fn tooltip(mut self, tooltip: impl Fn(&mut Window, &mut App) -> AnyView + 'static) -> Self { + self.tooltip = Some(Rc::new(tooltip)); + self + } +} + +impl RenderOnce for Badge { + fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { + let tooltip = self.tooltip; + + h_flex() + .id(self.label.clone()) + .h_full() + .gap_1() + .pl_1() + .pr_2() + .border_1() + .border_color(cx.theme().colors().border.opacity(0.6)) + .bg(cx.theme().colors().element_background) + .rounded_sm() + .overflow_hidden() + .child( + Icon::new(self.icon) + .size(IconSize::XSmall) + .color(Color::Muted), + ) + .child(Divider::vertical().color(DividerColor::Border)) + .child(Label::new(self.label.clone()).size(LabelSize::Small).ml_1()) + .when_some(tooltip, |this, tooltip| { + this.tooltip(move |window, cx| tooltip(window, cx)) + }) + } +} + +impl Component for Badge { + fn scope() -> ComponentScope { + ComponentScope::DataDisplay + } + + fn description() -> Option<&'static str> { + Some( + "A compact, labeled component with optional icon for displaying status, categories, or metadata.", + ) + } + + fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> { + Some( + v_flex() + .gap_6() + .child(single_example( + "Basic Badge", + Badge::new("Default").into_any_element(), + )) + .child(single_example( + "With Tooltip", + Badge::new("Tooltip") + .tooltip(Tooltip::text("This is a tooltip.")) + .into_any_element(), + )) + .into_any_element(), + ) + } +} diff --git a/crates/ui/src/components/banner.rs b/crates/ui/src/components/banner.rs index 043791cdd86ccf6a94fb469356bd2aca7abaddf4..d88905d4664f83ff985cb6b4226ae9c6b43ebe91 100644 --- a/crates/ui/src/components/banner.rs +++ b/crates/ui/src/components/banner.rs @@ -19,8 +19,8 @@ pub enum Severity { /// use ui::{Banner}; /// /// Banner::new() -/// .severity(Severity::Info) -/// .children(Label::new("This is an informational message")) +/// .severity(Severity::Success) +/// .children(Label::new("This is a success message")) /// .action_slot( /// Button::new("learn-more", "Learn More") /// .icon(IconName::ArrowUpRight) @@ -32,7 +32,6 @@ pub enum Severity { pub struct Banner { severity: Severity, children: Vec<AnyElement>, - icon: Option<(IconName, Option<Color>)>, action_slot: Option<AnyElement>, } @@ -42,7 +41,6 @@ impl Banner { Self { severity: Severity::Info, children: Vec::new(), - icon: None, action_slot: None, } } @@ -53,12 +51,6 @@ impl Banner { self } - /// Sets an icon to display in the banner with an optional color. - pub fn icon(mut self, icon: IconName, color: Option<impl Into<Color>>) -> Self { - self.icon = Some((icon, color.map(|c| c.into()))); - self - } - /// A slot for actions, such as CTA or dismissal buttons. 
pub fn action_slot(mut self, element: impl IntoElement) -> Self { self.action_slot = Some(element.into_any_element()); @@ -73,12 +65,13 @@ impl ParentElement for Banner { } impl RenderOnce for Banner { - fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { - let base = h_flex() + fn render(self, window: &mut Window, cx: &mut App) -> impl IntoElement { + let banner = h_flex() .py_0p5() - .rounded_sm() + .gap_1p5() .flex_wrap() .justify_between() + .rounded_sm() .border_1(); let (icon, icon_color, bg_color, border_color) = match self.severity { @@ -108,35 +101,37 @@ impl RenderOnce for Banner { ), }; - let mut container = base.bg(bg_color).border_color(border_color); - - let mut content_area = h_flex().id("content_area").gap_1p5().overflow_x_scroll(); - - if self.icon.is_none() { - content_area = - content_area.child(Icon::new(icon).size(IconSize::XSmall).color(icon_color)); - } + let mut banner = banner.bg(bg_color).border_color(border_color); - content_area = content_area.children(self.children); + let icon_and_child = h_flex() + .items_start() + .min_w_0() + .gap_1p5() + .child( + h_flex() + .h(window.line_height()) + .flex_shrink_0() + .child(Icon::new(icon).size(IconSize::XSmall).color(icon_color)), + ) + .child(div().min_w_0().children(self.children)); if let Some(action_slot) = self.action_slot { - container = container + banner = banner .pl_2() - .pr_0p5() - .gap_2() - .child(content_area) + .pr_1() + .child(icon_and_child) .child(action_slot); } else { - container = container.px_2().child(div().w_full().child(content_area)); + banner = banner.px_2().child(icon_and_child); } - container + banner } } impl Component for Banner { fn scope() -> ComponentScope { - ComponentScope::Notification + ComponentScope::DataDisplay } fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> { diff --git a/crates/ui/src/components/button/button.rs b/crates/ui/src/components/button/button.rs index cae5d0e2ca4c39a69f56be5fb3ac2342c0919fb9..19f782fb98e7e78cca8f78202fd0b6d95448f2cd 100644 --- a/crates/ui/src/components/button/button.rs +++ b/crates/ui/src/components/button/button.rs @@ -393,6 +393,11 @@ impl ButtonCommon for Button { self } + fn tab_index(mut self, tab_index: impl Into<isize>) -> Self { + self.base = self.base.tab_index(tab_index); + self + } + fn layer(mut self, elevation: ElevationIndex) -> Self { self.base = self.base.layer(elevation); self diff --git a/crates/ui/src/components/button/button_like.rs b/crates/ui/src/components/button/button_like.rs index a0158b2fe745f383be179594c49ce1874b181176..35c78fbb5dff5987dac5cad9b1765ca0b05c0b54 100644 --- a/crates/ui/src/components/button/button_like.rs +++ b/crates/ui/src/components/button/button_like.rs @@ -1,7 +1,8 @@ use documented::Documented; use gpui::{ AnyElement, AnyView, ClickEvent, CursorStyle, DefiniteLength, Hsla, MouseButton, - MouseDownEvent, MouseUpEvent, Rems, relative, transparent_black, + MouseClickEvent, MouseDownEvent, MouseUpEvent, Rems, StyleRefinement, relative, + transparent_black, }; use smallvec::SmallVec; @@ -37,6 +38,8 @@ pub trait ButtonCommon: Clickable + Disableable { /// exceptions might a scroll bar, or a slider. fn tooltip(self, tooltip: impl Fn(&mut Window, &mut App) -> AnyView + 'static) -> Self; + fn tab_index(self, tab_index: impl Into<isize>) -> Self; + fn layer(self, elevation: ElevationIndex) -> Self; } @@ -126,6 +129,10 @@ pub enum ButtonStyle { /// coloring like an error or success button. 
Tinted(TintColor), + /// Usually used as a secondary action that should have more emphasis than + /// a fully transparent button. + Outlined, + /// The default button style, used for most buttons. Has a transparent background, /// but has a background color to indicate states like hover and active. #[default] @@ -180,6 +187,12 @@ impl ButtonStyle { icon_color: Color::Default.color(cx), }, ButtonStyle::Tinted(tint) => tint.button_like_style(cx), + ButtonStyle::Outlined => ButtonLikeStyles { + background: element_bg_from_elevation(elevation, cx), + border_color: cx.theme().colors().border_variant, + label_color: Color::Default.color(cx), + icon_color: Color::Default.color(cx), + }, ButtonStyle::Subtle => ButtonLikeStyles { background: cx.theme().colors().ghost_element_background, border_color: transparent_black(), @@ -219,6 +232,12 @@ impl ButtonStyle { styles.background = theme.darken(styles.background, 0.05, 0.2); styles } + ButtonStyle::Outlined => ButtonLikeStyles { + background: cx.theme().colors().ghost_element_hover, + border_color: cx.theme().colors().border, + label_color: Color::Default.color(cx), + icon_color: Color::Default.color(cx), + }, ButtonStyle::Subtle => ButtonLikeStyles { background: cx.theme().colors().ghost_element_hover, border_color: transparent_black(), @@ -251,6 +270,12 @@ impl ButtonStyle { label_color: Color::Default.color(cx), icon_color: Color::Default.color(cx), }, + ButtonStyle::Outlined => ButtonLikeStyles { + background: cx.theme().colors().element_active, + border_color: cx.theme().colors().border_variant, + label_color: Color::Default.color(cx), + icon_color: Color::Default.color(cx), + }, ButtonStyle::Transparent => ButtonLikeStyles { background: transparent_black(), border_color: transparent_black(), @@ -278,6 +303,12 @@ impl ButtonStyle { label_color: Color::Default.color(cx), icon_color: Color::Default.color(cx), }, + ButtonStyle::Outlined => ButtonLikeStyles { + background: cx.theme().colors().ghost_element_background, + border_color: cx.theme().colors().border, + label_color: Color::Default.color(cx), + icon_color: Color::Default.color(cx), + }, ButtonStyle::Transparent => ButtonLikeStyles { background: transparent_black(), border_color: cx.theme().colors().border_focused, @@ -308,6 +339,12 @@ impl ButtonStyle { label_color: Color::Disabled.color(cx), icon_color: Color::Disabled.color(cx), }, + ButtonStyle::Outlined => ButtonLikeStyles { + background: cx.theme().colors().element_disabled, + border_color: cx.theme().colors().border_disabled, + label_color: Color::Default.color(cx), + icon_color: Color::Default.color(cx), + }, ButtonStyle::Transparent => ButtonLikeStyles { background: transparent_black(), border_color: transparent_black(), @@ -324,6 +361,7 @@ impl ButtonStyle { #[derive(Default, PartialEq, Clone, Copy)] pub enum ButtonSize { Large, + Medium, #[default] Default, Compact, @@ -334,6 +372,7 @@ impl ButtonSize { pub fn rems(self) -> Rems { match self { ButtonSize::Large => rems_from_px(32.), + ButtonSize::Medium => rems_from_px(28.), ButtonSize::Default => rems_from_px(22.), ButtonSize::Compact => rems_from_px(18.), ButtonSize::None => rems_from_px(16.), @@ -357,6 +396,7 @@ pub struct ButtonLike { pub(super) width: Option<DefiniteLength>, pub(super) height: Option<DefiniteLength>, pub(super) layer: Option<ElevationIndex>, + tab_index: Option<isize>, size: ButtonSize, rounding: Option<ButtonLikeRounding>, tooltip: Option<Box<dyn Fn(&mut Window, &mut App) -> AnyView>>, @@ -385,6 +425,7 @@ impl ButtonLike { on_click: None, on_right_click: 
None, layer: None, + tab_index: None, } } @@ -489,6 +530,11 @@ impl ButtonCommon for ButtonLike { self } + fn tab_index(mut self, tab_index: impl Into<isize>) -> Self { + self.tab_index = Some(tab_index.into()); + self + } + fn layer(mut self, elevation: ElevationIndex) -> Self { self.layer = Some(elevation); self @@ -518,6 +564,7 @@ impl RenderOnce for ButtonLike { self.base .h_flex() .id(self.id.clone()) + .when_some(self.tab_index, |this, tab_index| this.tab_index(tab_index)) .font_ui(cx) .group("") .flex_none() @@ -525,6 +572,13 @@ impl RenderOnce for ButtonLike { .when_some(self.width, |this, width| { this.w(width).justify_center().text_center() }) + .when( + match self.style { + ButtonStyle::Outlined => true, + _ => false, + }, + |this| this.border_1(), + ) .when_some(self.rounding, |this, rounding| match rounding { ButtonLikeRounding::All => this.rounded_sm(), ButtonLikeRounding::Left => this.rounded_l_sm(), @@ -532,12 +586,13 @@ impl RenderOnce for ButtonLike { }) .gap(DynamicSpacing::Base04.rems(cx)) .map(|this| match self.size { - ButtonSize::Large => this.px(DynamicSpacing::Base06.rems(cx)), + ButtonSize::Large | ButtonSize::Medium => this.px(DynamicSpacing::Base06.rems(cx)), ButtonSize::Default | ButtonSize::Compact => { this.px(DynamicSpacing::Base04.rems(cx)) } ButtonSize::None => this, }) + .border_color(style.enabled(self.layer, cx).border_color) .bg(style.enabled(self.layer, cx).background) .when(self.disabled, |this| { if self.cursor_style == CursorStyle::PointingHand { @@ -547,8 +602,12 @@ impl RenderOnce for ButtonLike { } }) .when(!self.disabled, |this| { + let hovered_style = style.hovered(self.layer, cx); + let focus_color = + |refinement: StyleRefinement| refinement.bg(hovered_style.background); this.cursor(self.cursor_style) - .hover(|hover| hover.bg(style.hovered(self.layer, cx).background)) + .hover(focus_color) + .focus(focus_color) .active(|active| active.bg(style.active(cx).background)) }) .when_some( @@ -562,7 +621,7 @@ impl RenderOnce for ButtonLike { MouseButton::Right, move |event, window, cx| { cx.stop_propagation(); - let click_event = ClickEvent { + let click_event = ClickEvent::Mouse(MouseClickEvent { down: MouseDownEvent { button: MouseButton::Right, position: event.position, @@ -576,7 +635,7 @@ impl RenderOnce for ButtonLike { modifiers: event.modifiers, click_count: 1, }, - }; + }); (on_right_click)(&click_event, window, cx) }, ) diff --git a/crates/ui/src/components/button/icon_button.rs b/crates/ui/src/components/button/icon_button.rs index 050db6addd2ba32535edffdd6fde066ac57ec644..8d8718a6346eccf38ae6df2fa6e56c15c7cae3b9 100644 --- a/crates/ui/src/components/button/icon_button.rs +++ b/crates/ui/src/components/button/icon_button.rs @@ -164,6 +164,11 @@ impl ButtonCommon for IconButton { self } + fn tab_index(mut self, tab_index: impl Into<isize>) -> Self { + self.base = self.base.tab_index(tab_index); + self + } + fn layer(mut self, elevation: ElevationIndex) -> Self { self.base = self.base.layer(elevation); self @@ -178,7 +183,8 @@ impl VisibleOnHover for IconButton { } impl RenderOnce for IconButton { - fn render(self, window: &mut Window, cx: &mut App) -> impl IntoElement { + #[allow(refining_impl_trait)] + fn render(self, window: &mut Window, cx: &mut App) -> ButtonLike { let is_disabled = self.base.disabled; let is_selected = self.base.selected; let selected_style = self.base.selected_style; diff --git a/crates/ui/src/components/button/split_button.rs b/crates/ui/src/components/button/split_button.rs index 
c0811ecbab9f3897328edd25c8fdd6bd85ffabbc..14b9fd153cd5ad662467c75ff81700587667cee3 100644 --- a/crates/ui/src/components/button/split_button.rs +++ b/crates/ui/src/components/button/split_button.rs @@ -1,6 +1,6 @@ use gpui::{ AnyElement, App, BoxShadow, IntoElement, ParentElement, RenderOnce, Styled, Window, div, hsla, - point, px, + point, prelude::FluentBuilder, px, }; use theme::ActiveTheme; @@ -8,6 +8,13 @@ use crate::{ElevationIndex, h_flex}; use super::ButtonLike; +#[derive(Clone, Copy, PartialEq)] +pub enum SplitButtonStyle { + Filled, + Outlined, + Transparent, +} + /// /// A button with two parts: a primary action on the left and a secondary action on the right. /// /// The left side is a [`ButtonLike`] with the main action, while the right side can contain @@ -18,34 +25,53 @@ use super::ButtonLike; pub struct SplitButton { pub left: ButtonLike, pub right: AnyElement, + style: SplitButtonStyle, } impl SplitButton { pub fn new(left: ButtonLike, right: AnyElement) -> Self { - Self { left, right } + Self { + left, + right, + style: SplitButtonStyle::Filled, + } + } + + pub fn style(mut self, style: SplitButtonStyle) -> Self { + self.style = style; + self } } impl RenderOnce for SplitButton { fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { + let is_filled_or_outlined = matches!( + self.style, + SplitButtonStyle::Filled | SplitButtonStyle::Outlined + ); + h_flex() .rounded_sm() - .border_1() - .border_color(cx.theme().colors().text_muted.alpha(0.12)) + .when(is_filled_or_outlined, |this| { + this.border_1() + .border_color(cx.theme().colors().border.opacity(0.8)) + }) .child(div().flex_grow().child(self.left)) .child( div() .h_full() .w_px() - .bg(cx.theme().colors().text_muted.alpha(0.16)), + .bg(cx.theme().colors().border.opacity(0.5)), ) .child(self.right) - .bg(ElevationIndex::Surface.on_elevation_bg(cx)) - .shadow(vec![BoxShadow { - color: hsla(0.0, 0.0, 0.0, 0.16), - offset: point(px(0.), px(1.)), - blur_radius: px(0.), - spread_radius: px(0.), - }]) + .when(self.style == SplitButtonStyle::Filled, |this| { + this.bg(ElevationIndex::Surface.on_elevation_bg(cx)) + .shadow(vec![BoxShadow { + color: hsla(0.0, 0.0, 0.0, 0.16), + offset: point(px(0.), px(1.)), + blur_radius: px(0.), + spread_radius: px(0.), + }]) + }) } } diff --git a/crates/ui/src/components/button/toggle_button.rs b/crates/ui/src/components/button/toggle_button.rs index eca23fe6f7584ce30fdabd5e9c1445a53ec65a5f..91defa730b3e9d5be2cb4adb2fdf764169f2d55d 100644 --- a/crates/ui/src/components/button/toggle_button.rs +++ b/crates/ui/src/components/button/toggle_button.rs @@ -1,6 +1,8 @@ +use std::rc::Rc; + use gpui::{AnyView, ClickEvent}; -use crate::{ButtonLike, ButtonLikeRounding, ElevationIndex, prelude::*}; +use crate::{ButtonLike, ButtonLikeRounding, ElevationIndex, TintColor, Tooltip, prelude::*}; /// The position of a [`ToggleButton`] within a group of buttons. 
#[derive(Debug, PartialEq, Eq, Clone, Copy)] @@ -121,6 +123,11 @@ impl ButtonCommon for ToggleButton { self } + fn tab_index(mut self, tab_index: impl Into<isize>) -> Self { + self.base = self.base.tab_index(tab_index); + self + } + fn layer(mut self, elevation: ElevationIndex) -> Self { self.base = self.base.layer(elevation); self @@ -290,3 +297,670 @@ impl Component for ToggleButton { ) } } + +pub struct ButtonConfiguration { + label: SharedString, + icon: Option<IconName>, + on_click: Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>, + selected: bool, + tooltip: Option<Rc<dyn Fn(&mut Window, &mut App) -> AnyView>>, +} + +mod private { + pub trait ToggleButtonStyle {} +} + +pub trait ButtonBuilder: 'static + private::ToggleButtonStyle { + fn into_configuration(self) -> ButtonConfiguration; +} + +pub struct ToggleButtonSimple { + label: SharedString, + on_click: Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>, + selected: bool, + tooltip: Option<Rc<dyn Fn(&mut Window, &mut App) -> AnyView>>, +} + +impl ToggleButtonSimple { + pub fn new( + label: impl Into<SharedString>, + on_click: impl Fn(&ClickEvent, &mut Window, &mut App) + 'static, + ) -> Self { + Self { + label: label.into(), + on_click: Box::new(on_click), + selected: false, + tooltip: None, + } + } + + pub fn selected(mut self, selected: bool) -> Self { + self.selected = selected; + self + } + + pub fn tooltip(mut self, tooltip: impl Fn(&mut Window, &mut App) -> AnyView + 'static) -> Self { + self.tooltip = Some(Rc::new(tooltip)); + self + } +} + +impl private::ToggleButtonStyle for ToggleButtonSimple {} + +impl ButtonBuilder for ToggleButtonSimple { + fn into_configuration(self) -> ButtonConfiguration { + ButtonConfiguration { + label: self.label, + icon: None, + on_click: self.on_click, + selected: self.selected, + tooltip: self.tooltip, + } + } +} + +pub struct ToggleButtonWithIcon { + label: SharedString, + icon: IconName, + on_click: Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>, + selected: bool, + tooltip: Option<Rc<dyn Fn(&mut Window, &mut App) -> AnyView>>, +} + +impl ToggleButtonWithIcon { + pub fn new( + label: impl Into<SharedString>, + icon: IconName, + on_click: impl Fn(&ClickEvent, &mut Window, &mut App) + 'static, + ) -> Self { + Self { + label: label.into(), + icon, + on_click: Box::new(on_click), + selected: false, + tooltip: None, + } + } + + pub fn selected(mut self, selected: bool) -> Self { + self.selected = selected; + self + } + + pub fn tooltip(mut self, tooltip: impl Fn(&mut Window, &mut App) -> AnyView + 'static) -> Self { + self.tooltip = Some(Rc::new(tooltip)); + self + } +} + +impl private::ToggleButtonStyle for ToggleButtonWithIcon {} + +impl ButtonBuilder for ToggleButtonWithIcon { + fn into_configuration(self) -> ButtonConfiguration { + ButtonConfiguration { + label: self.label, + icon: Some(self.icon), + on_click: self.on_click, + selected: self.selected, + tooltip: self.tooltip, + } + } +} + +#[derive(Clone, Copy, PartialEq)] +pub enum ToggleButtonGroupStyle { + Transparent, + Filled, + Outlined, +} + +#[derive(Clone, Copy, PartialEq)] +pub enum ToggleButtonGroupSize { + Default, + Medium, +} + +#[derive(IntoElement)] +pub struct ToggleButtonGroup<T, const COLS: usize = 3, const ROWS: usize = 1> +where + T: ButtonBuilder, +{ + group_name: &'static str, + rows: [[T; COLS]; ROWS], + style: ToggleButtonGroupStyle, + size: ToggleButtonGroupSize, + button_width: Rems, + selected_index: usize, + tab_index: Option<isize>, +} + +impl<T: ButtonBuilder, const COLS: usize> 
ToggleButtonGroup<T, COLS> { + pub fn single_row(group_name: &'static str, buttons: [T; COLS]) -> Self { + Self { + group_name, + rows: [buttons], + style: ToggleButtonGroupStyle::Transparent, + size: ToggleButtonGroupSize::Default, + button_width: rems_from_px(100.), + selected_index: 0, + tab_index: None, + } + } +} + +impl<T: ButtonBuilder, const COLS: usize> ToggleButtonGroup<T, COLS, 2> { + pub fn two_rows(group_name: &'static str, first_row: [T; COLS], second_row: [T; COLS]) -> Self { + Self { + group_name, + rows: [first_row, second_row], + style: ToggleButtonGroupStyle::Transparent, + size: ToggleButtonGroupSize::Default, + button_width: rems_from_px(100.), + selected_index: 0, + tab_index: None, + } + } +} + +impl<T: ButtonBuilder, const COLS: usize, const ROWS: usize> ToggleButtonGroup<T, COLS, ROWS> { + pub fn style(mut self, style: ToggleButtonGroupStyle) -> Self { + self.style = style; + self + } + + pub fn size(mut self, size: ToggleButtonGroupSize) -> Self { + self.size = size; + self + } + + pub fn button_width(mut self, button_width: Rems) -> Self { + self.button_width = button_width; + self + } + + pub fn selected_index(mut self, index: usize) -> Self { + self.selected_index = index; + self + } + + /// Sets the tab index for the toggle button group. + /// The tab index is set to the initial value provided, then the + /// value is incremented by the number of buttons in the group. + pub fn tab_index(mut self, tab_index: &mut isize) -> Self { + self.tab_index = Some(*tab_index); + *tab_index += (COLS * ROWS) as isize; + self + } +} + +impl<T: ButtonBuilder, const COLS: usize, const ROWS: usize> RenderOnce + for ToggleButtonGroup<T, COLS, ROWS> +{ + fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { + let entries = + self.rows.into_iter().enumerate().map(|(row_index, row)| { + row.into_iter().enumerate().map(move |(col_index, button)| { + let ButtonConfiguration { + label, + icon, + on_click, + selected, + tooltip, + } = button.into_configuration(); + + let entry_index = row_index * COLS + col_index; + + ButtonLike::new((self.group_name, entry_index)) + .rounding(None) + .when_some(self.tab_index, |this, tab_index| { + this.tab_index(tab_index + entry_index as isize) + }) + .when(entry_index == self.selected_index || selected, |this| { + this.toggle_state(true) + .selected_style(ButtonStyle::Tinted(TintColor::Accent)) + }) + .when(self.style == ToggleButtonGroupStyle::Filled, |button| { + button.style(ButtonStyle::Filled) + }) + .when(self.size == ToggleButtonGroupSize::Medium, |button| { + button.size(ButtonSize::Medium) + }) + .child( + h_flex() + .min_w(self.button_width) + .gap_1p5() + .px_3() + .py_1() + .justify_center() + .when_some(icon, |this, icon| { + this.py_2() + .child(Icon::new(icon).size(IconSize::XSmall).map(|this| { + if entry_index == self.selected_index || selected { + this.color(Color::Accent) + } else { + this.color(Color::Muted) + } + })) + }) + .child(Label::new(label).size(LabelSize::Small).when( + entry_index == self.selected_index || selected, + |this| this.color(Color::Accent), + )), + ) + .when_some(tooltip, |this, tooltip| { + this.tooltip(move |window, cx| tooltip(window, cx)) + }) + .on_click(on_click) + .into_any_element() + }) + }); + + let border_color = cx.theme().colors().border.opacity(0.6); + let is_outlined_or_filled = self.style == ToggleButtonGroupStyle::Outlined + || self.style == ToggleButtonGroupStyle::Filled; + let is_transparent = self.style == ToggleButtonGroupStyle::Transparent; + + v_flex() + 
.rounded_md() + .overflow_hidden() + .map(|this| { + if is_transparent { + this.gap_px() + } else { + this.border_1().border_color(border_color) + } + }) + .children(entries.enumerate().map(|(row_index, row)| { + let last_row = row_index == ROWS - 1; + h_flex() + .when(!is_outlined_or_filled, |this| this.gap_px()) + .when(is_outlined_or_filled && !last_row, |this| { + this.border_b_1().border_color(border_color) + }) + .children(row.enumerate().map(|(item_index, item)| { + let last_item = item_index == COLS - 1; + div() + .when(is_outlined_or_filled && !last_item, |this| { + this.border_r_1().border_color(border_color) + }) + .child(item) + })) + })) + } +} + +fn register_toggle_button_group() { + component::register_component::<ToggleButtonGroup<ToggleButtonSimple>>(); +} + +component::__private::inventory::submit! { + component::ComponentFn::new(register_toggle_button_group) +} + +impl<T: ButtonBuilder, const COLS: usize, const ROWS: usize> Component + for ToggleButtonGroup<T, COLS, ROWS> +{ + fn name() -> &'static str { + "ToggleButtonGroup" + } + + fn scope() -> ComponentScope { + ComponentScope::Input + } + + fn sort_name() -> &'static str { + "ButtonG" + } + + fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> { + Some( + v_flex() + .gap_6() + .children(vec![example_group_with_title( + "Transparent Variant", + vec![ + single_example( + "Single Row Group", + ToggleButtonGroup::single_row( + "single_row_test", + [ + ToggleButtonSimple::new("First", |_, _, _| {}), + ToggleButtonSimple::new("Second", |_, _, _| {}), + ToggleButtonSimple::new("Third", |_, _, _| {}), + ], + ) + .selected_index(1) + .button_width(rems_from_px(100.)) + .into_any_element(), + ), + single_example( + "Single Row Group with icons", + ToggleButtonGroup::single_row( + "single_row_test_icon", + [ + ToggleButtonWithIcon::new( + "First", + IconName::AiZed, + |_, _, _| {}, + ), + ToggleButtonWithIcon::new( + "Second", + IconName::AiZed, + |_, _, _| {}, + ), + ToggleButtonWithIcon::new( + "Third", + IconName::AiZed, + |_, _, _| {}, + ), + ], + ) + .selected_index(1) + .button_width(rems_from_px(100.)) + .into_any_element(), + ), + single_example( + "Multiple Row Group", + ToggleButtonGroup::two_rows( + "multiple_row_test", + [ + ToggleButtonSimple::new("First", |_, _, _| {}), + ToggleButtonSimple::new("Second", |_, _, _| {}), + ToggleButtonSimple::new("Third", |_, _, _| {}), + ], + [ + ToggleButtonSimple::new("Fourth", |_, _, _| {}), + ToggleButtonSimple::new("Fifth", |_, _, _| {}), + ToggleButtonSimple::new("Sixth", |_, _, _| {}), + ], + ) + .selected_index(3) + .button_width(rems_from_px(100.)) + .into_any_element(), + ), + single_example( + "Multiple Row Group with Icons", + ToggleButtonGroup::two_rows( + "multiple_row_test_icons", + [ + ToggleButtonWithIcon::new( + "First", + IconName::AiZed, + |_, _, _| {}, + ), + ToggleButtonWithIcon::new( + "Second", + IconName::AiZed, + |_, _, _| {}, + ), + ToggleButtonWithIcon::new( + "Third", + IconName::AiZed, + |_, _, _| {}, + ), + ], + [ + ToggleButtonWithIcon::new( + "Fourth", + IconName::AiZed, + |_, _, _| {}, + ), + ToggleButtonWithIcon::new( + "Fifth", + IconName::AiZed, + |_, _, _| {}, + ), + ToggleButtonWithIcon::new( + "Sixth", + IconName::AiZed, + |_, _, _| {}, + ), + ], + ) + .selected_index(3) + .button_width(rems_from_px(100.)) + .into_any_element(), + ), + ], + )]) + .children(vec![example_group_with_title( + "Outlined Variant", + vec![ + single_example( + "Single Row Group", + ToggleButtonGroup::single_row( + "single_row_test_outline", + [ + 
ToggleButtonSimple::new("First", |_, _, _| {}), + ToggleButtonSimple::new("Second", |_, _, _| {}), + ToggleButtonSimple::new("Third", |_, _, _| {}), + ], + ) + .selected_index(1) + .style(ToggleButtonGroupStyle::Outlined) + .into_any_element(), + ), + single_example( + "Single Row Group with icons", + ToggleButtonGroup::single_row( + "single_row_test_icon_outlined", + [ + ToggleButtonWithIcon::new( + "First", + IconName::AiZed, + |_, _, _| {}, + ), + ToggleButtonWithIcon::new( + "Second", + IconName::AiZed, + |_, _, _| {}, + ), + ToggleButtonWithIcon::new( + "Third", + IconName::AiZed, + |_, _, _| {}, + ), + ], + ) + .selected_index(1) + .button_width(rems_from_px(100.)) + .style(ToggleButtonGroupStyle::Outlined) + .into_any_element(), + ), + single_example( + "Multiple Row Group", + ToggleButtonGroup::two_rows( + "multiple_row_test", + [ + ToggleButtonSimple::new("First", |_, _, _| {}), + ToggleButtonSimple::new("Second", |_, _, _| {}), + ToggleButtonSimple::new("Third", |_, _, _| {}), + ], + [ + ToggleButtonSimple::new("Fourth", |_, _, _| {}), + ToggleButtonSimple::new("Fifth", |_, _, _| {}), + ToggleButtonSimple::new("Sixth", |_, _, _| {}), + ], + ) + .selected_index(3) + .button_width(rems_from_px(100.)) + .style(ToggleButtonGroupStyle::Outlined) + .into_any_element(), + ), + single_example( + "Multiple Row Group with Icons", + ToggleButtonGroup::two_rows( + "multiple_row_test", + [ + ToggleButtonWithIcon::new( + "First", + IconName::AiZed, + |_, _, _| {}, + ), + ToggleButtonWithIcon::new( + "Second", + IconName::AiZed, + |_, _, _| {}, + ), + ToggleButtonWithIcon::new( + "Third", + IconName::AiZed, + |_, _, _| {}, + ), + ], + [ + ToggleButtonWithIcon::new( + "Fourth", + IconName::AiZed, + |_, _, _| {}, + ), + ToggleButtonWithIcon::new( + "Fifth", + IconName::AiZed, + |_, _, _| {}, + ), + ToggleButtonWithIcon::new( + "Sixth", + IconName::AiZed, + |_, _, _| {}, + ), + ], + ) + .selected_index(3) + .button_width(rems_from_px(100.)) + .style(ToggleButtonGroupStyle::Outlined) + .into_any_element(), + ), + ], + )]) + .children(vec![example_group_with_title( + "Filled Variant", + vec![ + single_example( + "Single Row Group", + ToggleButtonGroup::single_row( + "single_row_test_outline", + [ + ToggleButtonSimple::new("First", |_, _, _| {}), + ToggleButtonSimple::new("Second", |_, _, _| {}), + ToggleButtonSimple::new("Third", |_, _, _| {}), + ], + ) + .selected_index(2) + .style(ToggleButtonGroupStyle::Filled) + .into_any_element(), + ), + single_example( + "Single Row Group with icons", + ToggleButtonGroup::single_row( + "single_row_test_icon_outlined", + [ + ToggleButtonWithIcon::new( + "First", + IconName::AiZed, + |_, _, _| {}, + ), + ToggleButtonWithIcon::new( + "Second", + IconName::AiZed, + |_, _, _| {}, + ), + ToggleButtonWithIcon::new( + "Third", + IconName::AiZed, + |_, _, _| {}, + ), + ], + ) + .selected_index(1) + .button_width(rems_from_px(100.)) + .style(ToggleButtonGroupStyle::Filled) + .into_any_element(), + ), + single_example( + "Multiple Row Group", + ToggleButtonGroup::two_rows( + "multiple_row_test", + [ + ToggleButtonSimple::new("First", |_, _, _| {}), + ToggleButtonSimple::new("Second", |_, _, _| {}), + ToggleButtonSimple::new("Third", |_, _, _| {}), + ], + [ + ToggleButtonSimple::new("Fourth", |_, _, _| {}), + ToggleButtonSimple::new("Fifth", |_, _, _| {}), + ToggleButtonSimple::new("Sixth", |_, _, _| {}), + ], + ) + .selected_index(3) + .button_width(rems_from_px(100.)) + .style(ToggleButtonGroupStyle::Filled) + .into_any_element(), + ), + single_example( + "Multiple Row 
Group with Icons", + ToggleButtonGroup::two_rows( + "multiple_row_test", + [ + ToggleButtonWithIcon::new( + "First", + IconName::AiZed, + |_, _, _| {}, + ), + ToggleButtonWithIcon::new( + "Second", + IconName::AiZed, + |_, _, _| {}, + ), + ToggleButtonWithIcon::new( + "Third", + IconName::AiZed, + |_, _, _| {}, + ), + ], + [ + ToggleButtonWithIcon::new( + "Fourth", + IconName::AiZed, + |_, _, _| {}, + ), + ToggleButtonWithIcon::new( + "Fifth", + IconName::AiZed, + |_, _, _| {}, + ), + ToggleButtonWithIcon::new( + "Sixth", + IconName::AiZed, + |_, _, _| {}, + ), + ], + ) + .selected_index(3) + .button_width(rems_from_px(100.)) + .style(ToggleButtonGroupStyle::Filled) + .into_any_element(), + ), + ], + )]) + .children(vec![single_example( + "With Tooltips", + ToggleButtonGroup::single_row( + "with_tooltips", + [ + ToggleButtonSimple::new("First", |_, _, _| {}) + .tooltip(Tooltip::text("This is a tooltip. Hello!")), + ToggleButtonSimple::new("Second", |_, _, _| {}) + .tooltip(Tooltip::text("This is a tooltip. Hey?")), + ToggleButtonSimple::new("Third", |_, _, _| {}) + .tooltip(Tooltip::text("This is a tooltip. Get out of here now!")), + ], + ) + .selected_index(1) + .button_width(rems_from_px(100.)) + .into_any_element(), + )]) + .into_any_element(), + ) + } +} diff --git a/crates/ui/src/components/callout.rs b/crates/ui/src/components/callout.rs index d15fa122ed95e5e9a922c8bc694d1c35d975f9a4..9c1c9fb1a9d7b5b603bd3c55b64b19375b6b521e 100644 --- a/crates/ui/src/components/callout.rs +++ b/crates/ui/src/components/callout.rs @@ -158,7 +158,7 @@ impl RenderOnce for Callout { impl Component for Callout { fn scope() -> ComponentScope { - ComponentScope::Notification + ComponentScope::DataDisplay } fn description() -> Option<&'static str> { diff --git a/crates/ui/src/components/chip.rs b/crates/ui/src/components/chip.rs new file mode 100644 index 0000000000000000000000000000000000000000..e1262875feae77b69e660c0e9da17e1e669137b7 --- /dev/null +++ b/crates/ui/src/components/chip.rs @@ -0,0 +1,106 @@ +use crate::prelude::*; +use gpui::{AnyElement, Hsla, IntoElement, ParentElement, Styled}; + +/// Chips provide a container for an informative label. +/// +/// # Usage Example +/// +/// ``` +/// use ui::{Chip}; +/// +/// Chip::new("This Chip") +/// ``` +#[derive(IntoElement, RegisterComponent)] +pub struct Chip { + label: SharedString, + label_color: Color, + label_size: LabelSize, + bg_color: Option<Hsla>, +} + +impl Chip { + /// Creates a new `Chip` component with the specified label. + pub fn new(label: impl Into<SharedString>) -> Self { + Self { + label: label.into(), + label_color: Color::Default, + label_size: LabelSize::XSmall, + bg_color: None, + } + } + + /// Sets the color of the label. + pub fn label_color(mut self, color: Color) -> Self { + self.label_color = color; + self + } + + /// Sets the size of the label. + pub fn label_size(mut self, size: LabelSize) -> Self { + self.label_size = size; + self + } + + /// Sets a custom background color for the callout content. 
+ pub fn bg_color(mut self, color: Hsla) -> Self { + self.bg_color = Some(color); + self + } +} + +impl RenderOnce for Chip { + fn render(self, _: &mut Window, cx: &mut App) -> impl IntoElement { + let bg_color = self + .bg_color + .unwrap_or(cx.theme().colors().element_background); + + h_flex() + .min_w_0() + .flex_initial() + .px_1() + .border_1() + .rounded_sm() + .border_color(cx.theme().colors().border) + .bg(bg_color) + .overflow_hidden() + .child( + Label::new(self.label) + .size(self.label_size) + .color(self.label_color) + .buffer_font(cx), + ) + } +} + +impl Component for Chip { + fn scope() -> ComponentScope { + ComponentScope::DataDisplay + } + + fn preview(_window: &mut Window, cx: &mut App) -> Option<AnyElement> { + let chip_examples = vec![ + single_example("Default", Chip::new("Chip Example").into_any_element()), + single_example( + "Customized Label Color", + Chip::new("Chip Example") + .label_color(Color::Accent) + .into_any_element(), + ), + single_example( + "Customized Label Size", + Chip::new("Chip Example") + .label_size(LabelSize::Large) + .label_color(Color::Accent) + .into_any_element(), + ), + single_example( + "Customized Background Color", + Chip::new("Chip Example") + .bg_color(cx.theme().colors().text_accent.opacity(0.1)) + .into_any_element(), + ), + ]; + + Some(example_group(chip_examples).vertical().into_any_element()) + } +} diff --git a/crates/ui/src/components/context_menu.rs b/crates/ui/src/components/context_menu.rs index b5bdfdd8bbd6afda4ee1c2e08d50c4b9691a6892..77468fd29596a2aae015e73ad2d618c82031128c 100644 --- a/crates/ui/src/components/context_menu.rs +++ b/crates/ui/src/components/context_menu.rs @@ -139,6 +139,8 @@ impl ContextMenuEntry { } } +impl FluentBuilder for ContextMenuEntry {} + impl From<ContextMenuEntry> for ContextMenuItem { fn from(entry: ContextMenuEntry) -> Self { ContextMenuItem::Entry(entry) @@ -353,6 +355,10 @@ impl ContextMenu { self } + pub fn push_item(&mut self, item: impl Into<ContextMenuItem>) { + self.items.push(item.into()); + } + pub fn entry( mut self, label: impl Into<SharedString>, @@ -668,7 +674,7 @@ impl ContextMenu { } } - fn select_next(&mut self, _: &SelectNext, window: &mut Window, cx: &mut Context<Self>) { + pub fn select_next(&mut self, _: &SelectNext, window: &mut Window, cx: &mut Context<Self>) { if let Some(ix) = self.selected_index { let next_index = ix + 1; if self.items.len() <= next_index { @@ -972,12 +978,10 @@ impl ContextMenu { .children(action.as_ref().and_then(|action| { self.action_context .as_ref() - .map(|focus| { + .and_then(|focus| { KeyBinding::for_action_in(&**action, focus, window, cx) }) - .unwrap_or_else(|| { - KeyBinding::for_action(&**action, window, cx) - }) + .or_else(|| KeyBinding::for_action(&**action, window, cx)) .map(|binding| { div().ml_4().child(binding.disabled(*disabled)).when( *disabled && documentation_aside.is_some(), diff --git a/crates/ui/src/components/disclosure.rs b/crates/ui/src/components/disclosure.rs index a1fab02e542b2caefce67855b43a6e1d9ea978b8..98406cd1e278b1028587535dc47105ff5d634cf7 100644 --- a/crates/ui/src/components/disclosure.rs +++ b/crates/ui/src/components/disclosure.rs @@ -95,7 +95,7 @@ impl RenderOnce for Disclosure { impl Component for Disclosure { fn scope() -> ComponentScope { - ComponentScope::Navigation + ComponentScope::Input } fn description() -> Option<&'static str> { diff --git a/crates/ui/src/components/dropdown_menu.rs b/crates/ui/src/components/dropdown_menu.rs index 
189fac930fc78df0f53829d024acdf0ec1e5b784..7ad9400f0d0944a02292813f75c28fb1fcf9d78b 100644 --- a/crates/ui/src/components/dropdown_menu.rs +++ b/crates/ui/src/components/dropdown_menu.rs @@ -8,6 +8,7 @@ use super::PopoverMenuHandle; pub enum DropdownStyle { #[default] Solid, + Outlined, Ghost, } @@ -147,6 +148,23 @@ impl Component for DropdownMenu { ), ], ), + example_group_with_title( + "Styles", + vec![ + single_example( + "Outlined", + DropdownMenu::new("outlined", "Outlined Dropdown", menu.clone()) + .style(DropdownStyle::Outlined) + .into_any_element(), + ), + single_example( + "Ghost", + DropdownMenu::new("ghost", "Ghost Dropdown", menu.clone()) + .style(DropdownStyle::Ghost) + .into_any_element(), + ), + ], + ), example_group_with_title( "States", vec![single_example( @@ -170,10 +188,13 @@ pub struct DropdownTriggerStyle { impl DropdownTriggerStyle { pub fn for_style(style: DropdownStyle, cx: &App) -> Self { let colors = cx.theme().colors(); + let bg = match style { DropdownStyle::Solid => colors.editor_background, + DropdownStyle::Outlined => colors.surface_background, DropdownStyle::Ghost => colors.ghost_element_background, }; + Self { bg } } } @@ -244,29 +265,36 @@ impl RenderOnce for DropdownMenuTrigger { let disabled = self.disabled; let style = DropdownTriggerStyle::for_style(self.style, cx); + let is_outlined = matches!(self.style, DropdownStyle::Outlined); h_flex() .id("dropdown-menu-trigger") - .justify_between() - .rounded_sm() - .bg(style.bg) + .min_w_20() .pl_2() .pr_1p5() .py_0p5() .gap_2() - .min_w_20() - .map(|el| { + .justify_between() + .rounded_sm() + .map(|this| { if self.full_width { - el.w_full() + this.w_full() } else { - el.flex_none().w_auto() + this.flex_none().w_auto() } }) - .map(|el| { + .when(is_outlined, |this| { + this.border_1() + .border_color(cx.theme().colors().border) + .overflow_hidden() + }) + .map(|this| { if disabled { - el.cursor_not_allowed() + this.cursor_not_allowed() + .bg(cx.theme().colors().element_disabled) } else { - el.cursor_pointer() + this.bg(style.bg) + .hover(|s| s.bg(cx.theme().colors().element_hover)) } }) .child(match self.label { diff --git a/crates/ui/src/components/image.rs b/crates/ui/src/components/image.rs index 2deba68d88f52d19571146cd6c71da87565bc3eb..09c3bbeb943ca11a00d42621f0bdd73613efaee3 100644 --- a/crates/ui/src/components/image.rs +++ b/crates/ui/src/components/image.rs @@ -1,5 +1,6 @@ use std::sync::Arc; +use gpui::Transformation; use gpui::{App, IntoElement, Rems, RenderOnce, Size, Styled, Window, svg}; use serde::{Deserialize, Serialize}; use strum::{EnumIter, EnumString, IntoStaticStr}; @@ -12,11 +13,13 @@ use crate::prelude::*; )] #[strum(serialize_all = "snake_case")] pub enum VectorName { - ZedLogo, - ZedXCopilot, - Grid, AiGrid, DebuggerGrid, + Grid, + ProTrialStamp, + ProUserStamp, + ZedLogo, + ZedXCopilot, } impl VectorName { @@ -37,6 +40,7 @@ pub struct Vector { path: Arc<str>, color: Color, size: Size<Rems>, + transformation: Transformation, } impl Vector { @@ -46,6 +50,7 @@ impl Vector { path: vector.path(), color: Color::default(), size: Size { width, height }, + transformation: Transformation::default(), } } @@ -66,6 +71,11 @@ impl Vector { self.size = size; self } + + pub fn transform(mut self, transformation: Transformation) -> Self { + self.transformation = transformation; + self + } } impl RenderOnce for Vector { @@ -81,6 +91,7 @@ impl RenderOnce for Vector { .h(height) .path(self.path) .text_color(self.color.color(cx)) + .with_transformation(self.transformation) } } diff --git 
a/crates/ui/src/components/keybinding.rs b/crates/ui/src/components/keybinding.rs index 1d91492f26c7e9e93a761a1d9d46b06300ba3614..5779093ccc63845a8f8b5c151680b584f2b201bd 100644 --- a/crates/ui/src/components/keybinding.rs +++ b/crates/ui/src/components/keybinding.rs @@ -44,7 +44,7 @@ impl KeyBinding { pub fn for_action_in( action: &dyn Action, focus: &FocusHandle, - window: &mut Window, + window: &Window, cx: &App, ) -> Option<Self> { let key_binding = window.highest_precedence_binding_for_action_in(action, focus)?; diff --git a/crates/ui/src/components/keybinding_hint.rs b/crates/ui/src/components/keybinding_hint.rs index d6dc094d415bec9991b83dfc50a865a838c1bdf4..a34ca40ed8c413d2edd6278dd035b93329dc5339 100644 --- a/crates/ui/src/components/keybinding_hint.rs +++ b/crates/ui/src/components/keybinding_hint.rs @@ -206,7 +206,7 @@ impl RenderOnce for KeybindingHint { impl Component for KeybindingHint { fn scope() -> ComponentScope { - ComponentScope::None + ComponentScope::DataDisplay } fn description() -> Option<&'static str> { diff --git a/crates/ui/src/components/list.rs b/crates/ui/src/components/list.rs index 88650b6ae8d2aa8d60cf0dbe682474f417919a9d..6876f290ced9b33fbe368de47870e023ab0389a9 100644 --- a/crates/ui/src/components/list.rs +++ b/crates/ui/src/components/list.rs @@ -1,10 +1,12 @@ mod list; +mod list_bullet_item; mod list_header; mod list_item; mod list_separator; mod list_sub_header; pub use list::*; +pub use list_bullet_item::*; pub use list_header::*; pub use list_item::*; pub use list_separator::*; diff --git a/crates/ui/src/components/list/list.rs b/crates/ui/src/components/list/list.rs index 1402b5d3d3328f0b5344bb4d81c8bc8413a0dac0..b6950f06a4449265cccd48f9f13590650619a01c 100644 --- a/crates/ui/src/components/list/list.rs +++ b/crates/ui/src/components/list/list.rs @@ -84,7 +84,9 @@ impl RenderOnce for List { (false, _) => this.children(self.children), (true, Some(false)) => this, (true, _) => match self.empty_message { - EmptyMessage::Text(text) => this.child(Label::new(text).color(Color::Muted)), + EmptyMessage::Text(text) => { + this.px_2().child(Label::new(text).color(Color::Muted)) + } EmptyMessage::Element(element) => this.child(element), }, }) diff --git a/crates/ui/src/components/list/list_bullet_item.rs b/crates/ui/src/components/list/list_bullet_item.rs new file mode 100644 index 0000000000000000000000000000000000000000..6e079d9f112f74e5198ad174d823643ea923f822 --- /dev/null +++ b/crates/ui/src/components/list/list_bullet_item.rs @@ -0,0 +1,40 @@ +use crate::{ListItem, prelude::*}; +use gpui::{IntoElement, ParentElement, SharedString}; + +#[derive(IntoElement)] +pub struct ListBulletItem { + label: SharedString, +} + +impl ListBulletItem { + pub fn new(label: impl Into<SharedString>) -> Self { + Self { + label: label.into(), + } + } +} + +impl RenderOnce for ListBulletItem { + fn render(self, window: &mut Window, _cx: &mut App) -> impl IntoElement { + let line_height = 0.85 * window.line_height(); + + ListItem::new("list-item") + .selectable(false) + .child( + h_flex() + .w_full() + .min_w_0() + .gap_1() + .items_start() + .child( + h_flex().h(line_height).justify_center().child( + Icon::new(IconName::Dash) + .size(IconSize::XSmall) + .color(Color::Hidden), + ), + ) + .child(div().w_full().min_w_0().child(Label::new(self.label))), + ) + .into_any_element() + } +} diff --git a/crates/ui/src/components/modal.rs b/crates/ui/src/components/modal.rs index 2e926b7593808070ab65be36902b01483945e2ac..a70f5e1ea5a53a043086f3e102878f3614990d6e 100644 --- 
a/crates/ui/src/components/modal.rs +++ b/crates/ui/src/components/modal.rs @@ -1,5 +1,5 @@ use crate::{ - Clickable, Color, DynamicSpacing, Headline, HeadlineSize, IconButton, IconButtonShape, + Clickable, Color, DynamicSpacing, Headline, HeadlineSize, Icon, IconButton, IconButtonShape, IconName, Label, LabelCommon, LabelSize, h_flex, v_flex, }; use gpui::{prelude::FluentBuilder, *}; @@ -92,7 +92,9 @@ impl RenderOnce for Modal { #[derive(IntoElement)] pub struct ModalHeader { + icon: Option<Icon>, headline: Option<SharedString>, + description: Option<SharedString>, children: SmallVec<[AnyElement; 2]>, show_dismiss_button: bool, show_back_button: bool, @@ -107,13 +109,20 @@ impl Default for ModalHeader { impl ModalHeader { pub fn new() -> Self { Self { + icon: None, headline: None, + description: None, children: SmallVec::new(), show_dismiss_button: false, show_back_button: false, } } + pub fn icon(mut self, icon: Icon) -> Self { + self.icon = Some(icon); + self + } + /// Set the headline of the modal. /// /// This will insert the headline as the first item @@ -123,6 +132,11 @@ impl ModalHeader { self } + pub fn description(mut self, description: impl Into<SharedString>) -> Self { + self.description = Some(description.into()); + self + } + pub fn show_dismiss_button(mut self, show: bool) -> Self { self.show_dismiss_button = show; self @@ -171,7 +185,19 @@ impl RenderOnce for ModalHeader { }), ) }) - .child(div().flex_1().children(children)) + .child( + v_flex() + .flex_1() + .child( + h_flex() + .gap_1() + .when_some(self.icon, |this, icon| this.child(icon)) + .children(children), + ) + .when_some(self.description, |this, description| { + this.child(Label::new(description).color(Color::Muted).mb_2()) + }), + ) .when(self.show_dismiss_button, |this| { this.child( IconButton::new("dismiss", IconName::Close) diff --git a/crates/ui/src/components/numeric_stepper.rs b/crates/ui/src/components/numeric_stepper.rs index f9e6e88f01f64fb7c78e98410178b101383f9be4..2ddb86d9a0d595edffc76319b415f9f68f9c6b9c 100644 --- a/crates/ui/src/components/numeric_stepper.rs +++ b/crates/ui/src/components/numeric_stepper.rs @@ -2,15 +2,24 @@ use gpui::ClickEvent; use crate::{IconButtonShape, prelude::*}; -#[derive(IntoElement)] +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)] +pub enum NumericStepperStyle { + Outlined, + #[default] + Ghost, +} + +#[derive(IntoElement, RegisterComponent)] pub struct NumericStepper { id: ElementId, value: SharedString, + style: NumericStepperStyle, on_decrement: Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>, on_increment: Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>, /// Whether to reserve space for the reset button. 
reserve_space_for_reset: bool, on_reset: Option<Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>>, + tab_index: Option<isize>, } impl NumericStepper { @@ -23,13 +32,20 @@ impl NumericStepper { Self { id: id.into(), value: value.into(), + style: NumericStepperStyle::default(), on_decrement: Box::new(on_decrement), on_increment: Box::new(on_increment), reserve_space_for_reset: false, on_reset: None, + tab_index: None, } } + pub fn style(mut self, style: NumericStepperStyle) -> Self { + self.style = style; + self + } + pub fn reserve_space_for_reset(mut self, reserve_space_for_reset: bool) -> Self { self.reserve_space_for_reset = reserve_space_for_reset; self @@ -42,6 +58,11 @@ impl NumericStepper { self.on_reset = Some(Box::new(on_reset)); self } + + pub fn tab_index(mut self, tab_index: isize) -> Self { + self.tab_index = Some(tab_index); + self + } } impl RenderOnce for NumericStepper { @@ -49,6 +70,9 @@ impl RenderOnce for NumericStepper { let shape = IconButtonShape::Square; let icon_size = IconSize::Small; + let is_outlined = matches!(self.style, NumericStepperStyle::Outlined); + let mut tab_index = self.tab_index; + h_flex() .id(self.id) .gap_1() @@ -58,6 +82,10 @@ impl RenderOnce for NumericStepper { IconButton::new("reset", IconName::RotateCcw) .shape(shape) .icon_size(icon_size) + .when_some(tab_index.as_mut(), |this, tab_index| { + *tab_index += 1; + this.tab_index(*tab_index - 1) + }) .on_click(on_reset), ) } else if self.reserve_space_for_reset { @@ -74,22 +102,136 @@ impl RenderOnce for NumericStepper { .child( h_flex() .gap_1() - .px_1() - .rounded_xs() - .bg(cx.theme().colors().editor_background) - .child( - IconButton::new("decrement", IconName::Dash) - .shape(shape) - .icon_size(icon_size) - .on_click(self.on_decrement), - ) - .child(Label::new(self.value)) - .child( - IconButton::new("increment", IconName::Plus) - .shape(shape) - .icon_size(icon_size) - .on_click(self.on_increment), - ), + .rounded_sm() + .map(|this| { + if is_outlined { + this.overflow_hidden() + .bg(cx.theme().colors().surface_background) + .border_1() + .border_color(cx.theme().colors().border_variant) + } else { + this.px_1().bg(cx.theme().colors().editor_background) + } + }) + .map(|decrement| { + if is_outlined { + decrement.child( + h_flex() + .id("decrement_button") + .p_1p5() + .size_full() + .justify_center() + .hover(|s| s.bg(cx.theme().colors().element_hover)) + .border_r_1() + .border_color(cx.theme().colors().border_variant) + .child(Icon::new(IconName::Dash).size(IconSize::Small)) + .when_some(tab_index.as_mut(), |this, tab_index| { + *tab_index += 1; + this.tab_index(*tab_index - 1).focus(|style| { + style.bg(cx.theme().colors().element_hover) + }) + }) + .on_click(self.on_decrement), + ) + } else { + decrement.child( + IconButton::new("decrement", IconName::Dash) + .shape(shape) + .icon_size(icon_size) + .when_some(tab_index.as_mut(), |this, tab_index| { + *tab_index += 1; + this.tab_index(*tab_index - 1) + }) + .on_click(self.on_decrement), + ) + } + }) + .child(Label::new(self.value).mx_3()) + .map(|increment| { + if is_outlined { + increment.child( + h_flex() + .id("increment_button") + .p_1p5() + .size_full() + .justify_center() + .hover(|s| s.bg(cx.theme().colors().element_hover)) + .border_l_1() + .border_color(cx.theme().colors().border_variant) + .child(Icon::new(IconName::Plus).size(IconSize::Small)) + .when_some(tab_index.as_mut(), |this, tab_index| { + *tab_index += 1; + this.tab_index(*tab_index - 1).focus(|style| { + style.bg(cx.theme().colors().element_hover) + }) + 
}) + .on_click(self.on_increment), + ) + } else { + increment.child( + IconButton::new("increment", IconName::Plus) + .shape(shape) + .icon_size(icon_size) + .when_some(tab_index.as_mut(), |this, tab_index| { + *tab_index += 1; + this.tab_index(*tab_index - 1) + }) + .on_click(self.on_increment), + ) + } + }), ) } } + +impl Component for NumericStepper { + fn scope() -> ComponentScope { + ComponentScope::Input + } + + fn name() -> &'static str { + "Numeric Stepper" + } + + fn sort_name() -> &'static str { + Self::name() + } + + fn description() -> Option<&'static str> { + Some("A button used to increment or decrement a numeric value.") + } + + fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> { + Some( + v_flex() + .gap_6() + .children(vec![example_group_with_title( + "Styles", + vec![ + single_example( + "Default", + NumericStepper::new( + "numeric-stepper-component-preview", + "10", + move |_, _, _| {}, + move |_, _, _| {}, + ) + .into_any_element(), + ), + single_example( + "Outlined", + NumericStepper::new( + "numeric-stepper-with-border-component-preview", + "10", + move |_, _, _| {}, + move |_, _, _| {}, + ) + .style(NumericStepperStyle::Outlined) + .into_any_element(), + ), + ], + )]) + .into_any_element(), + ) + } +} diff --git a/crates/ui/src/components/popover.rs b/crates/ui/src/components/popover.rs index 24460f6d9ce8e28b80c1e345007a88e7ee21a7a9..7143514c5269baf6dba2802f96e59ec0f8634317 100644 --- a/crates/ui/src/components/popover.rs +++ b/crates/ui/src/components/popover.rs @@ -50,7 +50,7 @@ impl RenderOnce for Popover { v_flex() .elevation_2(cx) .py(POPOVER_Y_PADDING / 2.) - .children(self.children), + .child(div().children(self.children)), ) .when_some(self.aside, |this, aside| { this.child( diff --git a/crates/ui/src/components/progress/progress_bar.rs b/crates/ui/src/components/progress/progress_bar.rs index 67b6be6723fc9441a96321003f7194121467ea14..5cc5abd36d041bc03676410983020b94ac8d8809 100644 --- a/crates/ui/src/components/progress/progress_bar.rs +++ b/crates/ui/src/components/progress/progress_bar.rs @@ -69,8 +69,7 @@ impl RenderOnce for ProgressBar { .w_full() .h(px(8.0)) .rounded_full() - .py(px(2.0)) - .px(px(4.0)) + .p(px(2.0)) .bg(self.bg_color) .shadow(vec![gpui::BoxShadow { color: gpui::black().opacity(0.08), diff --git a/crates/ui/src/components/scrollbar.rs b/crates/ui/src/components/scrollbar.rs index 2a8c4885acff5f3b5e75c7e2f6ae62335f9b8ebe..605028202fffa37d67bbdb4a9f33a97459390dfa 100644 --- a/crates/ui/src/components/scrollbar.rs +++ b/crates/ui/src/components/scrollbar.rs @@ -1,11 +1,20 @@ -use std::{any::Any, cell::Cell, fmt::Debug, ops::Range, rc::Rc, sync::Arc}; +use std::{ + any::Any, + cell::{Cell, RefCell}, + fmt::Debug, + ops::Range, + rc::Rc, + sync::Arc, + time::Duration, +}; use crate::{IntoElement, prelude::*, px, relative}; use gpui::{ Along, App, Axis as ScrollbarAxis, BorderStyle, Bounds, ContentMask, Corners, CursorStyle, Edges, Element, ElementId, Entity, EntityId, GlobalElementId, Hitbox, HitboxBehavior, Hsla, - IsZero, LayoutId, ListState, MouseDownEvent, MouseMoveEvent, MouseUpEvent, Pixels, Point, - ScrollHandle, ScrollWheelEvent, Size, Style, UniformListScrollHandle, Window, quad, + IsZero, LayoutId, ListState, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent, Pixels, + Point, ScrollHandle, ScrollWheelEvent, Size, Style, Task, UniformListScrollHandle, Window, + quad, }; pub struct Scrollbar { @@ -29,8 +38,8 @@ impl ThumbState { } impl ScrollableHandle for UniformListScrollHandle { - fn
content_size(&self) -> Size<Pixels> { - self.0.borrow().base_handle.content_size() + fn max_offset(&self) -> Size<Pixels> { + self.0.borrow().base_handle.max_offset() } fn set_offset(&self, point: Point<Pixels>) { @@ -47,8 +56,8 @@ impl ScrollableHandle for UniformListScrollHandle { } impl ScrollableHandle for ListState { - fn content_size(&self) -> Size<Pixels> { - self.content_size_for_scrollbar() + fn max_offset(&self) -> Size<Pixels> { + self.max_offset_for_scrollbar() } fn set_offset(&self, point: Point<Pixels>) { @@ -73,8 +82,8 @@ impl ScrollableHandle for ListState { } impl ScrollableHandle for ScrollHandle { - fn content_size(&self) -> Size<Pixels> { - self.padded_content_size() + fn max_offset(&self) -> Size<Pixels> { + self.max_offset() } fn set_offset(&self, point: Point<Pixels>) { @@ -91,7 +100,10 @@ impl ScrollableHandle for ScrollHandle { } pub trait ScrollableHandle: Any + Debug { - fn content_size(&self) -> Size<Pixels>; + fn content_size(&self) -> Size<Pixels> { + self.viewport().size + self.max_offset() + } + fn max_offset(&self) -> Size<Pixels>; fn set_offset(&self, point: Point<Pixels>); fn offset(&self) -> Point<Pixels>; fn viewport(&self) -> Bounds<Pixels>; @@ -105,6 +117,25 @@ pub struct ScrollbarState { thumb_state: Rc<Cell<ThumbState>>, parent_id: Option<EntityId>, scroll_handle: Arc<dyn ScrollableHandle>, + auto_hide: Rc<RefCell<AutoHide>>, +} + +#[derive(Debug)] +enum AutoHide { + Disabled, + Hidden { + parent_id: EntityId, + }, + Visible { + parent_id: EntityId, + _task: Task<()>, + }, +} + +impl AutoHide { + fn is_hidden(&self) -> bool { + matches!(self, AutoHide::Hidden { .. }) + } } impl ScrollbarState { @@ -113,6 +144,7 @@ impl ScrollbarState { thumb_state: Default::default(), parent_id: None, scroll_handle: Arc::new(scroll), + auto_hide: Rc::new(RefCell::new(AutoHide::Disabled)), } } @@ -149,17 +181,17 @@ impl ScrollbarState { fn thumb_range(&self, axis: ScrollbarAxis) -> Option<Range<f32>> { const MINIMUM_THUMB_SIZE: Pixels = px(25.); - let content_size = self.scroll_handle.content_size().along(axis); + let max_offset = self.scroll_handle.max_offset().along(axis); let viewport_size = self.scroll_handle.viewport().size.along(axis); - if content_size.is_zero() || viewport_size.is_zero() || content_size <= viewport_size { + if max_offset.is_zero() || viewport_size.is_zero() { return None; } + let content_size = viewport_size + max_offset; let visible_percentage = viewport_size / content_size; let thumb_size = MINIMUM_THUMB_SIZE.max(viewport_size * visible_percentage); if thumb_size > viewport_size { return None; } - let max_offset = content_size - viewport_size; let current_offset = self .scroll_handle .offset() @@ -171,6 +203,38 @@ impl ScrollbarState { let thumb_percentage_end = (start_offset + thumb_size) / viewport_size; Some(thumb_percentage_start..thumb_percentage_end) } + + fn show_temporarily(&self, parent_id: EntityId, cx: &mut App) { + const SHOW_INTERVAL: Duration = Duration::from_secs(1); + + let auto_hide = self.auto_hide.clone(); + auto_hide.replace(AutoHide::Visible { + parent_id, + _task: cx.spawn({ + let this = auto_hide.clone(); + async move |cx| { + cx.background_executor().timer(SHOW_INTERVAL).await; + this.replace(AutoHide::Hidden { parent_id }); + cx.update(|cx| { + cx.notify(parent_id); + }) + .ok(); + } + }), + }); + } + + fn unhide(&self, position: &Point<Pixels>, cx: &mut App) { + let parent_id = match &*self.auto_hide.borrow() { + AutoHide::Disabled => return, + AutoHide::Hidden { parent_id } => *parent_id, + AutoHide::Visible { 
parent_id, _task } => *parent_id, + }; + + if self.scroll_handle().viewport().contains(position) { + self.show_temporarily(parent_id, cx); + } + } } impl Scrollbar { @@ -186,6 +250,14 @@ impl Scrollbar { let thumb = state.thumb_range(kind)?; Some(Self { thumb, state, kind }) } + + /// Automatically hide the scrollbar when idle + pub fn auto_hide<V: 'static>(self, cx: &mut Context<V>) -> Self { + if matches!(*self.state.auto_hide.borrow(), AutoHide::Disabled) { + self.state.show_temporarily(cx.entity_id(), cx); + } + self + } } impl Element for Scrollbar { @@ -281,16 +353,18 @@ impl Element for Scrollbar { .apply_along(axis.invert(), |width| width / 1.5), ); - let corners = Corners::all(thumb_bounds.size.along(axis.invert()) / 2.0); - - window.paint_quad(quad( - thumb_bounds, - corners, - thumb_background, - Edges::default(), - Hsla::transparent_black(), - BorderStyle::default(), - )); + if thumb_state.is_dragging() || !self.state.auto_hide.borrow().is_hidden() { + let corners = Corners::all(thumb_bounds.size.along(axis.invert()) / 2.0); + + window.paint_quad(quad( + thumb_bounds, + corners, + thumb_background, + Edges::default(), + Hsla::transparent_black(), + BorderStyle::default(), + )); + } if thumb_state.is_dragging() { window.set_window_cursor_style(CursorStyle::Arrow); @@ -298,8 +372,6 @@ impl Element for Scrollbar { window.set_cursor_style(CursorStyle::Arrow, hitbox); } - let scroll = self.state.scroll_handle.clone(); - enum ScrollbarMouseEvent { GutterClick, ThumbDrag(Pixels), @@ -307,7 +379,7 @@ impl Element for Scrollbar { let compute_click_offset = move |event_position: Point<Pixels>, - item_size: Size<Pixels>, + max_offset: Size<Pixels>, event_type: ScrollbarMouseEvent| { let viewport_size = padded_bounds.size.along(axis); @@ -323,7 +395,7 @@ impl Element for Scrollbar { - thumb_offset) .clamp(px(0.), viewport_size - thumb_size); - let max_offset = (item_size.along(axis) - viewport_size).max(px(0.)); + let max_offset = max_offset.along(axis); let percentage = if viewport_size > thumb_size { thumb_start / (viewport_size - thumb_size) } else { @@ -334,10 +406,12 @@ impl Element for Scrollbar { }; window.on_mouse_event({ - let scroll = scroll.clone(); let state = self.state.clone(); move |event: &MouseDownEvent, phase, _, _| { - if !(phase.bubble() && bounds.contains(&event.position)) { + if !phase.bubble() + || event.button != MouseButton::Left + || !bounds.contains(&event.position) + { return; } @@ -345,57 +419,78 @@ impl Element for Scrollbar { let offset = event.position.along(axis) - thumb_bounds.origin.along(axis); state.set_dragging(offset); } else { + let scroll_handle = state.scroll_handle(); let click_offset = compute_click_offset( event.position, - scroll.content_size(), + scroll_handle.max_offset(), ScrollbarMouseEvent::GutterClick, ); - scroll.set_offset(scroll.offset().apply_along(axis, |_| click_offset)); + scroll_handle + .set_offset(scroll_handle.offset().apply_along(axis, |_| click_offset)); } } }); window.on_mouse_event({ - let scroll = scroll.clone(); - move |event: &ScrollWheelEvent, phase, window, _| { - if phase.bubble() && bounds.contains(&event.position) { - let current_offset = scroll.offset(); - scroll.set_offset( - current_offset + event.delta.pixel_delta(window.line_height()), - ); + let state = self.state.clone(); + let scroll_handle = self.state.scroll_handle().clone(); + move |event: &ScrollWheelEvent, phase, window, cx| { + if phase.bubble() { + state.unhide(&event.position, cx); + + if bounds.contains(&event.position) { + let current_offset = 
scroll_handle.offset(); + scroll_handle.set_offset( + current_offset + event.delta.pixel_delta(window.line_height()), + ); + } } } }); - let state = self.state.clone(); - window.on_mouse_event(move |event: &MouseMoveEvent, _, window, cx| { - match state.thumb_state.get() { - ThumbState::Dragging(drag_state) if event.dragging() => { - let drag_offset = compute_click_offset( - event.position, - scroll.content_size(), - ScrollbarMouseEvent::ThumbDrag(drag_state), - ); - scroll.set_offset(scroll.offset().apply_along(axis, |_| drag_offset)); - window.refresh(); - if let Some(id) = state.parent_id { - cx.notify(id); + window.on_mouse_event({ + let state = self.state.clone(); + move |event: &MouseMoveEvent, phase, window, cx| { + if phase.bubble() { + state.unhide(&event.position, cx); + + match state.thumb_state.get() { + ThumbState::Dragging(drag_state) if event.dragging() => { + let scroll_handle = state.scroll_handle(); + let drag_offset = compute_click_offset( + event.position, + scroll_handle.max_offset(), + ScrollbarMouseEvent::ThumbDrag(drag_state), + ); + scroll_handle.set_offset( + scroll_handle.offset().apply_along(axis, |_| drag_offset), + ); + window.refresh(); + if let Some(id) = state.parent_id { + cx.notify(id); + } + } + _ if event.pressed_button.is_none() => { + state.set_thumb_hovered(thumb_bounds.contains(&event.position)) + } + _ => {} } } - _ => state.set_thumb_hovered(thumb_bounds.contains(&event.position)), } }); - let state = self.state.clone(); - let scroll = self.state.scroll_handle.clone(); - window.on_mouse_event(move |event: &MouseUpEvent, phase, _, cx| { - if phase.bubble() { - if state.is_dragging() { + + window.on_mouse_event({ + let state = self.state.clone(); + move |event: &MouseUpEvent, phase, _, cx| { + if phase.bubble() { + if state.is_dragging() { + state.scroll_handle().drag_ended(); + if let Some(id) = state.parent_id { + cx.notify(id); + } + } state.set_thumb_hovered(thumb_bounds.contains(&event.position)); } - scroll.drag_ended(); - if let Some(id) = state.parent_id { - cx.notify(id); - } } }); }) diff --git a/crates/ui/src/components/sticky_items.rs b/crates/ui/src/components/sticky_items.rs index 218f7aae3510213afeed9d80a28428ce9c0df28a..ca8b336a5aa97101f29d394399f36bda2fcc44b9 100644 --- a/crates/ui/src/components/sticky_items.rs +++ b/crates/ui/src/components/sticky_items.rs @@ -149,47 +149,7 @@ where ) -> AnyElement { let entries = (self.compute_fn)(visible_range.clone(), window, cx); - struct StickyAnchor<T> { - entry: T, - index: usize, - } - - let mut sticky_anchor = None; - let mut last_item_is_drifting = false; - - let mut iter = entries.iter().enumerate().peekable(); - while let Some((ix, current_entry)) = iter.next() { - let depth = current_entry.depth(); - - if depth < ix { - sticky_anchor = Some(StickyAnchor { - entry: current_entry.clone(), - index: visible_range.start + ix, - }); - break; - } - - if let Some(&(_next_ix, next_entry)) = iter.peek() { - let next_depth = next_entry.depth(); - let next_item_outdented = next_depth + 1 == depth; - - let depth_same_as_index = depth == ix; - let depth_greater_than_index = depth == ix + 1; - - if next_item_outdented && (depth_same_as_index || depth_greater_than_index) { - if depth_greater_than_index { - last_item_is_drifting = true; - } - sticky_anchor = Some(StickyAnchor { - entry: current_entry.clone(), - index: visible_range.start + ix, - }); - break; - } - } - } - - let Some(sticky_anchor) = sticky_anchor else { + let Some(sticky_anchor) = find_sticky_anchor(&entries, visible_range.start) else 
{ return StickyItemsElement { drifting_element: None, drifting_decoration: None, @@ -203,23 +163,21 @@ where let mut elements = (self.render_fn)(sticky_anchor.entry, window, cx); let items_count = elements.len(); - let indents: SmallVec<[usize; 8]> = { - elements - .iter() - .enumerate() - .map(|(ix, _)| anchor_depth.saturating_sub(items_count.saturating_sub(ix))) - .collect() - }; + let indents: SmallVec<[usize; 8]> = (0..items_count) + .map(|ix| anchor_depth.saturating_sub(items_count.saturating_sub(ix))) + .collect(); let mut last_decoration_element = None; let mut rest_decoration_elements = SmallVec::new(); - let available_space = size( - AvailableSpace::Definite(bounds.size.width), + let expanded_width = bounds.size.width + scroll_offset.x.abs(); + + let decor_available_space = size( + AvailableSpace::Definite(expanded_width), AvailableSpace::Definite(bounds.size.height), ); - let drifting_y_offset = if last_item_is_drifting { + let drifting_y_offset = if sticky_anchor.drifting { let scroll_top = -scroll_offset.y; let anchor_top = item_height * (sticky_anchor.index + 1); let sticky_area_height = item_height * items_count; @@ -228,7 +186,7 @@ where Pixels::ZERO }; - let (drifting_indent, rest_indents) = if last_item_is_drifting && !indents.is_empty() { + let (drifting_indent, rest_indents) = if sticky_anchor.drifting && !indents.is_empty() { let last = indents[indents.len() - 1]; let rest: SmallVec<[usize; 8]> = indents[..indents.len() - 1].iter().copied().collect(); (Some(last), rest) @@ -236,11 +194,14 @@ where (None, indents) }; + let base_origin = bounds.origin - point(px(0.), scroll_offset.y); + for decoration in &self.decorations { if let Some(drifting_indent) = drifting_indent { let drifting_indent_vec: SmallVec<[usize; 8]> = [drifting_indent].into_iter().collect(); - let sticky_origin = bounds.origin - scroll_offset + + let sticky_origin = base_origin + point(px(0.), item_height * rest_indents.len() + drifting_y_offset); let decoration_bounds = Bounds::new(sticky_origin, bounds.size); @@ -252,13 +213,13 @@ where window, cx, ); - drifting_dec.layout_as_root(available_space, window, cx); + drifting_dec.layout_as_root(decor_available_space, window, cx); drifting_dec.prepaint_at(sticky_origin, window, cx); last_decoration_element = Some(drifting_dec); } if !rest_indents.is_empty() { - let decoration_bounds = Bounds::new(bounds.origin - scroll_offset, bounds.size); + let decoration_bounds = Bounds::new(base_origin, bounds.size); let mut rest_dec = decoration.as_ref().compute( &rest_indents, decoration_bounds, @@ -267,46 +228,45 @@ where window, cx, ); - rest_dec.layout_as_root(available_space, window, cx); + rest_dec.layout_as_root(decor_available_space, window, cx); rest_dec.prepaint_at(bounds.origin, window, cx); rest_decoration_elements.push(rest_dec); } } let (mut drifting_element, mut rest_elements) = - if last_item_is_drifting && !elements.is_empty() { + if sticky_anchor.drifting && !elements.is_empty() { let last = elements.pop().unwrap(); (Some(last), elements) } else { (None, elements) }; - for (ix, element) in rest_elements.iter_mut().enumerate() { - let sticky_origin = bounds.origin - scroll_offset + point(px(0.), item_height * ix); - let element_available_space = size( - AvailableSpace::Definite(bounds.size.width), - AvailableSpace::Definite(item_height), - ); - - element.layout_as_root(element_available_space, window, cx); - element.prepaint_at(sticky_origin, window, cx); - } + let element_available_space = size( + AvailableSpace::Definite(expanded_width), + 
AvailableSpace::Definite(item_height), + ); + // order of prepaint is important here + // mouse events checks hitboxes in reverse insertion order if let Some(ref mut drifting_element) = drifting_element { - let sticky_origin = bounds.origin - scroll_offset + let sticky_origin = base_origin + point( px(0.), item_height * rest_elements.len() + drifting_y_offset, ); - let element_available_space = size( - AvailableSpace::Definite(bounds.size.width), - AvailableSpace::Definite(item_height), - ); drifting_element.layout_as_root(element_available_space, window, cx); drifting_element.prepaint_at(sticky_origin, window, cx); } + for (ix, element) in rest_elements.iter_mut().enumerate() { + let sticky_origin = base_origin + point(px(0.), item_height * ix); + + element.layout_as_root(element_available_space, window, cx); + element.prepaint_at(sticky_origin, window, cx); + } + StickyItemsElement { drifting_element, drifting_decoration: last_decoration_element, @@ -317,6 +277,48 @@ where } } +struct StickyAnchor<T> { + entry: T, + index: usize, + drifting: bool, +} + +fn find_sticky_anchor<T: StickyCandidate + Clone>( + entries: &SmallVec<[T; 8]>, + visible_range_start: usize, +) -> Option<StickyAnchor<T>> { + let mut iter = entries.iter().enumerate().peekable(); + while let Some((ix, current_entry)) = iter.next() { + let depth = current_entry.depth(); + + if depth < ix { + return Some(StickyAnchor { + entry: current_entry.clone(), + index: visible_range_start + ix, + drifting: false, + }); + } + + if let Some(&(_next_ix, next_entry)) = iter.peek() { + let next_depth = next_entry.depth(); + let next_item_outdented = next_depth + 1 == depth; + + let depth_same_as_index = depth == ix; + let depth_greater_than_index = depth == ix + 1; + + if next_item_outdented && (depth_same_as_index || depth_greater_than_index) { + return Some(StickyAnchor { + entry: current_entry.clone(), + index: visible_range_start + ix, + drifting: depth_greater_than_index, + }); + } + } + } + + None +} + /// A decoration for a [`StickyItems`]. This can be used for various things, /// such as rendering indent guides, or other visual effects. 
pub trait StickyItemsDecoration { diff --git a/crates/ui/src/components/stories/icon_button.rs b/crates/ui/src/components/stories/icon_button.rs index e787e81b5599756086f9552b6c1e719a6819e7ea..ad6886252d9beeabb64696dbb12292bfd841eb19 100644 --- a/crates/ui/src/components/stories/icon_button.rs +++ b/crates/ui/src/components/stories/icon_button.rs @@ -77,7 +77,7 @@ impl Render for IconButtonStory { let with_tooltip_button = StoryItem::new( "With `tooltip`", - IconButton::new("with_tooltip_button", IconName::MessageBubbles) + IconButton::new("with_tooltip_button", IconName::Chat) .tooltip(Tooltip::text("Open messages")), ) .description("Displays an icon button that has a tooltip when hovered.") diff --git a/crates/ui/src/components/tab.rs b/crates/ui/src/components/tab.rs index a205c33358eb7ac46d81572948a3165967b734d6..d704846a6834e094a6a6aeb5fcf1fda6ea66c8b2 100644 --- a/crates/ui/src/components/tab.rs +++ b/crates/ui/src/components/tab.rs @@ -179,7 +179,7 @@ impl RenderOnce for Tab { impl Component for Tab { fn scope() -> ComponentScope { - ComponentScope::None + ComponentScope::Navigation } fn description() -> Option<&'static str> { diff --git a/crates/ui/src/components/toggle.rs b/crates/ui/src/components/toggle.rs index 7a12e1f44509ba9f116305c1b4710da3c7e47001..4b985fd2c2552f731a7b16c6e850b522c26460db 100644 --- a/crates/ui/src/components/toggle.rs +++ b/crates/ui/src/components/toggle.rs @@ -1,10 +1,11 @@ use gpui::{ - AnyElement, AnyView, ElementId, Hsla, IntoElement, Styled, Window, div, hsla, prelude::*, + AnyElement, AnyView, ClickEvent, ElementId, Hsla, IntoElement, Styled, Window, div, hsla, + prelude::*, }; -use std::sync::Arc; +use std::{rc::Rc, sync::Arc}; use crate::utils::is_light; -use crate::{Color, Icon, IconName, ToggleState}; +use crate::{Color, Icon, IconName, ToggleState, Tooltip}; use crate::{ElevationIndex, KeyBinding, prelude::*}; // TODO: Checkbox, CheckboxWithLabel, and Switch could all be @@ -44,7 +45,7 @@ pub struct Checkbox { toggle_state: ToggleState, disabled: bool, placeholder: bool, - on_click: Option<Box<dyn Fn(&ToggleState, &mut Window, &mut App) + 'static>>, + on_click: Option<Box<dyn Fn(&ToggleState, &ClickEvent, &mut Window, &mut App) + 'static>>, filled: bool, style: ToggleStyle, tooltip: Option<Box<dyn Fn(&mut Window, &mut App) -> AnyView>>, @@ -83,6 +84,16 @@ impl Checkbox { pub fn on_click( mut self, handler: impl Fn(&ToggleState, &mut Window, &mut App) + 'static, + ) -> Self { + self.on_click = Some(Box::new(move |state, _, window, cx| { + handler(state, window, cx) + })); + self + } + + pub fn on_click_ext( + mut self, + handler: impl Fn(&ToggleState, &ClickEvent, &mut Window, &mut App) + 'static, ) -> Self { self.on_click = Some(Box::new(handler)); self @@ -226,8 +237,8 @@ impl RenderOnce for Checkbox { .when_some( self.on_click.filter(|_| !self.disabled), |this, on_click| { - this.on_click(move |_, window, cx| { - on_click(&self.toggle_state.inverse(), window, cx) + this.on_click(move |click, window, cx| { + on_click(&self.toggle_state.inverse(), click, window, cx) }) }, ) @@ -413,6 +424,7 @@ pub struct Switch { label: Option<SharedString>, key_binding: Option<KeyBinding>, color: SwitchColor, + tab_index: Option<isize>, } impl Switch { @@ -426,6 +438,7 @@ impl Switch { label: None, key_binding: None, color: SwitchColor::default(), + tab_index: None, } } @@ -461,6 +474,11 @@ impl Switch { self.key_binding = key_binding.into(); self } + + pub fn tab_index(mut self, tab_index: impl Into<isize>) -> Self { + self.tab_index = Some(tab_index.into()); 
+ self + } } impl RenderOnce for Switch { @@ -486,29 +504,46 @@ impl RenderOnce for Switch { let group_id = format!("switch_group_{:?}", self.id); - let switch = h_flex() - .w(DynamicSpacing::Base32.rems(cx)) - .h(DynamicSpacing::Base20.rems(cx)) - .group(group_id.clone()) + let switch = div() + .id((self.id.clone(), "switch")) + .p(px(1.0)) + .border_2() + .border_color(cx.theme().colors().border_transparent) + .rounded_full() + .when_some( + self.tab_index.filter(|_| !self.disabled), + |this, tab_index| { + this.tab_index(tab_index).focus(|mut style| { + style.border_color = Some(cx.theme().colors().border_focused); + style + }) + }, + ) .child( h_flex() - .when(is_on, |on| on.justify_end()) - .when(!is_on, |off| off.justify_start()) - .size_full() - .rounded_full() - .px(DynamicSpacing::Base02.px(cx)) - .bg(bg_color) - .when(!self.disabled, |this| { - this.group_hover(group_id.clone(), |el| el.bg(bg_hover_color)) - }) - .border_1() - .border_color(border_color) + .w(DynamicSpacing::Base32.rems(cx)) + .h(DynamicSpacing::Base20.rems(cx)) + .group(group_id.clone()) .child( - div() - .size(DynamicSpacing::Base12.rems(cx)) + h_flex() + .when(is_on, |on| on.justify_end()) + .when(!is_on, |off| off.justify_start()) + .size_full() .rounded_full() - .bg(thumb_color) - .opacity(thumb_opacity), + .px(DynamicSpacing::Base02.px(cx)) + .bg(bg_color) + .when(!self.disabled, |this| { + this.group_hover(group_id.clone(), |el| el.bg(bg_hover_color)) + }) + .border_1() + .border_color(border_color) + .child( + div() + .size(DynamicSpacing::Base12.rems(cx)) + .rounded_full() + .bg(thumb_color) + .opacity(thumb_opacity), + ), ), ); @@ -532,69 +567,287 @@ impl RenderOnce for Switch { } } -/// A [`Switch`] that has a [`Label`]. -#[derive(IntoElement)] -pub struct SwitchWithLabel { +/// # SwitchField +/// +/// A field component that combines a label, description, and switch into one reusable component. +/// +/// # Examples +/// +/// ``` +/// use ui::prelude::*; +/// +/// SwitchField::new( +/// "feature-toggle", +/// "Enable feature", +/// Some("This feature adds new functionality to the app.".into()), +/// ToggleState::Unselected, +/// |state, window, cx| { +/// // Logic here +/// } +/// ); +/// ``` +#[derive(IntoElement, RegisterComponent)] +pub struct SwitchField { id: ElementId, - label: Label, + label: SharedString, + description: Option<SharedString>, toggle_state: ToggleState, on_click: Arc<dyn Fn(&ToggleState, &mut Window, &mut App) + 'static>, disabled: bool, color: SwitchColor, + tooltip: Option<Rc<dyn Fn(&mut Window, &mut App) -> AnyView>>, + tab_index: Option<isize>, } -impl SwitchWithLabel { - /// Creates a switch with an attached label. +impl SwitchField { pub fn new( id: impl Into<ElementId>, - label: Label, + label: impl Into<SharedString>, + description: Option<SharedString>, toggle_state: impl Into<ToggleState>, on_click: impl Fn(&ToggleState, &mut Window, &mut App) + 'static, ) -> Self { Self { id: id.into(), - label, + label: label.into(), + description, toggle_state: toggle_state.into(), on_click: Arc::new(on_click), disabled: false, - color: SwitchColor::default(), + color: SwitchColor::Accent, + tooltip: None, + tab_index: None, } } - /// Sets the disabled state of the [`SwitchWithLabel`]. + pub fn description(mut self, description: impl Into<SharedString>) -> Self { + self.description = Some(description.into()); + self + } + pub fn disabled(mut self, disabled: bool) -> Self { self.disabled = disabled; self } /// Sets the color of the switch using the specified [`SwitchColor`].
+ /// This changes the color scheme of the switch when it's in the "on" state. pub fn color(mut self, color: SwitchColor) -> Self { self.color = color; self } + + pub fn tooltip(mut self, tooltip: impl Fn(&mut Window, &mut App) -> AnyView + 'static) -> Self { + self.tooltip = Some(Rc::new(tooltip)); + self + } + + pub fn tab_index(mut self, tab_index: isize) -> Self { + self.tab_index = Some(tab_index); + self + } } -impl RenderOnce for SwitchWithLabel { - fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { +impl RenderOnce for SwitchField { + fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement { + let tooltip = self.tooltip.map(|tooltip_fn| { + h_flex() + .gap_0p5() + .child(Label::new(self.label.clone())) + .child( + IconButton::new("tooltip_button", IconName::Info) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + .shape(crate::IconButtonShape::Square) + .tooltip({ + let tooltip = tooltip_fn.clone(); + move |window, cx| tooltip(window, cx) + }), + ) + }); + h_flex() - .id(SharedString::from(format!("{}-container", self.id))) - .gap(DynamicSpacing::Base08.rems(cx)) + .id((self.id.clone(), "container")) + .when(!self.disabled, |this| { + this.hover(|this| this.cursor_pointer()) + }) + .w_full() + .gap_4() + .justify_between() + .flex_wrap() + .child(match (&self.description, tooltip) { + (Some(description), Some(tooltip)) => v_flex() + .gap_0p5() + .max_w_5_6() + .child(tooltip) + .child(Label::new(description.clone()).color(Color::Muted)) + .into_any_element(), + (Some(description), None) => v_flex() + .gap_0p5() + .max_w_5_6() + .child(Label::new(self.label.clone())) + .child(Label::new(description.clone()).color(Color::Muted)) + .into_any_element(), + (None, Some(tooltip)) => tooltip.into_any_element(), + (None, None) => Label::new(self.label.clone()).into_any_element(), + }) .child( - Switch::new(self.id.clone(), self.toggle_state) - .disabled(self.disabled) + Switch::new((self.id.clone(), "switch"), self.toggle_state) .color(self.color) + .disabled(self.disabled) + .when_some( + self.tab_index.filter(|_| !self.disabled), + |this, tab_index| this.tab_index(tab_index), + ) .on_click({ let on_click = self.on_click.clone(); - move |checked, window, cx| { - (on_click)(checked, window, cx); + move |state, window, cx| { + (on_click)(state, window, cx); } }), ) - .child( - div() - .id(SharedString::from(format!("{}-label", self.id))) - .child(self.label), - ) + .when(!self.disabled, |this| { + this.on_click({ + let on_click = self.on_click.clone(); + let toggle_state = self.toggle_state; + move |_click, window, cx| { + (on_click)(&toggle_state.inverse(), window, cx); + } + }) + }) + } +} + +impl Component for SwitchField { + fn scope() -> ComponentScope { + ComponentScope::Input + } + + fn description() -> Option<&'static str> { + Some("A field component that combines a label, description, and switch") + } + + fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> { + Some( + v_flex() + .gap_6() + .children(vec![ + example_group_with_title( + "States", + vec![ + single_example( + "Unselected", + SwitchField::new( + "switch_field_unselected", + "Enable notifications", + Some("Receive notifications when new messages arrive.".into()), + ToggleState::Unselected, + |_, _, _| {}, + ) + .into_any_element(), + ), + single_example( + "Selected", + SwitchField::new( + "switch_field_selected", + "Enable notifications", + Some("Receive notifications when new messages arrive.".into()), + ToggleState::Selected, + |_, _, _| {}, + ) + 
.into_any_element(), + ), + ], + ), + example_group_with_title( + "Colors", + vec![ + single_example( + "Default", + SwitchField::new( + "switch_field_default", + "Default color", + Some("This uses the default switch color.".into()), + ToggleState::Selected, + |_, _, _| {}, + ) + .into_any_element(), + ), + single_example( + "Accent", + SwitchField::new( + "switch_field_accent", + "Accent color", + Some("This uses the accent color scheme.".into()), + ToggleState::Selected, + |_, _, _| {}, + ) + .color(SwitchColor::Accent) + .into_any_element(), + ), + ], + ), + example_group_with_title( + "Disabled", + vec![single_example( + "Disabled", + SwitchField::new( + "switch_field_disabled", + "Disabled field", + Some("This field is disabled and cannot be toggled.".into()), + ToggleState::Selected, + |_, _, _| {}, + ) + .disabled(true) + .into_any_element(), + )], + ), + example_group_with_title( + "No Description", + vec![single_example( + "No Description", + SwitchField::new( + "switch_field_disabled", + "Disabled field", + None, + ToggleState::Selected, + |_, _, _| {}, + ) + .into_any_element(), + )], + ), + example_group_with_title( + "With Tooltip", + vec![ + single_example( + "Tooltip with Description", + SwitchField::new( + "switch_field_tooltip_with_desc", + "Nice Feature", + Some("Enable advanced configuration options.".into()), + ToggleState::Unselected, + |_, _, _| {}, + ) + .tooltip(Tooltip::text("This is content for this tooltip!")) + .into_any_element(), + ), + single_example( + "Tooltip without Description", + SwitchField::new( + "switch_field_tooltip_no_desc", + "Nice Feature", + None, + ToggleState::Selected, + |_, _, _| {}, + ) + .tooltip(Tooltip::text("This is content for this tooltip!")) + .into_any_element(), + ), + ], + ), + ]) + .into_any_element(), + ) } } diff --git a/crates/ui/src/components/tooltip.rs b/crates/ui/src/components/tooltip.rs index 647b700c377b4ca6816924e592f698937646b6b8..ed0fdd0114137256273f420acd647228bf605218 100644 --- a/crates/ui/src/components/tooltip.rs +++ b/crates/ui/src/components/tooltip.rs @@ -1,3 +1,5 @@ +use std::rc::Rc; + use gpui::{Action, AnyElement, AnyView, AppContext as _, FocusHandle, IntoElement, Render}; use settings::Settings; use theme::ThemeSettings; @@ -7,15 +9,36 @@ use crate::{Color, KeyBinding, Label, LabelSize, StyledExt, h_flex, v_flex}; #[derive(RegisterComponent)] pub struct Tooltip { - title: SharedString, + title: Title, meta: Option<SharedString>, key_binding: Option<KeyBinding>, } +#[derive(Clone, IntoElement)] +enum Title { + Str(SharedString), + Callback(Rc<dyn Fn(&mut Window, &mut App) -> AnyElement>), +} + +impl From<SharedString> for Title { + fn from(value: SharedString) -> Self { + Title::Str(value) + } +} + +impl RenderOnce for Title { + fn render(self, window: &mut Window, cx: &mut App) -> impl gpui::IntoElement { + match self { + Title::Str(title) => title.into_any_element(), + Title::Callback(element) => element(window, cx), + } + } +} + impl Tooltip { pub fn simple(title: impl Into<SharedString>, cx: &mut App) -> AnyView { cx.new(|_| Self { - title: title.into(), + title: Title::Str(title.into()), meta: None, key_binding: None, }) @@ -26,7 +49,7 @@ impl Tooltip { let title = title.into(); move |_, cx| { cx.new(|_| Self { - title: title.clone(), + title: title.clone().into(), meta: None, key_binding: None, }) @@ -34,15 +57,15 @@ impl Tooltip { } } - pub fn for_action_title<Title: Into<SharedString>>( - title: Title, + pub fn for_action_title<T: Into<SharedString>>( + title: T, action: &dyn Action, - ) -> impl 
Fn(&mut Window, &mut App) -> AnyView + use<Title> { + ) -> impl Fn(&mut Window, &mut App) -> AnyView + use<T> { let title = title.into(); let action = action.boxed_clone(); move |window, cx| { cx.new(|cx| Self { - title: title.clone(), + title: Title::Str(title.clone()), meta: None, key_binding: KeyBinding::for_action(action.as_ref(), window, cx), }) @@ -60,7 +83,7 @@ impl Tooltip { let focus_handle = focus_handle.clone(); move |window, cx| { cx.new(|cx| Self { - title: title.clone(), + title: Title::Str(title.clone()), meta: None, key_binding: KeyBinding::for_action_in(action.as_ref(), &focus_handle, window, cx), }) @@ -75,7 +98,7 @@ impl Tooltip { cx: &mut App, ) -> AnyView { cx.new(|cx| Self { - title: title.into(), + title: Title::Str(title.into()), meta: None, key_binding: KeyBinding::for_action(action, window, cx), }) @@ -90,7 +113,7 @@ impl Tooltip { cx: &mut App, ) -> AnyView { cx.new(|cx| Self { - title: title.into(), + title: title.into().into(), meta: None, key_binding: KeyBinding::for_action_in(action, focus_handle, window, cx), }) @@ -105,7 +128,7 @@ impl Tooltip { cx: &mut App, ) -> AnyView { cx.new(|cx| Self { - title: title.into(), + title: title.into().into(), meta: Some(meta.into()), key_binding: action.and_then(|action| KeyBinding::for_action(action, window, cx)), }) @@ -121,7 +144,7 @@ impl Tooltip { cx: &mut App, ) -> AnyView { cx.new(|cx| Self { - title: title.into(), + title: title.into().into(), meta: Some(meta.into()), key_binding: action .and_then(|action| KeyBinding::for_action_in(action, focus_handle, window, cx)), @@ -131,12 +154,35 @@ impl Tooltip { pub fn new(title: impl Into<SharedString>) -> Self { Self { - title: title.into(), + title: title.into().into(), meta: None, key_binding: None, } } + pub fn new_element(title: impl Fn(&mut Window, &mut App) -> AnyElement + 'static) -> Self { + Self { + title: Title::Callback(Rc::new(title)), + meta: None, + key_binding: None, + } + } + + pub fn element( + title: impl Fn(&mut Window, &mut App) -> AnyElement + 'static, + ) -> impl Fn(&mut Window, &mut App) -> AnyView { + let title = Title::Callback(Rc::new(title)); + move |_, cx| { + let title = title.clone(); + cx.new(|_| Self { + title: title, + meta: None, + key_binding: None, + }) + .into() + } + } + pub fn meta(mut self, meta: impl Into<SharedString>) -> Self { self.meta = Some(meta.into()); self @@ -228,7 +274,7 @@ impl Render for LinkPreview { impl Component for Tooltip { fn scope() -> ComponentScope { - ComponentScope::None + ComponentScope::DataDisplay } fn description() -> Option<&'static str> { diff --git a/crates/ui/src/styles/animation.rs b/crates/ui/src/styles/animation.rs index 50c4e0eb0daf6d0868c5ab76db5374d695863f99..ee5352d45403183555fe8d6c72806a5b90f88ca8 100644 --- a/crates/ui/src/styles/animation.rs +++ b/crates/ui/src/styles/animation.rs @@ -99,7 +99,7 @@ struct Animation {} impl Component for Animation { fn scope() -> ComponentScope { - ComponentScope::None + ComponentScope::Utilities } fn description() -> Option<&'static str> { @@ -109,7 +109,7 @@ impl Component for Animation { fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> { let container_size = 128.0; let element_size = 32.0; - let left_offset = element_size - container_size / 2.0; + let offset = container_size / 2.0 - element_size / 2.0; Some( v_flex() .gap_6() @@ -129,7 +129,7 @@ impl Component for Animation { .id("animate-in-from-bottom") .absolute() .size(px(element_size)) - .left(px(left_offset)) + .left(px(offset)) .rounded_md() .bg(gpui::red()) 
.animate_in(AnimationDirection::FromBottom, false), @@ -148,7 +148,7 @@ impl Component for Animation { .id("animate-in-from-top") .absolute() .size(px(element_size)) - .left(px(left_offset)) + .left(px(offset)) .rounded_md() .bg(gpui::blue()) .animate_in(AnimationDirection::FromTop, false), @@ -167,7 +167,7 @@ impl Component for Animation { .id("animate-in-from-left") .absolute() .size(px(element_size)) - .left(px(left_offset)) + .top(px(offset)) .rounded_md() .bg(gpui::green()) .animate_in(AnimationDirection::FromLeft, false), @@ -186,7 +186,7 @@ impl Component for Animation { .id("animate-in-from-right") .absolute() .size(px(element_size)) - .left(px(left_offset)) + .top(px(offset)) .rounded_md() .bg(gpui::yellow()) .animate_in(AnimationDirection::FromRight, false), @@ -211,7 +211,7 @@ impl Component for Animation { .id("fade-animate-in-from-bottom") .absolute() .size(px(element_size)) - .left(px(left_offset)) + .left(px(offset)) .rounded_md() .bg(gpui::red()) .animate_in(AnimationDirection::FromBottom, true), @@ -230,7 +230,7 @@ impl Component for Animation { .id("fade-animate-in-from-top") .absolute() .size(px(element_size)) - .left(px(left_offset)) + .left(px(offset)) .rounded_md() .bg(gpui::blue()) .animate_in(AnimationDirection::FromTop, true), @@ -249,7 +249,7 @@ impl Component for Animation { .id("fade-animate-in-from-left") .absolute() .size(px(element_size)) - .left(px(left_offset)) + .top(px(offset)) .rounded_md() .bg(gpui::green()) .animate_in(AnimationDirection::FromLeft, true), @@ -268,7 +268,7 @@ impl Component for Animation { .id("fade-animate-in-from-right") .absolute() .size(px(element_size)) - .left(px(left_offset)) + .top(px(offset)) .rounded_md() .bg(gpui::yellow()) .animate_in(AnimationDirection::FromRight, true), diff --git a/crates/ui/src/styles/color.rs b/crates/ui/src/styles/color.rs index c7b995d39afc67d2441eeeae36832451ae071af7..586b2ccc576fc6321e95cd530d4703fa95b3c36f 100644 --- a/crates/ui/src/styles/color.rs +++ b/crates/ui/src/styles/color.rs @@ -126,7 +126,7 @@ impl From<Hsla> for Color { impl Component for Color { fn scope() -> ComponentScope { - ComponentScope::None + ComponentScope::Utilities } fn description() -> Option<&'static str> { diff --git a/crates/ui_input/src/ui_input.rs b/crates/ui_input/src/ui_input.rs index bd99814cb30534165ad2bfba3911233e2946271b..1a5bebaf1e952a02106bb05a2ec54055d361cb38 100644 --- a/crates/ui_input/src/ui_input.rs +++ b/crates/ui_input/src/ui_input.rs @@ -27,6 +27,8 @@ pub struct SingleLineInput { /// /// Its position is determined by the [`FieldLabelLayout`]. label: Option<SharedString>, + /// The size of the label text. + label_size: LabelSize, /// The placeholder text for the text field. placeholder: SharedString, /// Exposes the underlying [`Entity<Editor>`] to allow for customizing the editor beyond the provided API. 
@@ -59,6 +61,7 @@ impl SingleLineInput { Self { label: None, + label_size: LabelSize::Small, placeholder: placeholder_text, editor, start_icon: None, @@ -76,6 +79,11 @@ impl SingleLineInput { self } + pub fn label_size(mut self, size: LabelSize) -> Self { + self.label_size = size; + self + } + pub fn set_disabled(&mut self, disabled: bool, cx: &mut Context<Self>) { self.disabled = disabled; self.editor @@ -89,6 +97,10 @@ impl SingleLineInput { pub fn editor(&self) -> &Entity<Editor> { &self.editor } + + pub fn text(&self, cx: &App) -> String { + self.editor().read(cx).text(cx) + } } impl Render for SingleLineInput { @@ -127,6 +139,7 @@ impl Render for SingleLineInput { let editor_style = EditorStyle { background: theme_color.ghost_element_background, local_player: cx.theme().players().local(), + syntax: cx.theme().syntax().clone(), text: text_style, ..Default::default() }; @@ -138,7 +151,7 @@ impl Render for SingleLineInput { .when_some(self.label.clone(), |this, label| { this.child( Label::new(label) - .size(LabelSize::Small) + .size(self.label_size) .color(if self.disabled { Color::Disabled } else { @@ -148,16 +161,17 @@ impl Render for SingleLineInput { }) .child( h_flex() + .min_w_48() + .min_h_8() + .w_full() .px_2() .py_1p5() - .bg(style.background_color) + .flex_grow() .text_color(style.text_color) - .rounded_md() + .rounded_sm() + .bg(style.background_color) .border_1() .border_color(style.border_color) - .min_w_48() - .w_full() - .flex_grow() .when_some(self.start_icon, |this, icon| { this.gap_1() .child(Icon::new(icon).size(IconSize::Small).color(Color::Muted)) @@ -173,16 +187,28 @@ impl Component for SingleLineInput { } fn preview(window: &mut Window, cx: &mut App) -> Option<AnyElement> { - let input_1 = - cx.new(|cx| SingleLineInput::new(window, cx, "placeholder").label("Some Label")); + let input_small = + cx.new(|cx| SingleLineInput::new(window, cx, "placeholder").label("Small Label")); + + let input_regular = cx.new(|cx| { + SingleLineInput::new(window, cx, "placeholder") + .label("Regular Label") + .label_size(LabelSize::Default) + }); Some( v_flex() .gap_6() - .children(vec![example_group(vec![single_example( - "Default", - div().child(input_1.clone()).into_any_element(), - )])]) + .children(vec![example_group(vec![ + single_example( + "Small Label (Default)", + div().child(input_small.clone()).into_any_element(), + ), + single_example( + "Regular Label", + div().child(input_regular.clone()).into_any_element(), + ), + ])]) .into_any_element(), ) } diff --git a/crates/ui_prompt/src/ui_prompt.rs b/crates/ui_prompt/src/ui_prompt.rs index 2b6a030f26e752401a56a61a3f6a0a881bb89557..fe6dc5b3f4afc6d2d0097292555baaea4f077642 100644 --- a/crates/ui_prompt/src/ui_prompt.rs +++ b/crates/ui_prompt/src/ui_prompt.rs @@ -43,7 +43,7 @@ fn zed_prompt_renderer( let renderer = cx.new({ |cx| ZedPromptRenderer { _level: level, - message: message.to_string(), + message: cx.new(|cx| Markdown::new(SharedString::new(message), None, None, cx)), actions: actions.iter().map(|a| a.label().to_string()).collect(), focus: cx.focus_handle(), active_action_id: 0, @@ -58,7 +58,7 @@ fn zed_prompt_renderer( pub struct ZedPromptRenderer { _level: PromptLevel, - message: String, + message: Entity<Markdown>, actions: Vec<String>, focus: FocusHandle, active_action_id: usize, @@ -114,7 +114,7 @@ impl ZedPromptRenderer { impl Render for ZedPromptRenderer { fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { let settings = ThemeSettings::get_global(cx); - let font_family = 
settings.ui_font.family.clone(); + let font_size = settings.ui_font_size(cx).into(); let prompt = v_flex() .key_context("Prompt") .cursor_default() @@ -130,24 +130,38 @@ impl Render for ZedPromptRenderer { .overflow_hidden() .p_4() .gap_4() - .font_family(font_family) + .font_family(settings.ui_font.family.clone()) .child( div() .w_full() - .font_weight(FontWeight::BOLD) - .child(self.message.clone()) - .text_color(ui::Color::Default.color(cx)), + .child(MarkdownElement::new(self.message.clone(), { + let mut base_text_style = window.text_style(); + base_text_style.refine(&TextStyleRefinement { + font_family: Some(settings.ui_font.family.clone()), + font_size: Some(font_size), + font_weight: Some(FontWeight::BOLD), + color: Some(ui::Color::Default.color(cx)), + ..Default::default() + }); + MarkdownStyle { + base_text_style, + selection_background_color: cx + .theme() + .colors() + .element_selection_background, + ..Default::default() + } + })), ) .children(self.detail.clone().map(|detail| { div() .w_full() .text_xs() .child(MarkdownElement::new(detail, { - let settings = ThemeSettings::get_global(cx); let mut base_text_style = window.text_style(); base_text_style.refine(&TextStyleRefinement { font_family: Some(settings.ui_font.family.clone()), - font_size: Some(settings.ui_font_size(cx).into()), + font_size: Some(font_size), color: Some(ui::Color::Muted.color(cx)), ..Default::default() }); @@ -176,24 +190,28 @@ impl Render for ZedPromptRenderer { }), )); - div().size_full().occlude().child( - div() - .size_full() - .absolute() - .top_0() - .left_0() - .flex() - .flex_col() - .justify_around() - .child( - div() - .w_full() - .flex() - .flex_row() - .justify_around() - .child(prompt), - ), - ) + div() + .size_full() + .occlude() + .bg(gpui::black().opacity(0.2)) + .child( + div() + .size_full() + .absolute() + .top_0() + .left_0() + .flex() + .flex_col() + .justify_around() + .child( + div() + .w_full() + .flex() + .flex_row() + .justify_around() + .child(prompt), + ), + ) } } diff --git a/crates/util/src/archive.rs b/crates/util/src/archive.rs index d10b9967163978edd2b5cb7cc0a6d8c06d61429a..3e4d281c29902b682d14886431bc9387baf9cee3 100644 --- a/crates/util/src/archive.rs +++ b/crates/util/src/archive.rs @@ -2,6 +2,8 @@ use std::path::Path; use anyhow::{Context as _, Result}; use async_zip::base::read; +#[cfg(not(windows))] +use futures::AsyncSeek; use futures::{AsyncRead, io::BufReader}; #[cfg(windows)] @@ -62,7 +64,15 @@ pub async fn extract_zip<R: AsyncRead + Unpin>(destination: &Path, reader: R) -> futures::io::copy(&mut BufReader::new(reader), &mut file) .await .context("saving archive contents into the temporary file")?; - let mut reader = read::seek::ZipFileReader::new(BufReader::new(file)) + extract_seekable_zip(destination, file).await +} + +#[cfg(not(windows))] +pub async fn extract_seekable_zip<R: AsyncRead + AsyncSeek + Unpin>( + destination: &Path, + reader: R, +) -> Result<()> { + let mut reader = read::seek::ZipFileReader::new(BufReader::new(reader)) .await .context("reading the zip archive")?; let destination = &destination diff --git a/crates/util/src/fs.rs b/crates/util/src/fs.rs index 2738b6e213c10c03bd2271fd0138074d0eaea07a..3e96594f85caf801aa80d18ee423f3d94e23b426 100644 --- a/crates/util/src/fs.rs +++ b/crates/util/src/fs.rs @@ -95,9 +95,9 @@ pub async fn move_folder_files_to_folder<P: AsRef<Path>>( #[cfg(unix)] /// Set the permissions for the given path so that the file becomes executable. /// This is a noop for non-unix platforms. 
-pub async fn make_file_executable(path: &PathBuf) -> std::io::Result<()> { +pub async fn make_file_executable(path: &Path) -> std::io::Result<()> { fs::set_permissions( - &path, + path, <fs::Permissions as fs::unix::PermissionsExt>::from_mode(0o755), ) .await @@ -107,6 +107,6 @@ pub async fn make_file_executable(path: &PathBuf) -> std::io::Result<()> { #[allow(clippy::unused_async)] /// Set the permissions for the given path so that the file becomes executable. /// This is a noop for non-unix platforms. -pub async fn make_file_executable(_path: &PathBuf) -> std::io::Result<()> { +pub async fn make_file_executable(_path: &Path) -> std::io::Result<()> { Ok(()) } diff --git a/crates/util/src/redact.rs b/crates/util/src/redact.rs index 0b979fb4132f3e39381b75ddcaa85505c62de530..6b297dfb58bb0b4537d4032d8f9cf4db845f9d78 100644 --- a/crates/util/src/redact.rs +++ b/crates/util/src/redact.rs @@ -1,7 +1,14 @@ /// Whether a given environment variable name should have its value redacted pub fn should_redact(env_var_name: &str) -> bool { - const REDACTED_SUFFIXES: &[&str] = - &["KEY", "TOKEN", "PASSWORD", "SECRET", "PASS", "CREDENTIALS"]; + const REDACTED_SUFFIXES: &[&str] = &[ + "KEY", + "TOKEN", + "PASSWORD", + "SECRET", + "PASS", + "CREDENTIALS", + "LICENSE", + ]; REDACTED_SUFFIXES .iter() .any(|suffix| env_var_name.ends_with(suffix)) diff --git a/crates/util/src/schemars.rs b/crates/util/src/schemars.rs index 4d8ab530dd6beb3cf3c448256ff4bde89f9de8f7..e162b41933117eb603d36601aaf4b87b0e3d1d85 100644 --- a/crates/util/src/schemars.rs +++ b/crates/util/src/schemars.rs @@ -15,7 +15,6 @@ pub fn replace_subschema<T: JsonSchema>( generator: &mut schemars::SchemaGenerator, schema: impl Fn() -> schemars::Schema, ) -> schemars::Schema { - // fallback on just using the schema name, which could collide. let schema_name = T::schema_name(); let definitions = generator.definitions_mut(); assert!(!definitions.contains_key(&format!("{schema_name}2"))); diff --git a/crates/util/src/shell_env.rs b/crates/util/src/shell_env.rs index 21f6096f19fa0c89bf4516b122878be04361ddcd..2b1063316fa1d08ba3fa6e4c945b30175ff2cfdc 100644 --- a/crates/util/src/shell_env.rs +++ b/crates/util/src/shell_env.rs @@ -18,15 +18,19 @@ pub fn capture(directory: &std::path::Path) -> Result<collections::HashMap<Strin // In some shells, file descriptors greater than 2 cannot be used in interactive mode, // so file descriptor 0 (stdin) is used instead. This impacts zsh, old bash; perhaps others. 
// See: https://github.com/zed-industries/zed/pull/32136#issuecomment-2999645482 - const ENV_OUTPUT_FD: std::os::fd::RawFd = 0; - let redir = match shell_name { - Some("rc") => format!(">[1={}]", ENV_OUTPUT_FD), // `[1=0]` - _ => format!(">&{}", ENV_OUTPUT_FD), // `>&0` + const FD_STDIN: std::os::fd::RawFd = 0; + const FD_STDOUT: std::os::fd::RawFd = 1; + + let (fd_num, redir) = match shell_name { + Some("rc") => (FD_STDIN, format!(">[1={}]", FD_STDIN)), // `[1=0]` + Some("nu") | Some("tcsh") => (FD_STDOUT, "".to_string()), + _ => (FD_STDIN, format!(">&{}", FD_STDIN)), // `>&0` }; command.stdin(Stdio::null()); command.stdout(Stdio::piped()); command.stderr(Stdio::piped()); + let mut command_prefix = String::new(); match shell_name { Some("tcsh" | "csh") => { // For csh/tcsh, login shell requires passing `-` as 0th argument (instead of `-l`) @@ -37,18 +41,25 @@ pub fn capture(directory: &std::path::Path) -> Result<collections::HashMap<Strin command_string.push_str("emit fish_prompt;"); command.arg("-l"); } + Some("nu") => { + // nu needs special handling for -- options. + command_prefix = String::from("^"); + } _ => { command.arg("-l"); } } // cd into the directory, triggering directory specific side-effects (asdf, direnv, etc) command_string.push_str(&format!("cd '{}';", directory.display())); - command_string.push_str(&format!("{} --printenv {}", zed_path, redir)); + command_string.push_str(&format!( + "{}{} --printenv {}", + command_prefix, zed_path, redir + )); command.args(["-i", "-c", &command_string]); super::set_pre_exec_to_start_new_session(&mut command); - let (env_output, process_output) = spawn_and_read_fd(command, ENV_OUTPUT_FD)?; + let (env_output, process_output) = spawn_and_read_fd(command, fd_num)?; let env_output = String::from_utf8_lossy(&env_output); anyhow::ensure!( diff --git a/crates/vim/src/command.rs b/crates/vim/src/command.rs index c001f55a41c9d488240cb59fcc70ba111cca988b..7963db35712a22395a49e0a2767b7d24edba7654 100644 --- a/crates/vim/src/command.rs +++ b/crates/vim/src/command.rs @@ -6,7 +6,7 @@ use editor::{ actions::{SortLinesCaseInsensitive, SortLinesCaseSensitive}, display_map::ToDisplayPoint, }; -use gpui::{Action, App, AppContext as _, Context, Global, Window, actions}; +use gpui::{Action, App, AppContext as _, Context, Global, Keystroke, Window, actions}; use itertools::Itertools; use language::Point; use multi_buffer::MultiBufferRow; @@ -202,6 +202,7 @@ actions!( ArgumentRequired ] ); + /// Opens the specified file for editing. 
#[derive(Clone, PartialEq, Action)] #[action(namespace = vim, no_json, no_register)] @@ -209,6 +210,13 @@ struct VimEdit { pub filename: String, } +#[derive(Clone, PartialEq, Action)] +#[action(namespace = vim, no_json, no_register)] +struct VimNorm { + pub range: Option<CommandRange>, + pub command: String, +} + #[derive(Debug)] struct WrappedAction(Box<dyn Action>); @@ -447,6 +455,81 @@ pub fn register(editor: &mut Editor, cx: &mut Context<Vim>) { }); }); + Vim::action(editor, cx, |vim, action: &VimNorm, window, cx| { + let keystrokes = action + .command + .chars() + .map(|c| Keystroke::parse(&c.to_string()).unwrap()) + .collect(); + vim.switch_mode(Mode::Normal, true, window, cx); + let initial_selections = vim.update_editor(window, cx, |_, editor, _, _| { + editor.selections.disjoint_anchors() + }); + if let Some(range) = &action.range { + let result = vim.update_editor(window, cx, |vim, editor, window, cx| { + let range = range.buffer_range(vim, editor, window, cx)?; + editor.change_selections( + SelectionEffects::no_scroll().nav_history(false), + window, + cx, + |s| { + s.select_ranges( + (range.start.0..=range.end.0) + .map(|line| Point::new(line, 0)..Point::new(line, 0)), + ); + }, + ); + anyhow::Ok(()) + }); + if let Some(Err(err)) = result { + log::error!("Error selecting range: {}", err); + return; + } + }; + + let Some(workspace) = vim.workspace(window) else { + return; + }; + let task = workspace.update(cx, |workspace, cx| { + workspace.send_keystrokes_impl(keystrokes, window, cx) + }); + let had_range = action.range.is_some(); + + cx.spawn_in(window, async move |vim, cx| { + task.await; + vim.update_in(cx, |vim, window, cx| { + vim.update_editor(window, cx, |_, editor, window, cx| { + if had_range { + editor.change_selections(SelectionEffects::default(), window, cx, |s| { + s.select_anchor_ranges([s.newest_anchor().range()]); + }) + } + }); + if matches!(vim.mode, Mode::Insert | Mode::Replace) { + vim.normal_before(&Default::default(), window, cx); + } else { + vim.switch_mode(Mode::Normal, true, window, cx); + } + vim.update_editor(window, cx, |_, editor, _, cx| { + if let Some(first_sel) = initial_selections { + if let Some(tx_id) = editor + .buffer() + .update(cx, |multi, cx| multi.last_transaction_id(cx)) + { + let last_sel = editor.selections.disjoint_anchors(); + editor.modify_transaction_selection_history(tx_id, |old| { + old.0 = first_sel; + old.1 = Some(last_sel); + }); + } + } + }); + }) + .ok(); + }) + .detach(); + }); + Vim::action(editor, cx, |vim, _: &CountCommand, window, cx| { let Some(workspace) = vim.workspace(window) else { return; @@ -675,14 +758,15 @@ impl VimCommand { } else { return None; }; - if !args.is_empty() { + + let action = if args.is_empty() { + action + } else { // if command does not accept args and we have args then we should do no action - if let Some(args_fn) = &self.args { - args_fn.deref()(action, args) - } else { - None - } - } else if let Some(range) = range { + self.args.as_ref()?(action, args)? 
+ }; + + if let Some(range) = range { self.range.as_ref().and_then(|f| f(action, range)) } else { Some(action) @@ -1061,6 +1145,27 @@ fn generate_commands(_: &App) -> Vec<VimCommand> { save_intent: Some(SaveIntent::Skip), close_pinned: true, }), + VimCommand::new( + ("norm", "al"), + VimNorm { + command: "".into(), + range: None, + }, + ) + .args(|_, args| { + Some( + VimNorm { + command: args, + range: None, + } + .boxed_clone(), + ) + }) + .range(|action, range| { + let mut action: VimNorm = action.as_any().downcast_ref::<VimNorm>().unwrap().clone(); + action.range.replace(range.clone()); + Some(Box::new(action)) + }), VimCommand::new(("bn", "ext"), workspace::ActivateNextItem).count(), VimCommand::new(("bN", "ext"), workspace::ActivatePreviousItem).count(), VimCommand::new(("bp", "revious"), workspace::ActivatePreviousItem).count(), @@ -1085,12 +1190,12 @@ fn generate_commands(_: &App) -> Vec<VimCommand> { ), VimCommand::new( ("tabo", "nly"), - workspace::CloseInactiveItems { + workspace::CloseOtherItems { save_intent: Some(SaveIntent::Close), close_pinned: false, }, ) - .bang(workspace::CloseInactiveItems { + .bang(workspace::CloseOtherItems { save_intent: Some(SaveIntent::Skip), close_pinned: false, }), @@ -1106,13 +1211,28 @@ fn generate_commands(_: &App) -> Vec<VimCommand> { VimCommand::str(("cl", "ist"), "diagnostics::Deploy"), VimCommand::new(("cc", ""), editor::actions::Hover), VimCommand::new(("ll", ""), editor::actions::Hover), - VimCommand::new(("cn", "ext"), editor::actions::GoToDiagnostic).range(wrap_count), - VimCommand::new(("cp", "revious"), editor::actions::GoToPreviousDiagnostic) + VimCommand::new(("cn", "ext"), editor::actions::GoToDiagnostic::default()) .range(wrap_count), - VimCommand::new(("cN", "ext"), editor::actions::GoToPreviousDiagnostic).range(wrap_count), - VimCommand::new(("lp", "revious"), editor::actions::GoToPreviousDiagnostic) - .range(wrap_count), - VimCommand::new(("lN", "ext"), editor::actions::GoToPreviousDiagnostic).range(wrap_count), + VimCommand::new( + ("cp", "revious"), + editor::actions::GoToPreviousDiagnostic::default(), + ) + .range(wrap_count), + VimCommand::new( + ("cN", "ext"), + editor::actions::GoToPreviousDiagnostic::default(), + ) + .range(wrap_count), + VimCommand::new( + ("lp", "revious"), + editor::actions::GoToPreviousDiagnostic::default(), + ) + .range(wrap_count), + VimCommand::new( + ("lN", "ext"), + editor::actions::GoToPreviousDiagnostic::default(), + ) + .range(wrap_count), VimCommand::new(("j", "oin"), JoinLines).range(select_range), VimCommand::new(("fo", "ld"), editor::actions::FoldSelectedRanges).range(act_on_range), VimCommand::new(("foldo", "pen"), editor::actions::UnfoldLines) @@ -2283,4 +2403,78 @@ mod test { }); assert!(mark.is_none()) } + + #[gpui::test] + async fn test_normal_command(cx: &mut TestAppContext) { + let mut cx = NeovimBackedTestContext::new(cx).await; + + cx.set_shared_state(indoc! {" + The quick + brown« fox + jumpsˇ» over + the lazy dog + "}) + .await; + + cx.simulate_shared_keystrokes(": n o r m space w C w o r d") + .await; + cx.simulate_shared_keystrokes("enter").await; + + cx.shared_state().await.assert_eq(indoc! {" + The quick + brown word + jumps worˇd + the lazy dog + "}); + + cx.simulate_shared_keystrokes(": n o r m space _ w c i w t e s t") + .await; + cx.simulate_shared_keystrokes("enter").await; + + cx.shared_state().await.assert_eq(indoc! 
{" + The quick + brown word + jumps tesˇt + the lazy dog + "}); + + cx.simulate_shared_keystrokes("_ l v l : n o r m space s l a") + .await; + cx.simulate_shared_keystrokes("enter").await; + + cx.shared_state().await.assert_eq(indoc! {" + The quick + brown word + lˇaumps test + the lazy dog + "}); + + cx.set_shared_state(indoc! {" + ˇThe quick + brown fox + jumps over + the lazy dog + "}) + .await; + + cx.simulate_shared_keystrokes("c i w M y escape").await; + + cx.shared_state().await.assert_eq(indoc! {" + Mˇy quick + brown fox + jumps over + the lazy dog + "}); + + cx.simulate_shared_keystrokes(": n o r m space u").await; + cx.simulate_shared_keystrokes("enter").await; + + cx.shared_state().await.assert_eq(indoc! {" + ˇThe quick + brown fox + jumps over + the lazy dog + "}); + // Once ctrl-v to input character literals is added there should be a test for redo + } } diff --git a/crates/vim/src/helix.rs b/crates/vim/src/helix.rs index ec9b959b1220939394956e22e8936141c74fae1b..ca93c9c1de0993f4627073d645b347fa9a875ca0 100644 --- a/crates/vim/src/helix.rs +++ b/crates/vim/src/helix.rs @@ -1,21 +1,31 @@ -use editor::{DisplayPoint, Editor, movement}; +use editor::{DisplayPoint, Editor, SelectionEffects, ToOffset, ToPoint, movement}; use gpui::{Action, actions}; use gpui::{Context, Window}; use language::{CharClassifier, CharKind}; -use text::SelectionGoal; +use text::{Bias, SelectionGoal}; -use crate::{Vim, motion::Motion, state::Mode}; +use crate::{ + Vim, + motion::{Motion, right}, + state::Mode, +}; actions!( vim, [ /// Switches to normal mode after the cursor (Helix-style). - HelixNormalAfter + HelixNormalAfter, + /// Inserts at the beginning of the selection. + HelixInsert, + /// Appends at the end of the selection. + HelixAppend, ] ); pub fn register(editor: &mut Editor, cx: &mut Context<Vim>) { Vim::action(editor, cx, Vim::helix_normal_after); + Vim::action(editor, cx, Vim::helix_insert); + Vim::action(editor, cx, Vim::helix_append); } impl Vim { @@ -299,6 +309,112 @@ impl Vim { _ => self.helix_move_and_collapse(motion, times, window, cx), } } + + fn helix_insert(&mut self, _: &HelixInsert, window: &mut Window, cx: &mut Context<Self>) { + self.start_recording(cx); + self.update_editor(window, cx, |_, editor, window, cx| { + editor.change_selections(Default::default(), window, cx, |s| { + s.move_with(|_map, selection| { + // In helix normal mode, move cursor to start of selection and collapse + if !selection.is_empty() { + selection.collapse_to(selection.start, SelectionGoal::None); + } + }); + }); + }); + self.switch_mode(Mode::Insert, false, window, cx); + } + + fn helix_append(&mut self, _: &HelixAppend, window: &mut Window, cx: &mut Context<Self>) { + self.start_recording(cx); + self.switch_mode(Mode::Insert, false, window, cx); + self.update_editor(window, cx, |_, editor, window, cx| { + editor.change_selections(Default::default(), window, cx, |s| { + s.move_with(|map, selection| { + let point = if selection.is_empty() { + right(map, selection.head(), 1) + } else { + selection.end + }; + selection.collapse_to(point, SelectionGoal::None); + }); + }); + }); + } + + pub fn helix_replace(&mut self, text: &str, window: &mut Window, cx: &mut Context<Self>) { + self.update_editor(window, cx, |_, editor, window, cx| { + editor.transact(window, cx, |editor, window, cx| { + let (map, selections) = editor.selections.all_display(cx); + + // Store selection info for positioning after edit + let selection_info: Vec<_> = selections + .iter() + .map(|selection| { + let range = selection.range(); + 
let start_offset = range.start.to_offset(&map, Bias::Left); + let end_offset = range.end.to_offset(&map, Bias::Left); + let was_empty = range.is_empty(); + let was_reversed = selection.reversed; + ( + map.buffer_snapshot.anchor_at(start_offset, Bias::Left), + end_offset - start_offset, + was_empty, + was_reversed, + ) + }) + .collect(); + + let mut edits = Vec::new(); + for selection in &selections { + let mut range = selection.range(); + + // For empty selections, extend to replace one character + if range.is_empty() { + range.end = movement::saturating_right(&map, range.start); + } + + let byte_range = range.start.to_offset(&map, Bias::Left) + ..range.end.to_offset(&map, Bias::Left); + + if !byte_range.is_empty() { + let replacement_text = text.repeat(byte_range.len()); + edits.push((byte_range, replacement_text)); + } + } + + editor.edit(edits, cx); + + // Restore selections based on original info + let snapshot = editor.buffer().read(cx).snapshot(cx); + let ranges: Vec<_> = selection_info + .into_iter() + .map(|(start_anchor, original_len, was_empty, was_reversed)| { + let start_point = start_anchor.to_point(&snapshot); + if was_empty { + // For cursor-only, collapse to start + start_point..start_point + } else { + // For selections, span the replaced text + let replacement_len = text.len() * original_len; + let end_offset = start_anchor.to_offset(&snapshot) + replacement_len; + let end_point = snapshot.offset_to_point(end_offset); + if was_reversed { + end_point..start_point + } else { + start_point..end_point + } + } + }) + .collect(); + + editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { + s.select_ranges(ranges); + }); + }); + }); + self.switch_mode(Mode::HelixNormal, true, window, cx); + } } #[cfg(test)] @@ -497,4 +613,94 @@ mod test { cx.assert_state("«ˇaa»\n", Mode::HelixNormal); } + + #[gpui::test] + async fn test_insert_selected(cx: &mut gpui::TestAppContext) { + let mut cx = VimTestContext::new(cx, true).await; + cx.set_state( + indoc! {" + «The ˇ»quick brown + fox jumps over + the lazy dog."}, + Mode::HelixNormal, + ); + + cx.simulate_keystrokes("i"); + + cx.assert_state( + indoc! {" + ˇThe quick brown + fox jumps over + the lazy dog."}, + Mode::Insert, + ); + } + + #[gpui::test] + async fn test_append(cx: &mut gpui::TestAppContext) { + let mut cx = VimTestContext::new(cx, true).await; + // test from the end of the selection + cx.set_state( + indoc! {" + «Theˇ» quick brown + fox jumps over + the lazy dog."}, + Mode::HelixNormal, + ); + + cx.simulate_keystrokes("a"); + + cx.assert_state( + indoc! {" + Theˇ quick brown + fox jumps over + the lazy dog."}, + Mode::Insert, + ); + + // test from the beginning of the selection + cx.set_state( + indoc! {" + «ˇThe» quick brown + fox jumps over + the lazy dog."}, + Mode::HelixNormal, + ); + + cx.simulate_keystrokes("a"); + + cx.assert_state( + indoc! 
{" + Theˇ quick brown + fox jumps over + the lazy dog."}, + Mode::Insert, + ); + } + + #[gpui::test] + async fn test_replace(cx: &mut gpui::TestAppContext) { + let mut cx = VimTestContext::new(cx, true).await; + + // No selection (single character) + cx.set_state("ˇaa", Mode::HelixNormal); + + cx.simulate_keystrokes("r x"); + + cx.assert_state("ˇxa", Mode::HelixNormal); + + // Cursor at the beginning + cx.set_state("«ˇaa»", Mode::HelixNormal); + + cx.simulate_keystrokes("r x"); + + cx.assert_state("«ˇxx»", Mode::HelixNormal); + + // Cursor at the end + cx.set_state("«aaˇ»", Mode::HelixNormal); + + cx.simulate_keystrokes("r x"); + + cx.assert_state("«xxˇ»", Mode::HelixNormal); + } } diff --git a/crates/vim/src/insert.rs b/crates/vim/src/insert.rs index 89c60adee7f7c2a92b9f5c7d671cbcfac7045843..0a370e16ba418ae04cdfe47e1ccbdb3904b6af45 100644 --- a/crates/vim/src/insert.rs +++ b/crates/vim/src/insert.rs @@ -21,7 +21,7 @@ pub fn register(editor: &mut Editor, cx: &mut Context<Vim>) { } impl Vim { - fn normal_before( + pub(crate) fn normal_before( &mut self, action: &NormalBefore, window: &mut Window, diff --git a/crates/vim/src/motion.rs b/crates/vim/src/motion.rs index a50b238cc5c6591f163f2fb89ef0a2cdf145d23c..0e487f44104fc7307e5a606792e59d9c49e91963 100644 --- a/crates/vim/src/motion.rs +++ b/crates/vim/src/motion.rs @@ -987,7 +987,7 @@ impl Motion { SelectionGoal::None, ), NextWordEnd { ignore_punctuation } => ( - next_word_end(map, point, *ignore_punctuation, times, true), + next_word_end(map, point, *ignore_punctuation, times, true, true), SelectionGoal::None, ), PreviousWordStart { ignore_punctuation } => ( @@ -1723,14 +1723,19 @@ pub(crate) fn next_word_end( ignore_punctuation: bool, times: usize, allow_cross_newline: bool, + always_advance: bool, ) -> DisplayPoint { let classifier = map .buffer_snapshot .char_classifier_at(point.to_point(map)) .ignore_punctuation(ignore_punctuation); for _ in 0..times { - let new_point = next_char(map, point, allow_cross_newline); let mut need_next_char = false; + let new_point = if always_advance { + next_char(map, point, allow_cross_newline) + } else { + point + }; let new_point = movement::find_boundary_exclusive( map, new_point, @@ -3803,7 +3808,7 @@ mod test { cx.update_editor(|editor, _window, cx| { let range = editor.selections.newest_anchor().range(); let inlay_text = " field: int,\n field2: string\n field3: float"; - let inlay = Inlay::inline_completion(1, range.start, inlay_text); + let inlay = Inlay::edit_prediction(1, range.start, inlay_text); editor.splice_inlays(&[], vec![inlay], cx); }); @@ -3835,7 +3840,7 @@ mod test { let end_of_line = snapshot.anchor_after(Point::new(0, snapshot.line_len(MultiBufferRow(0)))); let inlay_text = " hint"; - let inlay = Inlay::inline_completion(1, end_of_line, inlay_text); + let inlay = Inlay::edit_prediction(1, end_of_line, inlay_text); editor.splice_inlays(&[], vec![inlay], cx); }); cx.simulate_keystrokes("$"); diff --git a/crates/vim/src/normal.rs b/crates/vim/src/normal.rs index 6131032f4fab7b6ac7f3d2965413464317e55490..13128e7b403ab0e921e7f323d03d83e187d1556d 100644 --- a/crates/vim/src/normal.rs +++ b/crates/vim/src/normal.rs @@ -64,6 +64,8 @@ actions!( DeleteRight, /// Deletes using Helix-style behavior. HelixDelete, + /// Collapse the current selection + HelixCollapseSelection, /// Changes from cursor to end of line. ChangeToEndOfLine, /// Deletes from cursor to end of line. 
@@ -143,6 +145,20 @@ pub(crate) fn register(editor: &mut Editor, cx: &mut Context<Vim>) { vim.switch_mode(Mode::HelixNormal, true, window, cx); }); + Vim::action(editor, cx, |vim, _: &HelixCollapseSelection, window, cx| { + vim.update_editor(window, cx, |_, editor, window, cx| { + editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { + s.move_with(|map, selection| { + let mut point = selection.head(); + if !selection.reversed && !selection.is_empty() { + point = movement::left(map, selection.head()); + } + selection.collapse_to(point, selection.goal) + }); + }); + }); + }); + Vim::action(editor, cx, |vim, _: &ChangeToEndOfLine, window, cx| { vim.start_recording(cx); let times = Vim::take_count(cx); diff --git a/crates/vim/src/normal/change.rs b/crates/vim/src/normal/change.rs index 9485f174771cd1f21f1513e9609008dce8479b14..c1bc7a70ae1830c2181f68f70d266c1998c8e1ef 100644 --- a/crates/vim/src/normal/change.rs +++ b/crates/vim/src/normal/change.rs @@ -51,6 +51,7 @@ impl Vim { ignore_punctuation, &text_layout_details, motion == Motion::NextSubwordStart { ignore_punctuation }, + !matches!(motion, Motion::NextWordStart { .. }), ) } _ => { @@ -89,7 +90,7 @@ impl Vim { if let Some(kind) = motion_kind { vim.copy_selections_content(editor, kind, window, cx); editor.insert("", window, cx); - editor.refresh_inline_completion(true, false, window, cx); + editor.refresh_edit_prediction(true, false, window, cx); } }); }); @@ -122,7 +123,7 @@ impl Vim { if objects_found { vim.copy_selections_content(editor, MotionKind::Exclusive, window, cx); editor.insert("", window, cx); - editor.refresh_inline_completion(true, false, window, cx); + editor.refresh_edit_prediction(true, false, window, cx); } }); }); @@ -148,6 +149,7 @@ fn expand_changed_word_selection( ignore_punctuation: bool, text_layout_details: &TextLayoutDetails, use_subword: bool, + always_advance: bool, ) -> Option<MotionKind> { let is_in_word = || { let classifier = map @@ -173,8 +175,14 @@ fn expand_changed_word_selection( selection.end = motion::next_subword_end(map, selection.end, ignore_punctuation, 1, false); } else { - selection.end = - motion::next_word_end(map, selection.end, ignore_punctuation, 1, false); + selection.end = motion::next_word_end( + map, + selection.end, + ignore_punctuation, + 1, + false, + always_advance, + ); } selection.end = motion::next_char(map, selection.end, false); } @@ -271,6 +279,10 @@ mod test { cx.simulate("c shift-w", "Test teˇst-test test") .await .assert_matches(); + + // on last character of word, `cw` doesn't eat subsequent punctuation + // see https://github.com/zed-industries/zed/issues/35269 + cx.simulate("c w", "tesˇt-test").await.assert_matches(); } #[gpui::test] diff --git a/crates/vim/src/normal/delete.rs b/crates/vim/src/normal/delete.rs index ccbb3dd0fd901b515258a34bb9377063e2a84cbd..2cf40292cf5ccf765f830302f3cfb2617728c769 100644 --- a/crates/vim/src/normal/delete.rs +++ b/crates/vim/src/normal/delete.rs @@ -82,7 +82,7 @@ impl Vim { selection.collapse_to(cursor, selection.goal) }); }); - editor.refresh_inline_completion(true, false, window, cx); + editor.refresh_edit_prediction(true, false, window, cx); }); }); } @@ -169,7 +169,7 @@ impl Vim { selection.collapse_to(cursor, selection.goal) }); }); - editor.refresh_inline_completion(true, false, window, cx); + editor.refresh_edit_prediction(true, false, window, cx); }); }); } diff --git a/crates/vim/src/test.rs b/crates/vim/src/test.rs index 2db1d4a20cb7c4162ca2e795f880ece500d88e0f..ce04b621cb91c7b6b7da57bd1e1b74e9c0e00bbc 
100644 --- a/crates/vim/src/test.rs +++ b/crates/vim/src/test.rs @@ -1006,8 +1006,6 @@ async fn test_rename(cx: &mut gpui::TestAppContext) { cx.assert_state("const afterˇ = 2; console.log(after)", Mode::Normal) } -// TODO: this test is flaky on our linux CI machines -#[cfg(target_os = "macos")] #[gpui::test] async fn test_remap(cx: &mut gpui::TestAppContext) { let mut cx = VimTestContext::new(cx, true).await; @@ -1048,8 +1046,6 @@ async fn test_remap(cx: &mut gpui::TestAppContext) { cx.simulate_keystrokes("g x"); cx.assert_state("1234fooˇ56789", Mode::Normal); - cx.executor().allow_parking(); - // test command cx.update(|_, cx| { cx.bind_keys([KeyBinding::new( diff --git a/crates/vim/src/vim.rs b/crates/vim/src/vim.rs index 95a08d7c66a49b0ca3f0f1d40ecc378276fbf131..72edbe77ed1e8f572866c6f5229477ca4f594c36 100644 --- a/crates/vim/src/vim.rs +++ b/crates/vim/src/vim.rs @@ -747,7 +747,7 @@ impl Vim { Vim::action( editor, cx, - |vim, action: &editor::AcceptEditPrediction, window, cx| { + |vim, action: &editor::actions::AcceptEditPrediction, window, cx| { vim.update_editor(window, cx, |_, editor, window, cx| { editor.accept_edit_prediction(action, window, cx); }); @@ -1639,6 +1639,7 @@ impl Vim { Mode::Visual | Mode::VisualLine | Mode::VisualBlock => { self.visual_replace(text, window, cx) } + Mode::HelixNormal => self.helix_replace(&text, window, cx), _ => self.clear_operator(window, cx), }, Some(Operator::Digraph { first_char }) => { @@ -1740,11 +1741,11 @@ impl Vim { editor.set_autoindent(vim.should_autoindent()); editor.selections.line_mode = matches!(vim.mode, Mode::VisualLine); - let hide_inline_completions = match vim.mode { + let hide_edit_predictions = match vim.mode { Mode::Insert | Mode::Replace => false, _ => true, }; - editor.set_inline_completions_hidden_for_vim_mode(hide_inline_completions, window, cx); + editor.set_edit_predictions_hidden_for_vim_mode(hide_edit_predictions, window, cx); }); cx.notify() } diff --git a/crates/vim/test_data/test_change_w.json b/crates/vim/test_data/test_change_w.json index 27be5435327013a08d1a7d61b496ff4d2b0864eb..149dac842093aa943382c20c3ab3d51c1e3d748e 100644 --- a/crates/vim/test_data/test_change_w.json +++ b/crates/vim/test_data/test_change_w.json @@ -30,3 +30,7 @@ {"Key":"c"} {"Key":"shift-w"} {"Get":{"state":"Test teˇ test","mode":"Insert"}} +{"Put":{"state":"tesˇt-test"}} +{"Key":"c"} +{"Key":"w"} +{"Get":{"state":"tesˇ-test","mode":"Insert"}} diff --git a/crates/vim/test_data/test_normal_command.json b/crates/vim/test_data/test_normal_command.json new file mode 100644 index 0000000000000000000000000000000000000000..efd1d532c4261976a5e1ef00e85fdac9b2b90fab --- /dev/null +++ b/crates/vim/test_data/test_normal_command.json @@ -0,0 +1,64 @@ +{"Put":{"state":"The quick\nbrown« fox\njumpsˇ» over\nthe lazy dog\n"}} +{"Key":":"} +{"Key":"n"} +{"Key":"o"} +{"Key":"r"} +{"Key":"m"} +{"Key":"space"} +{"Key":"w"} +{"Key":"C"} +{"Key":"w"} +{"Key":"o"} +{"Key":"r"} +{"Key":"d"} +{"Key":"enter"} +{"Get":{"state":"The quick\nbrown word\njumps worˇd\nthe lazy dog\n","mode":"Normal"}} +{"Key":":"} +{"Key":"n"} +{"Key":"o"} +{"Key":"r"} +{"Key":"m"} +{"Key":"space"} +{"Key":"_"} +{"Key":"w"} +{"Key":"c"} +{"Key":"i"} +{"Key":"w"} +{"Key":"t"} +{"Key":"e"} +{"Key":"s"} +{"Key":"t"} +{"Key":"enter"} +{"Get":{"state":"The quick\nbrown word\njumps tesˇt\nthe lazy dog\n","mode":"Normal"}} +{"Key":"_"} +{"Key":"l"} +{"Key":"v"} +{"Key":"l"} +{"Key":":"} +{"Key":"n"} +{"Key":"o"} +{"Key":"r"} +{"Key":"m"} +{"Key":"space"} +{"Key":"s"} +{"Key":"l"} +{"Key":"a"} 
+{"Key":"enter"} +{"Get":{"state":"The quick\nbrown word\nlˇaumps test\nthe lazy dog\n","mode":"Normal"}} +{"Put":{"state":"ˇThe quick\nbrown fox\njumps over\nthe lazy dog\n"}} +{"Key":"c"} +{"Key":"i"} +{"Key":"w"} +{"Key":"M"} +{"Key":"y"} +{"Key":"escape"} +{"Get":{"state":"Mˇy quick\nbrown fox\njumps over\nthe lazy dog\n","mode":"Normal"}} +{"Key":":"} +{"Key":"n"} +{"Key":"o"} +{"Key":"r"} +{"Key":"m"} +{"Key":"space"} +{"Key":"u"} +{"Key":"enter"} +{"Get":{"state":"ˇThe quick\nbrown fox\njumps over\nthe lazy dog\n","mode":"Normal"}} diff --git a/crates/web_search/Cargo.toml b/crates/web_search/Cargo.toml index e5b8ca63b25ef5a0eb030a58198492b12ff68470..4ba46faec4362ac98fffaffb6c606608c02373e8 100644 --- a/crates/web_search/Cargo.toml +++ b/crates/web_search/Cargo.toml @@ -13,8 +13,8 @@ path = "src/web_search.rs" [dependencies] anyhow.workspace = true +cloud_llm_client.workspace = true collections.workspace = true gpui.workspace = true serde.workspace = true workspace-hack.workspace = true -zed_llm_client.workspace = true diff --git a/crates/web_search/src/web_search.rs b/crates/web_search/src/web_search.rs index a131b0de7166709679a8ebc8e7d64af35118eb22..8578cfe4aaab77fdc731a8dd49c62c5afd514600 100644 --- a/crates/web_search/src/web_search.rs +++ b/crates/web_search/src/web_search.rs @@ -1,8 +1,9 @@ +use std::sync::Arc; + use anyhow::Result; +use cloud_llm_client::WebSearchResponse; use collections::HashMap; use gpui::{App, AppContext as _, Context, Entity, Global, SharedString, Task}; -use std::sync::Arc; -use zed_llm_client::WebSearchResponse; pub fn init(cx: &mut App) { let registry = cx.new(|_cx| WebSearchRegistry::default()); diff --git a/crates/web_search_providers/Cargo.toml b/crates/web_search_providers/Cargo.toml index 208cb63593f0647970c1b576f1733740ee99196c..f7a248d10649dc83d7d76b454e8db2d37b55cbef 100644 --- a/crates/web_search_providers/Cargo.toml +++ b/crates/web_search_providers/Cargo.toml @@ -14,7 +14,7 @@ path = "src/web_search_providers.rs" [dependencies] anyhow.workspace = true client.workspace = true -feature_flags.workspace = true +cloud_llm_client.workspace = true futures.workspace = true gpui.workspace = true http_client.workspace = true @@ -23,4 +23,3 @@ serde.workspace = true serde_json.workspace = true web_search.workspace = true workspace-hack.workspace = true -zed_llm_client.workspace = true diff --git a/crates/web_search_providers/src/cloud.rs b/crates/web_search_providers/src/cloud.rs index 79ccf97e47aacdeaa1da0cc2f063b3937e4f955f..52ee0da0d46287c78164d4ff6cc3eb31e46167b1 100644 --- a/crates/web_search_providers/src/cloud.rs +++ b/crates/web_search_providers/src/cloud.rs @@ -2,13 +2,12 @@ use std::sync::Arc; use anyhow::{Context as _, Result}; use client::Client; -use feature_flags::{FeatureFlagAppExt as _, ZedCloudFeatureFlag}; +use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, WebSearchBody, WebSearchResponse}; use futures::AsyncReadExt as _; use gpui::{App, AppContext, Context, Entity, Subscription, Task}; use http_client::{HttpClient, Method}; use language_model::{LlmApiToken, RefreshLlmTokenListener}; use web_search::{WebSearchProvider, WebSearchProviderId}; -use zed_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, WebSearchBody, WebSearchResponse}; pub struct CloudWebSearchProvider { state: Entity<State>, @@ -63,10 +62,7 @@ impl WebSearchProvider for CloudWebSearchProvider { let client = state.client.clone(); let llm_api_token = state.llm_api_token.clone(); let body = WebSearchBody { query }; - let use_cloud = cx.has_flag::<ZedCloudFeatureFlag>(); - 
cx.background_spawn(async move { - perform_web_search(client, llm_api_token, body, use_cloud).await - }) + cx.background_spawn(async move { perform_web_search(client, llm_api_token, body).await }) } } @@ -74,7 +70,6 @@ async fn perform_web_search( client: Arc<Client>, llm_api_token: LlmApiToken, body: WebSearchBody, - use_cloud: bool, ) -> Result<WebSearchResponse> { const MAX_RETRIES: usize = 3; @@ -91,11 +86,7 @@ async fn perform_web_search( let request = http_client::Request::builder() .method(Method::POST) - .uri( - http_client - .build_zed_llm_url("/web_search", &[], use_cloud)? - .as_ref(), - ) + .uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref()) .header("Content-Type", "application/json") .header("Authorization", format!("Bearer {token}")) .body(serde_json::to_string(&body)?.into())?; diff --git a/crates/welcome/Cargo.toml b/crates/welcome/Cargo.toml index 769dd8d6aa1591b2bd9c62dcb8c9ad9e48ad457b..acb3fe0f84daab8fc0cf1dd3c200c93d7a44c36f 100644 --- a/crates/welcome/Cargo.toml +++ b/crates/welcome/Cargo.toml @@ -29,7 +29,6 @@ project.workspace = true serde.workspace = true settings.workspace = true telemetry.workspace = true -theme.workspace = true ui.workspace = true util.workspace = true vim_mode_setting.workspace = true diff --git a/crates/welcome/src/welcome.rs b/crates/welcome/src/welcome.rs index ea4ac13de7f41f4b46da6b465f1e22076a5bae7b..b0a1c316f4228492c56c1a234d08d69101e47456 100644 --- a/crates/welcome/src/welcome.rs +++ b/crates/welcome/src/welcome.rs @@ -5,6 +5,7 @@ use gpui::{ ParentElement, Render, Styled, Subscription, Task, WeakEntity, Window, actions, svg, }; use language::language_settings::{EditPredictionProvider, all_language_settings}; +use project::DisableAiSettings; use settings::{Settings, SettingsStore}; use std::sync::Arc; use ui::{CheckboxWithLabel, ElevationIndex, Tooltip, prelude::*}; @@ -21,7 +22,6 @@ pub use multibuffer_hint::*; mod base_keymap_picker; mod multibuffer_hint; -mod welcome_ui; actions!( welcome, @@ -174,23 +174,25 @@ impl Render for WelcomePage { .ok(); })), ) - .child( - Button::new( - "try-zed-edit-prediction", - edit_prediction_label, + .when(!DisableAiSettings::get_global(cx).disable_ai, |parent| { + parent.child( + Button::new( + "edit_prediction_onboarding", + edit_prediction_label, + ) + .disabled(edit_prediction_provider_is_zed) + .icon(IconName::ZedPredict) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + .icon_position(IconPosition::Start) + .on_click( + cx.listener(|_, _, window, cx| { + telemetry::event!("Welcome Screen Try Edit Prediction clicked"); + window.dispatch_action(zed_actions::OpenZedPredictOnboarding.boxed_clone(), cx); + }), + ), ) - .disabled(edit_prediction_provider_is_zed) - .icon(IconName::ZedPredict) - .icon_size(IconSize::XSmall) - .icon_color(Color::Muted) - .icon_position(IconPosition::Start) - .on_click( - cx.listener(|_, _, window, cx| { - telemetry::event!("Welcome Screen Try Edit Prediction clicked"); - window.dispatch_action(zed_actions::OpenZedPredictOnboarding.boxed_clone(), cx); - }), - ), - ) + }) .child( Button::new("edit settings", "Edit Settings") .icon(IconName::Settings) diff --git a/crates/welcome/src/welcome_ui.rs b/crates/welcome/src/welcome_ui.rs deleted file mode 100644 index 622b6f448d01b99fd82763c963b683f83a738df9..0000000000000000000000000000000000000000 --- a/crates/welcome/src/welcome_ui.rs +++ /dev/null @@ -1 +0,0 @@ -mod theme_preview; diff --git a/crates/welcome/src/welcome_ui/theme_preview.rs b/crates/welcome/src/welcome_ui/theme_preview.rs deleted file 
mode 100644 index b3a80c74c3c9dfc1ff13e4c4bfa8b319a56eba1d..0000000000000000000000000000000000000000 --- a/crates/welcome/src/welcome_ui/theme_preview.rs +++ /dev/null @@ -1,280 +0,0 @@ -#![allow(unused, dead_code)] -use gpui::{Hsla, Length}; -use std::sync::Arc; -use theme::{Theme, ThemeRegistry}; -use ui::{ - IntoElement, RenderOnce, component_prelude::Documented, prelude::*, utils::inner_corner_radius, -}; - -/// Shows a preview of a theme as an abstract illustration -/// of a thumbnail-sized editor. -#[derive(IntoElement, RegisterComponent, Documented)] -pub struct ThemePreviewTile { - theme: Arc<Theme>, - selected: bool, - seed: f32, -} - -impl ThemePreviewTile { - pub fn new(theme: Arc<Theme>, selected: bool, seed: f32) -> Self { - Self { - theme, - selected, - seed, - } - } - - pub fn selected(mut self, selected: bool) -> Self { - self.selected = selected; - self - } -} - -impl RenderOnce for ThemePreviewTile { - fn render(self, _window: &mut ui::Window, _cx: &mut ui::App) -> impl IntoElement { - let color = self.theme.colors(); - - let root_radius = px(8.0); - let root_border = px(2.0); - let root_padding = px(2.0); - let child_border = px(1.0); - let inner_radius = - inner_corner_radius(root_radius, root_border, root_padding, child_border); - - let item_skeleton = |w: Length, h: Pixels, bg: Hsla| div().w(w).h(h).rounded_full().bg(bg); - - let skeleton_height = px(4.); - - let sidebar_seeded_width = |seed: f32, index: usize| { - let value = (seed * 1000.0 + index as f32 * 10.0).sin() * 0.5 + 0.5; - 0.5 + value * 0.45 - }; - - let sidebar_skeleton_items = 8; - - let sidebar_skeleton = (0..sidebar_skeleton_items) - .map(|i| { - let width = sidebar_seeded_width(self.seed, i); - item_skeleton( - relative(width).into(), - skeleton_height, - color.text.alpha(0.45), - ) - }) - .collect::<Vec<_>>(); - - let sidebar = div() - .h_full() - .w(relative(0.25)) - .border_r(px(1.)) - .border_color(color.border_transparent) - .bg(color.panel_background) - .child( - div() - .p_2() - .flex() - .flex_col() - .size_full() - .gap(px(4.)) - .children(sidebar_skeleton), - ); - - let pseudo_code_skeleton = |theme: Arc<Theme>, seed: f32| -> AnyElement { - let colors = theme.colors(); - let syntax = theme.syntax(); - - let keyword_color = syntax.get("keyword").color; - let function_color = syntax.get("function").color; - let string_color = syntax.get("string").color; - let comment_color = syntax.get("comment").color; - let variable_color = syntax.get("variable").color; - let type_color = syntax.get("type").color; - let punctuation_color = syntax.get("punctuation").color; - - let syntax_colors = [ - keyword_color, - function_color, - string_color, - variable_color, - type_color, - punctuation_color, - comment_color, - ]; - - let line_width = |line_idx: usize, block_idx: usize| -> f32 { - let val = (seed * 100.0 + line_idx as f32 * 20.0 + block_idx as f32 * 5.0).sin() - * 0.5 - + 0.5; - 0.05 + val * 0.2 - }; - - let indentation = |line_idx: usize| -> f32 { - let step = line_idx % 6; - if step < 3 { - step as f32 * 0.1 - } else { - (5 - step) as f32 * 0.1 - } - }; - - let pick_color = |line_idx: usize, block_idx: usize| -> Hsla { - let idx = ((seed * 10.0 + line_idx as f32 * 7.0 + block_idx as f32 * 3.0).sin() - * 3.5) - .abs() as usize - % syntax_colors.len(); - syntax_colors[idx].unwrap_or(colors.text) - }; - - let line_count = 13; - - let lines = (0..line_count) - .map(|line_idx| { - let block_count = (((seed * 30.0 + line_idx as f32 * 12.0).sin() * 0.5 + 0.5) - * 3.0) - .round() as usize - + 2; - - let 
indent = indentation(line_idx); - - let blocks = (0..block_count) - .map(|block_idx| { - let width = line_width(line_idx, block_idx); - let color = pick_color(line_idx, block_idx); - item_skeleton(relative(width).into(), skeleton_height, color) - }) - .collect::<Vec<_>>(); - - h_flex().gap(px(2.)).ml(relative(indent)).children(blocks) - }) - .collect::<Vec<_>>(); - - v_flex() - .size_full() - .p_1() - .gap(px(6.)) - .children(lines) - .into_any_element() - }; - - let pane = div() - .h_full() - .flex_grow() - .flex() - .flex_col() - // .child( - // div() - // .w_full() - // .border_color(color.border) - // .border_b(px(1.)) - // .h(relative(0.1)) - // .bg(color.tab_bar_background), - // ) - .child( - div() - .size_full() - .overflow_hidden() - .bg(color.editor_background) - .p_2() - .child(pseudo_code_skeleton(self.theme.clone(), self.seed)), - ); - - let content = div().size_full().flex().child(sidebar).child(pane); - - div() - .size_full() - .rounded(root_radius) - .p(root_padding) - .border(root_border) - .border_color(color.border_transparent) - .when(self.selected, |this| { - this.border_color(color.border_selected) - }) - .child( - div() - .size_full() - .rounded(inner_radius) - .border(child_border) - .border_color(color.border) - .bg(color.background) - .child(content), - ) - } -} - -impl Component for ThemePreviewTile { - fn description() -> Option<&'static str> { - Some(Self::DOCS) - } - - fn preview(_window: &mut Window, cx: &mut App) -> Option<AnyElement> { - let theme_registry = ThemeRegistry::global(cx); - - let one_dark = theme_registry.get("One Dark"); - let one_light = theme_registry.get("One Light"); - let gruvbox_dark = theme_registry.get("Gruvbox Dark"); - let gruvbox_light = theme_registry.get("Gruvbox Light"); - - let themes_to_preview = vec![ - one_dark.clone().ok(), - one_light.clone().ok(), - gruvbox_dark.clone().ok(), - gruvbox_light.clone().ok(), - ] - .into_iter() - .flatten() - .collect::<Vec<_>>(); - - Some( - v_flex() - .gap_6() - .p_4() - .children({ - if let Some(one_dark) = one_dark.ok() { - vec![example_group(vec![ - single_example( - "Default", - div() - .w(px(240.)) - .h(px(180.)) - .child(ThemePreviewTile::new(one_dark.clone(), false, 0.42)) - .into_any_element(), - ), - single_example( - "Selected", - div() - .w(px(240.)) - .h(px(180.)) - .child(ThemePreviewTile::new(one_dark, true, 0.42)) - .into_any_element(), - ), - ])] - } else { - vec![] - } - }) - .child( - example_group(vec![single_example( - "Default Themes", - h_flex() - .gap_4() - .children( - themes_to_preview - .iter() - .enumerate() - .map(|(i, theme)| { - div().w(px(200.)).h(px(140.)).child(ThemePreviewTile::new( - theme.clone(), - false, - 0.42, - )) - }) - .collect::<Vec<_>>(), - ) - .into_any_element(), - )]) - .grow(), - ) - .into_any_element(), - ) - } -} diff --git a/crates/workspace/src/dock.rs b/crates/workspace/src/dock.rs index 8fcd55b784fc4202a34a1a34f72590933bb0f3d1..ca63d3e5532a393436046e04a3b50a448a0e94f0 100644 --- a/crates/workspace/src/dock.rs +++ b/crates/workspace/src/dock.rs @@ -221,9 +221,9 @@ pub enum DockPosition { impl DockPosition { fn label(&self) -> &'static str { match self { - Self::Left => "left", - Self::Bottom => "bottom", - Self::Right => "right", + Self::Left => "Left", + Self::Bottom => "Bottom", + Self::Right => "Right", } } @@ -242,6 +242,7 @@ struct PanelEntry { pub struct PanelButtons { dock: Entity<Dock>, + _settings_subscription: Subscription, } impl Dock { @@ -833,7 +834,11 @@ impl Render for Dock { impl PanelButtons { pub fn new(dock: 
Entity<Dock>, cx: &mut Context<Self>) -> Self { cx.observe(&dock, |_, _, cx| cx.notify()).detach(); - Self { dock } + let settings_subscription = cx.observe_global::<SettingsStore>(|_, cx| cx.notify()); + Self { + dock, + _settings_subscription: settings_subscription, + } } } @@ -864,7 +869,7 @@ impl Render for PanelButtons { let action = dock.toggle_action(); let tooltip: SharedString = - format!("Close {} dock", dock.position.label()).into(); + format!("Close {} Dock", dock.position.label()).into(); (action, tooltip) } else { @@ -873,6 +878,8 @@ impl Render for PanelButtons { (action, icon_tooltip.into()) }; + let focus_handle = dock.focus_handle(cx); + Some( right_click_menu(name) .menu(move |window, cx| { @@ -909,6 +916,7 @@ impl Render for PanelButtons { .on_click({ let action = action.boxed_clone(); move |_, window, cx| { + window.focus(&focus_handle); window.dispatch_action(action.boxed_clone(), cx) } }) @@ -923,8 +931,13 @@ impl Render for PanelButtons { .collect(); let has_buttons = !buttons.is_empty(); + h_flex() .gap_1() + .when( + has_buttons && dock.position == DockPosition::Bottom, + |this| this.child(Divider::vertical().color(DividerColor::Border)), + ) .children(buttons) .when(has_buttons && dock.position == DockPosition::Left, |this| { this.child(Divider::vertical().color(DividerColor::Border)) diff --git a/crates/workspace/src/pane.rs b/crates/workspace/src/pane.rs index 56db7fa57009b739909bcfa40c4b0a28967f776b..a9e7304e47d8a8ee4d7bea2b6571aa4686964a48 100644 --- a/crates/workspace/src/pane.rs +++ b/crates/workspace/src/pane.rs @@ -18,7 +18,7 @@ use futures::{StreamExt, stream::FuturesUnordered}; use gpui::{ Action, AnyElement, App, AsyncWindowContext, ClickEvent, ClipboardItem, Context, Corner, Div, DragMoveEvent, Entity, EntityId, EventEmitter, ExternalPaths, FocusHandle, FocusOutEvent, - Focusable, KeyContext, MouseButton, MouseDownEvent, NavigationDirection, Pixels, Point, + Focusable, IsZero, KeyContext, MouseButton, MouseDownEvent, NavigationDirection, Pixels, Point, PromptLevel, Render, ScrollHandle, Subscription, Task, WeakEntity, WeakFocusHandle, Window, actions, anchored, deferred, prelude::*, }; @@ -40,13 +40,14 @@ use std::{ Arc, atomic::{AtomicUsize, Ordering}, }, + time::Duration, }; use theme::ThemeSettings; use ui::{ ButtonSize, Color, ContextMenu, ContextMenuEntry, ContextMenuItem, DecoratedIcon, IconButton, IconButtonShape, IconDecoration, IconDecorationKind, IconName, IconSize, Indicator, Label, - PopoverMenu, PopoverMenuHandle, ScrollableHandle, Tab, TabBar, TabPosition, Tooltip, - prelude::*, right_click_menu, + PopoverMenu, PopoverMenuHandle, Tab, TabBar, TabPosition, Tooltip, prelude::*, + right_click_menu, }; use util::{ResultExt, debug_panic, maybe, truncate_and_remove_front}; @@ -61,7 +62,7 @@ pub struct SelectedEntry { #[derive(Debug)] pub struct DraggedSelection { pub active_selection: SelectedEntry, - pub marked_selections: Arc<BTreeSet<SelectedEntry>>, + pub marked_selections: Arc<[SelectedEntry]>, } impl DraggedSelection { @@ -115,7 +116,8 @@ pub struct CloseActiveItem { #[derive(Clone, PartialEq, Debug, Deserialize, JsonSchema, Default, Action)] #[action(namespace = pane)] #[serde(deny_unknown_fields)] -pub struct CloseInactiveItems { +#[action(deprecated_aliases = ["pane::CloseInactiveItems"])] +pub struct CloseOtherItems { #[serde(default)] pub save_intent: Option<SaveIntent>, #[serde(default)] @@ -364,6 +366,7 @@ pub struct Pane { pinned_tab_count: usize, diagnostics: HashMap<ProjectPath, DiagnosticSeverity>, zoom_out_on_close: bool, + 
diagnostic_summary_update: Task<()>, /// If a certain project item wants to get recreated with specific data, it can persist its data before the recreation here. pub project_item_restoration_data: HashMap<ProjectItemKind, Box<dyn Any + Send>>, } @@ -505,6 +508,7 @@ impl Pane { pinned_tab_count: 0, diagnostics: Default::default(), zoom_out_on_close: true, + diagnostic_summary_update: Task::ready(()), project_item_restoration_data: HashMap::default(), } } @@ -616,8 +620,16 @@ impl Pane { project::Event::DiskBasedDiagnosticsFinished { .. } | project::Event::DiagnosticsUpdated { .. } => { if ItemSettings::get_global(cx).show_diagnostics != ShowDiagnostics::Off { - self.update_diagnostics(cx); - cx.notify(); + self.diagnostic_summary_update = cx.spawn(async move |this, cx| { + cx.background_executor() + .timer(Duration::from_millis(30)) + .await; + this.update(cx, |this, cx| { + this.update_diagnostics(cx); + cx.notify(); + }) + .log_err(); + }); } } _ => {} @@ -1343,9 +1355,10 @@ impl Pane { }) } - pub fn close_inactive_items( + pub fn close_other_items( &mut self, - action: &CloseInactiveItems, + action: &CloseOtherItems, + target_item_id: Option<EntityId>, window: &mut Window, cx: &mut Context<Self>, ) -> Task<Result<()>> { @@ -1353,7 +1366,11 @@ impl Pane { return Task::ready(Ok(())); } - let active_item_id = self.active_item_id(); + let active_item_id = match target_item_id { + Some(result) => result, + None => self.active_item_id(), + }; + let pinned_item_ids = self.pinned_item_ids(); self.close_items( @@ -1647,10 +1664,33 @@ impl Pane { } if should_save { - if !Self::save_item(project.clone(), &pane, &*item_to_close, save_intent, cx) - .await? + match Self::save_item(project.clone(), &pane, &*item_to_close, save_intent, cx) + .await { - break; + Ok(success) => { + if !success { + break; + } + } + Err(err) => { + let answer = pane.update_in(cx, |_, window, cx| { + let detail = Self::file_names_for_prompt( + &mut [&item_to_close].into_iter(), + cx, + ); + window.prompt( + PromptLevel::Warning, + &format!("Unable to save file: {}", &err), + Some(&detail), + &["Close Without Saving", "Cancel"], + cx, + ) + })?; + match answer.await { + Ok(0) => {} + Ok(1..) | Err(_) => break, + } + } } } @@ -2562,7 +2602,7 @@ impl Pane { save_intent: None, close_pinned: true, }; - let close_inactive_items_action = CloseInactiveItems { + let close_inactive_items_action = CloseOtherItems { save_intent: None, close_pinned: false, }; @@ -2594,8 +2634,9 @@ impl Pane { .action(Box::new(close_inactive_items_action.clone())) .disabled(total_items == 1) .handler(window.handler_for(&pane, move |pane, window, cx| { - pane.close_inactive_items( + pane.close_other_items( &close_inactive_items_action, + Some(item_id), window, cx, ) @@ -2814,7 +2855,7 @@ impl Pane { }) .collect::<Vec<_>>(); let tab_count = tab_items.len(); - if self.pinned_tab_count > tab_count { + if self.is_tab_pinned(tab_count) { log::warn!( "Pinned tab count ({}) exceeds actual tab count ({}). \ This should not happen. 
If possible, add reproduction steps, \ @@ -2847,10 +2888,9 @@ impl Pane { } }) .children(pinned_tabs.len().ne(&0).then(|| { - let content_width = self.tab_bar_scroll_handle.content_size().width; - let viewport_width = self.tab_bar_scroll_handle.viewport().size.width; + let max_scroll = self.tab_bar_scroll_handle.max_offset().width; // We need to check both because offset returns delta values even when the scroll handle is not scrollable - let is_scrollable = content_width > viewport_width; + let is_scrollable = !max_scroll.is_zero(); let is_scrolled = self.tab_bar_scroll_handle.offset().x < px(0.); let has_active_unpinned_tab = self.active_item_index >= self.pinned_tab_count; h_flex() @@ -2905,7 +2945,7 @@ impl Pane { this.handle_external_paths_drop(paths, window, cx) })) .on_click(cx.listener(move |this, event: &ClickEvent, window, cx| { - if event.up.click_count == 2 { + if event.click_count() == 2 { window.dispatch_action( this.double_click_dispatch_action.boxed_clone(), cx, @@ -3013,7 +3053,7 @@ impl Pane { || cfg!(not(target_os = "macos")) && window.modifiers().control; let from_pane = dragged_tab.pane.clone(); - let from_ix = dragged_tab.ix; + self.workspace .update(cx, |_, cx| { cx.defer_in(window, move |workspace, window, cx| { @@ -3045,9 +3085,13 @@ impl Pane { } to_pane.update(cx, |this, _| { if to_pane == from_pane { - let moved_right = ix > from_ix; - let ix = if moved_right { ix - 1 } else { ix }; - let is_pinned_in_to_pane = this.is_tab_pinned(ix); + let actual_ix = this + .items + .iter() + .position(|item| item.item_id() == item_id) + .unwrap_or(0); + + let is_pinned_in_to_pane = this.is_tab_pinned(actual_ix); if !was_pinned_in_from_pane && is_pinned_in_to_pane { this.pinned_tab_count += 1; @@ -3222,28 +3266,37 @@ impl Pane { split_direction = None; } - if let Ok(open_task) = workspace.update_in(cx, |workspace, window, cx| { - if let Some(split_direction) = split_direction { - to_pane = workspace.split_pane(to_pane, split_direction, window, cx); - } - workspace.open_paths( - paths, - OpenOptions { - visible: Some(OpenVisible::OnlyDirectories), - ..Default::default() - }, - Some(to_pane.downgrade()), - window, - cx, - ) - }) { + if let Ok((open_task, to_pane)) = + workspace.update_in(cx, |workspace, window, cx| { + if let Some(split_direction) = split_direction { + to_pane = + workspace.split_pane(to_pane, split_direction, window, cx); + } + ( + workspace.open_paths( + paths, + OpenOptions { + visible: Some(OpenVisible::OnlyDirectories), + ..Default::default() + }, + Some(to_pane.downgrade()), + window, + cx, + ), + to_pane, + ) + }) + { let opened_items: Vec<_> = open_task.await; - _ = workspace.update(cx, |workspace, cx| { + _ = workspace.update_in(cx, |workspace, window, cx| { for item in opened_items.into_iter().flatten() { if let Err(e) = item { workspace.show_error(&e, cx); } } + if to_pane.read(cx).items_len() == 0 { + workspace.remove_pane(to_pane, None, window, cx); + } }); } }) @@ -3504,8 +3557,8 @@ impl Render for Pane { }), ) .on_action( - cx.listener(|pane: &mut Self, action: &CloseInactiveItems, window, cx| { - pane.close_inactive_items(action, window, cx) + cx.listener(|pane: &mut Self, action: &CloseOtherItems, window, cx| { + pane.close_other_items(action, None, window, cx) .detach_and_log_err(cx); }), ) @@ -3587,7 +3640,7 @@ impl Render for Pane { .justify_center() .on_click(cx.listener( move |this, event: &ClickEvent, window, cx| { - if event.up.click_count == 2 { + if event.click_count() == 2 { window.dispatch_action( 
this.double_click_dispatch_action.boxed_clone(), cx, @@ -4924,6 +4977,43 @@ mod tests { assert_item_labels(&pane_a, ["B!", "A*!"], cx); } + #[gpui::test] + async fn test_dragging_pinned_tab_onto_unpinned_tab_reduces_unpinned_tab_count( + cx: &mut TestAppContext, + ) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + + let project = Project::test(fs, None, cx).await; + let (workspace, cx) = + cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + let pane_a = workspace.read_with(cx, |workspace, _| workspace.active_pane().clone()); + + // Add A, B to pane A and pin A + let item_a = add_labeled_item(&pane_a, "A", false, cx); + add_labeled_item(&pane_a, "B", false, cx); + pane_a.update_in(cx, |pane, window, cx| { + let ix = pane.index_for_item_id(item_a.item_id()).unwrap(); + pane.pin_tab_at(ix, window, cx); + }); + assert_item_labels(&pane_a, ["A!", "B*"], cx); + + // Drag pinned A on top of B in the same pane, which changes tab order to B, A + pane_a.update_in(cx, |pane, window, cx| { + let dragged_tab = DraggedTab { + pane: pane_a.clone(), + item: item_a.boxed_clone(), + ix: 0, + detail: 0, + is_active: true, + }; + pane.handle_tab_drop(&dragged_tab, 1, window, cx); + }); + + // Neither are pinned + assert_item_labels(&pane_a, ["B", "A*"], cx); + } + #[gpui::test] async fn test_drag_pinned_tab_beyond_unpinned_tab_in_same_pane_becomes_unpinned( cx: &mut TestAppContext, @@ -5836,11 +5926,12 @@ mod tests { assert_item_labels(&pane, ["A!", "B!", "C", "D", "E*"], cx); pane.update_in(cx, |pane, window, cx| { - pane.close_inactive_items( - &CloseInactiveItems { + pane.close_other_items( + &CloseOtherItems { save_intent: None, close_pinned: false, }, + None, window, cx, ) @@ -5850,6 +5941,43 @@ mod tests { assert_item_labels(&pane, ["A!", "B!", "E*"], cx); } + #[gpui::test] + async fn test_running_close_inactive_items_via_an_inactive_item(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + + let project = Project::test(fs, None, cx).await; + let (workspace, cx) = + cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + let pane = workspace.read_with(cx, |workspace, _| workspace.active_pane().clone()); + + add_labeled_item(&pane, "A", false, cx); + assert_item_labels(&pane, ["A*"], cx); + + let item_b = add_labeled_item(&pane, "B", false, cx); + assert_item_labels(&pane, ["A", "B*"], cx); + + add_labeled_item(&pane, "C", false, cx); + add_labeled_item(&pane, "D", false, cx); + add_labeled_item(&pane, "E", false, cx); + assert_item_labels(&pane, ["A", "B", "C", "D", "E*"], cx); + + pane.update_in(cx, |pane, window, cx| { + pane.close_other_items( + &CloseOtherItems { + save_intent: None, + close_pinned: false, + }, + Some(item_b.item_id()), + window, + cx, + ) + }) + .await + .unwrap(); + assert_item_labels(&pane, ["B*"], cx); + } + #[gpui::test] async fn test_close_clean_items(cx: &mut TestAppContext) { init_test(cx); @@ -6201,11 +6329,12 @@ mod tests { .unwrap(); pane.update_in(cx, |pane, window, cx| { - pane.close_inactive_items( - &CloseInactiveItems { + pane.close_other_items( + &CloseOtherItems { save_intent: None, close_pinned: false, }, + None, window, cx, ) diff --git a/crates/workspace/src/pane_group.rs b/crates/workspace/src/pane_group.rs index 4565cef34719cdf3d4c506e7ba73dedb8cc6e3de..5c87206e9e96cf3866b183684d981b02692d039f 100644 --- a/crates/workspace/src/pane_group.rs +++ b/crates/workspace/src/pane_group.rs @@ -943,6 +943,8 @@ mod element { pub struct PaneAxisElement { axis: Axis, 
basis: usize, + /// Equivalent to ColumnWidths (but in terms of flexes instead of percentages) + /// For example, flexes "1.33, 1, 1", instead of "40%, 30%, 30%" flexes: Arc<Mutex<Vec<f32>>>, bounding_boxes: Arc<Mutex<Vec<Option<Bounds<Pixels>>>>>, children: SmallVec<[AnyElement; 2]>, @@ -998,6 +1000,7 @@ mod element { let mut flexes = flexes.lock(); debug_assert!(flex_values_in_bounds(flexes.as_slice())); + // Math to convert a flex value to a pixel value let size = move |ix, flexes: &[f32]| { container_size.along(axis) * (flexes[ix] / flexes.len() as f32) }; @@ -1007,9 +1010,13 @@ mod element { return; } + // This is basically a "bucket" of pixel changes that need to be applied in response to this + // mouse event. Probably a small, fractional number like 0.5 or 1.5 pixels let mut proposed_current_pixel_change = (e.position - child_start).along(axis) - size(ix, flexes.as_slice()); + // This takes a pixel change, and computes the flex changes that correspond to this pixel change + // as well as the next one, for some reason let flex_changes = |pixel_dx, target_ix, next: isize, flexes: &[f32]| { let flex_change = pixel_dx / container_size.along(axis); let current_target_flex = flexes[target_ix] + flex_change; @@ -1017,6 +1024,9 @@ mod element { (current_target_flex, next_target_flex) }; + // Generate the list of flex successors, from the current index. + // If you're dragging column 3 forward, out of 6 columns, then this code will produce [4, 5, 6] + // If you're dragging column 3 backward, out of 6 columns, then this code will produce [2, 1, 0] let mut successors = iter::from_fn({ let forward = proposed_current_pixel_change > px(0.); let mut ix_offset = 0; @@ -1034,6 +1044,7 @@ mod element { } }); + // Now actually loop over these, and empty our bucket of pixel changes while proposed_current_pixel_change.abs() > px(0.) { let Some(current_ix) = successors.next() else { break; diff --git a/crates/workspace/src/persistence.rs b/crates/workspace/src/persistence.rs index 406f37419d1b02d14817ad165d4fa5cdd4c6d452..6fa5c969e7ff9dfdf66d0303a61c42d070d91560 100644 --- a/crates/workspace/src/persistence.rs +++ b/crates/workspace/src/persistence.rs @@ -939,6 +939,26 @@ impl WorkspaceDb { } } + query! { + pub async fn update_ssh_project_paths_query(ssh_project_id: u64, paths: String) -> Result<Option<SerializedSshProject>> { + UPDATE ssh_projects + SET paths = ?2 + WHERE id = ?1 + RETURNING id, host, port, paths, user + } + } + + pub(crate) async fn update_ssh_project_paths( + &self, + ssh_project_id: SshProjectId, + new_paths: Vec<String>, + ) -> Result<SerializedSshProject> { + let paths = serde_json::to_string(&new_paths)?; + self.update_ssh_project_paths_query(ssh_project_id.0, paths) + .await? + .context("failed to update ssh project paths") + } + query! { pub async fn next_id() -> Result<WorkspaceId> { INSERT INTO workspaces DEFAULT VALUES RETURNING workspace_id @@ -1336,6 +1356,14 @@ impl WorkspaceDb { } } + query! 
{ + pub(crate) async fn set_session_id(workspace_id: WorkspaceId, session_id: Option<String>) -> Result<()> { + UPDATE workspaces + SET session_id = ?2 + WHERE workspace_id = ?1 + } + } + pub async fn toolchain( &self, workspace_id: WorkspaceId, @@ -2616,4 +2644,56 @@ mod tests { assert_eq!(workspace.center_group, new_workspace.center_group); } + + #[gpui::test] + async fn test_update_ssh_project_paths() { + zlog::init_test(); + + let db = WorkspaceDb::open_test_db("test_update_ssh_project_paths").await; + + let (host, port, initial_paths, user) = ( + "example.com".to_string(), + Some(22_u16), + vec!["/home/user".to_string(), "/etc/nginx".to_string()], + Some("user".to_string()), + ); + + let project = db + .get_or_create_ssh_project(host.clone(), port, initial_paths.clone(), user.clone()) + .await + .unwrap(); + + assert_eq!(project.host, host); + assert_eq!(project.paths, initial_paths); + assert_eq!(project.user, user); + + let new_paths = vec![ + "/home/user".to_string(), + "/etc/nginx".to_string(), + "/var/log".to_string(), + "/opt/app".to_string(), + ]; + + let updated_project = db + .update_ssh_project_paths(project.id, new_paths.clone()) + .await + .unwrap(); + + assert_eq!(updated_project.id, project.id); + assert_eq!(updated_project.paths, new_paths); + + let retrieved_project = db + .get_ssh_project( + host.clone(), + port, + serde_json::to_string(&new_paths).unwrap(), + user.clone(), + ) + .await + .unwrap() + .unwrap(); + + assert_eq!(retrieved_project.id, project.id); + assert_eq!(retrieved_project.paths, new_paths); + } } diff --git a/crates/workspace/src/status_bar.rs b/crates/workspace/src/status_bar.rs index 798d49eec5c7e0ea8d8c4bc7427389af8de2c658..edeb382de7d386b37d81b2649af85cf97f9e8b31 100644 --- a/crates/workspace/src/status_bar.rs +++ b/crates/workspace/src/status_bar.rs @@ -42,7 +42,7 @@ impl Render for StatusBar { .justify_between() .gap(DynamicSpacing::Base08.rems(cx)) .py(DynamicSpacing::Base04.rems(cx)) - .px(DynamicSpacing::Base08.rems(cx)) + .px(DynamicSpacing::Base06.rems(cx)) .bg(cx.theme().colors().status_bar_background) .map(|el| match window.window_decorations() { Decorations::Server => el, @@ -58,22 +58,23 @@ impl Render for StatusBar { .border_b(px(1.0)) .border_color(cx.theme().colors().status_bar_background), }) - .child(self.render_left_tools(cx)) - .child(self.render_right_tools(cx)) + .child(self.render_left_tools()) + .child(self.render_right_tools()) } } impl StatusBar { - fn render_left_tools(&self, cx: &mut Context<Self>) -> impl IntoElement { + fn render_left_tools(&self) -> impl IntoElement { h_flex() - .gap(DynamicSpacing::Base04.rems(cx)) + .gap_1() .overflow_x_hidden() .children(self.left_items.iter().map(|item| item.to_any())) } - fn render_right_tools(&self, cx: &mut Context<Self>) -> impl IntoElement { + fn render_right_tools(&self) -> impl IntoElement { h_flex() - .gap(DynamicSpacing::Base04.rems(cx)) + .gap_1() + .overflow_x_hidden() .children(self.right_items.iter().rev().map(|item| item.to_any())) } } diff --git a/crates/workspace/src/tasks.rs b/crates/workspace/src/tasks.rs index 26edbd8d03ed37d4bddca65f0a94cc9413760dd9..32d066c7eb74f9019348d3bcac9402ebb7216a4e 100644 --- a/crates/workspace/src/tasks.rs +++ b/crates/workspace/src/tasks.rs @@ -73,7 +73,7 @@ impl Workspace { if let Some(terminal_provider) = self.terminal_provider.as_ref() { let task_status = terminal_provider.spawn(spawn_in_terminal, window, cx); - cx.background_spawn(async move { + let task = cx.background_spawn(async move { match task_status.await { 
Some(Ok(status)) => { if status.success() { @@ -82,11 +82,11 @@ impl Workspace { log::debug!("Task spawn failed, code: {:?}", status.code()); } } - Some(Err(e)) => log::error!("Task spawn failed: {e}"), + Some(Err(e)) => log::error!("Task spawn failed: {e:#}"), None => log::debug!("Task spawn got cancelled"), } - }) - .detach(); + }); + self.scheduled_tasks.push(task); } } diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index 125b5bb98a7dd056880305e631c44bea815a6009..aab8a36f45b94a0f34bc7800da37d8af77c7fb08 100644 --- a/crates/workspace/src/workspace.rs +++ b/crates/workspace/src/workspace.rs @@ -32,7 +32,7 @@ use futures::{ mpsc::{self, UnboundedReceiver, UnboundedSender}, oneshot, }, - future::try_join_all, + future::{Shared, try_join_all}, }; use gpui::{ Action, AnyEntity, AnyView, AnyWeakView, App, AsyncApp, AsyncWindowContext, Bounds, Context, @@ -48,7 +48,10 @@ pub use item::{ ProjectItem, SerializableItem, SerializableItemHandle, WeakItemHandle, }; use itertools::Itertools; -use language::{Buffer, LanguageRegistry, Rope}; +use language::{ + Buffer, LanguageRegistry, Rope, + language_settings::{AllLanguageSettings, all_language_settings}, +}; pub use modal_layer::*; use node_runtime::NodeRuntime; use notifications::{ @@ -74,7 +77,7 @@ use remote::{SshClientDelegate, SshConnectionOptions, ssh_session::ConnectionIde use schemars::JsonSchema; use serde::Deserialize; use session::AppSession; -use settings::Settings; +use settings::{Settings, update_settings_file}; use shared_screen::SharedScreen; use sqlez::{ bindable::{Bind, Column, StaticColumnCount}, @@ -87,7 +90,7 @@ use std::{ borrow::Cow, cell::RefCell, cmp, - collections::hash_map::DefaultHasher, + collections::{VecDeque, hash_map::DefaultHasher}, env, hash::{Hash, Hasher}, path::{Path, PathBuf}, @@ -233,6 +236,8 @@ actions!( ToggleBottomDock, /// Toggles centered layout mode. ToggleCenteredLayout, + /// Toggles edit prediction feature globally for all files. + ToggleEditPrediction, /// Toggles the left dock. ToggleLeftDock, /// Toggles the right dock. @@ -1016,6 +1021,15 @@ pub enum OpenVisible { OnlyDirectories, } +enum WorkspaceLocation { + // Valid local paths or SSH project to serialize + Location(SerializedWorkspaceLocation), + // No valid location found hence clear session id + DetachFromSession, + // No valid location found to serialize + None, +} + type PromptForNewPath = Box< dyn Fn( &mut Workspace, @@ -1034,6 +1048,13 @@ type PromptForOpenPath = Box< ) -> oneshot::Receiver<Option<Vec<PathBuf>>>, >; +#[derive(Default)] +struct DispatchingKeystrokes { + dispatched: HashSet<Vec<Keystroke>>, + queue: VecDeque<Keystroke>, + task: Option<Shared<Task<()>>>, +} + /// Collects everything project-related for a certain window opened. /// In some way, is a counterpart of a window, as the [`WindowHandle`] could be downcast into `Workspace`. 
/// @@ -1049,7 +1070,6 @@ pub struct Workspace { center: PaneGroup, left_dock: Entity<Dock>, bottom_dock: Entity<Dock>, - bottom_dock_layout: BottomDockLayout, right_dock: Entity<Dock>, panes: Vec<Entity<Pane>>, panes_by_item: HashMap<EntityId, WeakEntity<Pane>>, @@ -1066,16 +1086,18 @@ pub struct Workspace { follower_states: HashMap<CollaboratorId, FollowerState>, last_leaders_by_pane: HashMap<WeakEntity<Pane>, CollaboratorId>, window_edited: bool, + last_window_title: Option<String>, dirty_items: HashMap<EntityId, Subscription>, active_call: Option<(Entity<ActiveCall>, Vec<Subscription>)>, leader_updates_tx: mpsc::UnboundedSender<(PeerId, proto::UpdateFollowers)>, database_id: Option<WorkspaceId>, app_state: Arc<AppState>, - dispatching_keystrokes: Rc<RefCell<(HashSet<String>, Vec<Keystroke>)>>, + dispatching_keystrokes: Rc<RefCell<DispatchingKeystrokes>>, _subscriptions: Vec<Subscription>, _apply_leader_updates: Task<Result<()>>, _observe_current_user: Task<Result<()>>, - _schedule_serialize: Option<Task<()>>, + _schedule_serialize_workspace: Option<Task<()>>, + _schedule_serialize_ssh_paths: Option<Task<()>>, pane_history_timestamp: Arc<AtomicUsize>, bounds: Bounds<Pixels>, pub centered_layout: bool, @@ -1088,6 +1110,7 @@ pub struct Workspace { serialized_ssh_project: Option<SerializedSshProject>, _items_serializer: Task<Result<()>>, session_id: Option<String>, + scheduled_tasks: Vec<Task<()>>, } impl EventEmitter<Event> for Workspace {} @@ -1133,9 +1156,10 @@ impl Workspace { project::Event::WorktreeRemoved(_) | project::Event::WorktreeAdded(_) => { this.update_window_title(window, cx); + this.update_ssh_paths(cx); + this.serialize_ssh_paths(window, cx); this.serialize_workspace(window, cx); // This event could be triggered by `AddFolderToProject` or `RemoveFromProject`. - // So we need to update the history. 
this.update_history(cx); } @@ -1291,7 +1315,6 @@ impl Workspace { ) .detach(); - let bottom_dock_layout = WorkspaceSettings::get_global(cx).bottom_dock_layout; let left_dock = Dock::new(DockPosition::Left, modal_layer.clone(), window, cx); let bottom_dock = Dock::new(DockPosition::Bottom, modal_layer.clone(), window, cx); let right_dock = Dock::new(DockPosition::Right, modal_layer.clone(), window, cx); @@ -1390,20 +1413,21 @@ impl Workspace { suppressed_notifications: HashSet::default(), left_dock, bottom_dock, - bottom_dock_layout, right_dock, project: project.clone(), follower_states: Default::default(), last_leaders_by_pane: Default::default(), dispatching_keystrokes: Default::default(), window_edited: false, + last_window_title: None, dirty_items: Default::default(), active_call, database_id: workspace_id, app_state, _observe_current_user, _apply_leader_updates, - _schedule_serialize: None, + _schedule_serialize_workspace: None, + _schedule_serialize_ssh_paths: None, leader_updates_tx, _subscriptions: subscriptions, pane_history_timestamp, @@ -1420,6 +1444,7 @@ impl Workspace { _items_serializer, session_id: Some(session_id), serialized_ssh_project: None, + scheduled_tasks: Vec::new(), } } @@ -1616,10 +1641,6 @@ impl Workspace { &self.bottom_dock } - pub fn bottom_dock_layout(&self) -> BottomDockLayout { - self.bottom_dock_layout - } - pub fn set_bottom_dock_layout( &mut self, layout: BottomDockLayout, @@ -1631,7 +1652,6 @@ impl Workspace { content.bottom_dock_layout = Some(layout); }); - self.bottom_dock_layout = layout; cx.notify(); self.serialize_workspace(window, cx); } @@ -1711,6 +1731,27 @@ impl Workspace { history } + pub fn recent_active_item_by_type<T: 'static>(&self, cx: &App) -> Option<Entity<T>> { + let mut recent_item: Option<Entity<T>> = None; + let mut recent_timestamp = 0; + for pane_handle in &self.panes { + let pane = pane_handle.read(cx); + let item_map: HashMap<EntityId, &Box<dyn ItemHandle>> = + pane.items().map(|item| (item.item_id(), item)).collect(); + for entry in pane.activation_history() { + if entry.timestamp > recent_timestamp { + if let Some(&item) = item_map.get(&entry.entity_id) { + if let Some(typed_item) = item.act_as::<T>(cx) { + recent_timestamp = entry.timestamp; + recent_item = Some(typed_item); + } + } + } + } + } + recent_item + } + pub fn recent_navigation_history_iter( &self, cx: &App, @@ -1774,10 +1815,7 @@ impl Workspace { .max_by(|b1, b2| b1.worktree_id.cmp(&b2.worktree_id)) }); - match latest_project_path_opened { - Some(latest_project_path_opened) => latest_project_path_opened == history_path, - None => true, - } + latest_project_path_opened.map_or(true, |path| path == history_path) }) } @@ -2282,49 +2320,65 @@ impl Workspace { window: &mut Window, cx: &mut Context<Self>, ) { - let mut state = self.dispatching_keystrokes.borrow_mut(); - if !state.0.insert(action.0.clone()) { - cx.propagate(); - return; - } - let mut keystrokes: Vec<Keystroke> = action + let keystrokes: Vec<Keystroke> = action .0 .split(' ') .flat_map(|k| Keystroke::parse(k).log_err()) .collect(); - keystrokes.reverse(); + let _ = self.send_keystrokes_impl(keystrokes, window, cx); + } + + pub fn send_keystrokes_impl( + &mut self, + keystrokes: Vec<Keystroke>, + window: &mut Window, + cx: &mut Context<Self>, + ) -> Shared<Task<()>> { + let mut state = self.dispatching_keystrokes.borrow_mut(); + if !state.dispatched.insert(keystrokes.clone()) { + cx.propagate(); + return state.task.clone().unwrap(); + } - state.1.append(&mut keystrokes); - drop(state); + 
state.queue.extend(keystrokes); let keystrokes = self.dispatching_keystrokes.clone(); - window - .spawn(cx, async move |cx| { - // limit to 100 keystrokes to avoid infinite recursion. - for _ in 0..100 { - let Some(keystroke) = keystrokes.borrow_mut().1.pop() else { - keystrokes.borrow_mut().0.clear(); - return Ok(()); - }; - cx.update(|window, cx| { - let focused = window.focused(cx); - window.dispatch_keystroke(keystroke.clone(), cx); - if window.focused(cx) != focused { - // dispatch_keystroke may cause the focus to change. - // draw's side effect is to schedule the FocusChanged events in the current flush effect cycle - // And we need that to happen before the next keystroke to keep vim mode happy... - // (Note that the tests always do this implicitly, so you must manually test with something like: - // "bindings": { "g z": ["workspace::SendKeystrokes", ": j <enter> u"]} - // ) - window.draw(cx).clear(); + if state.task.is_none() { + state.task = Some( + window + .spawn(cx, async move |cx| { + // limit to 100 keystrokes to avoid infinite recursion. + for _ in 0..100 { + let mut state = keystrokes.borrow_mut(); + let Some(keystroke) = state.queue.pop_front() else { + state.dispatched.clear(); + state.task.take(); + return; + }; + drop(state); + cx.update(|window, cx| { + let focused = window.focused(cx); + window.dispatch_keystroke(keystroke.clone(), cx); + if window.focused(cx) != focused { + // dispatch_keystroke may cause the focus to change. + // draw's side effect is to schedule the FocusChanged events in the current flush effect cycle + // And we need that to happen before the next keystroke to keep vim mode happy... + // (Note that the tests always do this implicitly, so you must manually test with something like: + // "bindings": { "g z": ["workspace::SendKeystrokes", ": j <enter> u"]} + // ) + window.draw(cx).clear(); + } + }) + .ok(); } - })?; - } - *keystrokes.borrow_mut() = Default::default(); - anyhow::bail!("over 100 keystrokes passed to send_keystrokes"); - }) - .detach_and_log_err(cx); + *keystrokes.borrow_mut() = Default::default(); + log::error!("over 100 keystrokes passed to send_keystrokes"); + }) + .shared(), + ); + } + state.task.clone().unwrap() } fn save_all_internal( @@ -2772,11 +2826,12 @@ impl Workspace { if retain_active_pane { let current_pane_close = current_pane.update(cx, |pane, cx| { - pane.close_inactive_items( - &CloseInactiveItems { + pane.close_other_items( + &CloseOtherItems { save_intent: None, close_pinned: false, }, + None, window, cx, ) @@ -3841,11 +3896,13 @@ impl Workspace { if *local { self.unfollow_in_pane(&pane, window, cx); } + serialize_workspace = *focus_changed || pane != self.active_pane(); if pane == self.active_pane() { self.active_item_path_changed(window, cx); self.update_active_view_for_followers(window, cx); + } else if *local { + self.set_active_pane(&pane, window, cx); } - serialize_workspace = *focus_changed || pane != self.active_pane(); } pane::Event::UserSavedItem { item, save_intent } => { cx.emit(Event::UserSavedItem { @@ -4348,7 +4405,13 @@ impl Workspace { title.push_str(" ↗"); } + if let Some(last_title) = self.last_window_title.as_ref() { + if &title == last_title { + return; + } + } window.set_window_title(&title); + self.last_window_title = Some(title); } fn update_window_edited(&mut self, window: &mut Window, cx: &mut App) { @@ -4738,7 +4801,7 @@ impl Workspace { .remote_id(&self.app_state.client, window, cx) .map(|id| id.to_proto()); - if let Some(id) = id.clone() { + if let Some(id) = id { if let Some(variant) = 
item.to_state_proto(window, cx) { let view = Some(proto::View { id: id.clone(), @@ -4751,7 +4814,7 @@ impl Workspace { update = proto::UpdateActiveView { view, // TODO: Remove after version 0.145.x stabilizes. - id: id.clone(), + id, leader_id: leader_peer_id, }; } @@ -5024,6 +5087,46 @@ impl Workspace { } } + fn update_ssh_paths(&mut self, cx: &App) { + let project = self.project().read(cx); + if !project.is_local() { + let paths: Vec<String> = project + .visible_worktrees(cx) + .map(|worktree| worktree.read(cx).abs_path().to_string_lossy().to_string()) + .collect(); + if let Some(ssh_project) = &mut self.serialized_ssh_project { + ssh_project.paths = paths; + } + } + } + + fn serialize_ssh_paths(&mut self, window: &mut Window, cx: &mut Context<Workspace>) { + if self._schedule_serialize_ssh_paths.is_none() { + self._schedule_serialize_ssh_paths = + Some(cx.spawn_in(window, async move |this, cx| { + cx.background_executor() + .timer(SERIALIZATION_THROTTLE_TIME) + .await; + this.update_in(cx, |this, window, cx| { + let task = if let Some(ssh_project) = &this.serialized_ssh_project { + let ssh_project_id = ssh_project.id; + let ssh_project_paths = ssh_project.paths.clone(); + window.spawn(cx, async move |_| { + persistence::DB + .update_ssh_project_paths(ssh_project_id, ssh_project_paths) + .await + }) + } else { + Task::ready(Err(anyhow::anyhow!("No SSH project to serialize"))) + }; + task.detach(); + this._schedule_serialize_ssh_paths.take(); + }) + .log_err(); + })); + } + } + fn remove_panes(&mut self, member: Member, window: &mut Window, cx: &mut Context<Workspace>) { match member { Member::Axis(PaneAxis { members, .. }) => { @@ -5067,17 +5170,18 @@ impl Workspace { } fn serialize_workspace(&mut self, window: &mut Window, cx: &mut Context<Self>) { - if self._schedule_serialize.is_none() { - self._schedule_serialize = Some(cx.spawn_in(window, async move |this, cx| { - cx.background_executor() - .timer(Duration::from_millis(100)) - .await; - this.update_in(cx, |this, window, cx| { - this.serialize_workspace_internal(window, cx).detach(); - this._schedule_serialize.take(); - }) - .log_err(); - })); + if self._schedule_serialize_workspace.is_none() { + self._schedule_serialize_workspace = + Some(cx.spawn_in(window, async move |this, cx| { + cx.background_executor() + .timer(SERIALIZATION_THROTTLE_TIME) + .await; + this.update_in(cx, |this, window, cx| { + this.serialize_workspace_internal(window, cx).detach(); + this._schedule_serialize_workspace.take(); + }) + .log_err(); + })); } } @@ -5194,48 +5298,58 @@ impl Workspace { } } - if let Some(location) = self.serialize_workspace_location(cx) { - let breakpoints = self.project.update(cx, |project, cx| { - project - .breakpoint_store() - .read(cx) - .all_source_breakpoints(cx) - }); + match self.serialize_workspace_location(cx) { + WorkspaceLocation::Location(location) => { + let breakpoints = self.project.update(cx, |project, cx| { + project + .breakpoint_store() + .read(cx) + .all_source_breakpoints(cx) + }); - let center_group = build_serialized_pane_group(&self.center.root, window, cx); - let docks = build_serialized_docks(self, window, cx); - let window_bounds = Some(SerializedWindowBounds(window.window_bounds())); - let serialized_workspace = SerializedWorkspace { - id: database_id, - location, - center_group, - window_bounds, - display: Default::default(), - docks, - centered_layout: self.centered_layout, - session_id: self.session_id.clone(), - breakpoints, - window_id: Some(window.window_handle().window_id().as_u64()), - }; + let 
center_group = build_serialized_pane_group(&self.center.root, window, cx); + let docks = build_serialized_docks(self, window, cx); + let window_bounds = Some(SerializedWindowBounds(window.window_bounds())); + let serialized_workspace = SerializedWorkspace { + id: database_id, + location, + center_group, + window_bounds, + display: Default::default(), + docks, + centered_layout: self.centered_layout, + session_id: self.session_id.clone(), + breakpoints, + window_id: Some(window.window_handle().window_id().as_u64()), + }; - return window.spawn(cx, async move |_| { - persistence::DB.save_workspace(serialized_workspace).await; - }); + window.spawn(cx, async move |_| { + persistence::DB.save_workspace(serialized_workspace).await; + }) + } + WorkspaceLocation::DetachFromSession => window.spawn(cx, async move |_| { + persistence::DB + .set_session_id(database_id, None) + .await + .log_err(); + }), + WorkspaceLocation::None => Task::ready(()), } - Task::ready(()) } - fn serialize_workspace_location(&self, cx: &App) -> Option<SerializedWorkspaceLocation> { + fn serialize_workspace_location(&self, cx: &App) -> WorkspaceLocation { if let Some(ssh_project) = &self.serialized_ssh_project { - Some(SerializedWorkspaceLocation::Ssh(ssh_project.clone())) + WorkspaceLocation::Location(SerializedWorkspaceLocation::Ssh(ssh_project.clone())) } else if let Some(local_paths) = self.local_paths(cx) { if !local_paths.is_empty() { - Some(SerializedWorkspaceLocation::from_local_paths(local_paths)) + WorkspaceLocation::Location(SerializedWorkspaceLocation::from_local_paths( + local_paths, + )) } else { - None + WorkspaceLocation::DetachFromSession } } else { - None + WorkspaceLocation::None } } @@ -5243,8 +5357,9 @@ impl Workspace { let Some(id) = self.database_id() else { return; }; - let Some(location) = self.serialize_workspace_location(cx) else { - return; + let location = match self.serialize_workspace_location(cx) { + WorkspaceLocation::Location(location) => location, + _ => return, }; if let Some(manager) = HistoryManager::global(cx) { manager.update(cx, |this, cx| { @@ -5441,6 +5556,7 @@ impl Workspace { .on_action(cx.listener(Self::activate_pane_at_index)) .on_action(cx.listener(Self::move_item_to_pane_at_index)) .on_action(cx.listener(Self::move_focused_panel_to_next_position)) + .on_action(cx.listener(Self::toggle_edit_predictions_all_files)) .on_action(cx.listener(|workspace, _: &Unfollow, window, cx| { let pane = workspace.active_pane().clone(); workspace.unfollow_in_pane(&pane, window, cx); @@ -5629,7 +5745,6 @@ impl Workspace { let client = project.read(cx).client(); let user_store = project.read(cx).user_store(); - let workspace_store = cx.new(|cx| WorkspaceStore::new(client.clone(), cx)); let session = cx.new(|cx| AppSession::new(Session::test(), cx)); window.activate_window(); @@ -5873,6 +5988,19 @@ impl Workspace { } }); } + + fn toggle_edit_predictions_all_files( + &mut self, + _: &ToggleEditPrediction, + _window: &mut Window, + cx: &mut Context<Self>, + ) { + let fs = self.project().read(cx).fs().clone(); + let show_edit_predictions = all_language_settings(None, cx).show_edit_predictions(None, cx); + update_settings_file::<AllLanguageSettings>(fs, cx, move |file, _| { + file.defaults.show_edit_predictions = Some(!show_edit_predictions) + }); + } } fn leader_border_for_pane( @@ -6178,6 +6306,7 @@ impl Render for Workspace { .iter() .map(|(_, notification)| notification.entity_id()) .collect::<Vec<_>>(); + let bottom_dock_layout = WorkspaceSettings::get_global(cx).bottom_dock_layout; 
client_side_decorations( self.actions(div(), window, cx) @@ -6301,7 +6430,7 @@ impl Render for Workspace { )) }) .child({ - match self.bottom_dock_layout { + match bottom_dock_layout { BottomDockLayout::Full => div() .flex() .flex_col() @@ -6833,10 +6962,13 @@ async fn join_channel_internal( match status { Status::Connecting | Status::Authenticating + | Status::Authenticated | Status::Reconnecting | Status::Reauthenticating => continue, Status::Connected { .. } => break 'outer, - Status::SignedOut => return Err(ErrorCode::SignedOut.into()), + Status::SignedOut | Status::AuthenticationError => { + return Err(ErrorCode::SignedOut.into()); + } Status::UpgradeRequired => return Err(ErrorCode::UpgradeRequired.into()), Status::ConnectionError | Status::ConnectionLost | Status::ReconnectionError { .. } => { return Err(ErrorCode::Disconnected.into()); @@ -7335,6 +7467,17 @@ async fn open_ssh_project_inner( return Err(project_path_errors.pop().context("no paths given")?); } + if let Some(detach_session_task) = window + .update(cx, |_workspace, window, cx| { + cx.spawn_in(window, async move |this, cx| { + this.update_in(cx, |this, window, cx| this.remove_from_session(window, cx)) + }) + }) + .ok() + { + detach_session_task.await.ok(); + } + cx.update_window(window.into(), |_, window, cx| { window.replace_root(cx, |window, cx| { telemetry::event!("SSH Project Opened"); @@ -9447,11 +9590,12 @@ mod tests { ); }); let close_all_but_multi_buffer_task = pane.update_in(cx, |pane, window, cx| { - pane.close_inactive_items( - &CloseInactiveItems { + pane.close_other_items( + &CloseOtherItems { save_intent: Some(SaveIntent::Save), close_pinned: true, }, + None, window, cx, ) diff --git a/crates/worktree/src/worktree.rs b/crates/worktree/src/worktree.rs index 8c407fdd3eab5a6b7189f67ff46b8ce76d1a428d..b5a0f71e81171be061c834e15d2f50e04688bbb5 100644 --- a/crates/worktree/src/worktree.rs +++ b/crates/worktree/src/worktree.rs @@ -62,7 +62,7 @@ use std::{ }, time::{Duration, Instant}, }; -use sum_tree::{Bias, Edit, KeyedItem, SeekTarget, SumTree, Summary, TreeMap, TreeSet, Unit}; +use sum_tree::{Bias, Dimensions, Edit, KeyedItem, SeekTarget, SumTree, Summary, TreeMap, TreeSet}; use text::{LineEnding, Rope}; use util::{ ResultExt, @@ -407,12 +407,12 @@ struct LocalRepositoryEntry { } impl sum_tree::Item for LocalRepositoryEntry { - type Summary = PathSummary<Unit>; + type Summary = PathSummary<&'static ()>; fn summary(&self, _: &<Self::Summary as Summary>::Context) -> Self::Summary { PathSummary { max_path: self.work_directory.path_key().0, - item_summary: Unit, + item_summary: &(), } } } @@ -425,12 +425,6 @@ impl KeyedItem for LocalRepositoryEntry { } } -//impl LocalRepositoryEntry { -// pub fn repo(&self) -> &Arc<dyn GitRepository> { -// &self.repo_ptr -// } -//} - impl Deref for LocalRepositoryEntry { type Target = WorkDirectory; @@ -2454,16 +2448,16 @@ impl Snapshot { self.entries_by_path = { let mut cursor = self.entries_by_path.cursor::<TraversalProgress>(&()); let mut new_entries_by_path = - cursor.slice(&TraversalTarget::path(&removed_entry.path), Bias::Left, &()); + cursor.slice(&TraversalTarget::path(&removed_entry.path), Bias::Left); while let Some(entry) = cursor.item() { if entry.path.starts_with(&removed_entry.path) { self.entries_by_id.remove(&entry.id, &()); - cursor.next(&()); + cursor.next(); } else { break; } } - new_entries_by_path.append(cursor.suffix(&()), &()); + new_entries_by_path.append(cursor.suffix(), &()); new_entries_by_path }; @@ -2576,7 +2570,6 @@ impl Snapshot { include_ignored, }, 
Bias::Right, - &(), ); Traversal { snapshot: self, @@ -2632,7 +2625,7 @@ impl Snapshot { options: ChildEntriesOptions, ) -> ChildEntriesIter<'a> { let mut cursor = self.entries_by_path.cursor(&()); - cursor.seek(&TraversalTarget::path(parent_path), Bias::Right, &()); + cursor.seek(&TraversalTarget::path(parent_path), Bias::Right); let traversal = Traversal { snapshot: self, cursor, @@ -3056,9 +3049,9 @@ impl BackgroundScannerState { .snapshot .entries_by_path .cursor::<TraversalProgress>(&()); - new_entries = cursor.slice(&TraversalTarget::path(path), Bias::Left, &()); - removed_entries = cursor.slice(&TraversalTarget::successor(path), Bias::Left, &()); - new_entries.append(cursor.suffix(&()), &()); + new_entries = cursor.slice(&TraversalTarget::path(path), Bias::Left); + removed_entries = cursor.slice(&TraversalTarget::successor(path), Bias::Left); + new_entries.append(cursor.suffix(), &()); } self.snapshot.entries_by_path = new_entries; @@ -3573,10 +3566,15 @@ impl<'a> sum_tree::Dimension<'a, PathSummary<GitSummary>> for GitSummary { } } -impl<'a> sum_tree::SeekTarget<'a, PathSummary<GitSummary>, (TraversalProgress<'a>, GitSummary)> +impl<'a> + sum_tree::SeekTarget<'a, PathSummary<GitSummary>, Dimensions<TraversalProgress<'a>, GitSummary>> for PathTarget<'_> { - fn cmp(&self, cursor_location: &(TraversalProgress<'a>, GitSummary), _: &()) -> Ordering { + fn cmp( + &self, + cursor_location: &Dimensions<TraversalProgress<'a>, GitSummary>, + _: &(), + ) -> Ordering { self.cmp_path(&cursor_location.0.max_path) } } @@ -4925,15 +4923,15 @@ fn build_diff( let mut old_paths = old_snapshot.entries_by_path.cursor::<PathKey>(&()); let mut new_paths = new_snapshot.entries_by_path.cursor::<PathKey>(&()); let mut last_newly_loaded_dir_path = None; - old_paths.next(&()); - new_paths.next(&()); + old_paths.next(); + new_paths.next(); for path in event_paths { let path = PathKey(path.clone()); if old_paths.item().map_or(false, |e| e.path < path.0) { - old_paths.seek_forward(&path, Bias::Left, &()); + old_paths.seek_forward(&path, Bias::Left); } if new_paths.item().map_or(false, |e| e.path < path.0) { - new_paths.seek_forward(&path, Bias::Left, &()); + new_paths.seek_forward(&path, Bias::Left); } loop { match (old_paths.item(), new_paths.item()) { @@ -4949,7 +4947,7 @@ fn build_diff( match Ord::cmp(&old_entry.path, &new_entry.path) { Ordering::Less => { changes.push((old_entry.path.clone(), old_entry.id, Removed)); - old_paths.next(&()); + old_paths.next(); } Ordering::Equal => { if phase == EventsReceivedDuringInitialScan { @@ -4975,8 +4973,8 @@ fn build_diff( changes.push((new_entry.path.clone(), new_entry.id, Updated)); } } - old_paths.next(&()); - new_paths.next(&()); + old_paths.next(); + new_paths.next(); } Ordering::Greater => { let is_newly_loaded = phase == InitialScan @@ -4988,13 +4986,13 @@ fn build_diff( new_entry.id, if is_newly_loaded { Loaded } else { Added }, )); - new_paths.next(&()); + new_paths.next(); } } } (Some(old_entry), None) => { changes.push((old_entry.path.clone(), old_entry.id, Removed)); - old_paths.next(&()); + old_paths.next(); } (None, Some(new_entry)) => { let is_newly_loaded = phase == InitialScan @@ -5006,7 +5004,7 @@ fn build_diff( new_entry.id, if is_newly_loaded { Loaded } else { Added }, )); - new_paths.next(&()); + new_paths.next(); } (None, None) => break, } @@ -5255,7 +5253,7 @@ impl<'a> Traversal<'a> { start_path: &Path, ) -> Self { let mut cursor = snapshot.entries_by_path.cursor(&()); - cursor.seek(&TraversalTarget::path(start_path), Bias::Left, &()); + 
cursor.seek(&TraversalTarget::path(start_path), Bias::Left); let mut traversal = Self { snapshot, cursor, @@ -5282,14 +5280,13 @@ impl<'a> Traversal<'a> { include_ignored: self.include_ignored, }, Bias::Left, - &(), ) } pub fn advance_to_sibling(&mut self) -> bool { while let Some(entry) = self.cursor.item() { self.cursor - .seek_forward(&TraversalTarget::successor(&entry.path), Bias::Left, &()); + .seek_forward(&TraversalTarget::successor(&entry.path), Bias::Left); if let Some(entry) = self.cursor.item() { if (self.include_files || !entry.is_file()) && (self.include_dirs || !entry.is_dir()) @@ -5307,7 +5304,7 @@ impl<'a> Traversal<'a> { return false; }; self.cursor - .seek(&TraversalTarget::path(parent_path), Bias::Left, &()) + .seek(&TraversalTarget::path(parent_path), Bias::Left) } pub fn entry(&self) -> Option<&'a Entry> { @@ -5326,7 +5323,7 @@ impl<'a> Traversal<'a> { pub fn end_offset(&self) -> usize { self.cursor - .end(&()) + .end() .count(self.include_files, self.include_dirs, self.include_ignored) } } @@ -5419,7 +5416,7 @@ impl<'a> SeekTarget<'a, EntrySummary, TraversalProgress<'a>> for TraversalTarget } } -impl<'a> SeekTarget<'a, PathSummary<Unit>, TraversalProgress<'a>> for TraversalTarget<'_> { +impl<'a> SeekTarget<'a, PathSummary<&'static ()>, TraversalProgress<'a>> for TraversalTarget<'_> { fn cmp(&self, cursor_location: &TraversalProgress<'a>, _: &()) -> Ordering { self.cmp_progress(cursor_location) } diff --git a/crates/x_ai/Cargo.toml b/crates/x_ai/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..7ca0ca09397111404a59dff85d1ccf0659c0ea45 --- /dev/null +++ b/crates/x_ai/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "x_ai" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/x_ai.rs" + +[features] +default = [] +schemars = ["dep:schemars"] + +[dependencies] +anyhow.workspace = true +schemars = { workspace = true, optional = true } +serde.workspace = true +strum.workspace = true +workspace-hack.workspace = true diff --git a/crates/x_ai/LICENSE-GPL b/crates/x_ai/LICENSE-GPL new file mode 120000 index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4 --- /dev/null +++ b/crates/x_ai/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/x_ai/src/x_ai.rs b/crates/x_ai/src/x_ai.rs new file mode 100644 index 0000000000000000000000000000000000000000..ac116b2f8f610614b4d1efd380169739bbdbc9f2 --- /dev/null +++ b/crates/x_ai/src/x_ai.rs @@ -0,0 +1,126 @@ +use anyhow::Result; +use serde::{Deserialize, Serialize}; +use strum::EnumIter; + +pub const XAI_API_URL: &str = "https://api.x.ai/v1"; + +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)] +pub enum Model { + #[serde(rename = "grok-2-vision-latest")] + Grok2Vision, + #[default] + #[serde(rename = "grok-3-latest")] + Grok3, + #[serde(rename = "grok-3-mini-latest")] + Grok3Mini, + #[serde(rename = "grok-3-fast-latest")] + Grok3Fast, + #[serde(rename = "grok-3-mini-fast-latest")] + Grok3MiniFast, + #[serde(rename = "grok-4-latest")] + Grok4, + #[serde(rename = "custom")] + Custom { + name: String, + /// The name displayed in the UI, such as in the assistant panel model dropdown menu. 
+ display_name: Option<String>, + max_tokens: u64, + max_output_tokens: Option<u64>, + max_completion_tokens: Option<u64>, + }, +} + +impl Model { + pub fn default_fast() -> Self { + Self::Grok3Fast + } + + pub fn from_id(id: &str) -> Result<Self> { + match id { + "grok-2-vision" => Ok(Self::Grok2Vision), + "grok-3" => Ok(Self::Grok3), + "grok-3-mini" => Ok(Self::Grok3Mini), + "grok-3-fast" => Ok(Self::Grok3Fast), + "grok-3-mini-fast" => Ok(Self::Grok3MiniFast), + _ => anyhow::bail!("invalid model id '{id}'"), + } + } + + pub fn id(&self) -> &str { + match self { + Self::Grok2Vision => "grok-2-vision", + Self::Grok3 => "grok-3", + Self::Grok3Mini => "grok-3-mini", + Self::Grok3Fast => "grok-3-fast", + Self::Grok3MiniFast => "grok-3-mini-fast", + Self::Grok4 => "grok-4", + Self::Custom { name, .. } => name, + } + } + + pub fn display_name(&self) -> &str { + match self { + Self::Grok2Vision => "Grok 2 Vision", + Self::Grok3 => "Grok 3", + Self::Grok3Mini => "Grok 3 Mini", + Self::Grok3Fast => "Grok 3 Fast", + Self::Grok3MiniFast => "Grok 3 Mini Fast", + Self::Grok4 => "Grok 4", + Self::Custom { + name, display_name, .. + } => display_name.as_ref().unwrap_or(name), + } + } + + pub fn max_token_count(&self) -> u64 { + match self { + Self::Grok3 | Self::Grok3Mini | Self::Grok3Fast | Self::Grok3MiniFast => 131_072, + Self::Grok4 => 256_000, + Self::Grok2Vision => 8_192, + Self::Custom { max_tokens, .. } => *max_tokens, + } + } + + pub fn max_output_tokens(&self) -> Option<u64> { + match self { + Self::Grok3 | Self::Grok3Mini | Self::Grok3Fast | Self::Grok3MiniFast => Some(8_192), + Self::Grok4 => Some(64_000), + Self::Grok2Vision => Some(4_096), + Self::Custom { + max_output_tokens, .. + } => *max_output_tokens, + } + } + + pub fn supports_parallel_tool_calls(&self) -> bool { + match self { + Self::Grok2Vision + | Self::Grok3 + | Self::Grok3Mini + | Self::Grok3Fast + | Self::Grok3MiniFast + | Self::Grok4 => true, + Model::Custom { .. } => false, + } + } + + pub fn supports_tool(&self) -> bool { + match self { + Self::Grok2Vision + | Self::Grok3 + | Self::Grok3Mini + | Self::Grok3Fast + | Self::Grok3MiniFast + | Self::Grok4 => true, + Model::Custom { .. } => false, + } + } + + pub fn supports_images(&self) -> bool { + match self { + Self::Grok2Vision => true, + _ => false, + } + } +} diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index 3af1709b74af3d65539d80d4e39a9978a8da86d5..5997e43864bca1acfb8ae056165df011e58fa31b 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -2,7 +2,7 @@ description = "The fast, collaborative code editor." 
edition.workspace = true name = "zed" -version = "0.196.0" +version = "0.200.0" publish.workspace = true license = "GPL-3.0-or-later" authors = ["Zed Team <hi@zed.dev>"] @@ -45,6 +45,7 @@ collections.workspace = true command_palette.workspace = true component.workspace = true copilot.workspace = true +crashes.workspace = true dap_adapters.workspace = true db.workspace = true debug_adapter_extension.workspace = true @@ -56,6 +57,7 @@ env_logger.workspace = true extension.workspace = true extension_host.workspace = true extensions_ui.workspace = true +feature_flags.workspace = true feedback.workspace = true file_finder.workspace = true fs.workspace = true @@ -75,7 +77,7 @@ gpui_tokio.workspace = true http_client.workspace = true image_viewer.workspace = true indoc.workspace = true -inline_completion_button.workspace = true +edit_prediction_button.workspace = true inspector_ui.workspace = true install_cli.workspace = true jj_ui.workspace = true @@ -95,14 +97,17 @@ svg_preview.workspace = true menu.workspace = true migrator.workspace = true mimalloc = { version = "0.1", optional = true } +nc.workspace = true nix = { workspace = true, features = ["pthread", "signal"] } node_runtime.workspace = true notifications.workspace = true +onboarding.workspace = true outline.workspace = true outline_panel.workspace = true parking_lot.workspace = true paths.workspace = true picker.workspace = true +settings_profile_selector.workspace = true profiling.workspace = true project.workspace = true project_panel.workspace = true @@ -113,6 +118,7 @@ recent_projects.workspace = true release_channel.workspace = true remote.workspace = true repl.workspace = true +reqwest.workspace = true reqwest_client.workspace = true rope.workspace = true search.workspace = true diff --git a/crates/zed/build.rs b/crates/zed/build.rs index 0cfb3eba9fbecff024697bfffb542ae1b8b8d829..eb18617adde491908e690495917fd55974635642 100644 --- a/crates/zed/build.rs +++ b/crates/zed/build.rs @@ -50,12 +50,12 @@ fn main() { println!("cargo:rustc-link-arg=/stack:{}", 8 * 1024 * 1024); } - let release_channel = option_env!("RELEASE_CHANNEL").unwrap_or("nightly"); - + let release_channel = option_env!("RELEASE_CHANNEL").unwrap_or("dev"); let icon = match release_channel { "stable" => "resources/windows/app-icon.ico", "preview" => "resources/windows/app-icon-preview.ico", "nightly" => "resources/windows/app-icon-nightly.ico", + "dev" => "resources/windows/app-icon-dev.ico", _ => "resources/windows/app-icon-dev.ico", }; let icon = std::path::Path::new(icon); diff --git a/crates/zed/resources/app-icon-nightly.png b/crates/zed/resources/app-icon-nightly.png index 5f1304a6af8d57bbf7414d175611829989d122da..776cd06b1bca36c74257dafbc4bffebbbc8f55ad 100644 Binary files a/crates/zed/resources/app-icon-nightly.png and b/crates/zed/resources/app-icon-nightly.png differ diff --git a/crates/zed/resources/app-icon-nightly@2x.png b/crates/zed/resources/app-icon-nightly@2x.png index edb416ede489be0733bdbb00c467a1d828173e56..6d781594ac658d32e5fcff01f66543f7f4f70d93 100644 Binary files a/crates/zed/resources/app-icon-nightly@2x.png and b/crates/zed/resources/app-icon-nightly@2x.png differ diff --git a/crates/zed/resources/flatpak/manifest-template.json b/crates/zed/resources/flatpak/manifest-template.json index 1560027e9fefaf7ebd2cbcf3032f03d6123cc5e0..0a14a1c2b09054823147fc445f41e4d8eb134b37 100644 --- a/crates/zed/resources/flatpak/manifest-template.json +++ b/crates/zed/resources/flatpak/manifest-template.json @@ -38,7 +38,7 @@ }, "build-commands": [ "install 
-Dm644 $ICON_FILE.png /app/share/icons/hicolor/512x512/apps/$APP_ID.png", - "envsubst < zed.desktop.in > zed.desktop && install -Dm644 zed.desktop /app/share/applications/$APP_ID.desktop", + "envsubst < zed.desktop.in > zed.desktop && install -Dm755 zed.desktop /app/share/applications/$APP_ID.desktop", "envsubst < flatpak/zed.metainfo.xml.in > zed.metainfo.xml && install -Dm644 zed.metainfo.xml /app/share/metainfo/$APP_ID.metainfo.xml", "sed -i -e '/@release_info@/{r flatpak/release-info/$CHANNEL' -e 'd}' /app/share/metainfo/$APP_ID.metainfo.xml", "install -Dm755 bin/zed /app/bin/zed", diff --git a/crates/zed/resources/windows/app-icon-dev.ico b/crates/zed/resources/windows/app-icon-dev.ico index de92b6dd3c4a34164a72dec5f2427589e97536a6..1d6367b78853393ab14407748a823898c93e8667 100644 Binary files a/crates/zed/resources/windows/app-icon-dev.ico and b/crates/zed/resources/windows/app-icon-dev.ico differ diff --git a/crates/zed/resources/windows/app-icon-nightly.ico b/crates/zed/resources/windows/app-icon-nightly.ico index 7d5f53f2eec36306e1b423fa3cee0f0642501103..15e06a6e17631ddd9aed52679b328a46ffe482ff 100644 Binary files a/crates/zed/resources/windows/app-icon-nightly.ico and b/crates/zed/resources/windows/app-icon-nightly.ico differ diff --git a/crates/zed/resources/windows/app-icon-preview.ico b/crates/zed/resources/windows/app-icon-preview.ico index ebc0286e79c997a7bc2cf947a3cc362c38e55cd1..5c8601d314c8a1e1e777944d35f11245f4d31113 100644 Binary files a/crates/zed/resources/windows/app-icon-preview.ico and b/crates/zed/resources/windows/app-icon-preview.ico differ diff --git a/crates/zed/resources/windows/app-icon.ico b/crates/zed/resources/windows/app-icon.ico index 321e90fcfa15d8f84c2619b4d12af892ea5cda66..9c5761b9e9d25361ff30d15d08e524c7be93981e 100644 Binary files a/crates/zed/resources/windows/app-icon.ico and b/crates/zed/resources/windows/app-icon.ico differ diff --git a/crates/zed/resources/windows/zed.iss b/crates/zed/resources/windows/zed.iss index 9d104d1f1540acde9f7045ebbe103caf1ae5605a..2e76f35a0b7f1081d2bff988fad01bcb0e8cfd92 100644 --- a/crates/zed/resources/windows/zed.iss +++ b/crates/zed/resources/windows/zed.iss @@ -62,6 +62,7 @@ Source: "{#ResourcesDir}\Zed.exe"; DestDir: "{code:GetInstallDir}"; Flags: ignor Source: "{#ResourcesDir}\bin\*"; DestDir: "{code:GetInstallDir}\bin"; Flags: ignoreversion Source: "{#ResourcesDir}\tools\*"; DestDir: "{app}\tools"; Flags: ignoreversion Source: "{#ResourcesDir}\appx\*"; DestDir: "{app}\appx"; BeforeInstall: RemoveAppxPackage; AfterInstall: AddAppxPackage; Flags: ignoreversion; Check: IsWindows11OrLater +Source: "{#ResourcesDir}\amd_ags_x64.dll"; DestDir: "{app}"; Flags: ignoreversion [Icons] Name: "{group}\{#AppName}"; Filename: "{app}\{#AppExeName}.exe"; AppUserModelID: "{#AppUserId}" @@ -1245,16 +1246,6 @@ Root: HKCU; Subkey: "Software\Classes\zed\DefaultIcon"; ValueType: "string"; Val Root: HKCU; Subkey: "Software\Classes\zed\shell\open\command"; ValueType: "string"; ValueData: """{app}\Zed.exe"" ""%1""" [Code] -function InitializeSetup(): Boolean; -begin - Result := True; - - if not WizardSilent() and IsAdmin() then begin - MsgBox('This User Installer is not meant to be run as an Administrator.', mbError, MB_OK); - Result := False; - end; -end; - function WizardNotSilent(): Boolean; begin Result := not WizardSilent(); diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index 59f432faafbef0c607dd75d8f4623cc2be7b959a..e4a14b5d326d63d04abf36a225baff8f680f8413 100644 --- a/crates/zed/src/main.rs +++ 
b/crates/zed/src/main.rs @@ -1,6 +1,7 @@ mod reliability; mod zed; +use agent_ui::AgentPanel; use anyhow::{Context as _, Result}; use clap::{Parser, command}; use cli::FORCE_CLI_MODE_ENV_VAR_NAME; @@ -14,7 +15,7 @@ use extension_host::ExtensionStore; use fs::{Fs, RealFs}; use futures::{StreamExt, channel::oneshot, future}; use git::GitHostingProviderRegistry; -use gpui::{App, AppContext as _, Application, AsyncApp, UpdateGlobal as _}; +use gpui::{App, AppContext as _, Application, AsyncApp, Focusable as _, UpdateGlobal as _}; use gpui_tokio::Tokio; use http_client::{Url, read_proxy_from_env}; @@ -41,7 +42,7 @@ use theme::{ ActiveTheme, IconThemeNotFoundError, SystemAppearance, ThemeNotFoundError, ThemeRegistry, ThemeSettings, }; -use util::{ConnectionResult, ResultExt, TryFutureExt, maybe}; +use util::{ResultExt, TryFutureExt, maybe}; use uuid::Uuid; use welcome::{FIRST_OPEN, show_welcome_view}; use workspace::{ @@ -50,11 +51,13 @@ use workspace::{ }; use zed::{ OpenListener, OpenRequest, RawOpenRequest, app_menus, build_window_options, - derive_paths_with_position, handle_cli_connection, handle_keymap_file_changes, - handle_settings_changed, handle_settings_file_changes, initialize_workspace, - inline_completion_registry, open_paths_with_positions, + derive_paths_with_position, edit_prediction_registry, handle_cli_connection, + handle_keymap_file_changes, handle_settings_changed, handle_settings_file_changes, + initialize_workspace, open_paths_with_positions, }; +use crate::zed::OpenRequestKind; + #[cfg(feature = "mimalloc")] #[global_allocator] static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; @@ -169,12 +172,29 @@ pub fn main() { let args = Args::parse(); + // `zed --crash-handler` Makes zed operate in minidump crash handler mode + if let Some(socket) = &args.crash_handler { + crashes::crash_server(socket.as_path()); + return; + } + // `zed --askpass` Makes zed operate in nc/netcat mode for use with askpass if let Some(socket) = &args.askpass { askpass::main(socket); return; } + // `zed --nc` Makes zed operate in nc/netcat mode for use with MCP + if let Some(socket) = &args.nc { + match nc::main(socket) { + Ok(()) => return, + Err(err) => { + eprintln!("Error: {}", err); + process::exit(1); + } + } + } + // `zed --printenv` Outputs environment variables as JSON to stdout if args.printenv { util::shell_env::print_env(); @@ -250,6 +270,9 @@ pub fn main() { let session_id = Uuid::new_v4().to_string(); let session = app.background_executor().block(Session::new()); + app.background_executor() + .spawn(crashes::init(session_id.clone())) + .detach(); reliability::init_panic_hook( app_version, app_commit_sha.clone(), @@ -540,15 +563,12 @@ pub fn main() { supermaven::init(app_state.client.clone(), cx); language_model::init(app_state.client.clone(), cx); language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx); + agent_settings::init(cx); agent_servers::init(cx); web_search::init(cx); web_search_providers::init(app_state.client.clone(), cx); snippet_provider::init(cx); - inline_completion_registry::init( - app_state.client.clone(), - app_state.user_store.clone(), - cx, - ); + edit_prediction_registry::init(app_state.client.clone(), app_state.user_store.clone(), cx); let prompt_builder = PromptBuilder::load(app_state.fs.clone(), stdout_is_a_pty(), cx); agent_ui::init( app_state.fs.clone(), @@ -598,6 +618,7 @@ pub fn main() { language_selector::init(cx); toolchain_selector::init(cx); theme_selector::init(cx); + settings_profile_selector::init(cx); 
language_tools::init(cx); call::init(app_state.client.clone(), app_state.user_store.clone(), cx); notifications::init(app_state.client.clone(), app_state.user_store.clone(), cx); @@ -608,6 +629,7 @@ pub fn main() { markdown_preview::init(cx); svg_preview::init(cx); welcome::init(cx); + onboarding::init(cx); settings_ui::init(cx); extensions_ui::init(cx); zeta::init(cx); @@ -665,17 +687,9 @@ pub fn main() { cx.spawn({ let client = app_state.client.clone(); - async move |cx| match authenticate(client, &cx).await { - ConnectionResult::Timeout => log::error!("Timeout during initial auth"), - ConnectionResult::ConnectionReset => { - log::error!("Connection reset during initial auth") - } - ConnectionResult::Result(r) => { - r.log_err(); - } - } + async move |cx| authenticate(client, &cx).await }) - .detach(); + .detach_and_log_err(cx); let urls: Vec<_> = args .paths_or_urls @@ -734,15 +748,46 @@ pub fn main() { } fn handle_open_request(request: OpenRequest, app_state: Arc<AppState>, cx: &mut App) { - if let Some(connection) = request.cli_connection { - let app_state = app_state.clone(); - cx.spawn(async move |cx| handle_cli_connection(connection, app_state, cx).await) - .detach(); - return; - } + if let Some(kind) = request.kind { + match kind { + OpenRequestKind::CliConnection(connection) => { + let app_state = app_state.clone(); + cx.spawn(async move |cx| handle_cli_connection(connection, app_state, cx).await) + .detach(); + } + OpenRequestKind::Extension { extension_id } => { + cx.spawn(async move |cx| { + let workspace = + workspace::get_any_active_workspace(app_state, cx.clone()).await?; + workspace.update(cx, |_, window, cx| { + window.dispatch_action( + Box::new(zed_actions::Extensions { + category_filter: None, + id: Some(extension_id), + }), + cx, + ); + }) + }) + .detach_and_log_err(cx); + } + OpenRequestKind::AgentPanel => { + cx.spawn(async move |cx| { + let workspace = + workspace::get_any_active_workspace(app_state, cx.clone()).await?; + workspace.update(cx, |workspace, window, cx| { + if let Some(panel) = workspace.panel::<AgentPanel>(cx) { + panel.focus_handle(cx).focus(window); + } + }) + }) + .detach_and_log_err(cx); + } + OpenRequestKind::DockMenuAction { index } => { + cx.perform_dock_menu_action(index); + } + } - if let Some(action_index) = request.dock_menu_action { - cx.perform_dock_menu_action(action_index); return; } @@ -794,15 +839,7 @@ fn handle_open_request(request: OpenRequest, app_state: Arc<AppState>, cx: &mut let client = app_state.client.clone(); // we continue even if authentication fails as join_channel/ open channel notes will // show a visible error message. 
- match authenticate(client, &cx).await { - ConnectionResult::Timeout => { - log::error!("Timeout during open request handling") - } - ConnectionResult::ConnectionReset => { - log::error!("Connection reset during open request handling") - } - ConnectionResult::Result(r) => r?, - }; + authenticate(client, &cx).await.log_err(); if let Some(channel_id) = request.join_channel { cx.update(|cx| { @@ -852,18 +889,18 @@ fn handle_open_request(request: OpenRequest, app_state: Arc<AppState>, cx: &mut } } -async fn authenticate(client: Arc<Client>, cx: &AsyncApp) -> ConnectionResult<()> { +async fn authenticate(client: Arc<Client>, cx: &AsyncApp) -> Result<()> { if stdout_is_a_pty() { if client::IMPERSONATE_LOGIN.is_some() { - return client.authenticate_and_connect(false, cx).await; + client.sign_in_with_optional_connect(false, cx).await?; } else if client.has_credentials(cx).await { - return client.authenticate_and_connect(true, cx).await; + client.sign_in_with_optional_connect(true, cx).await?; } } else if client.has_credentials(cx).await { - return client.authenticate_and_connect(true, cx).await; + client.sign_in_with_optional_connect(true, cx).await?; } - ConnectionResult::Result(Ok(())) + Ok(()) } async fn system_id() -> Result<IdType> { @@ -1097,6 +1134,7 @@ fn init_paths() -> HashMap<io::ErrorKind, Vec<&'static Path>> { paths::config_dir(), paths::extensions_dir(), paths::languages_dir(), + paths::debug_adapters_dir(), paths::database_dir(), paths::logs_dir(), paths::temp_dir(), @@ -1151,6 +1189,16 @@ struct Args { #[arg(long, hide = true)] askpass: Option<String>, + /// Used for the MCP Server, to remove the need for netcat as a dependency, + /// by having Zed act like netcat communicating over a Unix socket. + #[arg(long, hide = true)] + nc: Option<String>, + + /// Used for recording minidumps on crashes by having Zed run a separate + /// process communicating over a socket. + #[arg(long, hide = true)] + crash_handler: Option<PathBuf>, + /// Run zed in the foreground, only used on Windows, to match the behavior on macOS. 
#[arg(long)] #[cfg(target_os = "windows")] @@ -1391,7 +1439,6 @@ fn dump_all_gpui_actions() { documentation: Option<&'static str>, } let mut actions = gpui::generate_list_of_all_registered_actions() - .into_iter() .map(|action| ActionDef { name: action.name, human_name: command_palette::humanize_action_name(action.name), diff --git a/crates/zed/src/reliability.rs b/crates/zed/src/reliability.rs index ccbe57e7b3903e9e5b380ad0c0323be65864397d..53539699cc164173cbbe1bf9d1016d3f7de91a2b 100644 --- a/crates/zed/src/reliability.rs +++ b/crates/zed/src/reliability.rs @@ -2,21 +2,32 @@ use crate::stdout_is_a_pty; use anyhow::{Context as _, Result}; use backtrace::{self, Backtrace}; use chrono::Utc; -use client::{TelemetrySettings, telemetry}; +use client::{ + TelemetrySettings, + telemetry::{self, MINIDUMP_ENDPOINT}, +}; use db::kvp::KEY_VALUE_STORE; +use futures::AsyncReadExt; use gpui::{App, AppContext as _, SemanticVersion}; use http_client::{self, HttpClient, HttpClientWithUrl, HttpRequestExt, Method}; use paths::{crashes_dir, crashes_retired_dir}; use project::Project; use release_channel::{AppCommitSha, RELEASE_CHANNEL, ReleaseChannel}; +use reqwest::multipart::{Form, Part}; use settings::Settings; use smol::stream::StreamExt; use std::{ env, ffi::{OsStr, c_void}, - sync::{Arc, atomic::Ordering}, + fs, + io::Write, + panic, + sync::{ + Arc, + atomic::{AtomicU32, Ordering}, + }, + thread, }; -use std::{io::Write, panic, sync::atomic::AtomicU32, thread}; use telemetry_events::{LocationData, Panic, PanicRequest}; use url::Url; use util::ResultExt; @@ -37,9 +48,10 @@ pub fn init_panic_hook( if prior_panic_count > 0 { // Give the panic-ing thread time to write the panic file loop { - std::thread::yield_now(); + thread::yield_now(); } } + crashes::handle_panic(); let thread = thread::current(); let thread_name = thread.name().unwrap_or("<unnamed>"); @@ -63,7 +75,7 @@ pub fn init_panic_hook( location.column(), match app_commit_sha.as_ref() { Some(commit_sha) => format!( - "https://github.com/zed-industries/zed/blob/{}/src/{}#L{} \ + "https://github.com/zed-industries/zed/blob/{}/{}#L{} \ (may not be uploaded, line may be incorrect if files modified)\n", commit_sha.full(), location.file(), @@ -136,9 +148,9 @@ pub fn init_panic_hook( if let Some(panic_data_json) = serde_json::to_string(&panic_data).log_err() { let timestamp = chrono::Utc::now().format("%Y_%m_%d %H_%M_%S").to_string(); let panic_file_path = paths::logs_dir().join(format!("zed-{timestamp}.panic")); - let panic_file = std::fs::OpenOptions::new() - .append(true) - .create(true) + let panic_file = fs::OpenOptions::new() + .write(true) + .create_new(true) .open(&panic_file_path) .log_err(); if let Some(mut panic_file) = panic_file { @@ -205,27 +217,31 @@ pub fn init( if let Some(ssh_client) = project.ssh_client() { ssh_client.update(cx, |client, cx| { if TelemetrySettings::get_global(cx).diagnostics { - let request = client.proto_client().request(proto::GetPanicFiles {}); + let request = client.proto_client().request(proto::GetCrashFiles {}); cx.background_spawn(async move { - let panic_files = request.await?; - for file in panic_files.file_contents { - let panic: Option<Panic> = serde_json::from_str(&file) - .log_err() - .or_else(|| { - file.lines() - .next() - .and_then(|line| serde_json::from_str(line).ok()) - }) - .unwrap_or_else(|| { - log::error!("failed to deserialize panic file {:?}", file); - None - }); - - if let Some(mut panic) = panic { + let crash_files = request.await?; + for crash in crash_files.crashes { + let mut panic: 
Option<Panic> = crash + .panic_contents + .and_then(|s| serde_json::from_str(&s).log_err()); + + if let Some(panic) = panic.as_mut() { panic.session_id = session_id.clone(); panic.system_id = system_id.clone(); panic.installation_id = installation_id.clone(); + } + + if let Some(minidump) = crash.minidump_contents { + upload_minidump( + http_client.clone(), + minidump.clone(), + panic.as_ref(), + ) + .await + .log_err(); + } + if let Some(panic) = panic { upload_panic(&http_client, &panic_report_url, panic, &mut None) .await?; } @@ -510,6 +526,22 @@ async fn upload_previous_panics( }); if let Some(panic) = panic { + let minidump_path = paths::logs_dir() + .join(&panic.session_id) + .with_extension("dmp"); + if minidump_path.exists() { + let minidump = smol::fs::read(&minidump_path) + .await + .context("Failed to read minidump")?; + if upload_minidump(http.clone(), minidump, Some(&panic)) + .await + .log_err() + .is_some() + { + fs::remove_file(minidump_path).ok(); + } + } + if !upload_panic(&http, &panic_report_url, panic, &mut most_recent_panic).await? { continue; } @@ -517,13 +549,80 @@ async fn upload_previous_panics( } // We've done what we can, delete the file - std::fs::remove_file(child_path) + fs::remove_file(child_path) .context("error removing panic") .log_err(); } + + if MINIDUMP_ENDPOINT.is_none() { + return Ok(most_recent_panic); + } + + // loop back over the directory again to upload any minidumps that are missing panics + let mut children = smol::fs::read_dir(paths::logs_dir()).await?; + while let Some(child) = children.next().await { + let child = child?; + let child_path = child.path(); + if child_path.extension() != Some(OsStr::new("dmp")) { + continue; + } + if upload_minidump( + http.clone(), + smol::fs::read(&child_path) + .await + .context("Failed to read minidump")?, + None, + ) + .await + .log_err() + .is_some() + { + fs::remove_file(child_path).ok(); + } + } + Ok(most_recent_panic) } +async fn upload_minidump( + http: Arc<HttpClientWithUrl>, + minidump: Vec<u8>, + panic: Option<&Panic>, +) -> Result<()> { + let minidump_endpoint = MINIDUMP_ENDPOINT + .to_owned() + .ok_or_else(|| anyhow::anyhow!("Minidump endpoint not set"))?; + + let mut form = Form::new() + .part( + "upload_file_minidump", + Part::bytes(minidump) + .file_name("minidump.dmp") + .mime_str("application/octet-stream")?, + ) + .text("platform", "rust"); + if let Some(panic) = panic { + form = form + .text( + "sentry[release]", + format!("{}-{}", panic.release_channel, panic.app_version), + ) + .text("sentry[logentry][formatted]", panic.payload.clone()); + } + + let mut response_text = String::new(); + let mut response = http.send_multipart_form(&minidump_endpoint, form).await?; + response + .body_mut() + .read_to_string(&mut response_text) + .await?; + if !response.status().is_success() { + anyhow::bail!("failed to upload minidump: {response_text}"); + } + log::info!("Uploaded minidump. 
event id: {response_text}"); + Ok(()) +} + async fn upload_panic( http: &Arc<HttpClientWithUrl>, panic_report_url: &Url, diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index dc094a6c12fb1ba11642cc988f5d06d2cce01078..8c89a7d85a4a7be4c3fd9d905deaecde6670260d 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -1,6 +1,6 @@ mod app_menus; pub mod component_preview; -pub mod inline_completion_registry; +pub mod edit_prediction_registry; #[cfg(target_os = "macos")] pub(crate) mod mac_only_instance; mod migrate; @@ -19,6 +19,7 @@ use collections::VecDeque; use debugger_ui::debugger_panel::DebugPanel; use editor::ProposedChangesEditorToolbar; use editor::{Editor, MultiBuffer}; +use feature_flags::{FeatureFlagAppExt, PanicFeatureFlag}; use futures::future::Either; use futures::{StreamExt, channel::mpsc, select_biased}; use git_ui::git_panel::GitPanel; @@ -53,9 +54,12 @@ use settings::{ initial_local_debug_tasks_content, initial_project_settings_content, initial_tasks_content, update_settings_file, }; -use std::path::PathBuf; -use std::sync::atomic::{self, AtomicBool}; -use std::{borrow::Cow, path::Path, sync::Arc}; +use std::{ + borrow::Cow, + path::{Path, PathBuf}, + sync::Arc, + sync::atomic::{self, AtomicBool}, +}; use terminal_view::terminal_panel::{self, TerminalPanel}; use theme::{ActiveTheme, ThemeSettings}; use ui::{PopoverMenuHandle, prelude::*}; @@ -107,6 +111,8 @@ actions!( Zoom, /// Triggers a test panic for debugging. TestPanic, + /// Triggers a hard crash for debugging. + TestCrash, ] ); @@ -120,11 +126,28 @@ pub fn init(cx: &mut App) { cx.on_action(quit); cx.on_action(|_: &RestoreBanner, cx| title_bar::restore_banner(cx)); - - if ReleaseChannel::global(cx) == ReleaseChannel::Dev { - cx.on_action(test_panic); - } - + let flag = cx.wait_for_flag::<PanicFeatureFlag>(); + cx.spawn(async |cx| { + if cx + .update(|cx| ReleaseChannel::global(cx) == ReleaseChannel::Dev) + .unwrap_or_default() + || flag.await + { + cx.update(|cx| { + cx.on_action(|_: &TestPanic, _| panic!("Ran the TestPanic action")); + cx.on_action(|_: &TestCrash, _| { + unsafe extern "C" { + fn puts(s: *const i8); + } + unsafe { + puts(0xabad1d3a as *const i8); + } + }); + }) + .ok(); + }; + }) + .detach(); cx.on_action(|_: &OpenLog, cx| { with_active_or_new_workspace(cx, |workspace, window, cx| { open_log_file(workspace, window, cx); @@ -309,18 +332,18 @@ pub fn initialize_workspace( show_software_emulation_warning_if_needed(specs, window, cx); } - let inline_completion_menu_handle = PopoverMenuHandle::default(); + let edit_prediction_menu_handle = PopoverMenuHandle::default(); let edit_prediction_button = cx.new(|cx| { - inline_completion_button::InlineCompletionButton::new( + edit_prediction_button::EditPredictionButton::new( app_state.fs.clone(), app_state.user_store.clone(), - inline_completion_menu_handle.clone(), + edit_prediction_menu_handle.clone(), cx, ) }); workspace.register_action({ - move |_, _: &inline_completion_button::ToggleMenu, window, cx| { - inline_completion_menu_handle.toggle(window, cx); + move |_, _: &edit_prediction_button::ToggleMenu, window, cx| { + edit_prediction_menu_handle.toggle(window, cx); } }); @@ -987,10 +1010,6 @@ fn about( .detach(); } -fn test_panic(_: &TestPanic, _: &mut App) { - panic!("Ran the TestPanic action") -} - fn install_cli( _: &mut Workspace, _: &install_cli::Install, @@ -3957,6 +3976,7 @@ mod tests { language::init(cx); workspace::init(app_state.clone(), cx); welcome::init(cx); + onboarding::init(cx); Project::init_settings(cx); app_state }) @@ 
-4327,12 +4347,14 @@ mod tests { "jj", "journal", "keymap_editor", + "keystroke_input", "language_selector", "lsp_tool", "markdown", "menu", "notebook", "notification_panel", + "onboarding", "outline", "outline_panel", "pane", @@ -4345,6 +4367,7 @@ mod tests { "repl", "rules_library", "search", + "settings_profile_selector", "snippets", "supermaven", "svg", @@ -4416,7 +4439,7 @@ mod tests { }); for name in languages.language_names() { languages - .language_for_name(&name) + .language_for_name(name.as_ref()) .await .with_context(|| format!("language name {name}")) .unwrap(); diff --git a/crates/zed/src/zed/app_menus.rs b/crates/zed/src/zed/app_menus.rs index dac9f6495b83cca08257fb551e8f87667ea483de..15d5659f03dd2d42840c50001fcba061705e506e 100644 --- a/crates/zed/src/zed/app_menus.rs +++ b/crates/zed/src/zed/app_menus.rs @@ -1,5 +1,6 @@ use collab_ui::collab_panel; use gpui::{Menu, MenuItem, OsAction}; +use settings_ui::keybindings; use terminal_view::terminal_panel; pub fn app_menus() -> Vec<Menu> { @@ -16,13 +17,17 @@ pub fn app_menus() -> Vec<Menu> { name: "Settings".into(), items: vec![ MenuItem::action("Open Settings", super::OpenSettings), - MenuItem::action("Open Key Bindings", zed_actions::OpenKeymap), + MenuItem::action("Open Key Bindings", keybindings::OpenKeymapEditor), MenuItem::action("Open Default Settings", super::OpenDefaultSettings), MenuItem::action( "Open Default Key Bindings", zed_actions::OpenDefaultKeymap, ), MenuItem::action("Open Project Settings", super::OpenProjectSettings), + MenuItem::action( + "Select Settings Profile...", + zed_actions::settings_profile_selector::Toggle, + ), MenuItem::action( "Select Theme...", zed_actions::theme_selector::Toggle::default(), @@ -144,15 +149,15 @@ pub fn app_menus() -> Vec<Menu> { items: vec![ MenuItem::action( "Zoom In", - zed_actions::IncreaseBufferFontSize { persist: true }, + zed_actions::IncreaseBufferFontSize { persist: false }, ), MenuItem::action( "Zoom Out", - zed_actions::DecreaseBufferFontSize { persist: true }, + zed_actions::DecreaseBufferFontSize { persist: false }, ), MenuItem::action( "Reset Zoom", - zed_actions::ResetBufferFontSize { persist: true }, + zed_actions::ResetBufferFontSize { persist: false }, ), MenuItem::separator(), MenuItem::action("Toggle Left Dock", workspace::ToggleLeftDock), @@ -199,8 +204,11 @@ pub fn app_menus() -> Vec<Menu> { MenuItem::action("Go to Type Definition", editor::actions::GoToTypeDefinition), MenuItem::action("Find All References", editor::actions::FindAllReferences), MenuItem::separator(), - MenuItem::action("Next Problem", editor::actions::GoToDiagnostic), - MenuItem::action("Previous Problem", editor::actions::GoToPreviousDiagnostic), + MenuItem::action("Next Problem", editor::actions::GoToDiagnostic::default()), + MenuItem::action( + "Previous Problem", + editor::actions::GoToPreviousDiagnostic::default(), + ), ], }, Menu { diff --git a/crates/zed/src/zed/component_preview.rs b/crates/zed/src/zed/component_preview.rs index c32248cbe00f08d5982dbf394b806f3226c814ae..db75b544f611589377031b22aff5682875bb62dd 100644 --- a/crates/zed/src/zed/component_preview.rs +++ b/crates/zed/src/zed/component_preview.rs @@ -105,7 +105,9 @@ enum PreviewPage { struct ComponentPreview { active_page: PreviewPage, active_thread: Option<Entity<ActiveThread>>, + reset_key: usize, component_list: ListState, + entries: Vec<PreviewEntry>, component_map: HashMap<ComponentId, ComponentMetadata>, components: Vec<ComponentMetadata>, cursor_index: usize, @@ -138,8 +140,7 @@ impl ComponentPreview { let 
project_clone = project.clone(); cx.spawn_in(window, async move |entity, cx| { - let thread_store_future = - load_preview_thread_store(workspace_clone.clone(), project_clone.clone(), cx); + let thread_store_future = load_preview_thread_store(project_clone.clone(), cx); let text_thread_store_future = load_preview_text_thread_store(workspace_clone.clone(), project_clone.clone(), cx); @@ -172,23 +173,14 @@ impl ComponentPreview { sorted_components.len(), gpui::ListAlignment::Top, px(1500.0), - { - let this = cx.entity().downgrade(); - move |ix, window: &mut Window, cx: &mut App| { - this.update(cx, |this, cx| { - let component = this.get_component(ix); - this.render_preview(&component, window, cx) - .into_any_element() - }) - .unwrap() - } - }, ); let mut component_preview = Self { active_page, active_thread: None, + reset_key: 0, component_list, + entries: Vec::new(), component_map: component_registry.component_map(), components: sorted_components, cursor_index: selected_index, @@ -265,15 +257,16 @@ impl ComponentPreview { } fn set_active_page(&mut self, page: PreviewPage, cx: &mut Context<Self>) { - self.active_page = page; - cx.emit(ItemEvent::UpdateTab); + if self.active_page == page { + // Force the current preview page to render again + self.reset_key = self.reset_key.wrapping_add(1); + } else { + self.active_page = page; + cx.emit(ItemEvent::UpdateTab); + } cx.notify(); } - fn get_component(&self, ix: usize) -> ComponentMetadata { - self.components[ix].clone() - } - fn filtered_components(&self) -> Vec<ComponentMetadata> { if self.filter_text.is_empty() { return self.components.clone(); @@ -369,7 +362,6 @@ impl ComponentPreview { // Always show all components first entries.push(PreviewEntry::AllComponents); entries.push(PreviewEntry::ActiveThread); - entries.push(PreviewEntry::Separator); let mut scopes: Vec<_> = scope_groups .keys() @@ -382,7 +374,9 @@ impl ComponentPreview { for scope in scopes { if let Some(components) = scope_groups.remove(&scope) { if !components.is_empty() { + entries.push(PreviewEntry::Separator); entries.push(PreviewEntry::SectionHeader(scope.to_string().into())); + let mut sorted_components = components; sorted_components.sort_by_key(|(component, _)| component.sort_name()); @@ -413,7 +407,6 @@ impl ComponentPreview { fn update_component_list(&mut self, cx: &mut Context<Self>) { let entries = self.scope_ordered_entries(); let new_len = entries.len(); - let weak_entity = cx.entity().downgrade(); if new_len > 0 { self.nav_scroll_handle @@ -439,56 +432,9 @@ impl ComponentPreview { } } - self.component_list = ListState::new( - filtered_components.len(), - gpui::ListAlignment::Top, - px(1500.0), - { - let components = filtered_components.clone(); - let this = cx.entity().downgrade(); - move |ix, window: &mut Window, cx: &mut App| { - if ix >= components.len() { - return div().w_full().h_0().into_any_element(); - } - - this.update(cx, |this, cx| { - let component = &components[ix]; - this.render_preview(component, window, cx) - .into_any_element() - }) - .unwrap() - } - }, - ); + self.component_list = ListState::new(new_len, gpui::ListAlignment::Top, px(1500.0)); + self.entries = entries; - let new_list = ListState::new( - new_len, - gpui::ListAlignment::Top, - px(1500.0), - move |ix, window, cx| { - if ix >= entries.len() { - return div().w_full().h_0().into_any_element(); - } - - let entry = &entries[ix]; - - weak_entity - .update(cx, |this, cx| match entry { - PreviewEntry::Component(component, _) => this - .render_preview(component, window, cx) - 
.into_any_element(), - PreviewEntry::SectionHeader(shared_string) => this - .render_scope_header(ix, shared_string.clone(), window, cx) - .into_any_element(), - PreviewEntry::AllComponents => div().w_full().h_0().into_any_element(), - PreviewEntry::ActiveThread => div().w_full().h_0().into_any_element(), - PreviewEntry::Separator => div().w_full().h_0().into_any_element(), - }) - .unwrap() - }, - ); - - self.component_list = new_list; cx.emit(ItemEvent::UpdateTab); } @@ -515,16 +461,12 @@ impl ComponentPreview { Vec::new() }; if valid_positions.is_empty() { - Label::new(name.clone()) - .color(Color::Default) - .into_any_element() + Label::new(name.clone()).into_any_element() } else { HighlightedLabel::new(name.clone(), valid_positions).into_any_element() } } else { - Label::new(name.clone()) - .color(Color::Default) - .into_any_element() + Label::new(name.clone()).into_any_element() }) .selectable(true) .toggle_state(selected) @@ -542,7 +484,7 @@ impl ComponentPreview { let selected = self.active_page == PreviewPage::AllComponents; ListItem::new(ix) - .child(Label::new("All Components").color(Color::Default)) + .child(Label::new("All Components")) .selectable(true) .toggle_state(selected) .inset(true) @@ -555,7 +497,7 @@ impl ComponentPreview { let selected = self.active_page == PreviewPage::ActiveThread; ListItem::new(ix) - .child(Label::new("Active Thread").color(Color::Default)) + .child(Label::new("Active Thread")) .selectable(true) .toggle_state(selected) .inset(true) @@ -565,12 +507,8 @@ impl ComponentPreview { .into_any_element() } PreviewEntry::Separator => ListItem::new(ix) - .child( - h_flex() - .occlude() - .pt_3() - .child(Divider::horizontal_dashed()), - ) + .disabled(true) + .child(div().w_full().py_2().child(Divider::horizontal())) .into_any_element(), } } @@ -585,7 +523,6 @@ impl ComponentPreview { h_flex() .w_full() .h_10() - .items_center() .child(Headline::new(title).size(HeadlineSize::XSmall)) .child(Divider::horizontal()) } @@ -674,10 +611,35 @@ impl ComponentPreview { .child(format!("No components matching '{}'.", self.filter_text)) .into_any_element() } else { - list(self.component_list.clone()) - .flex_grow() - .with_sizing_behavior(gpui::ListSizingBehavior::Auto) - .into_any_element() + list( + self.component_list.clone(), + cx.processor(|this, ix, window, cx| { + if ix >= this.entries.len() { + return div().w_full().h_0().into_any_element(); + } + + let entry = &this.entries[ix]; + + match entry { + PreviewEntry::Component(component, _) => this + .render_preview(component, window, cx) + .into_any_element(), + PreviewEntry::SectionHeader(shared_string) => this + .render_scope_header(ix, shared_string.clone(), window, cx) + .into_any_element(), + PreviewEntry::AllComponents => { + div().w_full().h_0().into_any_element() + } + PreviewEntry::ActiveThread => { + div().w_full().h_0().into_any_element() + } + PreviewEntry::Separator => div().w_full().h_0().into_any_element(), + } + }), + ) + .flex_grow() + .with_sizing_behavior(gpui::ListSizingBehavior::Auto) + .into_any_element() }, ) } @@ -698,6 +660,7 @@ impl ComponentPreview { component.clone(), self.workspace.clone(), self.active_thread.clone(), + self.reset_key, )) .into_any_element() } else { @@ -798,7 +761,7 @@ impl Render for ComponentPreview { ) .track_scroll(self.nav_scroll_handle.clone()) .p_2p5() - .w(px(240.)) + .w(px(229.)) .h_full() .flex_1(), ) @@ -1049,6 +1012,7 @@ pub struct ComponentPreviewPage { component: ComponentMetadata, workspace: WeakEntity<Workspace>, active_thread: 
Option<Entity<ActiveThread>>,
+    reset_key: usize,
 }
 
 impl ComponentPreviewPage {
@@ -1056,6 +1020,7 @@
         component: ComponentMetadata,
         workspace: WeakEntity<Workspace>,
         active_thread: Option<Entity<ActiveThread>>,
+        reset_key: usize,
         // languages: Arc<LanguageRegistry>
     ) -> Self {
         Self {
@@ -1063,6 +1028,7 @@
             component,
             workspace,
             active_thread,
+            reset_key,
         }
     }
@@ -1163,6 +1129,7 @@
         };
 
         v_flex()
+            .id(("component-preview", self.reset_key))
             .size_full()
             .flex_1()
             .px_12()
diff --git a/crates/zed/src/zed/component_preview/preview_support/active_thread.rs b/crates/zed/src/zed/component_preview/preview_support/active_thread.rs
index 825744572d6cc2546b23ef608841ff3bda6610e4..de98106faeb5b39dfeee5903db55be7fa61f0f3c 100644
--- a/crates/zed/src/zed/component_preview/preview_support/active_thread.rs
+++ b/crates/zed/src/zed/component_preview/preview_support/active_thread.rs
@@ -12,21 +12,19 @@ use ui::{App, Window};
 use workspace::Workspace;
 
 pub fn load_preview_thread_store(
-    workspace: WeakEntity<Workspace>,
     project: Entity<Project>,
     cx: &mut AsyncApp,
 ) -> Task<Result<Entity<ThreadStore>>> {
-    workspace
-        .update(cx, |_, cx| {
-            ThreadStore::load(
-                project.clone(),
-                cx.new(|_| ToolWorkingSet::default()),
-                None,
-                Arc::new(PromptBuilder::new(None).unwrap()),
-                cx,
-            )
-        })
-        .unwrap_or(Task::ready(Err(anyhow!("workspace dropped"))))
+    cx.update(|cx| {
+        ThreadStore::load(
+            project.clone(),
+            cx.new(|_| ToolWorkingSet::default()),
+            None,
+            Arc::new(PromptBuilder::new(None).unwrap()),
+            cx,
+        )
+    })
+    .unwrap_or(Task::ready(Err(anyhow!("workspace dropped"))))
 }
 
 pub fn load_preview_text_thread_store(
diff --git a/crates/zed/src/zed/inline_completion_registry.rs b/crates/zed/src/zed/edit_prediction_registry.rs
similarity index 91%
rename from crates/zed/src/zed/inline_completion_registry.rs
rename to crates/zed/src/zed/edit_prediction_registry.rs
index f2e9d21b96ad54462a27f80e0c9c352a215eacdf..b9f561c0e7884f3d7c8579cd7b65fa1c787b5291 100644
--- a/crates/zed/src/zed/inline_completion_registry.rs
+++ b/crates/zed/src/zed/edit_prediction_registry.rs
@@ -11,7 +11,7 @@ use supermaven::{Supermaven, SupermavenCompletionProvider};
 use ui::Window;
 use util::ResultExt;
 use workspace::Workspace;
-use zeta::{ProviderDataCollection, ZetaInlineCompletionProvider};
+use zeta::{ProviderDataCollection, ZetaEditPredictionProvider};
 
 pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
     let editors: Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>> = Rc::default();
@@ -90,10 +90,7 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
             let new_provider = all_language_settings(None, cx).edit_predictions.provider;
 
             if new_provider != provider {
-                let tos_accepted = user_store
-                    .read(cx)
-                    .current_user_has_accepted_terms()
-                    .unwrap_or(false);
+                let tos_accepted = user_store.read(cx).has_accepted_terms_of_service();
 
                 telemetry::event!(
                     "Edit Prediction Provider Changed",
@@ -174,7 +171,7 @@ fn register_backward_compatible_actions(editor: &mut Editor, cx: &mut Context<Ed
     editor
         .register_action(cx.listener(
             |editor, _: &copilot::Suggest, window: &mut Window, cx: &mut Context<Editor>| {
-                editor.show_inline_completion(&Default::default(), window, cx);
+                editor.show_edit_prediction(&Default::default(), window, cx);
             },
         ))
        .detach();
@@ -195,16 +192,6 @@ fn register_backward_compatible_actions(editor: &mut Editor, cx: &mut Context<Ed
             },
         ))
         .detach();
-    editor
-
.register_action(cx.listener( - |editor, - _: &editor::actions::AcceptPartialCopilotSuggestion, - window: &mut Window, - cx: &mut Context<Editor>| { - editor.accept_partial_inline_completion(&Default::default(), window, cx); - }, - )) - .detach(); } fn assign_edit_prediction_provider( @@ -220,7 +207,7 @@ fn assign_edit_prediction_provider( match provider { EditPredictionProvider::None => { - editor.set_edit_prediction_provider::<ZetaInlineCompletionProvider>(None, window, cx); + editor.set_edit_prediction_provider::<ZetaEditPredictionProvider>(None, window, cx); } EditPredictionProvider::Copilot => { if let Some(copilot) = Copilot::global(cx) { @@ -242,7 +229,7 @@ fn assign_edit_prediction_provider( } } EditPredictionProvider::Zed => { - if client.status().borrow().is_connected() { + if user_store.read(cx).current_user().is_some() { let mut worktree = None; if let Some(buffer) = &singleton_buffer { @@ -278,7 +265,7 @@ fn assign_edit_prediction_provider( ProviderDataCollection::new(zeta.clone(), singleton_buffer, cx); let provider = - cx.new(|_| zeta::ZetaInlineCompletionProvider::new(zeta, data_collection)); + cx.new(|_| zeta::ZetaEditPredictionProvider::new(zeta, data_collection)); editor.set_edit_prediction_provider(Some(provider), window, cx); } diff --git a/crates/zed/src/zed/open_listener.rs b/crates/zed/src/zed/open_listener.rs index 0fb08d1be5790557674ee91a08cd35d28ea0b062..2fd9b0a68c7c14fd6df0ba2a52c537e34cdd7ceb 100644 --- a/crates/zed/src/zed/open_listener.rs +++ b/crates/zed/src/zed/open_listener.rs @@ -12,7 +12,7 @@ use futures::channel::mpsc::{UnboundedReceiver, UnboundedSender}; use futures::channel::{mpsc, oneshot}; use futures::future::join_all; use futures::{FutureExt, SinkExt, StreamExt}; -use git_ui::diff_view::DiffView; +use git_ui::file_diff_view::FileDiffView; use gpui::{App, AsyncApp, Global, WindowHandle}; use language::Point; use recent_projects::{SshSettings, open_ssh_project}; @@ -30,13 +30,20 @@ use workspace::{AppState, OpenOptions, SerializedWorkspaceLocation, Workspace}; #[derive(Default, Debug)] pub struct OpenRequest { - pub cli_connection: Option<(mpsc::Receiver<CliRequest>, IpcSender<CliResponse>)>, + pub kind: Option<OpenRequestKind>, pub open_paths: Vec<String>, pub diff_paths: Vec<[String; 2]>, pub open_channel_notes: Vec<(u64, Option<String>)>, pub join_channel: Option<u64>, pub ssh_connection: Option<SshConnectionOptions>, - pub dock_menu_action: Option<usize>, +} + +#[derive(Debug)] +pub enum OpenRequestKind { + CliConnection((mpsc::Receiver<CliRequest>, IpcSender<CliResponse>)), + Extension { extension_id: String }, + AgentPanel, + DockMenuAction { index: usize }, } impl OpenRequest { @@ -44,9 +51,11 @@ impl OpenRequest { let mut this = Self::default(); for url in request.urls { if let Some(server_name) = url.strip_prefix("zed-cli://") { - this.cli_connection = Some(connect_to_cli(server_name)?); + this.kind = Some(OpenRequestKind::CliConnection(connect_to_cli(server_name)?)); } else if let Some(action_index) = url.strip_prefix("zed-dock-action://") { - this.dock_menu_action = Some(action_index.parse()?); + this.kind = Some(OpenRequestKind::DockMenuAction { + index: action_index.parse()?, + }); } else if let Some(file) = url.strip_prefix("file://") { this.parse_file_path(file) } else if let Some(file) = url.strip_prefix("zed://file") { @@ -54,6 +63,12 @@ impl OpenRequest { } else if let Some(file) = url.strip_prefix("zed://ssh") { let ssh_url = "ssh:/".to_string() + file; this.parse_ssh_file_path(&ssh_url, cx)? 
+ } else if let Some(extension_id) = url.strip_prefix("zed://extension/") { + this.kind = Some(OpenRequestKind::Extension { + extension_id: extension_id.to_string(), + }); + } else if url == "zed://agent" { + this.kind = Some(OpenRequestKind::AgentPanel); } else if url.starts_with("ssh://") { this.parse_ssh_file_path(&url, cx)? } else if let Some(request_path) = parse_zed_link(&url, cx) { @@ -247,7 +262,7 @@ pub async fn open_paths_with_positions( let old_path = Path::new(&diff_pair[0]).canonicalize()?; let new_path = Path::new(&diff_pair[1]).canonicalize()?; if let Ok(diff_view) = workspace.update(cx, |workspace, window, cx| { - DiffView::open(old_path, new_path, workspace, window, cx) + FileDiffView::open(old_path, new_path, workspace, window, cx) }) { if let Some(diff_view) = diff_view.await.log_err() { items.push(Some(Ok(Box::new(diff_view)))) diff --git a/crates/zed/src/zed/quick_action_bar.rs b/crates/zed/src/zed/quick_action_bar.rs index 36d446579a1c5cc2da010579a8e78ea3d2ed7076..e76bef59a38004d42cc769e574ecfc4ac4621037 100644 --- a/crates/zed/src/zed/quick_action_bar.rs +++ b/crates/zed/src/zed/quick_action_bar.rs @@ -15,6 +15,7 @@ use gpui::{ FocusHandle, Focusable, InteractiveElement, ParentElement, Render, Styled, Subscription, WeakEntity, Window, anchored, deferred, point, }; +use project::DisableAiSettings; use project::project_settings::DiagnosticSeverity; use search::{BufferSearchBar, buffer_search}; use settings::{Settings, SettingsStore}; @@ -32,6 +33,7 @@ const MAX_CODE_ACTION_MENU_LINES: u32 = 16; pub struct QuickActionBar { _inlay_hints_enabled_subscription: Option<Subscription>, + _ai_settings_subscription: Subscription, active_item: Option<Box<dyn ItemHandle>>, buffer_search_bar: Entity<BufferSearchBar>, show: bool, @@ -46,8 +48,28 @@ impl QuickActionBar { workspace: &Workspace, cx: &mut Context<Self>, ) -> Self { + let mut was_ai_disabled = DisableAiSettings::get_global(cx).disable_ai; + let mut was_agent_enabled = AgentSettings::get_global(cx).enabled; + let mut was_agent_button = AgentSettings::get_global(cx).button; + + let ai_settings_subscription = cx.observe_global::<SettingsStore>(move |_, cx| { + let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai; + let agent_settings = AgentSettings::get_global(cx); + + if was_ai_disabled != is_ai_disabled + || was_agent_enabled != agent_settings.enabled + || was_agent_button != agent_settings.button + { + was_ai_disabled = is_ai_disabled; + was_agent_enabled = agent_settings.enabled; + was_agent_button = agent_settings.button; + cx.notify(); + } + }); + let mut this = Self { _inlay_hints_enabled_subscription: None, + _ai_settings_subscription: ai_settings_subscription, active_item: None, buffer_search_bar, show: true, @@ -170,7 +192,7 @@ impl Render for QuickActionBar { }; v_flex() .child( - IconButton::new("toggle_code_actions_icon", IconName::Bolt) + IconButton::new("toggle_code_actions_icon", IconName::BoltOutlined) .icon_size(IconSize::Small) .style(ButtonStyle::Subtle) .disabled(!has_available_code_actions) @@ -255,8 +277,11 @@ impl Render for QuickActionBar { .action("Go to Symbol", Box::new(ToggleOutline)) .action("Go to Line/Column", Box::new(ToggleGoToLine)) .separator() - .action("Next Problem", Box::new(GoToDiagnostic)) - .action("Previous Problem", Box::new(GoToPreviousDiagnostic)) + .action("Next Problem", Box::new(GoToDiagnostic::default())) + .action( + "Previous Problem", + Box::new(GoToPreviousDiagnostic::default()), + ) .separator() .action_disabled_when(!has_diff_hunks, "Next Hunk", 
Box::new(GoToHunk)) .action_disabled_when( @@ -356,7 +381,7 @@ impl Render for QuickActionBar { } if has_edit_prediction_provider { - let mut inline_completion_entry = ContextMenuEntry::new("Edit Predictions") + let mut edit_prediction_entry = ContextMenuEntry::new("Edit Predictions") .toggleable(IconPosition::Start, edit_predictions_enabled_at_cursor && show_edit_predictions) .disabled(!edit_predictions_enabled_at_cursor) .action( @@ -376,12 +401,12 @@ impl Render for QuickActionBar { } }); if !edit_predictions_enabled_at_cursor { - inline_completion_entry = inline_completion_entry.documentation_aside(DocumentationSide::Left, |_| { + edit_prediction_entry = edit_prediction_entry.documentation_aside(DocumentationSide::Left, |_| { Label::new("You can't toggle edit predictions for this file as it is within the excluded files list.").into_any_element() }); } - menu = menu.item(inline_completion_entry); + menu = menu.item(edit_prediction_entry); } menu = menu.separator(); @@ -572,7 +597,9 @@ impl Render for QuickActionBar { .children(self.render_preview_button(self.workspace.clone(), cx)) .children(search_button) .when( - AgentSettings::get_global(cx).enabled && AgentSettings::get_global(cx).button, + AgentSettings::get_global(cx).enabled + && AgentSettings::get_global(cx).button + && !DisableAiSettings::get_global(cx).disable_ai, |bar| bar.child(assistant_button), ) .children(code_actions_dropdown) diff --git a/crates/zed_actions/src/lib.rs b/crates/zed_actions/src/lib.rs index 06121a9de8e0b68316c8ffda1d4a393beedb217f..64891b6973bac04efdfb4cbadadd12cfcca3be10 100644 --- a/crates/zed_actions/src/lib.rs +++ b/crates/zed_actions/src/lib.rs @@ -76,6 +76,9 @@ pub struct Extensions { /// Filters the extensions page down to extensions that are in the specified category. #[serde(default)] pub category_filter: Option<ExtensionCategoryFilter>, + /// Focuses just the extension with the specified ID. + #[serde(default)] + pub id: Option<String>, } /// Decreases the font size in the editor buffer. @@ -257,14 +260,25 @@ pub mod icon_theme_selector { } } +pub mod settings_profile_selector { + use gpui::Action; + use schemars::JsonSchema; + use serde::Deserialize; + + #[derive(PartialEq, Clone, Default, Debug, Deserialize, JsonSchema, Action)] + #[action(namespace = settings_profile_selector)] + pub struct Toggle; +} + pub mod agent { use gpui::actions; actions!( agent, [ - /// Opens the agent configuration panel. - OpenConfiguration, + /// Opens the agent settings panel. + #[action(deprecated_aliases = ["agent::OpenConfiguration"])] + OpenSettings, /// Opens the agent onboarding modal. OpenOnboardingModal, /// Resets the agent onboarding state. @@ -274,7 +288,10 @@ pub mod agent { /// Displays the previous message in the history. PreviousHistoryMessage, /// Displays the next message in the history. - NextHistoryMessage + NextHistoryMessage, + /// Toggles the language model selector dropdown. 
+            #[action(deprecated_aliases = ["assistant::ToggleModelSelector", "assistant2::ToggleModelSelector"])]
+            ToggleModelSelector
         ]
     );
 }
diff --git a/crates/zeta/Cargo.toml b/crates/zeta/Cargo.toml
index 1609773339a57df929ce317ce2a793fb8b067bca..9f1d02b79003c57092221e9953b509636946b61f 100644
--- a/crates/zeta/Cargo.toml
+++ b/crates/zeta/Cargo.toml
@@ -17,11 +17,14 @@ doctest = false
 test-support = []
 
 [dependencies]
+ai_onboarding.workspace = true
 anyhow.workspace = true
 arrayvec.workspace = true
 client.workspace = true
+cloud_llm_client.workspace = true
 collections.workspace = true
 command_palette_hooks.workspace = true
+copilot.workspace = true
 db.workspace = true
 editor.workspace = true
 feature_flags.workspace = true
@@ -30,16 +33,13 @@ futures.workspace = true
 gpui.workspace = true
 http_client.workspace = true
 indoc.workspace = true
-inline_completion.workspace = true
+edit_prediction.workspace = true
 language.workspace = true
 language_model.workspace = true
 log.workspace = true
 menu.workspace = true
-migrator.workspace = true
-paths.workspace = true
 postage.workspace = true
 project.workspace = true
-proto.workspace = true
 regex.workspace = true
 release_channel.workspace = true
 serde.workspace = true
@@ -52,16 +52,17 @@ thiserror.workspace = true
 ui.workspace = true
 util.workspace = true
 uuid.workspace = true
+workspace-hack.workspace = true
 workspace.workspace = true
 worktree.workspace = true
 zed_actions.workspace = true
-zed_llm_client.workspace = true
-workspace-hack.workspace = true
 
 [dev-dependencies]
-collections = { workspace = true, features = ["test-support"] }
+call = { workspace = true, features = ["test-support"] }
 client = { workspace = true, features = ["test-support"] }
 clock = { workspace = true, features = ["test-support"] }
+cloud_api_types.workspace = true
+collections = { workspace = true, features = ["test-support"] }
 ctor.workspace = true
 editor = { workspace = true, features = ["test-support"] }
 gpui = { workspace = true, features = ["test-support"] }
@@ -77,5 +78,4 @@ tree-sitter-rust.workspace = true
 unindent.workspace = true
 workspace = { workspace = true, features = ["test-support"] }
 worktree = { workspace = true, features = ["test-support"] }
-call = { workspace = true, features = ["test-support"] }
 zlog.workspace = true
diff --git a/crates/zeta/src/completion_diff_element.rs b/crates/zeta/src/completion_diff_element.rs
index 3b7355d797de4a07f5f84f0230774b003ce3bde4..73c3cb20cd7de5da92fbf6e5a32a8ca8d42a5933 100644
--- a/crates/zeta/src/completion_diff_element.rs
+++ b/crates/zeta/src/completion_diff_element.rs
@@ -1,6 +1,6 @@
 use std::cmp;
 
-use crate::InlineCompletion;
+use crate::EditPrediction;
 use gpui::{
     AnyElement, App, BorderStyle, Bounds, Corners, Edges, HighlightStyle, Hsla, StyledText,
     TextLayout, TextStyle, point, prelude::*, quad, size,
@@ -17,7 +17,7 @@ pub struct CompletionDiffElement {
 }
 
 impl CompletionDiffElement {
-    pub fn new(completion: &InlineCompletion, cx: &App) -> Self {
+    pub fn new(completion: &EditPrediction, cx: &App) -> Self {
         let mut diff = completion
             .snapshot
             .text_for_range(completion.excerpt_range.clone())
diff --git a/crates/zeta/src/init.rs b/crates/zeta/src/init.rs
index 6411e423a4d2e0b0f8b9e8b6e2e745a11e7864e6..a01e3a89a2bc0e365dd58a80e2377a215e303c64 100644
--- a/crates/zeta/src/init.rs
+++ b/crates/zeta/src/init.rs
@@ -4,7 +4,8 @@ use command_palette_hooks::CommandPaletteFilter;
 use feature_flags::{FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag};
 use gpui::actions;
 use
language::language_settings::{AllLanguageSettings, EditPredictionProvider}; -use settings::update_settings_file; +use project::DisableAiSettings; +use settings::{Settings, SettingsStore, update_settings_file}; use ui::App; use workspace::Workspace; @@ -21,6 +22,8 @@ actions!( ); pub fn init(cx: &mut App) { + feature_gate_predict_edits_actions(cx); + cx.observe_new(move |workspace: &mut Workspace, _, _cx| { workspace.register_action(|workspace, _: &RateCompletions, window, cx| { if cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>() { @@ -34,7 +37,6 @@ pub fn init(cx: &mut App) { workspace, workspace.user_store().clone(), workspace.client().clone(), - workspace.app_state().fs.clone(), window, cx, ) @@ -54,27 +56,57 @@ pub fn init(cx: &mut App) { }); }) .detach(); - - feature_gate_predict_edits_rating_actions(cx); } -fn feature_gate_predict_edits_rating_actions(cx: &mut App) { +fn feature_gate_predict_edits_actions(cx: &mut App) { let rate_completion_action_types = [TypeId::of::<RateCompletions>()]; + let reset_onboarding_action_types = [TypeId::of::<ResetOnboarding>()]; + let zeta_all_action_types = [ + TypeId::of::<RateCompletions>(), + TypeId::of::<ResetOnboarding>(), + zed_actions::OpenZedPredictOnboarding.type_id(), + TypeId::of::<crate::ClearHistory>(), + TypeId::of::<crate::ThumbsUpActiveCompletion>(), + TypeId::of::<crate::ThumbsDownActiveCompletion>(), + TypeId::of::<crate::NextEdit>(), + TypeId::of::<crate::PreviousEdit>(), + ]; CommandPaletteFilter::update_global(cx, |filter, _cx| { filter.hide_action_types(&rate_completion_action_types); + filter.hide_action_types(&reset_onboarding_action_types); filter.hide_action_types(&[zed_actions::OpenZedPredictOnboarding.type_id()]); }); + cx.observe_global::<SettingsStore>(move |cx| { + let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai; + let has_feature_flag = cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>(); + + CommandPaletteFilter::update_global(cx, |filter, _cx| { + if is_ai_disabled { + filter.hide_action_types(&zeta_all_action_types); + } else { + if has_feature_flag { + filter.show_action_types(rate_completion_action_types.iter()); + } else { + filter.hide_action_types(&rate_completion_action_types); + } + } + }); + }) + .detach(); + cx.observe_flag::<PredictEditsRateCompletionsFeatureFlag, _>(move |is_enabled, cx| { - if is_enabled { - CommandPaletteFilter::update_global(cx, |filter, _cx| { - filter.show_action_types(rate_completion_action_types.iter()); - }); - } else { - CommandPaletteFilter::update_global(cx, |filter, _cx| { - filter.hide_action_types(&rate_completion_action_types); - }); + if !DisableAiSettings::get_global(cx).disable_ai { + if is_enabled { + CommandPaletteFilter::update_global(cx, |filter, _cx| { + filter.show_action_types(rate_completion_action_types.iter()); + }); + } else { + CommandPaletteFilter::update_global(cx, |filter, _cx| { + filter.hide_action_types(&rate_completion_action_types); + }); + } } }) .detach(); diff --git a/crates/zeta/src/onboarding_modal.rs b/crates/zeta/src/onboarding_modal.rs index c123d76c53c801fb8eb7eb95416b8f53fc3f58f6..1d59f36b0532429f8cc24f3fc6adcdd468279d33 100644 --- a/crates/zeta/src/onboarding_modal.rs +++ b/crates/zeta/src/onboarding_modal.rs @@ -1,40 +1,33 @@ -use std::{sync::Arc, time::Duration}; +use std::sync::Arc; -use crate::{ZED_PREDICT_DATA_COLLECTION_CHOICE, onboarding_event}; -use anyhow::Context as _; +use crate::{ZedPredictUpsell, onboarding_event}; +use ai_onboarding::EditPredictionOnboarding; use client::{Client, UserStore}; -use 
db::kvp::KEY_VALUE_STORE; +use db::kvp::Dismissable; use fs::Fs; use gpui::{ - Animation, AnimationExt as _, ClickEvent, DismissEvent, Entity, EventEmitter, FocusHandle, - Focusable, MouseDownEvent, Render, ease_in_out, svg, + ClickEvent, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, MouseDownEvent, Render, + linear_color_stop, linear_gradient, }; use language::language_settings::{AllLanguageSettings, EditPredictionProvider}; -use settings::{Settings, update_settings_file}; -use ui::{Checkbox, TintColor, prelude::*}; -use util::ResultExt; -use workspace::{ModalView, Workspace, notifications::NotifyTaskExt}; +use settings::update_settings_file; +use ui::{Vector, VectorName, prelude::*}; +use workspace::{ModalView, Workspace}; /// Introduces user to Zed's Edit Prediction feature and terms of service pub struct ZedPredictModal { - user_store: Entity<UserStore>, - client: Arc<Client>, - fs: Arc<dyn Fs>, + onboarding: Entity<EditPredictionOnboarding>, focus_handle: FocusHandle, - sign_in_status: SignInStatus, - terms_of_service: bool, - data_collection_expanded: bool, - data_collection_opted_in: bool, } -#[derive(PartialEq, Eq)] -enum SignInStatus { - /// Signed out or signed in but not from this modal - Idle, - /// Authentication triggered from this modal - Waiting, - /// Signed in after authentication from this modal - SignedIn, +pub(crate) fn set_edit_prediction_provider(provider: EditPredictionProvider, cx: &mut App) { + let fs = <dyn Fs>::global(cx); + update_settings_file::<AllLanguageSettings>(fs, cx, move |settings, _| { + settings + .features + .get_or_insert(Default::default()) + .edit_prediction_provider = Some(provider); + }); } impl ZedPredictModal { @@ -42,127 +35,45 @@ impl ZedPredictModal { workspace: &mut Workspace, user_store: Entity<UserStore>, client: Arc<Client>, - fs: Arc<dyn Fs>, window: &mut Window, cx: &mut Context<Workspace>, ) { - workspace.toggle_modal(window, cx, |_window, cx| Self { - user_store, - client, - fs, - focus_handle: cx.focus_handle(), - sign_in_status: SignInStatus::Idle, - terms_of_service: false, - data_collection_expanded: false, - data_collection_opted_in: false, - }); - } - - fn view_terms(&mut self, _: &ClickEvent, _: &mut Window, cx: &mut Context<Self>) { - cx.open_url("https://zed.dev/terms-of-service"); - cx.notify(); - - onboarding_event!("ToS Link Clicked"); - } - - fn view_blog(&mut self, _: &ClickEvent, _: &mut Window, cx: &mut Context<Self>) { - cx.open_url("https://zed.dev/blog/edit-prediction"); - cx.notify(); - - onboarding_event!("Blog Link clicked"); - } - - fn inline_completions_doc(&mut self, _: &ClickEvent, _: &mut Window, cx: &mut Context<Self>) { - cx.open_url("https://zed.dev/docs/configuring-zed#disabled-globs"); - cx.notify(); - - onboarding_event!("Docs Link Clicked"); - } - - fn accept_and_enable(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) { - let task = self - .user_store - .update(cx, |this, cx| this.accept_terms_of_service(cx)); - let fs = self.fs.clone(); - - cx.spawn(async move |this, cx| { - task.await?; - - let mut data_collection_opted_in = false; - this.update(cx, |this, _cx| { - data_collection_opted_in = this.data_collection_opted_in; - }) - .ok(); - - KEY_VALUE_STORE - .write_kvp( - ZED_PREDICT_DATA_COLLECTION_CHOICE.into(), - data_collection_opted_in.to_string(), - ) - .await - .log_err(); - - // Make sure edit prediction provider setting is using the new key - let settings_path = paths::settings_file().as_path(); - let settings_path = 
fs.canonicalize(settings_path).await.with_context(|| { - format!("Failed to canonicalize settings path {:?}", settings_path) - })?; - - if let Some(settings) = fs.load(&settings_path).await.log_err() { - if let Some(new_settings) = - migrator::migrate_edit_prediction_provider_settings(&settings)? - { - fs.atomic_write(settings_path, new_settings).await?; - } + workspace.toggle_modal(window, cx, |_window, cx| { + let weak_entity = cx.weak_entity(); + Self { + onboarding: cx.new(|cx| { + EditPredictionOnboarding::new( + user_store.clone(), + client.clone(), + copilot::Copilot::global(cx) + .map_or(false, |copilot| copilot.read(cx).status().is_configured()), + Arc::new({ + let this = weak_entity.clone(); + move |_window, cx| { + ZedPredictUpsell::set_dismissed(true, cx); + set_edit_prediction_provider(EditPredictionProvider::Zed, cx); + this.update(cx, |_, cx| cx.emit(DismissEvent)).ok(); + } + }), + Arc::new({ + let this = weak_entity.clone(); + move |window, cx| { + ZedPredictUpsell::set_dismissed(true, cx); + set_edit_prediction_provider(EditPredictionProvider::Copilot, cx); + this.update(cx, |_, cx| cx.emit(DismissEvent)).ok(); + copilot::initiate_sign_in(window, cx); + } + }), + cx, + ) + }), + focus_handle: cx.focus_handle(), } - - this.update(cx, |this, cx| { - update_settings_file::<AllLanguageSettings>(this.fs.clone(), cx, move |file, _| { - file.features - .get_or_insert(Default::default()) - .edit_prediction_provider = Some(EditPredictionProvider::Zed); - }); - - cx.emit(DismissEvent); - }) - }) - .detach_and_notify_err(window, cx); - - onboarding_event!( - "Enable Clicked", - data_collection_opted_in = self.data_collection_opted_in, - ); - } - - fn sign_in(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) { - let client = self.client.clone(); - self.sign_in_status = SignInStatus::Waiting; - - cx.spawn(async move |this, cx| { - let result = client - .authenticate_and_connect(true, &cx) - .await - .into_response(); - - let status = match result { - Ok(_) => SignInStatus::SignedIn, - Err(_) => SignInStatus::Idle, - }; - - this.update(cx, |this, cx| { - this.sign_in_status = status; - onboarding_event!("Signed In"); - cx.notify() - })?; - - result - }) - .detach_and_notify_err(window, cx); - - onboarding_event!("Sign In Clicked"); + }); } fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) { + ZedPredictUpsell::set_dismissed(true, cx); cx.emit(DismissEvent); } } @@ -177,85 +88,12 @@ impl Focusable for ZedPredictModal { impl ModalView for ZedPredictModal {} -impl ZedPredictModal { - fn render_data_collection_explanation(&self, cx: &Context<Self>) -> impl IntoElement { - fn label_item(label_text: impl Into<SharedString>) -> impl Element { - Label::new(label_text).color(Color::Muted).into_element() - } - - fn info_item(label_text: impl Into<SharedString>) -> impl Element { - h_flex() - .items_start() - .gap_2() - .child( - div() - .mt_1p5() - .child(Icon::new(IconName::Check).size(IconSize::XSmall)), - ) - .child(div().w_full().child(label_item(label_text))) - } - - fn multiline_info_item<E1: Into<SharedString>, E2: IntoElement>( - first_line: E1, - second_line: E2, - ) -> impl Element { - v_flex() - .child(info_item(first_line)) - .child(div().pl_5().child(second_line)) - } - - v_flex() - .mt_2() - .p_2() - .rounded_sm() - .bg(cx.theme().colors().editor_background.opacity(0.5)) - .border_1() - .border_color(cx.theme().colors().border_variant) - .child( - div().child( - Label::new("To improve edit predictions, please consider contributing 
to our open dataset based on your interactions within open source repositories.") - .mb_1() - ) - ) - .child(info_item( - "We collect data exclusively from open source projects.", - )) - .child(info_item( - "Zed automatically detects if your project is open source.", - )) - .child(info_item("Toggle participation at any time via the status bar menu.")) - .child(multiline_info_item( - "If turned on, this setting applies for all open source repositories", - label_item("you open in Zed.") - )) - .child(multiline_info_item( - "Files with sensitive data, like `.env`, are excluded by default", - h_flex() - .w_full() - .flex_wrap() - .child(label_item("via the")) - .child( - Button::new("doc-link", "disabled_globs").on_click( - cx.listener(Self::inline_completions_doc), - ), - ) - .child(label_item("setting.")), - )) - } -} - impl Render for ZedPredictModal { fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { let window_height = window.viewport_size().height; let max_height = window_height - px(200.); - let has_subscription_period = self.user_store.read(cx).subscription_period().is_some(); - let plan = self.user_store.read(cx).current_plan().filter(|_| { - // Since the user might be on the legacy free plan we filter based on whether we have a subscription period. - has_subscription_period - }); - - let base = v_flex() + v_flex() .id("edit-prediction-onboarding") .key_context("ZedPredictModal") .relative() @@ -264,14 +102,9 @@ impl Render for ZedPredictModal { .max_h(max_height) .p_4() .gap_2() - .when(self.data_collection_expanded, |element| { - element.overflow_y_scroll() - }) - .when(!self.data_collection_expanded, |element| { - element.overflow_hidden() - }) .elevation_3(cx) .track_focus(&self.focus_handle(cx)) + .overflow_hidden() .on_action(cx.listener(Self::cancel)) .on_action(cx.listener(|_, _: &menu::Cancel, _window, cx| { onboarding_event!("Cancelled", trigger = "Action"); @@ -282,77 +115,30 @@ impl Render for ZedPredictModal { })) .child( div() - .p_1p5() + .opacity(0.5) .absolute() - .top_1() - .left_1() + .top(px(-8.0)) .right_0() - .h(px(200.)) + .w(px(400.)) + .h(px(92.)) .child( - svg() - .path("icons/zed_predict_bg.svg") - .text_color(cx.theme().colors().icon_disabled) - .w(px(530.)) - .h(px(128.)) - .overflow_hidden(), + Vector::new(VectorName::AiGrid, rems_from_px(400.), rems_from_px(92.)) + .color(Color::Custom(cx.theme().colors().text.alpha(0.32))), ), ) .child( - h_flex() - .w_full() - .mb_2() - .justify_between() - .child( - v_flex() - .gap_1() - .child( - Label::new("Introducing Zed AI's") - .size(LabelSize::Small) - .color(Color::Muted), - ) - .child(Headline::new("Edit Prediction").size(HeadlineSize::Large)), - ) - .child({ - let tab = |n: usize| { - let text_color = cx.theme().colors().text; - let border_color = cx.theme().colors().text_accent.opacity(0.4); - - h_flex().child( - h_flex() - .px_4() - .py_0p5() - .bg(cx.theme().colors().editor_background) - .border_1() - .border_color(border_color) - .rounded_sm() - .font(theme::ThemeSettings::get_global(cx).buffer_font.clone()) - .text_size(TextSize::XSmall.rems(cx)) - .text_color(text_color) - .child("tab") - .with_animation( - n, - Animation::new(Duration::from_secs(2)).repeat(), - move |tab, delta| { - let delta = (delta - 0.15 * n as f32) / 0.7; - let delta = 1.0 - (0.5 - delta).abs() * 2.; - let delta = ease_in_out(delta.clamp(0., 1.)); - let delta = 0.1 + 0.9 * delta; - - tab.border_color(border_color.opacity(delta)) - .text_color(text_color.opacity(delta)) - }, - ), - ) - }; - - 
v_flex() - .gap_2() - .items_center() - .pr_2p5() - .child(tab(0).ml_neg_20()) - .child(tab(1)) - .child(tab(2).ml_20()) - }), + div() + .absolute() + .top_0() + .right_0() + .w(px(660.)) + .h(px(401.)) + .overflow_hidden() + .bg(linear_gradient( + 75., + linear_color_stop(cx.theme().colors().panel_background.alpha(0.01), 1.0), + linear_color_stop(cx.theme().colors().panel_background, 0.45), + )), ) .child(h_flex().absolute().top_2().right_2().child( IconButton::new("cancel", IconName::X).on_click(cx.listener( @@ -361,148 +147,7 @@ impl Render for ZedPredictModal { cx.emit(DismissEvent); }, )), - )); - - let blog_post_button = Button::new("view-blog", "Read the Blog Post") - .full_width() - .icon(IconName::ArrowUpRight) - .icon_size(IconSize::Indicator) - .icon_color(Color::Muted) - .on_click(cx.listener(Self::view_blog)); - - if self.user_store.read(cx).current_user().is_some() { - let copy = match self.sign_in_status { - SignInStatus::Idle => { - "Zed can now predict your next edit on every keystroke. Powered by Zeta, our open-source, open-dataset language model." - } - SignInStatus::SignedIn => "Almost there! Ensure you:", - SignInStatus::Waiting => unreachable!(), - }; - - let accordion_icons = if self.data_collection_expanded { - (IconName::ChevronUp, IconName::ChevronDown) - } else { - (IconName::ChevronDown, IconName::ChevronUp) - }; - let plan = plan.unwrap_or(proto::Plan::Free); - - base.child(Label::new(copy).color(Color::Muted)) - .child( - h_flex().child( - Checkbox::new("plan", ToggleState::Selected) - .fill() - .disabled(true) - .label(format!( - "You get {} edit predictions through your {}.", - if plan == proto::Plan::Free { - "2,000" - } else { - "unlimited" - }, - match plan { - proto::Plan::Free => "Zed Free plan", - proto::Plan::ZedPro => "Zed Pro plan", - proto::Plan::ZedProTrial => "Zed Pro trial", - } - )), - ), - ) - .child( - h_flex() - .child( - Checkbox::new("tos-checkbox", self.terms_of_service.into()) - .fill() - .label("I have read and accept the") - .on_click(cx.listener(move |this, state, _window, cx| { - this.terms_of_service = *state == ToggleState::Selected; - cx.notify(); - })), - ) - .child( - Button::new("view-tos", "Terms of Service") - .icon(IconName::ArrowUpRight) - .icon_size(IconSize::Indicator) - .icon_color(Color::Muted) - .on_click(cx.listener(Self::view_terms)), - ), - ) - .child( - v_flex() - .child( - h_flex() - .flex_wrap() - .child( - Checkbox::new( - "training-data-checkbox", - self.data_collection_opted_in.into(), - ) - .label( - "Contribute to the open dataset when editing open source.", - ) - .fill() - .on_click(cx.listener( - move |this, state, _window, cx| { - this.data_collection_opted_in = - *state == ToggleState::Selected; - cx.notify() - }, - )), - ) - .child( - Button::new("learn-more", "Learn More") - .icon(accordion_icons.0) - .icon_size(IconSize::Indicator) - .icon_color(Color::Muted) - .on_click(cx.listener(|this, _, _, cx| { - this.data_collection_expanded = - !this.data_collection_expanded; - cx.notify(); - - if this.data_collection_expanded { - onboarding_event!( - "Data Collection Learn More Clicked" - ); - } - })), - ), - ) - .when(self.data_collection_expanded, |element| { - element.child(self.render_data_collection_explanation(cx)) - }), - ) - .child( - v_flex() - .mt_2() - .gap_2() - .w_full() - .child( - Button::new("accept-tos", "Enable Edit Prediction") - .disabled(!self.terms_of_service) - .style(ButtonStyle::Tinted(TintColor::Accent)) - .full_width() - .on_click(cx.listener(Self::accept_and_enable)), - ) - 
.child(blog_post_button), - ) - } else { - base.child( - Label::new("To set Zed as your edit prediction provider, please sign in.") - .color(Color::Muted), - ) - .child( - v_flex() - .mt_2() - .gap_2() - .w_full() - .child( - Button::new("accept-tos", "Sign in with GitHub") - .disabled(self.sign_in_status == SignInStatus::Waiting) - .style(ButtonStyle::Tinted(TintColor::Accent)) - .full_width() - .on_click(cx.listener(Self::sign_in)), - ) - .child(blog_post_button), - ) - } + )) + .child(self.onboarding.clone()) } } diff --git a/crates/zeta/src/rate_completion_modal.rs b/crates/zeta/src/rate_completion_modal.rs index 5a873fb8de70a42c8a8d0289a14e019f6ef3d0e5..ac7fcade9137d4a60b22e88273b0d625371028e1 100644 --- a/crates/zeta/src/rate_completion_modal.rs +++ b/crates/zeta/src/rate_completion_modal.rs @@ -1,4 +1,4 @@ -use crate::{CompletionDiffElement, InlineCompletion, InlineCompletionRating, Zeta}; +use crate::{CompletionDiffElement, EditPrediction, EditPredictionRating, Zeta}; use editor::Editor; use gpui::{App, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, actions, prelude::*}; use language::language_settings; @@ -34,7 +34,7 @@ pub struct RateCompletionModal { } struct ActiveCompletion { - completion: InlineCompletion, + completion: EditPrediction, feedback_editor: Entity<Editor>, } @@ -157,7 +157,7 @@ impl RateCompletionModal { if let Some(active) = &self.active_completion { zeta.rate_completion( &active.completion, - InlineCompletionRating::Positive, + EditPredictionRating::Positive, active.feedback_editor.read(cx).text(cx), cx, ); @@ -189,7 +189,7 @@ impl RateCompletionModal { self.zeta.update(cx, |zeta, cx| { zeta.rate_completion( &active.completion, - InlineCompletionRating::Negative, + EditPredictionRating::Negative, active.feedback_editor.read(cx).text(cx), cx, ); @@ -250,7 +250,7 @@ impl RateCompletionModal { pub fn select_completion( &mut self, - completion: Option<InlineCompletion>, + completion: Option<EditPrediction>, focus: bool, window: &mut Window, cx: &mut Context<Self>, diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index 12d3d4bfbc7aae1c5b2ed2d36c7a89dd1f526723..b1bd737dbf097adaee6cfd6e0f4d1a2275e343ca 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -7,10 +7,9 @@ mod onboarding_telemetry; mod rate_completion_modal; pub(crate) use completion_diff_element::*; -use db::kvp::KEY_VALUE_STORE; -use feature_flags::{FeatureFlagAppExt as _, ZedCloudFeatureFlag}; +use db::kvp::{Dismissable, KEY_VALUE_STORE}; +use edit_prediction::DataCollectionState; pub use init::*; -use inline_completion::DataCollectionState; use license_detection::LICENSE_FILES_TO_CHECK; pub use license_detection::is_license_eligible_for_data_collection; pub use rate_completion_modal::*; @@ -18,6 +17,10 @@ pub use rate_completion_modal::*; use anyhow::{Context as _, Result, anyhow}; use arrayvec::ArrayVec; use client::{Client, EditPredictionUsage, UserStore}; +use cloud_llm_client::{ + AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, + PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, ZED_VERSION_HEADER_NAME, +}; use collections::{HashMap, HashSet, VecDeque}; use futures::AsyncReadExt; use gpui::{ @@ -31,7 +34,7 @@ use language::{ }; use language_model::{LlmApiToken, RefreshLlmTokenListener}; use postage::watch; -use project::Project; +use project::{Project, ProjectPath}; use release_channel::AppVersion; use settings::WorktreeId; use std::str::FromStr; @@ -47,17 +50,13 @@ use std::{ sync::Arc, 
time::{Duration, Instant}, }; -use telemetry_events::InlineCompletionRating; +use telemetry_events::EditPredictionRating; use thiserror::Error; use util::ResultExt; use uuid::Uuid; use workspace::Workspace; use workspace::notifications::{ErrorMessagePrompt, NotificationId}; use worktree::Worktree; -use zed_llm_client::{ - AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, - PredictEditsBody, PredictEditsResponse, ZED_VERSION_HEADER_NAME, -}; const CURSOR_MARKER: &'static str = "<|user_cursor_is_here|>"; const START_OF_FILE_MARKER: &'static str = "<|start_of_file|>"; @@ -82,28 +81,61 @@ actions!( ); #[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)] -pub struct InlineCompletionId(Uuid); +pub struct EditPredictionId(Uuid); -impl From<InlineCompletionId> for gpui::ElementId { - fn from(value: InlineCompletionId) -> Self { +impl From<EditPredictionId> for gpui::ElementId { + fn from(value: EditPredictionId) -> Self { gpui::ElementId::Uuid(value.0) } } -impl std::fmt::Display for InlineCompletionId { +impl std::fmt::Display for EditPredictionId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.0) } } +struct ZedPredictUpsell; + +impl Dismissable for ZedPredictUpsell { + const KEY: &'static str = "dismissed-edit-predict-upsell"; + + fn dismissed() -> bool { + // To make this backwards compatible with older versions of Zed, we + // check if the user has seen the previous Edit Prediction Onboarding + // before, by checking the data collection choice which was written to + // the database once the user clicked on "Accept and Enable" + if KEY_VALUE_STORE + .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE) + .log_err() + .map_or(false, |s| s.is_some()) + { + return true; + } + + KEY_VALUE_STORE + .read_kvp(Self::KEY) + .log_err() + .map_or(false, |s| s.is_some()) + } +} + +pub fn should_show_upsell_modal(user_store: &Entity<UserStore>, cx: &App) -> bool { + if user_store.read(cx).has_accepted_terms_of_service() { + !ZedPredictUpsell::dismissed() + } else { + true + } +} + #[derive(Clone)] struct ZetaGlobal(Entity<Zeta>); impl Global for ZetaGlobal {} #[derive(Clone)] -pub struct InlineCompletion { - id: InlineCompletionId, +pub struct EditPrediction { + id: EditPredictionId, path: Arc<Path>, excerpt_range: Range<usize>, cursor_offset: usize, @@ -114,14 +146,14 @@ pub struct InlineCompletion { input_events: Arc<str>, input_excerpt: Arc<str>, output_excerpt: Arc<str>, - request_sent_at: Instant, + buffer_snapshotted_at: Instant, response_received_at: Instant, } -impl InlineCompletion { +impl EditPrediction { fn latency(&self) -> Duration { self.response_received_at - .duration_since(self.request_sent_at) + .duration_since(self.buffer_snapshotted_at) } fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> { @@ -175,9 +207,9 @@ fn interpolate( if edits.is_empty() { None } else { Some(edits) } } -impl std::fmt::Debug for InlineCompletion { +impl std::fmt::Debug for EditPrediction { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("InlineCompletion") + f.debug_struct("EditPrediction") .field("id", &self.id) .field("path", &self.path) .field("edits", &self.edits) @@ -190,17 +222,14 @@ pub struct Zeta { client: Arc<Client>, events: VecDeque<Event>, registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>, - shown_completions: VecDeque<InlineCompletion>, - rated_completions: HashSet<InlineCompletionId>, + shown_completions: 
VecDeque<EditPrediction>, + rated_completions: HashSet<EditPredictionId>, data_collection_choice: Entity<DataCollectionChoice>, llm_token: LlmApiToken, _llm_token_subscription: Subscription, - /// Whether the terms of service have been accepted. - tos_accepted: bool, /// Whether an update to a newer version of Zed is required to continue using Zeta. update_required: bool, user_store: Entity<UserStore>, - _user_store_subscription: Subscription, license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>, } @@ -275,22 +304,7 @@ impl Zeta { .detach_and_log_err(cx); }, ), - tos_accepted: user_store - .read(cx) - .current_user_has_accepted_terms() - .unwrap_or(false), update_required: false, - _user_store_subscription: cx.subscribe(&user_store, |this, user_store, event, cx| { - match event { - client::user::Event::PrivateUserInfoUpdated => { - this.tos_accepted = user_store - .read(cx) - .current_user_has_accepted_terms() - .unwrap_or(false); - } - _ => {} - } - }), license_detection_watchers: HashMap::default(), user_store, } @@ -370,119 +384,70 @@ impl Zeta { can_collect_data: bool, cx: &mut Context<Self>, perform_predict_edits: F, - ) -> Task<Result<Option<InlineCompletion>>> + ) -> Task<Result<Option<EditPrediction>>> where F: FnOnce(PerformPredictEditsParams) -> R + 'static, R: Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> + Send + 'static, { + let buffer = buffer.clone(); + let buffer_snapshotted_at = Instant::now(); let snapshot = self.report_changes_for_buffer(&buffer, cx); - let diagnostic_groups = snapshot.diagnostic_groups(None); - let cursor_point = cursor.to_point(&snapshot); - let cursor_offset = cursor_point.to_offset(&snapshot); - let events = self.events.clone(); - let path: Arc<Path> = snapshot - .file() - .map(|f| Arc::from(f.full_path(cx).as_path())) - .unwrap_or_else(|| Arc::from(Path::new("untitled"))); - let zeta = cx.entity(); + let events = self.events.clone(); let client = self.client.clone(); let llm_token = self.llm_token.clone(); let app_version = AppVersion::global(cx); - let use_cloud = cx.has_flag::<ZedCloudFeatureFlag>(); - let buffer = buffer.clone(); - - let local_lsp_store = - project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local()); - let diagnostic_groups = if let Some(local_lsp_store) = local_lsp_store { - Some( - diagnostic_groups - .into_iter() - .filter_map(|(language_server_id, diagnostic_group)| { - let language_server = - local_lsp_store.running_language_server_for_id(language_server_id)?; - - Some(( - language_server.name(), - diagnostic_group.resolve::<usize>(&snapshot), - )) - }) - .collect::<Vec<_>>(), - ) + let git_info = if let (true, Some(project), Some(file)) = + (can_collect_data, project, snapshot.file()) + { + git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx) } else { None }; - cx.spawn(async move |this, cx| { - let request_sent_at = Instant::now(); - - struct BackgroundValues { - input_events: String, - input_excerpt: String, - speculated_output: String, - editable_range: Range<usize>, - input_outline: String, - } + let full_path: Arc<Path> = snapshot + .file() + .map(|f| Arc::from(f.full_path(cx).as_path())) + .unwrap_or_else(|| Arc::from(Path::new("untitled"))); + let full_path_str = full_path.to_string_lossy().to_string(); + let cursor_point = cursor.to_point(&snapshot); + let cursor_offset = cursor_point.to_offset(&snapshot); + let make_events_prompt = move || prompt_for_events(&events, MAX_EVENT_TOKENS); + let gather_task = gather_context( + 
project, + full_path_str, + &snapshot, + cursor_point, + make_events_prompt, + can_collect_data, + git_info, + cx, + ); - let values = cx - .background_spawn({ - let snapshot = snapshot.clone(); - let path = path.clone(); - async move { - let path = path.to_string_lossy(); - let input_excerpt = excerpt_for_cursor_position( - cursor_point, - &path, - &snapshot, - MAX_REWRITE_TOKENS, - MAX_CONTEXT_TOKENS, - ); - let input_events = prompt_for_events(&events, MAX_EVENT_TOKENS); - let input_outline = prompt_for_outline(&snapshot); - - anyhow::Ok(BackgroundValues { - input_events, - input_excerpt: input_excerpt.prompt, - speculated_output: input_excerpt.speculated_output, - editable_range: input_excerpt.editable_range.to_offset(&snapshot), - input_outline, - }) - } - }) - .await?; + cx.spawn(async move |this, cx| { + let GatherContextOutput { + body, + editable_range, + } = gather_task.await?; log::debug!( "Events:\n{}\nExcerpt:\n{:?}", - values.input_events, - values.input_excerpt + body.input_events, + body.input_excerpt ); - let body = PredictEditsBody { - input_events: values.input_events.clone(), - input_excerpt: values.input_excerpt.clone(), - speculated_output: Some(values.speculated_output), - outline: Some(values.input_outline.clone()), - can_collect_data, - diagnostic_groups: diagnostic_groups.and_then(|diagnostic_groups| { - diagnostic_groups - .into_iter() - .map(|(name, diagnostic_group)| { - Ok((name.to_string(), serde_json::to_value(diagnostic_group)?)) - }) - .collect::<Result<Vec<_>>>() - .log_err() - }), - }; + let input_outline = body.outline.clone().unwrap_or_default(); + let input_events = body.input_events.clone(); + let input_excerpt = body.input_excerpt.clone(); let response = perform_predict_edits(PerformPredictEditsParams { client, llm_token, app_version, body, - use_cloud, }) .await; let (response, usage) = match response { @@ -534,13 +499,13 @@ impl Zeta { response, buffer, &snapshot, - values.editable_range, + editable_range, cursor_offset, - path, - values.input_outline, - values.input_events, - values.input_excerpt, - request_sent_at, + full_path, + input_outline, + input_events, + input_excerpt, + buffer_snapshotted_at, &cx, ) .await @@ -708,7 +673,7 @@ and then another position: language::Anchor, response: PredictEditsResponse, cx: &mut Context<Self>, - ) -> Task<Result<Option<InlineCompletion>>> { + ) -> Task<Result<Option<EditPrediction>>> { use std::future::ready; self.request_completion_impl(None, project, buffer, position, false, cx, |_params| { @@ -723,7 +688,7 @@ and then another position: language::Anchor, can_collect_data: bool, cx: &mut Context<Self>, - ) -> Task<Result<Option<InlineCompletion>>> { + ) -> Task<Result<Option<EditPrediction>>> { let workspace = self .workspace .as_ref() @@ -739,7 +704,7 @@ and then another ) } - fn perform_predict_edits( + pub fn perform_predict_edits( params: PerformPredictEditsParams, ) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> { async move { @@ -748,7 +713,6 @@ and then another llm_token, app_version, body, - use_cloud, .. } = params; @@ -764,7 +728,7 @@ and then another } else { request_builder.uri( http_client - .build_zed_llm_url("/predict_edits/v2", &[], use_cloud)? + .build_zed_llm_url("/predict_edits/v2", &[])? 
.as_ref(), ) }; @@ -818,13 +782,12 @@ and then another fn accept_edit_prediction( &mut self, - request_id: InlineCompletionId, + request_id: EditPredictionId, cx: &mut Context<Self>, ) -> Task<Result<()>> { let client = self.client.clone(); let llm_token = self.llm_token.clone(); let app_version = AppVersion::global(cx); - let use_cloud = cx.has_flag::<ZedCloudFeatureFlag>(); cx.spawn(async move |this, cx| { let http_client = client.http_client(); let mut response = llm_token_retry(&llm_token, &client, |token| { @@ -835,7 +798,7 @@ and then another } else { request_builder.uri( http_client - .build_zed_llm_url("/predict_edits/accept", &[], use_cloud)? + .build_zed_llm_url("/predict_edits/accept", &[])? .as_ref(), ) }; @@ -896,9 +859,9 @@ and then another input_outline: String, input_events: String, input_excerpt: String, - request_sent_at: Instant, + buffer_snapshotted_at: Instant, cx: &AsyncApp, - ) -> Task<Result<Option<InlineCompletion>>> { + ) -> Task<Result<Option<EditPrediction>>> { let snapshot = snapshot.clone(); let request_id = prediction_response.request_id; let output_excerpt = prediction_response.output_excerpt; @@ -930,8 +893,8 @@ and then another let edit_preview = edit_preview.await; - Ok(Some(InlineCompletion { - id: InlineCompletionId(request_id), + Ok(Some(EditPrediction { + id: EditPredictionId(request_id), path, excerpt_range: editable_range, cursor_offset, @@ -942,7 +905,7 @@ and then another input_events: input_events.into(), input_excerpt: input_excerpt.into(), output_excerpt, - request_sent_at, + buffer_snapshotted_at, response_received_at: Instant::now(), })) }) @@ -1041,11 +1004,11 @@ and then another .collect() } - pub fn is_completion_rated(&self, completion_id: InlineCompletionId) -> bool { + pub fn is_completion_rated(&self, completion_id: EditPredictionId) -> bool { self.rated_completions.contains(&completion_id) } - pub fn completion_shown(&mut self, completion: &InlineCompletion, cx: &mut Context<Self>) { + pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context<Self>) { self.shown_completions.push_front(completion.clone()); if self.shown_completions.len() > 50 { let completion = self.shown_completions.pop_back().unwrap(); @@ -1056,8 +1019,8 @@ and then another pub fn rate_completion( &mut self, - completion: &InlineCompletion, - rating: InlineCompletionRating, + completion: &EditPrediction, + rating: EditPredictionRating, feedback: String, cx: &mut Context<Self>, ) { @@ -1075,7 +1038,7 @@ and then another cx.notify(); } - pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &InlineCompletion> { + pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> { self.shown_completions.iter() } @@ -1126,12 +1089,11 @@ and then another } } -struct PerformPredictEditsParams { +pub struct PerformPredictEditsParams { pub client: Arc<Client>, pub llm_token: LlmApiToken, pub app_version: SemanticVersion, pub body: PredictEditsBody, - pub use_cloud: bool, } #[derive(Error, Debug)] @@ -1202,6 +1164,108 @@ fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: .sum() } +fn git_info_for_file( + project: &Entity<Project>, + project_path: &ProjectPath, + cx: &App, +) -> Option<PredictEditsGitInfo> { + let git_store = project.read(cx).git_store().read(cx); + if let Some((repository, _repo_path)) = + git_store.repository_and_path_for_project_path(project_path, cx) + { + let repository = repository.read(cx); + let head_sha = repository + .head_commit + .as_ref() + .map(|head_commit| 
head_commit.sha.to_string()); + let remote_origin_url = repository.remote_origin_url.clone(); + let remote_upstream_url = repository.remote_upstream_url.clone(); + if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() { + return None; + } + Some(PredictEditsGitInfo { + head_sha, + remote_origin_url, + remote_upstream_url, + }) + } else { + None + } +} + +pub struct GatherContextOutput { + pub body: PredictEditsBody, + pub editable_range: Range<usize>, +} + +pub fn gather_context( + project: Option<&Entity<Project>>, + full_path_str: String, + snapshot: &BufferSnapshot, + cursor_point: language::Point, + make_events_prompt: impl FnOnce() -> String + Send + 'static, + can_collect_data: bool, + git_info: Option<PredictEditsGitInfo>, + cx: &App, +) -> Task<Result<GatherContextOutput>> { + let local_lsp_store = + project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local()); + let diagnostic_groups: Vec<(String, serde_json::Value)> = + if let Some(local_lsp_store) = local_lsp_store { + snapshot + .diagnostic_groups(None) + .into_iter() + .filter_map(|(language_server_id, diagnostic_group)| { + let language_server = + local_lsp_store.running_language_server_for_id(language_server_id)?; + let diagnostic_group = diagnostic_group.resolve::<usize>(&snapshot); + let language_server_name = language_server.name().to_string(); + let serialized = serde_json::to_value(diagnostic_group).unwrap(); + Some((language_server_name, serialized)) + }) + .collect::<Vec<_>>() + } else { + Vec::new() + }; + + cx.background_spawn({ + let snapshot = snapshot.clone(); + async move { + let diagnostic_groups = if diagnostic_groups.is_empty() { + None + } else { + Some(diagnostic_groups) + }; + + let input_excerpt = excerpt_for_cursor_position( + cursor_point, + &full_path_str, + &snapshot, + MAX_REWRITE_TOKENS, + MAX_CONTEXT_TOKENS, + ); + let input_events = make_events_prompt(); + let input_outline = prompt_for_outline(&snapshot); + let editable_range = input_excerpt.editable_range.to_offset(&snapshot); + + let body = PredictEditsBody { + input_events, + input_excerpt: input_excerpt.prompt, + speculated_output: Some(input_excerpt.speculated_output), + outline: Some(input_outline), + can_collect_data, + diagnostic_groups, + git_info, + }; + + Ok(GatherContextOutput { + body, + editable_range, + }) + } + }) +} + fn prompt_for_outline(snapshot: &BufferSnapshot) -> String { let mut input_outline = String::new(); @@ -1252,7 +1316,7 @@ struct RegisteredBuffer { } #[derive(Clone)] -enum Event { +pub enum Event { BufferChange { old_snapshot: BufferSnapshot, new_snapshot: BufferSnapshot, @@ -1299,12 +1363,12 @@ impl Event { } #[derive(Debug, Clone)] -struct CurrentInlineCompletion { +struct CurrentEditPrediction { buffer_id: EntityId, - completion: InlineCompletion, + completion: EditPrediction, } -impl CurrentInlineCompletion { +impl CurrentEditPrediction { fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool { if self.buffer_id != old_completion.buffer_id { return true; @@ -1473,17 +1537,17 @@ async fn llm_token_retry( } } -pub struct ZetaInlineCompletionProvider { +pub struct ZetaEditPredictionProvider { zeta: Entity<Zeta>, pending_completions: ArrayVec<PendingCompletion, 2>, next_pending_completion_id: usize, - current_completion: Option<CurrentInlineCompletion>, + current_completion: Option<CurrentEditPrediction>, /// None if this is entirely disabled for this provider provider_data_collection: ProviderDataCollection, 
last_request_timestamp: Instant, } -impl ZetaInlineCompletionProvider { +impl ZetaEditPredictionProvider { pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300); pub fn new(zeta: Entity<Zeta>, provider_data_collection: ProviderDataCollection) -> Self { @@ -1498,7 +1562,7 @@ impl ZetaInlineCompletionProvider { } } -impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider { +impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider { fn name() -> &'static str { "zed-predict" } @@ -1547,7 +1611,12 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider } fn needs_terms_acceptance(&self, cx: &App) -> bool { - !self.zeta.read(cx).tos_accepted + !self + .zeta + .read(cx) + .user_store + .read(cx) + .has_accepted_terms_of_service() } fn is_refreshing(&self) -> bool { @@ -1562,7 +1631,7 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider _debounce: bool, cx: &mut Context<Self>, ) { - if !self.zeta.read(cx).tos_accepted { + if self.needs_terms_acceptance(cx) { return; } @@ -1574,7 +1643,7 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider .zeta .read(cx) .user_store - .read_with(cx, |user_store, _| { + .read_with(cx, |user_store, _cx| { user_store.account_too_young() || user_store.has_overdue_invoices() }) { @@ -1621,7 +1690,7 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider Ok(completion_request) => { let completion_request = completion_request.await; completion_request.map(|c| { - c.map(|completion| CurrentInlineCompletion { + c.map(|completion| CurrentEditPrediction { buffer_id: buffer.entity_id(), completion, }) @@ -1694,7 +1763,7 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider &mut self, _buffer: Entity<Buffer>, _cursor_position: language::Anchor, - _direction: inline_completion::Direction, + _direction: edit_prediction::Direction, _cx: &mut Context<Self>, ) { // Right now we don't support cycling. @@ -1725,8 +1794,8 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider buffer: &Entity<Buffer>, cursor_position: language::Anchor, cx: &mut Context<Self>, - ) -> Option<inline_completion::InlineCompletion> { - let CurrentInlineCompletion { + ) -> Option<edit_prediction::EditPrediction> { + let CurrentEditPrediction { buffer_id, completion, .. 
@@ -1774,7 +1843,7 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider } } - Some(inline_completion::InlineCompletion { + Some(edit_prediction::EditPrediction { id: Some(completion.id.to_string().into()), edits: edits[edit_start_ix..edit_end_ix].to_vec(), edit_preview: Some(completion.edit_preview.clone()), @@ -1791,19 +1860,20 @@ fn tokens_for_bytes(bytes: usize) -> usize { #[cfg(test)] mod tests { + use client::UserStore; use client::test::FakeServer; use clock::FakeSystemClock; + use cloud_api_types::{CreateLlmTokenResponse, LlmToken}; use gpui::TestAppContext; use http_client::FakeHttpClient; use indoc::indoc; use language::Point; - use rpc::proto; use settings::SettingsStore; use super::*; #[gpui::test] - async fn test_inline_completion_basic_interpolation(cx: &mut TestAppContext) { + async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) { let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx)); let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| { to_completion_edits( @@ -1818,19 +1888,19 @@ mod tests { .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx)) .await; - let completion = InlineCompletion { + let completion = EditPrediction { edits, edit_preview, path: Path::new("").into(), snapshot: cx.read(|cx| buffer.read(cx).snapshot()), - id: InlineCompletionId(Uuid::new_v4()), + id: EditPredictionId(Uuid::new_v4()), excerpt_range: 0..0, cursor_offset: 0, input_outline: "".into(), input_events: "".into(), input_excerpt: "".into(), output_excerpt: "".into(), - request_sent_at: Instant::now(), + buffer_snapshotted_at: Instant::now(), response_received_at: Instant::now(), }; @@ -1984,7 +2054,7 @@ mod tests { } #[gpui::test] - async fn test_inline_completion_end_of_buffer(cx: &mut TestAppContext) { + async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) { cx.update(|cx| { let settings_store = SettingsStore::test(cx); cx.set_global(settings_store); @@ -2001,28 +2071,45 @@ mod tests { <|editable_region_end|> ```"}; - let http_client = FakeHttpClient::create(move |_| async move { - Ok(http_client::Response::builder() - .status(200) - .body( - serde_json::to_string(&PredictEditsResponse { - request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45") - .unwrap(), - output_excerpt: completion_response.to_string(), - }) - .unwrap() - .into(), - ) - .unwrap()) + let http_client = FakeHttpClient::create(move |req| async move { + match (req.method(), req.uri().path()) { + (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&CreateLlmTokenResponse { + token: LlmToken("the-llm-token".to_string()), + }) + .unwrap() + .into(), + ) + .unwrap()), + (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&PredictEditsResponse { + request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45") + .unwrap(), + output_excerpt: completion_response.to_string(), + }) + .unwrap() + .into(), + ) + .unwrap()), + _ => Ok(http_client::Response::builder() + .status(404) + .body("Not Found".into()) + .unwrap()), + } }); let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx)); cx.update(|cx| { RefreshLlmTokenListener::register(client.clone(), cx); }); - let server = FakeServer::for_client(42, &client, cx).await; + // Construct the fake server to authenticate. 
+ let _server = FakeServer::for_client(42, &client, cx).await; let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx)); + let zeta = cx.new(|cx| Zeta::new(None, client, user_store.clone(), cx)); let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0))); @@ -2030,13 +2117,6 @@ mod tests { zeta.request_completion(None, &buffer, cursor, false, cx) }); - server.receive::<proto::GetUsers>().await.unwrap(); - let token_request = server.receive::<proto::GetLlmToken>().await.unwrap(); - server.respond( - token_request.receipt(), - proto::GetLlmTokenResponse { token: "".into() }, - ); - let completion = completion_task.await.unwrap().unwrap(); buffer.update(cx, |buffer, cx| { buffer.edit(completion.edits.iter().cloned(), None, cx) @@ -2053,20 +2133,36 @@ mod tests { cx: &mut TestAppContext, ) -> Vec<(Range<Point>, String)> { let completion_response = completion_response.to_string(); - let http_client = FakeHttpClient::create(move |_| { + let http_client = FakeHttpClient::create(move |req| { let completion = completion_response.clone(); async move { - Ok(http_client::Response::builder() - .status(200) - .body( - serde_json::to_string(&PredictEditsResponse { - request_id: Uuid::new_v4(), - output_excerpt: completion, - }) - .unwrap() - .into(), - ) - .unwrap()) + match (req.method(), req.uri().path()) { + (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&CreateLlmTokenResponse { + token: LlmToken("the-llm-token".to_string()), + }) + .unwrap() + .into(), + ) + .unwrap()), + (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&PredictEditsResponse { + request_id: Uuid::new_v4(), + output_excerpt: completion, + }) + .unwrap() + .into(), + ) + .unwrap()), + _ => Ok(http_client::Response::builder() + .status(404) + .body("Not Found".into()) + .unwrap()), + } } }); @@ -2074,9 +2170,10 @@ mod tests { cx.update(|cx| { RefreshLlmTokenListener::register(client.clone(), cx); }); - let server = FakeServer::for_client(42, &client, cx).await; + // Construct the fake server to authenticate. 
+ let _server = FakeServer::for_client(42, &client, cx).await; let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx)); + let zeta = cx.new(|cx| Zeta::new(None, client, user_store.clone(), cx)); let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot()); @@ -2085,13 +2182,6 @@ mod tests { zeta.request_completion(None, &buffer, cursor, false, cx) }); - server.receive::<proto::GetUsers>().await.unwrap(); - let token_request = server.receive::<proto::GetLlmToken>().await.unwrap(); - server.respond( - token_request.receipt(), - proto::GetLlmTokenResponse { token: "".into() }, - ); - let completion = completion_task.await.unwrap().unwrap(); completion .edits diff --git a/crates/zeta_cli/Cargo.toml b/crates/zeta_cli/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..e77351c219bac4425136e2a3f1752d73e76adbbf --- /dev/null +++ b/crates/zeta_cli/Cargo.toml @@ -0,0 +1,45 @@ +[package] +name = "zeta_cli" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[[bin]] +name = "zeta" +path = "src/main.rs" + +[dependencies] +anyhow.workspace = true +clap.workspace = true +client.workspace = true +debug_adapter_extension.workspace = true +extension.workspace = true +fs.workspace = true +futures.workspace = true +gpui.workspace = true +gpui_tokio.workspace = true +language.workspace = true +language_extension.workspace = true +language_model.workspace = true +language_models.workspace = true +languages = { workspace = true, features = ["load-grammars"] } +node_runtime.workspace = true +paths.workspace = true +project.workspace = true +prompt_store.workspace = true +release_channel.workspace = true +reqwest_client.workspace = true +serde.workspace = true +serde_json.workspace = true +settings.workspace = true +shellexpand.workspace = true +terminal_view.workspace = true +util.workspace = true +watch.workspace = true +workspace-hack.workspace = true +zeta.workspace = true +smol.workspace = true diff --git a/crates/zeta_cli/LICENSE-GPL b/crates/zeta_cli/LICENSE-GPL new file mode 120000 index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4 --- /dev/null +++ b/crates/zeta_cli/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/zeta_cli/build.rs b/crates/zeta_cli/build.rs new file mode 100644 index 0000000000000000000000000000000000000000..ccbb54c5b4e6db939a8adba8b09f7a2f1174a510 --- /dev/null +++ b/crates/zeta_cli/build.rs @@ -0,0 +1,14 @@ +fn main() { + let cargo_toml = + std::fs::read_to_string("../zed/Cargo.toml").expect("Failed to read Cargo.toml"); + let version = cargo_toml + .lines() + .find(|line| line.starts_with("version = ")) + .expect("Version not found in crates/zed/Cargo.toml") + .split('=') + .nth(1) + .expect("Invalid version format") + .trim() + .trim_matches('"'); + println!("cargo:rustc-env=ZED_PKG_VERSION={}", version); +} diff --git a/crates/zeta_cli/src/headless.rs b/crates/zeta_cli/src/headless.rs new file mode 100644 index 0000000000000000000000000000000000000000..959bb91a8f17b816c233c9143fe4ecdcd2449540 --- /dev/null +++ b/crates/zeta_cli/src/headless.rs @@ -0,0 +1,128 @@ +use client::{Client, ProxySettings, UserStore}; +use extension::ExtensionHostProxy; +use fs::RealFs; +use gpui::http_client::read_proxy_from_env; +use gpui::{App, AppContext, Entity}; +use 
gpui_tokio::Tokio; +use language::LanguageRegistry; +use language_extension::LspAccess; +use node_runtime::{NodeBinaryOptions, NodeRuntime}; +use project::Project; +use project::project_settings::ProjectSettings; +use release_channel::AppVersion; +use reqwest_client::ReqwestClient; +use settings::{Settings, SettingsStore}; +use std::path::PathBuf; +use std::sync::Arc; +use util::ResultExt as _; + +/// Headless subset of `workspace::AppState`. +pub struct ZetaCliAppState { + pub languages: Arc<LanguageRegistry>, + pub client: Arc<Client>, + pub user_store: Entity<UserStore>, + pub fs: Arc<dyn fs::Fs>, + pub node_runtime: NodeRuntime, +} + +// TODO: dedupe with crates/eval/src/eval.rs +pub fn init(cx: &mut App) -> ZetaCliAppState { + let app_version = AppVersion::load(env!("ZED_PKG_VERSION")); + release_channel::init(app_version, cx); + gpui_tokio::init(cx); + + let mut settings_store = SettingsStore::new(cx); + settings_store + .set_default_settings(settings::default_settings().as_ref(), cx) + .unwrap(); + cx.set_global(settings_store); + client::init_settings(cx); + + // Set User-Agent so we can download language servers from GitHub + let user_agent = format!( + "Zed/{} ({}; {})", + app_version, + std::env::consts::OS, + std::env::consts::ARCH + ); + let proxy_str = ProxySettings::get_global(cx).proxy.to_owned(); + let proxy_url = proxy_str + .as_ref() + .and_then(|input| input.parse().ok()) + .or_else(read_proxy_from_env); + let http = { + let _guard = Tokio::handle(cx).enter(); + + ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent) + .expect("could not start HTTP client") + }; + cx.set_http_client(Arc::new(http)); + + Project::init_settings(cx); + + let client = Client::production(cx); + cx.set_http_client(client.http_client()); + + let git_binary_path = None; + let fs = Arc::new(RealFs::new( + git_binary_path, + cx.background_executor().clone(), + )); + + let mut languages = LanguageRegistry::new(cx.background_executor().clone()); + languages.set_language_server_download_dir(paths::languages_dir().clone()); + let languages = Arc::new(languages); + + let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); + + extension::init(cx); + + let (mut tx, rx) = watch::channel(None); + cx.observe_global::<SettingsStore>(move |cx| { + let settings = &ProjectSettings::get_global(cx).node; + let options = NodeBinaryOptions { + allow_path_lookup: !settings.ignore_system_version, + allow_binary_download: true, + use_paths: settings.path.as_ref().map(|node_path| { + let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref()); + let npm_path = settings + .npm_path + .as_ref() + .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref())); + ( + node_path.clone(), + npm_path.unwrap_or_else(|| { + let base_path = PathBuf::new(); + node_path.parent().unwrap_or(&base_path).join("npm") + }), + ) + }), + }; + tx.send(Some(options)).log_err(); + }) + .detach(); + let node_runtime = NodeRuntime::new(client.http_client(), None, rx); + + let extension_host_proxy = ExtensionHostProxy::global(cx); + + language::init(cx); + debug_adapter_extension::init(extension_host_proxy.clone(), cx); + language_extension::init( + LspAccess::Noop, + extension_host_proxy.clone(), + languages.clone(), + ); + language_model::init(client.clone(), cx); + language_models::init(user_store.clone(), client.clone(), cx); + languages::init(languages.clone(), node_runtime.clone(), cx); + prompt_store::init(cx); + terminal_view::init(cx); + + ZetaCliAppState { + languages, + client, + user_store, + fs, + 
node_runtime, + } +} diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs new file mode 100644 index 0000000000000000000000000000000000000000..adf768315267348bfba00a79a82402873b979f93 --- /dev/null +++ b/crates/zeta_cli/src/main.rs @@ -0,0 +1,378 @@ +mod headless; + +use anyhow::{Result, anyhow}; +use clap::{Args, Parser, Subcommand}; +use futures::channel::mpsc; +use futures::{FutureExt as _, StreamExt as _}; +use gpui::{AppContext, Application, AsyncApp}; +use gpui::{Entity, Task}; +use language::Bias; +use language::Buffer; +use language::Point; +use language_model::LlmApiToken; +use project::{Project, ProjectPath}; +use release_channel::AppVersion; +use reqwest_client::ReqwestClient; +use std::path::{Path, PathBuf}; +use std::process::exit; +use std::str::FromStr; +use std::sync::Arc; +use std::time::Duration; +use zeta::{GatherContextOutput, PerformPredictEditsParams, Zeta, gather_context}; + +use crate::headless::ZetaCliAppState; + +#[derive(Parser, Debug)] +#[command(name = "zeta")] +struct ZetaCliArgs { + #[command(subcommand)] + command: Commands, +} + +#[derive(Subcommand, Debug)] +enum Commands { + Context(ContextArgs), + Predict { + #[arg(long)] + predict_edits_body: Option<FileOrStdin>, + #[clap(flatten)] + context_args: Option<ContextArgs>, + }, +} + +#[derive(Debug, Args)] +#[group(requires = "worktree")] +struct ContextArgs { + #[arg(long)] + worktree: PathBuf, + #[arg(long)] + cursor: CursorPosition, + #[arg(long)] + use_language_server: bool, + #[arg(long)] + events: Option<FileOrStdin>, +} + +#[derive(Debug, Clone)] +enum FileOrStdin { + File(PathBuf), + Stdin, +} + +impl FileOrStdin { + async fn read_to_string(&self) -> Result<String, std::io::Error> { + match self { + FileOrStdin::File(path) => smol::fs::read_to_string(path).await, + FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await, + } + } +} + +impl FromStr for FileOrStdin { + type Err = <PathBuf as FromStr>::Err; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + match s { + "-" => Ok(Self::Stdin), + _ => Ok(Self::File(PathBuf::from_str(s)?)), + } + } +} + +#[derive(Debug, Clone)] +struct CursorPosition { + path: PathBuf, + point: Point, +} + +impl FromStr for CursorPosition { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result<Self> { + let parts: Vec<&str> = s.split(':').collect(); + if parts.len() != 3 { + return Err(anyhow!( + "Invalid cursor format. 
Expected 'file.rs:line:column', got '{}'", + s + )); + } + + let path = PathBuf::from(parts[0]); + let line: u32 = parts[1] + .parse() + .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?; + let column: u32 = parts[2] + .parse() + .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?; + + // Convert from 1-based to 0-based indexing + let point = Point::new(line.saturating_sub(1), column.saturating_sub(1)); + + Ok(CursorPosition { path, point }) + } +} + +async fn get_context( + args: ContextArgs, + app_state: &Arc<ZetaCliAppState>, + cx: &mut AsyncApp, +) -> Result<GatherContextOutput> { + let ContextArgs { + worktree: worktree_path, + cursor, + use_language_server, + events, + } = args; + + let worktree_path = worktree_path.canonicalize()?; + if cursor.path.is_absolute() { + return Err(anyhow!("Absolute paths are not supported in --cursor")); + } + + let (project, _lsp_open_handle, buffer) = if use_language_server { + let (project, lsp_open_handle, buffer) = + open_buffer_with_language_server(&worktree_path, &cursor.path, &app_state, cx).await?; + (Some(project), Some(lsp_open_handle), buffer) + } else { + let abs_path = worktree_path.join(&cursor.path); + let content = smol::fs::read_to_string(&abs_path).await?; + let buffer = cx.new(|cx| Buffer::local(content, cx))?; + (None, None, buffer) + }; + + let worktree_name = worktree_path + .file_name() + .ok_or_else(|| anyhow!("--worktree path must end with a folder name"))?; + let full_path_str = PathBuf::from(worktree_name) + .join(&cursor.path) + .to_string_lossy() + .to_string(); + + let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?; + let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left); + if clipped_cursor != cursor.point { + let max_row = snapshot.max_point().row; + if cursor.point.row < max_row { + return Err(anyhow!( + "Cursor position {:?} is out of bounds (line length is {})", + cursor.point, + snapshot.line_len(cursor.point.row) + )); + } else { + return Err(anyhow!( + "Cursor position {:?} is out of bounds (max row is {})", + cursor.point, + max_row + )); + } + } + + let events = match events { + Some(events) => events.read_to_string().await?, + None => String::new(), + }; + let can_collect_data = false; + let git_info = None; + cx.update(|cx| { + gather_context( + project.as_ref(), + full_path_str, + &snapshot, + clipped_cursor, + move || events, + can_collect_data, + git_info, + cx, + ) + })? + .await +} + +pub async fn open_buffer_with_language_server( + worktree_path: &Path, + path: &Path, + app_state: &Arc<ZetaCliAppState>, + cx: &mut AsyncApp, +) -> Result<(Entity<Project>, Entity<Entity<Buffer>>, Entity<Buffer>)> { + let project = cx.update(|cx| { + Project::local( + app_state.client.clone(), + app_state.node_runtime.clone(), + app_state.user_store.clone(), + app_state.languages.clone(), + app_state.fs.clone(), + None, + cx, + ) + })?; + + let worktree = project + .update(cx, |project, cx| { + project.create_worktree(worktree_path, true, cx) + })? + .await?; + + let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath { + worktree_id: worktree.id(), + path: path.to_path_buf().into(), + })?; + + let buffer = project + .update(cx, |project, cx| project.open_buffer(project_path, cx))? 
+ .await?; + + let lsp_open_handle = project.update(cx, |project, cx| { + project.register_buffer_with_language_servers(&buffer, cx) + })?; + + let log_prefix = path.to_string_lossy().to_string(); + wait_for_lang_server(&project, &buffer, log_prefix, cx).await?; + + Ok((project, lsp_open_handle, buffer)) +} + +// TODO: Dedupe with similar function in crates/eval/src/instance.rs +pub fn wait_for_lang_server( + project: &Entity<Project>, + buffer: &Entity<Buffer>, + log_prefix: String, + cx: &mut AsyncApp, +) -> Task<Result<()>> { + println!("{}⏵ Waiting for language server", log_prefix); + + let (mut tx, mut rx) = mpsc::channel(1); + + let lsp_store = project + .read_with(cx, |project, _| project.lsp_store()) + .unwrap(); + + let has_lang_server = buffer + .update(cx, |buffer, cx| { + lsp_store.update(cx, |lsp_store, cx| { + lsp_store + .language_servers_for_local_buffer(&buffer, cx) + .next() + .is_some() + }) + }) + .unwrap_or(false); + + if has_lang_server { + project + .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)) + .unwrap() + .detach(); + } + + let subscriptions = [ + cx.subscribe(&lsp_store, { + let log_prefix = log_prefix.clone(); + move |_, event, _| match event { + project::LspStoreEvent::LanguageServerUpdate { + message: + client::proto::update_language_server::Variant::WorkProgress( + client::proto::LspWorkProgress { + message: Some(message), + .. + }, + ), + .. + } => println!("{}⟲ {message}", log_prefix), + _ => {} + } + }), + cx.subscribe(&project, { + let buffer = buffer.clone(); + move |project, event, cx| match event { + project::Event::LanguageServerAdded(_, _, _) => { + let buffer = buffer.clone(); + project + .update(cx, |project, cx| project.save_buffer(buffer, cx)) + .detach(); + } + project::Event::DiskBasedDiagnosticsFinished { .. } => { + tx.try_send(()).ok(); + } + _ => {} + } + }), + ]; + + cx.spawn(async move |cx| { + let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0)); + let result = futures::select! { + _ = rx.next() => { + println!("{}⚑ Language server idle", log_prefix); + anyhow::Ok(()) + }, + _ = timeout.fuse() => { + anyhow::bail!("LSP wait timed out after 5 minutes"); + } + }; + drop(subscriptions); + result + }) +} + +fn main() { + let args = ZetaCliArgs::parse(); + let http_client = Arc::new(ReqwestClient::new()); + let app = Application::headless().with_http_client(http_client); + + app.run(move |cx| { + let app_state = Arc::new(headless::init(cx)); + cx.spawn(async move |cx| { + let result = match args.command { + Commands::Context(context_args) => get_context(context_args, &app_state, cx) + .await + .map(|output| serde_json::to_string_pretty(&output.body).unwrap()), + Commands::Predict { + predict_edits_body, + context_args, + } => { + cx.spawn(async move |cx| { + let app_version = cx.update(|cx| AppVersion::global(cx))?; + app_state.client.sign_in(true, cx).await?; + let llm_token = LlmApiToken::default(); + llm_token.refresh(&app_state.client).await?; + + let predict_edits_body = + if let Some(predict_edits_body) = predict_edits_body { + serde_json::from_str(&predict_edits_body.read_to_string().await?)? + } else if let Some(context_args) = context_args { + get_context(context_args, &app_state, cx).await?.body + } else { + return Err(anyhow!( + "Expected either --predict-edits-body-file \ + or the required args of the `context` command." 
+ )); + }; + + let (response, _usage) = + Zeta::perform_predict_edits(PerformPredictEditsParams { + client: app_state.client.clone(), + llm_token, + app_version, + body: predict_edits_body, + }) + .await?; + + Ok(response.output_excerpt) + }) + .await + } + }; + match result { + Ok(output) => { + println!("{}", output); + let _ = cx.update(|cx| cx.quit()); + } + Err(e) => { + eprintln!("Failed: {:?}", e); + exit(1); + } + } + }) + .detach(); + }); +} diff --git a/crates/zlog/src/sink.rs b/crates/zlog/src/sink.rs index acf0469c775ec89135dfd87813ee20a9351781f5..17aa08026e6dea4bbc98946e044ce3828e5aa28f 100644 --- a/crates/zlog/src/sink.rs +++ b/crates/zlog/src/sink.rs @@ -21,6 +21,8 @@ const ANSI_MAGENTA: &str = "\x1b[35m"; /// Whether stdout output is enabled. static mut ENABLED_SINKS_STDOUT: bool = false; +/// Whether stderr output is enabled. +static mut ENABLED_SINKS_STDERR: bool = false; /// Is Some(file) if file output is enabled. static ENABLED_SINKS_FILE: Mutex<Option<std::fs::File>> = Mutex::new(None); @@ -45,6 +47,12 @@ pub fn init_output_stdout() { } } +pub fn init_output_stderr() { + unsafe { + ENABLED_SINKS_STDERR = true; + } +} + pub fn init_output_file( path: &'static PathBuf, path_rotate: Option<&'static PathBuf>, @@ -115,6 +123,21 @@ pub fn submit(record: Record) { }, record.message ); + } else if unsafe { ENABLED_SINKS_STDERR } { + let mut stdout = std::io::stderr().lock(); + _ = writeln!( + &mut stdout, + "{} {ANSI_BOLD}{}{}{ANSI_RESET} {} {}", + chrono::Local::now().format("%Y-%m-%dT%H:%M:%S%:z"), + LEVEL_ANSI_COLORS[record.level as usize], + LEVEL_OUTPUT_STRINGS[record.level as usize], + SourceFmt { + scope: record.scope, + module_path: record.module_path, + ansi: true, + }, + record.message + ); } let mut file = ENABLED_SINKS_FILE.lock().unwrap_or_else(|handle| { ENABLED_SINKS_FILE.clear_poison(); diff --git a/crates/zlog/src/zlog.rs b/crates/zlog/src/zlog.rs index 570c82314c5d1a56e03610a2740d35833ef07d69..5b40278f3fb0adbafe1815608765aa4ab3d44e57 100644 --- a/crates/zlog/src/zlog.rs +++ b/crates/zlog/src/zlog.rs @@ -5,7 +5,7 @@ mod env_config; pub mod filter; pub mod sink; -pub use sink::{flush, init_output_file, init_output_stdout}; +pub use sink::{flush, init_output_file, init_output_stderr, init_output_stdout}; pub const SCOPE_DEPTH_MAX: usize = 4; diff --git a/docs/README.md b/docs/README.md index 55993c9e36e9a78a9271dbc509ef9129d8c91422..a225903674966b142d5f35845018c98ce9770258 100644 --- a/docs/README.md +++ b/docs/README.md @@ -69,3 +69,64 @@ Templates are just functions that modify the source of the docs pages (usually w - Template Trait: crates/docs_preprocessor/src/templates.rs - Example template: crates/docs_preprocessor/src/templates/keybinding.rs - Client-side plugins: docs/theme/plugins.js + +## Postprocessor + +A postprocessor is implemented as a sub-command of `docs_preprocessor` that wraps the builtin `html` renderer and applies post-processing to the `html` files, to add support for page-specific title and meta description values. + +An example of the syntax can be found in `git.md`, as well as below + +```md +--- +title: Some more detailed title for this page +description: A page-specific description +--- + +# Editor +``` + +The above will be transformed into (with non-relevant tags removed) + +```html +<head> + <title>Editor | Some more detailed title for this page + + + +

<h1>Editor</h1> + </body>
+ +``` + +If no front-matter is provided, or if one or both keys aren't provided, the title and description will be set based on the `default-title` and `default-description` keys in `book.toml`, respectively. + +### Implementation details + +Unfortunately, `mdbook` does not support post-processing like it does pre-processing, and only supports defining one description to put in the meta tag per book rather than per file. So in order to apply post-processing (necessary to modify the HTML head tags), the global book description is set to a marker value `#description#` and the html renderer is replaced with a sub-command of `docs_preprocessor` that wraps the builtin `html` renderer and applies post-processing to the `html` files, replacing the marker value and the `<title>(.*)</title>` with the contents of the front-matter if there is one. + +### Known limitations + +The front-matter parsing is extremely simple, which avoids needing to take on an additional dependency or implement full YAML parsing. + +- Double quotes and multi-line values are not supported, i.e., keys and values must be entirely on the same line, with no double quotes around the value. + +The following will not work: + +```md +--- +title: Some + Multi-line + Title +--- +``` + +And neither will: + +```md +--- +title: "Some title" +--- +``` + +- The front-matter must be at the top of the file, with only whitespace preceding it + +- The contents of the title and description will not be HTML-escaped. They should be simple ASCII text with no Unicode or emoji characters diff --git a/docs/book.toml b/docs/book.toml index d04447d90f846ff33b7437b22d4dc82bbf586c7e..60ddc5ac515cb73f7b0b4f2f8c2c193bdddf228b 100644 --- a/docs/book.toml +++ b/docs/book.toml @@ -6,38 +6,88 @@ src = "src" title = "Zed" site-url = "/docs/" -[output.html] +[build] +extra-watch-dirs = ["../crates/docs_preprocessor"] + +# zed-html is a "custom" renderer that just wraps the +# builtin mdbook html renderer, and applies post-processing +# as post-processing is not possible with mdbook in the same way +# pre-processing is. +# The config is passed directly to the html renderer, so all config +# options that apply to html apply to zed-html +[output.zed-html] +command = "cargo run -p docs_preprocessor -- postprocess" +# Set here instead of above as we only use it to replace the `#description#` we set in the template +# when no front-matter value is provided +default-description = "Learn how to use and customize Zed, the fast, collaborative code editor. Official docs on features, configuration, AI tools, and workflows." +default-title = "Zed Code Editor Documentation" no-section-label = true preferred-dark-theme = "dark" additional-css = ["theme/page-toc.css", "theme/plugins.css", "theme/highlight.css"] additional-js = ["theme/page-toc.js", "theme/plugins.js"] -[output.html.print] +[output.zed-html.print] enable = false -[output.html.redirect] -"/elixir.html" = "/docs/languages/elixir.html" -"/javascript.html" = "/docs/languages/javascript.html" -"/ruby.html" = "/docs/languages/ruby.html" -"/python.html" = "/docs/languages/python.html" -"/adding-new-languages.html" = "/docs/extensions/languages.html" -"/language-model-integration.html" = "/docs/assistant/assistant.html" -"/assistant.html" = "/docs/assistant/assistant.html" -"/developing-zed.html" = "/docs/development.html" -"/conversations.html" = "/community-links" +# Redirects for `/docs` pages. +# +# All of the source URLs are interpreted relative to mdBook, so they must: +# 1. Not start with `/docs` +# 2.
End in `.html` +# +# The destination URLs are interpreted relative to `https://zed.dev`. +# - Redirects to other docs pages should end in `.html` +# - You can link to pages on the Zed site by omitting the `/docs` in front of it. +[output.zed-html.redirect] +# AI "/ai.html" = "/docs/ai/overview.html" +"/assistant-panel.html" = "/docs/ai/agent-panel.html" +"/assistant.html" = "/docs/assistant/assistant.html" +"/assistant/assistant-panel.html" = "/docs/ai/agent-panel.html" "/assistant/assistant.html" = "/docs/ai/overview.html" +"/assistant/commands.html" = "/docs/ai/text-threads.html" "/assistant/configuration.html" = "/docs/ai/configuration.html" -"/assistant/assistant-panel.html" = "/docs/ai/agent-panel.html" +"/assistant/context-servers.html" = "/docs/ai/mcp.html" "/assistant/contexts.html" = "/docs/ai/text-threads.html" "/assistant/inline-assistant.html" = "/docs/ai/inline-assistant.html" -"/assistant/commands.html" = "/docs/ai/text-threads.html" -"/assistant/prompting.html" = "/docs/ai/rules.html" -"/assistant/context-servers.html" = "/docs/ai/mcp.html" "/assistant/model-context-protocol.html" = "/docs/ai/mcp.html" +"/assistant/prompting.html" = "/docs/ai/rules.html" +"/language-model-integration.html" = "/docs/assistant/assistant.html" "/model-improvement.html" = "/docs/ai/ai-improvement.html" +"/ai/temperature.html" = "/docs/ai/agent-settings.html#model-temperature" + +# Community +"/community/feedback.html" = "/community-links" +"/conversations.html" = "/community-links" + +# Debugger +"/debuggers.html" = "/docs/debugger.html" + +# MCP +"/assistant/model-context-protocolCitedby.html" = "/docs/ai/mcp.html" +"/context-servers.html" = "/docs/ai/mcp.html" "/extensions/context-servers.html" = "/docs/extensions/mcp-extensions.html" +# Languages +"/adding-new-languages.html" = "/docs/extensions/languages.html" +"/elixir.html" = "/docs/languages/elixir.html" +"/javascript.html" = "/docs/languages/javascript.html" +"/languages/languages/html.html" = "/docs/languages/html.html" +"/languages/languages/javascript.html" = "/docs/languages/javascript.html" +"/languages/languages/makefile.html" = "/docs/languages/makefile.html" +"/languages/languages/nim.html" = "/docs/languages/nim.html" +"/languages/languages/ruby.html" = "/docs/languages/ruby.html" +"/languages/languages/scala.html" = "/docs/languages/scala.html" +"/python.html" = "/docs/languages/python.html" +"/ruby.html" = "/docs/languages/ruby.html" + +# Zed development +"/contribute-to-zed.html" = "/docs/development.html#contributor-links" +"/contributing.html" = "/docs/development.html#contributor-links" +"/developing-zed.html" = "/docs/development.html" +"/development/development/linux.html" = "/docs/development/linux.html" +"/development/development/macos.html" = "/docs/development/macos.html" +"/development/development/windows.html" = "/docs/development/windows.html" # Our custom preprocessor for expanding commands like `{#kb action::ActionName}`, # and other docs-related functions. 
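The front-matter handling added in `docs/README.md` above is intentionally line-based rather than a full YAML parser. As a rough illustration of what those constraints imply, here is a minimal sketch in Rust; the `FrontMatter` type and `parse_front_matter` function are assumptions for illustration only, not the actual `docs_preprocessor` implementation:

```rust
/// A minimal sketch (not the actual `docs_preprocessor` code) of the line-based
/// front-matter parsing described above: the block must be the first
/// non-whitespace content in the file, and every `key: value` pair must sit on
/// a single line with no quoting.
#[derive(Debug, Default)]
struct FrontMatter {
    title: Option<String>,
    description: Option<String>,
}

fn parse_front_matter(source: &str) -> Option<FrontMatter> {
    // The front-matter must open with `---` as the first non-whitespace line.
    let mut lines = source.trim_start().lines();
    if lines.next()? != "---" {
        return None;
    }
    let mut front_matter = FrontMatter::default();
    for line in lines {
        if line.trim() == "---" {
            // Closing delimiter: everything after this is the page body.
            return Some(front_matter);
        }
        // Single-line `key: value` pairs only; unknown keys are ignored.
        let (key, value) = line.split_once(':')?;
        match key.trim() {
            "title" => front_matter.title = Some(value.trim().to_string()),
            "description" => front_matter.description = Some(value.trim().to_string()),
            _ => {}
        }
    }
    // No closing `---` was found, so treat the file as having no front-matter.
    None
}
```

A page that omits one or both keys would then fall back to the `default-title` and `default-description` values configured for `output.zed-html` in `book.toml` above.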
diff --git a/docs/src/SUMMARY.md b/docs/src/SUMMARY.md index 1d43872547a366e03136876475004918d9b827b9..fc936d6bd0cfda980df10c4b6c41768ac02486de 100644 --- a/docs/src/SUMMARY.md +++ b/docs/src/SUMMARY.md @@ -45,13 +45,14 @@ - [Overview](./ai/overview.md) - [Agent Panel](./ai/agent-panel.md) - [Tools](./ai/tools.md) - - [Model Temperature](./ai/temperature.md) - [Inline Assistant](./ai/inline-assistant.md) - [Edit Prediction](./ai/edit-prediction.md) - [Text Threads](./ai/text-threads.md) - [Rules](./ai/rules.md) - [Model Context Protocol](./ai/mcp.md) - [Configuration](./ai/configuration.md) + - [LLM Providers](./ai/llm-providers.md) + - [Agent Settings](./ai/agent-settings.md) - [Subscription](./ai/subscription.md) - [Plans and Usage](./ai/plans-and-usage.md) - [Billing](./ai/billing.md) diff --git a/docs/src/accounts.md b/docs/src/accounts.md index c13c98ad9aadd77ccf60c1f1cc3c33e6fac7690d..1ce23cf902dc558de4163621d4ec886d2b719e15 100644 --- a/docs/src/accounts.md +++ b/docs/src/accounts.md @@ -5,7 +5,7 @@ Signing in to Zed is not a requirement. You can use most features you'd expect i ## What Features Require Signing In? 1. All real-time [collaboration features](./collaboration.md). -2. [LLM-powered features](./ai/overview.md), if you are using Zed as the provider of your LLM models. Alternatively, you can [bring and configure your own API keys](./ai/configuration.md#use-your-own-keys) if you'd prefer, and avoid having to sign in. +2. [LLM-powered features](./ai/overview.md), if you are using Zed as the provider of your LLM models. Alternatively, you can [bring and configure your own API keys](./ai/llm-providers.md#use-your-own-keys) if you'd prefer, and avoid having to sign in. ## Signing In diff --git a/docs/src/ai/agent-panel.md b/docs/src/ai/agent-panel.md index 3c04ae5c43f87ee54e96a253300aa20524d6d844..f944eb88b06c8be21002ff319a972ff1843de39d 100644 --- a/docs/src/ai/agent-panel.md +++ b/docs/src/ai/agent-panel.md @@ -1,18 +1,21 @@ # Agent Panel -The Agent Panel provides you with a way to interact with LLMs. -You can use it for various tasks, such as generating code, asking questions about your code base, and general inquiries such as emails and documentation. +The Agent Panel provides you with a surface to interact with LLMs, enabling various types of tasks, such as generating code, asking questions about your codebase, and general inquiries like emails, documentation, and more. -To open the Agent Panel, use the `agent: new thread` action in [the Command Palette](../getting-started.md#command-palette) or click the ✨ (sparkles) icon in the status bar. +To open it, use the `agent: new thread` action in [the Command Palette](../getting-started.md#command-palette) or click the ✨ (sparkles) icon in the status bar. -If you're using the Agent Panel for the first time, you'll need to [configure at least one LLM provider](./configuration.md). +If you're using the Agent Panel for the first time, you need to have at least one LLM provider configured. +You can do that by: + +1. [subscribing to our Pro plan](https://zed.dev/pricing), so you have access to our hosted models +2. or by [bringing your own API keys](./llm-providers.md#use-your-own-keys) for your desired provider ## Overview {#overview} After you've configured one or more LLM providers, type at the message editor and hit `enter` to submit your prompt. If you need extra room to type, you can expand the message editor with {#kb agent::ExpandMessageEditor}. 
-You should start to see the responses stream in with indications of [which tools](./tools.md) the AI is using to fulfill your prompt. +You should start to see the responses stream in with indications of [which tools](./tools.md) the model is using to fulfill your prompt. ### Editing Messages {#editing-messages} @@ -21,13 +24,13 @@ You can click on the card that contains your message and re-submit it with an ad ### Checkpoints {#checkpoints} -Every time the AI performs an edit, you should see a "Restore Checkpoint" button to the top of your message, allowing you to return your codebase to the state it was in prior to that message. +Every time the AI performs an edit, you should see a "Restore Checkpoint" button to the top of your message, allowing you to return your code base to the state it was in prior to that message. The checkpoint button appears even if you interrupt the thread midway through an edit attempt, as this is likely a moment when you've identified that the agent is not heading in the right direction and you want to revert back. ### Navigating History {#navigating-history} -To quickly navigate through recently opened threads, use the {#kb agent::ToggleNavigationMenu} binding, when focused on the panel's editor, or click the hamburger icon button at the top left of the panel to open the dropdown that shows you the six most recent threads. +To quickly navigate through recently opened threads, use the {#kb agent::ToggleNavigationMenu} binding, when focused on the panel's editor, or click the menu icon button at the top left of the panel to open the dropdown that shows you the six most recent threads. The items in this menu function similarly to tabs, and closing them doesn’t delete the thread; instead, it simply removes them from the recent list. @@ -39,6 +42,8 @@ Zed is built with collaboration natively integrated. This approach extends to collaboration with AI as well. To follow the agent reading through your codebase and performing edits, click on the "crosshair" icon button at the bottom left of the panel. +You can also do that with the keyboard by pressing the `cmd`/`ctrl` modifier with `enter` when submitting a message. + ### Get Notified {#get-notified} If you send a prompt to the Agent and then move elsewhere, thus putting Zed in the background, you can be notified of whether its response is finished either via: @@ -63,12 +68,12 @@ So, if your active tab had edits made by the AI, you'll see diffs with the same ## Adding Context {#adding-context} -Although Zed's agent is very efficient at reading through your codebase to autonomously pick up relevant files, directories, and other context, manually adding context is still encouraged as a way to speed up and improve the AI's response quality. +Although Zed's agent is very efficient at reading through your code base to autonomously pick up relevant files, directories, and other context, manually adding context is still encouraged as a way to speed up and improve the AI's response quality. -If you have a tab open when opening the Agent Panel, that tab appears as a suggested context in form of a dashed button. +If you have a tab open while using the Agent Panel, that tab appears as a suggested context in form of a dashed button. You can also add other forms of context by either mentioning them with `@` or hitting the `+` icon button. 
-You can even add previous threads as context by mentioning them with `@thread`, or by selecting the "New From Summary" option from the top-right menu to continue a longer conversation, keeping it within the context window. +You can even add previous threads as context by mentioning them with `@thread`, or by selecting the "New From Summary" option from the `+` menu to continue a longer conversation, keeping it within the context window. Pasting images as context is also supported by the Agent Panel. @@ -82,7 +87,7 @@ You can also do this at any time with an ongoing thread via the "Agent Options" ## Changing Models {#changing-models} -After you've configured your LLM providers—either via [a custom API key](./configuration.md#use-your-own-keys) or through [Zed's hosted models](./models.md)—you can switch between them by clicking on the model selector on the message editor or by using the {#kb agent::ToggleModelSelector} keybinding. +After you've configured your LLM providers—either via [a custom API key](./llm-providers.md#use-your-own-keys) or through [Zed's hosted models](./models.md)—you can switch between them by clicking on the model selector on the message editor or by using the {#kb agent::ToggleModelSelector} keybinding. ## Using Tools {#using-tools} @@ -116,6 +121,12 @@ Zed will store this profile in your settings using the same profile name as the All custom profiles can be edited via the UI or by hand under the `assistant.profiles` key in your `settings.json` file. +### Tool Approval + +Zed's Agent Panel surfaces the `agent.always_allow_tool_actions` setting that, if turned to `false`, will require you to give permission to any editing attempt as well as tool calls coming from MCP servers. + +You can change that by setting this key to `true` in either your `settings.json` or via the Agent Panel's settings view. + ### Model Support {#model-support} Tool calling needs to be individually supported by each model and model provider. @@ -141,24 +152,17 @@ You can remove and edit responses from the LLM, swap roles, and include more con For users who have been with us for some time, you'll notice that text threads are our original assistant panel—users love it for the control it offers. We do not plan to deprecate text threads, but it should be noted that if you want the AI to write to your code base autonomously, that's only available in the newer, and now default, "Threads". -### Text Thread History {#text-thread-history} - -Content from text thread are saved to your file system. -Visit [the dedicated docs](./text-threads.md#history) for more info. - ## Errors and Debugging {#errors-and-debugging} In case of any error or strange LLM response behavior, the best way to help the Zed team debug is by reaching for the `agent: open thread as markdown` action and attaching that data as part of your issue on GitHub. -This action exposes the entire thread in the form of Markdown and allows for deeper understanding of what each tool call was doing. - You can also open threads as Markdown by clicking on the file icon button, to the right of the thumbs down button, when focused on the panel's editor. ## Feedback {#feedback} -Every change we make to Zed's system prompt and tool set, needs to be backed by an eval with good scores. +Every change we make to Zed's system prompt and tool set, needs to be backed by a thorough eval with good scores. 
-Every time the LLM performs a weird change or investigates a certain topic in your codebase completely incorrectly, it's an indication that there's an improvement opportunity. +Every time the LLM performs a weird change or investigates a certain topic in your code base incorrectly, it's an indication that there's an improvement opportunity. > Note that rating responses will send your data related to that response to Zed's servers. > See [AI Improvement](./ai-improvement.md) and [Privacy and Security](./privacy-and-security.md) for more information about Zed's approach to AI improvement, privacy, and security. diff --git a/docs/src/ai/agent-settings.md b/docs/src/ai/agent-settings.md new file mode 100644 index 0000000000000000000000000000000000000000..ff97bcb8eeb941d2c072b95dbcc8089da927df42 --- /dev/null +++ b/docs/src/ai/agent-settings.md @@ -0,0 +1,226 @@ +# Agent Settings + +Learn about all the settings you can customize in Zed's Agent Panel. + +## Model Settings {#model-settings} + +### Default Model {#default-model} + +If you're using [Zed's hosted LLM service](./plans-and-usage.md), it sets `claude-sonnet-4` as the default model. +But if you're not subscribed to it or simply want to change it, you can do so either via the model dropdown in the Agent Panel's bottom-right corner or by manually editing the `default_model` object in your settings: + +```json +{ + "agent": { + "default_model": { + "provider": "zed.dev", + "model": "gpt-4o" + } + } +} +``` + +### Feature-specific Models {#feature-specific-models} + +Assign distinct and specific models for the following AI-powered features in Zed: + +- Thread summary model: Used for generating thread summaries +- Inline assistant model: Used for the inline assistant feature +- Commit message model: Used for generating Git commit messages + +```json +{ + "agent": { + "default_model": { + "provider": "zed.dev", + "model": "claude-sonnet-4" + }, + "inline_assistant_model": { + "provider": "anthropic", + "model": "claude-3-5-sonnet" + }, + "commit_message_model": { + "provider": "openai", + "model": "gpt-4o-mini" + }, + "thread_summary_model": { + "provider": "google", + "model": "gemini-2.0-flash" + } + } +} +``` + +> If a custom model isn't set for one of these features, they automatically fall back to using the default model. + +### Alternative Models for Inline Assists {#alternative-assists} + +The Inline Assist feature in particular has the capacity to perform multiple generations in parallel using different models. +That is possible by assigning more than one model to it, taking the configuration shown above one step further. + +When configured, the inline assist UI will surface controls to cycle between the outputs generated by each model. + +The models you specify here are always used in _addition_ to your [default model](#default-model). + +For example, the following configuration will generate two outputs for every assist. +One with Claude Sonnet 4 (the default model), and one with GPT-4o.
+ +```json +{ + "agent": { + "default_model": { + "provider": "zed.dev", + "model": "claude-sonnet-4" + }, + "inline_alternatives": [ + { + "provider": "zed.dev", + "model": "gpt-4o" + } + ] + } +} +``` + +### Model Temperature + +Specify a custom temperature for a provider and/or model: + +```json +"model_parameters": [ + // To set parameters for all requests to OpenAI models: + { + "provider": "openai", + "temperature": 0.5 + }, + // To set parameters for all requests in general: + { + "temperature": 0 + }, + // To set parameters for a specific provider and model: + { + "provider": "zed.dev", + "model": "claude-sonnet-4", + "temperature": 1.0 + } +], +``` + +## Agent Panel Settings {#agent-panel-settings} + +Note that some of these settings are also surfaced in the Agent Panel's settings UI, which you can access either via the `agent: open settings` action or by the dropdown menu on the top-right corner of the panel. + +### Default View + +Use the `default_view` setting to change the default view of the Agent Panel. +You can choose between `thread` (the default) and `text_thread`: + +```json +{ + "agent": { + "default_view": "text_thread" + } +} +``` + +### Auto-run Commands + +Control whether you want to allow the agent to run commands without asking you for permission. +The default value is `false`. + +```json +{ + "agent": { + "always_allow_tool_actions": true + } +} +``` + +> This setting is available via the Agent Panel's settings UI. + +### Single-file Review + +Control whether you want to see review actions (accept & reject) in single buffers after the agent is done performing edits. +The default value is `false`. + +```json +{ + "agent": { + "single_file_review": true + } +} +``` + +When set to `false`, these controls are only available in the multibuffer review tab. + +> This setting is available via the Agent Panel's settings UI. + +### Sound Notification + +Control whether you want to hear a notification sound when the agent is done generating changes or needs your input. +The default value is `false`. + +```json +{ + "agent": { + "play_sound_when_agent_done": true + } +} +``` + +> This setting is available via the Agent Panel's settings UI. + +### Modifier to Send + +Make a modifier (`cmd` on macOS, `ctrl` on Linux) required to send messages. +This is encouraged for more thoughtful prompt crafting. +The default value is `false`. + +```json +{ + "agent": { + "use_modifier_to_send": true + } +} +``` + +> This setting is available via the Agent Panel's settings UI. + +### Edit Card + +Use the `expand_edit_card` setting to control whether edit cards show the full diff in the Agent Panel. +It is set to `true` by default, but if set to `false`, the card's height is capped to a certain number of lines, requiring a click to be expanded. + +```json +{ + "agent": { + "expand_edit_card": false + } +} +``` + +### Terminal Card + +Use the `expand_terminal_card` setting to control whether terminal cards show the command output in the Agent Panel. +It is set to `true` by default, but if set to `false`, the card will be fully collapsed even while the command is running, requiring a click to be expanded. +

```json +{ + "agent": { + "expand_terminal_card": false + } +} +``` + +### Feedback Controls + +Control whether you want to see the thumbs up/down buttons to give Zed feedback about the agent's performance. +The default value is `true`.
+ +```json +{ + "agent": { + "enable_feedback": "false" + } +} +``` diff --git a/docs/src/ai/billing.md b/docs/src/ai/billing.md index c49bacd8831c2c4df384339b303bc332bb2165cd..d519b136aeea8c505979cde224d406ac995b65f0 100644 --- a/docs/src/ai/billing.md +++ b/docs/src/ai/billing.md @@ -1,7 +1,7 @@ # Billing We use Stripe as our billing and payments provider. All Pro plans require payment via credit card. -For invoice-based billing, a Business plan is required. Contact sales@zed.dev for more information. +For invoice-based billing, a Business plan is required. Contact [sales@zed.dev](mailto:sales@zed.dev) for more information. ## Settings {#settings} @@ -12,7 +12,8 @@ Clicking the button under Account Settings will navigate you to Stripe’s secur Zed is billed on a monthly basis based on the date you initially subscribe. -We’ll also bill in-month for additional prompts used beyond your plan’s prompt limit, if usage exceeds $20 before month end. See [usage-based pricing](./plans-and-usage.md#ubp) for more. +We’ll also bill in-month for additional prompts used beyond your plan’s prompt limit, if usage exceeds $20 before month end. +See [usage-based pricing](./plans-and-usage.md#ubp) for more. ## Invoice History {#invoice-history} @@ -25,3 +26,12 @@ From Stripe’s secure portal, you can download all current and historical invoi You can update your payment method, company name, address, and tax information through the billing portal. Please note that changes to billing information will **only** affect future invoices — **we cannot modify historical invoices**. + +## Sales Tax {#sales-tax} + +Zed partners with [Sphere](https://www.getsphere.com/) to calculate indirect tax rate for invoices, based on customer location and the product being sold. Tax is listed as a separate line item on invoices, based preferentially on your billing address, followed by the card issue country known to Stripe. + +If you have a VAT/GST ID, you can add it at [zed.dev/account](https://zed.dev/account) by clicking "Manage" on your subscription. Check the box that denotes you as a business. + +Please note that changes to VAT/GST IDs and address will **only** affect future invoices — **we cannot modify historical invoices**. +Questions or issues can be directed to [billing-support@zed.dev](mailto:billing-support@zed.dev). diff --git a/docs/src/ai/configuration.md b/docs/src/ai/configuration.md index ade1ae672f51944949c47e9f098c60a9a8198423..d28a7e8ed006b1c788cc0f649362bae41879a99b 100644 --- a/docs/src/ai/configuration.md +++ b/docs/src/ai/configuration.md @@ -1,682 +1,20 @@ # Configuration -There are various aspects about the Agent Panel that you can customize. -All of them can be seen by either visiting [the Configuring Zed page](../configuring-zed.md#agent) or by running the `zed: open default settings` action and searching for `"agent"`. +When using AI in Zed, you can customize several aspects: -Alternatively, you can also visit the panel's Settings view by running the `agent: open configuration` action or going to the top-right menu and hitting "Settings". +1. Which [LLM providers](./llm-providers.md) you can use +2. [Model parameters and usage](./agent-settings.md#model-settings) +3. [Interactions with the Agent Panel](./agent-settings.md#agent-panel-settings) -## LLM Providers +## Turning AI Off Entirely -Zed supports multiple large language model providers. 
-Here's an overview of the supported providers and tool call support: - -| Provider | Tool Use Supported | -| ----------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| [Amazon Bedrock](#amazon-bedrock) | Depends on the model | -| [Anthropic](#anthropic) | ✅ | -| [DeepSeek](#deepseek) | ✅ | -| [GitHub Copilot Chat](#github-copilot-chat) | For some models ([link](https://github.com/zed-industries/zed/blob/9e0330ba7d848755c9734bf456c716bddf0973f3/crates/language_models/src/provider/copilot_chat.rs#L189-L198)) | -| [Google AI](#google-ai) | ✅ | -| [LM Studio](#lmstudio) | ✅ | -| [Mistral](#mistral) | ✅ | -| [Ollama](#ollama) | ✅ | -| [OpenAI](#openai) | ✅ | -| [OpenAI API Compatible](#openai-api-compatible) | 🚫 | -| [OpenRouter](#openrouter) | ✅ | - -## Use Your Own Keys {#use-your-own-keys} - -While Zed offers hosted versions of models through [our various plans](./plans-and-usage.md), we're always happy to support users wanting to supply their own API keys. -Below, you can learn how to do that for each provider. - -> Using your own API keys is _free_—you do not need to subscribe to a Zed plan to use our AI features with your own keys. - -### Amazon Bedrock {#amazon-bedrock} - -> ✅ Supports tool use with models that support streaming tool use. -> More details can be found in the [Amazon Bedrock's Tool Use documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-supported-models-features.html). - -To use Amazon Bedrock's models, an AWS authentication is required. -Ensure your credentials have the following permissions set up: - -- `bedrock:InvokeModelWithResponseStream` -- `bedrock:InvokeModel` -- `bedrock:ConverseStream` - -Your IAM policy should look similar to: - -```json -{ - "Version": "2012-10-17", - "Statement": [ - { - "Effect": "Allow", - "Action": [ - "bedrock:InvokeModel", - "bedrock:InvokeModelWithResponseStream", - "bedrock:ConverseStream" - ], - "Resource": "*" - } - ] -} -``` - -With that done, choose one of the two authentication methods: - -#### Authentication via Named Profile (Recommended) - -1. Ensure you have the AWS CLI installed and configured with a named profile -2. Open your `settings.json` (`zed: open settings`) and include the `bedrock` key under `language_models` with the following settings: - ```json - { - "language_models": { - "bedrock": { - "authentication_method": "named_profile", - "region": "your-aws-region", - "profile": "your-profile-name" - } - } - } - ``` - -#### Authentication via Static Credentials - -While it's possible to configure through the Agent Panel settings UI by entering your AWS access key and secret directly, we recommend using named profiles instead for better security practices. -To do this: - -1. Create an IAM User that you can assume in the [IAM Console](https://us-east-1.console.aws.amazon.com/iam/home?region=us-east-1#/users). -2. Create security credentials for that User, save them and keep them secure. -3. Open the Agent Configuration with (`agent: open configuration`) and go to the Amazon Bedrock section -4. Copy the credentials from Step 2 into the respective **Access Key ID**, **Secret Access Key**, and **Region** fields. 
- -#### Cross-Region Inference - -The Zed implementation of Amazon Bedrock uses [Cross-Region inference](https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference.html) for all the models and region combinations that support it. -With Cross-Region inference, you can distribute traffic across multiple AWS Regions, enabling higher throughput. - -For example, if you use `Claude Sonnet 3.7 Thinking` from `us-east-1`, it may be processed across the US regions, namely: `us-east-1`, `us-east-2`, or `us-west-2`. -Cross-Region inference requests are kept within the AWS Regions that are part of the geography where the data originally resides. -For example, a request made within the US is kept within the AWS Regions in the US. - -Although the data remains stored only in the source Region, your input prompts and output results might move outside of your source Region during cross-Region inference. -All data will be transmitted encrypted across Amazon's secure network. - -We will support Cross-Region inference for each of the models on a best-effort basis, please refer to the [Cross-Region Inference method Code](https://github.com/zed-industries/zed/blob/main/crates/bedrock/src/models.rs#L297). - -For the most up-to-date supported regions and models, refer to the [Supported Models and Regions for Cross Region inference](https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html). - -### Anthropic {#anthropic} - -> ✅ Supports tool use - -You can use Anthropic models by choosing it via the model dropdown in the Agent Panel. - -1. Sign up for Anthropic and [create an API key](https://console.anthropic.com/settings/keys) -2. Make sure that your Anthropic account has credits -3. Open the settings view (`agent: open configuration`) and go to the Anthropic section -4. Enter your Anthropic API key - -Even if you pay for Claude Pro, you will still have to [pay for additional credits](https://console.anthropic.com/settings/plans) to use it via the API. - -Zed will also use the `ANTHROPIC_API_KEY` environment variable if it's defined. - -#### Custom Models {#anthropic-custom-models} - -You can add custom models to the Anthropic provider by adding the following to your Zed `settings.json`: - -```json -{ - "language_models": { - "anthropic": { - "available_models": [ - { - "name": "claude-3-5-sonnet-20240620", - "display_name": "Sonnet 2024-June", - "max_tokens": 128000, - "max_output_tokens": 2560, - "cache_configuration": { - "max_cache_anchors": 10, - "min_total_token": 10000, - "should_speculate": false - }, - "tool_override": "some-model-that-supports-toolcalling" - } - ] - } - } -} -``` - -Custom models will be listed in the model dropdown in the Agent Panel. - -You can configure a model to use [extended thinking](https://docs.anthropic.com/en/docs/about-claude/models/extended-thinking-models) (if it supports it) by changing the mode in your model's configuration to `thinking`, for example: - -```json -{ - "name": "claude-sonnet-4-latest", - "display_name": "claude-sonnet-4-thinking", - "max_tokens": 200000, - "mode": { - "type": "thinking", - "budget_tokens": 4_096 - } -} -``` - -### DeepSeek {#deepseek} - -> ✅ Supports tool use - -1. Visit the DeepSeek platform and [create an API key](https://platform.deepseek.com/api_keys) -2. Open the settings view (`agent: open configuration`) and go to the DeepSeek section -3. Enter your DeepSeek API key - -The DeepSeek API key will be saved in your keychain. 
- -Zed will also use the `DEEPSEEK_API_KEY` environment variable if it's defined. - -#### Custom Models {#deepseek-custom-models} - -The Zed agent comes pre-configured to use the latest version for common models (DeepSeek Chat, DeepSeek Reasoner). -If you wish to use alternate models or customize the API endpoint, you can do so by adding the following to your Zed `settings.json`: - -```json -{ - "language_models": { - "deepseek": { - "api_url": "https://api.deepseek.com", - "available_models": [ - { - "name": "deepseek-chat", - "display_name": "DeepSeek Chat", - "max_tokens": 64000 - }, - { - "name": "deepseek-reasoner", - "display_name": "DeepSeek Reasoner", - "max_tokens": 64000, - "max_output_tokens": 4096 - } - ] - } - } -} -``` - -Custom models will be listed in the model dropdown in the Agent Panel. -You can also modify the `api_url` to use a custom endpoint if needed. - -### GitHub Copilot Chat {#github-copilot-chat} - -> ✅ Supports tool use in some cases. -> Visit [the Copilot Chat code](https://github.com/zed-industries/zed/blob/9e0330ba7d848755c9734bf456c716bddf0973f3/crates/language_models/src/provider/copilot_chat.rs#L189-L198) for the supported subset. - -You can use GitHub Copilot Chat with the Zed agent by choosing it via the model dropdown in the Agent Panel. - -1. Open the settings view (`agent: open configuration`) and go to the GitHub Copilot Chat section -2. Click on `Sign in to use GitHub Copilot`, follow the steps shown in the modal. - -Alternatively, you can provide an OAuth token via the `GH_COPILOT_TOKEN` environment variable. - -> **Note**: If you don't see specific models in the dropdown, you may need to enable them in your [GitHub Copilot settings](https://github.com/settings/copilot/features). - -To use Copilot Enterprise with Zed (for both agent and inline completions), you must configure your enterprise endpoint as described in [Configuring GitHub Copilot Enterprise](./edit-prediction.md#github-copilot-enterprise). - -### Google AI {#google-ai} - -> ✅ Supports tool use - -You can use Gemini models with the Zed agent by choosing it via the model dropdown in the Agent Panel. - -1. Go to the Google AI Studio site and [create an API key](https://aistudio.google.com/app/apikey). -2. Open the settings view (`agent: open configuration`) and go to the Google AI section -3. Enter your Google AI API key and press enter. - -The Google AI API key will be saved in your keychain. - -Zed will also use the `GOOGLE_AI_API_KEY` environment variable if it's defined. - -#### Custom Models {#google-ai-custom-models} - -By default, Zed will use `stable` versions of models, but you can use specific versions of models, including [experimental models](https://ai.google.dev/gemini-api/docs/models/experimental-models). You can configure a model to use [thinking mode](https://ai.google.dev/gemini-api/docs/thinking) (if it supports it) by adding a `mode` configuration to your model. This is useful for controlling reasoning token usage and response speed. If not specified, Gemini will automatically choose the thinking budget. - -Here is an example of a custom Google AI model you could add to your Zed `settings.json`: - -```json -{ - "language_models": { - "google": { - "available_models": [ - { - "name": "gemini-2.5-flash-preview-05-20", - "display_name": "Gemini 2.5 Flash (Thinking)", - "max_tokens": 1000000, - "mode": { - "type": "thinking", - "budget_tokens": 24000 - } - } - ] - } - } -} -``` - -Custom models will be listed in the model dropdown in the Agent Panel. 
- -### LM Studio {#lmstudio} - -> ✅ Supports tool use - -1. Download and install [the latest version of LM Studio](https://lmstudio.ai/download) -2. In the app press `cmd/ctrl-shift-m` and download at least one model (e.g., qwen2.5-coder-7b). Alternatively, you can get models via the LM Studio CLI: - - ```sh - lms get qwen2.5-coder-7b - ``` - -3. Make sure the LM Studio API server is running by executing: - - ```sh - lms server start - ``` - -Tip: Set [LM Studio as a login item](https://lmstudio.ai/docs/advanced/headless#run-the-llm-service-on-machine-login) to automate running the LM Studio server. - -### Mistral {#mistral} - -> ✅ Supports tool use - -1. Visit the Mistral platform and [create an API key](https://console.mistral.ai/api-keys/) -2. Open the configuration view (`agent: open configuration`) and navigate to the Mistral section -3. Enter your Mistral API key - -The Mistral API key will be saved in your keychain. - -Zed will also use the `MISTRAL_API_KEY` environment variable if it's defined. - -#### Custom Models {#mistral-custom-models} - -The Zed agent comes pre-configured with several Mistral models (codestral-latest, mistral-large-latest, mistral-medium-latest, mistral-small-latest, open-mistral-nemo, and open-codestral-mamba). -All the default models support tool use. -If you wish to use alternate models or customize their parameters, you can do so by adding the following to your Zed `settings.json`: - -```json -{ - "language_models": { - "mistral": { - "api_url": "https://api.mistral.ai/v1", - "available_models": [ - { - "name": "mistral-tiny-latest", - "display_name": "Mistral Tiny", - "max_tokens": 32000, - "max_output_tokens": 4096, - "max_completion_tokens": 1024, - "supports_tools": true, - "supports_images": false - } - ] - } - } -} -``` - -Custom models will be listed in the model dropdown in the Agent Panel. - -### Ollama {#ollama} - -> ✅ Supports tool use - -Download and install Ollama from [ollama.com/download](https://ollama.com/download) (Linux or macOS) and ensure it's running with `ollama --version`. - -1. Download one of the [available models](https://ollama.com/models), for example, for `mistral`: - - ```sh - ollama pull mistral - ``` - -2. Make sure that the Ollama server is running. You can start it either via running Ollama.app (macOS) or launching: - - ```sh - ollama serve - ``` - -3. In the Agent Panel, select one of the Ollama models using the model dropdown. - -#### Ollama Context Length {#ollama-context} - -Zed has pre-configured maximum context lengths (`max_tokens`) to match the capabilities of common models. -Zed API requests to Ollama include this as the `num_ctx` parameter, but the default values do not exceed `16384` so users with ~16GB of RAM are able to use most models out of the box. - -See [get_max_tokens in ollama.rs](https://github.com/zed-industries/zed/blob/main/crates/ollama/src/ollama.rs) for a complete set of defaults. - -> **Note**: Token counts displayed in the Agent Panel are only estimates and will differ from the model's native tokenizer. 
- -Depending on your hardware or use-case you may wish to limit or increase the context length for a specific model via settings.json: - -```json -{ - "language_models": { - "ollama": { - "api_url": "http://localhost:11434", - "available_models": [ - { - "name": "qwen2.5-coder", - "display_name": "qwen 2.5 coder 32K", - "max_tokens": 32768, - "supports_tools": true, - "supports_thinking": true, - "supports_images": true - } - ] - } - } -} -``` - -If you specify a context length that is too large for your hardware, Ollama will log an error. -You can watch these logs by running: `tail -f ~/.ollama/logs/ollama.log` (macOS) or `journalctl -u ollama -f` (Linux). -Depending on the memory available on your machine, you may need to adjust the context length to a smaller value. - -You may also optionally specify a value for `keep_alive` for each available model. -This can be an integer (seconds) or alternatively a string duration like "5m", "10m", "1h", "1d", etc. -For example, `"keep_alive": "120s"` will allow the remote server to unload the model (freeing up GPU VRAM) after 120 seconds. - -The `supports_tools` option controls whether the model will use additional tools. -If the model is tagged with `tools` in the Ollama catalog, this option should be supplied, and the built-in profiles `Ask` and `Write` can be used. -If the model is not tagged with `tools` in the Ollama catalog, this option can still be supplied with the value `true`; however, be aware that only the `Minimal` built-in profile will work. - -The `supports_thinking` option controls whether the model will perform an explicit "thinking" (reasoning) pass before producing its final answer. -If the model is tagged with `thinking` in the Ollama catalog, set this option and you can use it in Zed. - -The `supports_images` option enables the model's vision capabilities, allowing it to process images included in the conversation context. -If the model is tagged with `vision` in the Ollama catalog, set this option and you can use it in Zed. - -### OpenAI {#openai} - -> ✅ Supports tool use - -1. Visit the OpenAI platform and [create an API key](https://platform.openai.com/account/api-keys) -2. Make sure that your OpenAI account has credits -3. Open the settings view (`agent: open configuration`) and go to the OpenAI section -4. Enter your OpenAI API key - -The OpenAI API key will be saved in your keychain. - -Zed will also use the `OPENAI_API_KEY` environment variable if it's defined. - -#### Custom Models {#openai-custom-models} - -The Zed agent comes pre-configured to use the latest version for common models (GPT-3.5 Turbo, GPT-4, GPT-4 Turbo, GPT-4o, GPT-4o mini). -To use alternate models, perhaps a preview release or a dated model release, or if you wish to control the request parameters, you can do so by adding the following to your Zed `settings.json`: - -```json -{ - "language_models": { - "openai": { - "available_models": [ - { - "name": "gpt-4o-2024-08-06", - "display_name": "GPT 4o Summer 2024", - "max_tokens": 128000 - }, - { - "name": "o1-mini", - "display_name": "o1-mini", - "max_tokens": 128000, - "max_completion_tokens": 20000 - } - ], - "version": "1" - } - } -} -``` - -You must provide the model's context window in the `max_tokens` parameter; this can be found in the [OpenAI model documentation](https://platform.openai.com/docs/models). - -OpenAI `o1` models should set `max_completion_tokens` as well to avoid incurring high reasoning token costs. -Custom models will be listed in the model dropdown in the Agent Panel. 
- -### OpenAI API Compatible {#openai-api-compatible} - -Zed supports using OpenAI compatible APIs by specifying a custom `endpoint` and `available_models` for the OpenAI provider. - -You can add a custom API URL for OpenAI either via the UI or by editing your `settings.json`. -Here are a few model examples you can plug in by using this feature: - -#### X.ai Grok - -Example configuration for using X.ai Grok with Zed: - -```json - "language_models": { - "openai": { - "api_url": "https://api.x.ai/v1", - "available_models": [ - { - "name": "grok-beta", - "display_name": "X.ai Grok (Beta)", - "max_tokens": 131072 - } - ], - "version": "1" - }, - } -``` - -### OpenRouter {#openrouter} - -> ✅ Supports tool use - -OpenRouter provides access to multiple AI models through a single API. It supports tool use for compatible models. - -1. Visit [OpenRouter](https://openrouter.ai) and create an account -2. Generate an API key from your [OpenRouter keys page](https://openrouter.ai/keys) -3. Open the settings view (`agent: open configuration`) and go to the OpenRouter section -4. Enter your OpenRouter API key - -The OpenRouter API key will be saved in your keychain. - -Zed will also use the `OPENROUTER_API_KEY` environment variable if it's defined. - -#### Custom Models {#openrouter-custom-models} - -You can add custom models to the OpenRouter provider by adding the following to your Zed `settings.json`: - -```json -{ - "language_models": { - "open_router": { - "api_url": "https://openrouter.ai/api/v1", - "available_models": [ - { - "name": "google/gemini-2.0-flash-thinking-exp", - "display_name": "Gemini 2.0 Flash (Thinking)", - "max_tokens": 200000, - "max_output_tokens": 8192, - "supports_tools": true, - "supports_images": true, - "mode": { - "type": "thinking", - "budget_tokens": 8000 - } - } - ] - } - } -} -``` - -The available configuration options for each model are: - -- `name` (required): The model identifier used by OpenRouter -- `display_name` (optional): A human-readable name shown in the UI -- `max_tokens` (required): The model's context window size -- `max_output_tokens` (optional): Maximum tokens the model can generate -- `max_completion_tokens` (optional): Maximum completion tokens -- `supports_tools` (optional): Whether the model supports tool/function calling -- `supports_images` (optional): Whether the model supports image inputs -- `mode` (optional): Special mode configuration for thinking models - -You can find available models and their specifications on the [OpenRouter models page](https://openrouter.ai/models). - -Custom models will be listed in the model dropdown in the Agent Panel. - -### Vercel v0 - -[Vercel v0](https://vercel.com/docs/v0/api) is an expert model for generating full-stack apps, with framework-aware completions optimized for modern stacks like Next.js and Vercel. -It supports text and image inputs and provides fast streaming responses. - -The v0 models are [OpenAI-compatible models](/#openai-api-compatible), but Vercel is listed as first-class provider in the panel's settings view. - -To start using it with Zed, ensure you have first created a [v0 API key](https://v0.dev/chat/settings/keys). -Once you have it, paste it directly into the Vercel provider section in the panel's settings view. - -You should then find it as `v0-1.5-md` in the model dropdown in the Agent Panel. 
- -## Advanced Configuration {#advanced-configuration} - -### Custom Provider Endpoints {#custom-provider-endpoint} - -You can use a custom API endpoint for different providers, as long as it's compatible with the provider's API structure. -To do so, add the following to your `settings.json`: - -```json -{ - "language_models": { - "some-provider": { - "api_url": "http://localhost:11434" - } - } -} -``` - -Where `some-provider` can be any of the following values: `anthropic`, `google`, `ollama`, `openai`. - -### Default Model {#default-model} - -Zed's hosted LLM service sets `claude-sonnet-4` as the default model. -However, you can change it either via the model dropdown in the Agent Panel's bottom-right corner or by manually editing the `default_model` object in your settings: - -```json -{ - "agent": { - "version": "2", - "default_model": { - "provider": "zed.dev", - "model": "gpt-4o" - } - } -} -``` - -### Feature-specific Models {#feature-specific-models} - -If a feature-specific model is not set, it will fall back to using the default model, which is the one you set on the Agent Panel. - -You can configure the following feature-specific models: - -- Thread summary model: Used for generating thread summaries -- Inline assistant model: Used for the inline assistant feature -- Commit message model: Used for generating Git commit messages - -Example configuration: - -```json -{ - "agent": { - "version": "2", - "default_model": { - "provider": "zed.dev", - "model": "claude-sonnet-4" - }, - "inline_assistant_model": { - "provider": "anthropic", - "model": "claude-3-5-sonnet" - }, - "commit_message_model": { - "provider": "openai", - "model": "gpt-4o-mini" - }, - "thread_summary_model": { - "provider": "google", - "model": "gemini-2.0-flash" - } - } -} -``` - -### Alternative Models for Inline Assists {#alternative-assists} - -You can configure additional models that will be used to perform inline assists in parallel. -When you do this, the inline assist UI will surface controls to cycle between the alternatives generated by each model. - -The models you specify here are always used in _addition_ to your [default model](#default-model). -For example, the following configuration will generate two outputs for every assist. -One with Claude 3.7 Sonnet, and one with GPT-4o. - -```json -{ - "agent": { - "default_model": { - "provider": "zed.dev", - "model": "claude-sonnet-4" - }, - "inline_alternatives": [ - { - "provider": "zed.dev", - "model": "gpt-4o" - } - ], - "version": "2" - } -} -``` - -### Default View - -Use the `default_view` setting to set change the default view of the Agent Panel. -You can choose between `thread` (the default) and `text_thread`: - -```json -{ - "agent": { - "default_view": "text_thread" - } -} -``` - -### Edit Card - -Use the `expand_edit_card` setting to control whether edit cards show the full diff in the Agent Panel. -It is set to `true` by default, but if set to false, the card's height is capped to a certain number of lines, requiring a click to be expanded. - -```json -{ - "agent": { - "expand_edit_card": "false" - } -} -``` - -This setting is currently only available in Preview. -It should be up in Stable by the next release. - -### Terminal Card - -Use the `expand_terminal_card` setting to control whether terminal cards show the command output in the Agent Panel. -It is set to `true` by default, but if set to false, the card will be fully collapsed even while the command is running, requiring a click to be expanded. 
+We want to respect users who want to use Zed without interacting with AI whatsoever. +To do that, add the following key to your `settings.json`: ```json { - "agent": { - "expand_terminal_card": "false" - } + "disable_ai": true } ``` -This setting is currently only available in Preview. -It should be up in Stable by the next release. +Read [the following blog post](https://zed.dev/blog/disable-ai-features) to learn more about our motivation to promote this, as much as we also encourage users to explore AI-assisted programming. diff --git a/docs/src/ai/inline-assistant.md b/docs/src/ai/inline-assistant.md index cd0ace3ce67876990f02f2618ec53aea4c391a03..da894e2cd87faf6ce8afa9c54a5f2d55bcd07827 100644 --- a/docs/src/ai/inline-assistant.md +++ b/docs/src/ai/inline-assistant.md @@ -12,7 +12,7 @@ You can also perform multiple generation requests in parallel by pressing `ctrl- Give the Inline Assistant context the same way you can in [the Agent Panel](./agent-panel.md), allowing you to provide additional instructions or rules for code transformations with @-mentions. -A useful pattern here is to create a thread in the Agent Panel, and then use the mention that thread with `@thread` in the Inline Assistant to include it as context. +A useful pattern here is to create a thread in the Agent Panel, and then mention that thread with `@thread` in the Inline Assistant to include it as context. > The Inline Assistant is limited to normal mode context windows ([see Models](./models.md) for more). diff --git a/docs/src/ai/llm-providers.md b/docs/src/ai/llm-providers.md new file mode 100644 index 0000000000000000000000000000000000000000..8fdb7ea325db5daea3949b8cb25d85c09487f3cc --- /dev/null +++ b/docs/src/ai/llm-providers.md @@ -0,0 +1,583 @@ +# LLM Providers + +To use AI in Zed, you need to have at least one large language model provider set up. + +You can do that by either subscribing to [one of Zed's plans](./plans-and-usage.md), or by using API keys you already have for the supported providers. + +## Use Your Own Keys {#use-your-own-keys} + +If you already have an API key for an existing LLM provider—say Anthropic or OpenAI, for example—you can insert them in Zed and use the Agent Panel **_for free_**. + +You can add your API key to a given provider either via the Agent Panel's settings UI or directly via the `settings.json` through the `language_models` key. + +## Supported Providers + +Here's all the supported LLM providers for which you can use your own API keys: + +| Provider | +| ----------------------------------------------- | +| [Amazon Bedrock](#amazon-bedrock) | +| [Anthropic](#anthropic) | +| [DeepSeek](#deepseek) | +| [GitHub Copilot Chat](#github-copilot-chat) | +| [Google AI](#google-ai) | +| [LM Studio](#lmstudio) | +| [Mistral](#mistral) | +| [Ollama](#ollama) | +| [OpenAI](#openai) | +| [OpenAI API Compatible](#openai-api-compatible) | +| [OpenRouter](#openrouter) | +| [Vercel](#vercel-v0) | +| [xAI](#xai) | + +### Amazon Bedrock {#amazon-bedrock} + +> Supports tool use with models that support streaming tool use. +> More details can be found in the [Amazon Bedrock's Tool Use documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-supported-models-features.html). + +To use Amazon Bedrock's models, an AWS authentication is required. 
+Ensure your credentials have the following permissions set up: + +- `bedrock:InvokeModelWithResponseStream` +- `bedrock:InvokeModel` +- `bedrock:ConverseStream` + +Your IAM policy should look similar to: + +```json +{ + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "bedrock:InvokeModel", + "bedrock:InvokeModelWithResponseStream", + "bedrock:ConverseStream" + ], + "Resource": "*" + } + ] +} +``` + +With that done, choose one of the two authentication methods: + +#### Authentication via Named Profile (Recommended) + +1. Ensure you have the AWS CLI installed and configured with a named profile +2. Open your `settings.json` (`zed: open settings`) and include the `bedrock` key under `language_models` with the following settings: + ```json + { + "language_models": { + "bedrock": { + "authentication_method": "named_profile", + "region": "your-aws-region", + "profile": "your-profile-name" + } + } + } + ``` + +#### Authentication via Static Credentials + +While it's possible to configure through the Agent Panel settings UI by entering your AWS access key and secret directly, we recommend using named profiles instead for better security practices. +To do this: + +1. Create an IAM User that you can assume in the [IAM Console](https://us-east-1.console.aws.amazon.com/iam/home?region=us-east-1#/users). +2. Create security credentials for that User, save them and keep them secure. +3. Open the Agent Configuration with (`agent: open settings`) and go to the Amazon Bedrock section +4. Copy the credentials from Step 2 into the respective **Access Key ID**, **Secret Access Key**, and **Region** fields. + +#### Cross-Region Inference + +The Zed implementation of Amazon Bedrock uses [Cross-Region inference](https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference.html) for all the models and region combinations that support it. +With Cross-Region inference, you can distribute traffic across multiple AWS Regions, enabling higher throughput. + +For example, if you use `Claude Sonnet 3.7 Thinking` from `us-east-1`, it may be processed across the US regions, namely: `us-east-1`, `us-east-2`, or `us-west-2`. +Cross-Region inference requests are kept within the AWS Regions that are part of the geography where the data originally resides. +For example, a request made within the US is kept within the AWS Regions in the US. + +Although the data remains stored only in the source Region, your input prompts and output results might move outside of your source Region during cross-Region inference. +All data will be transmitted encrypted across Amazon's secure network. + +We will support Cross-Region inference for each of the models on a best-effort basis, please refer to the [Cross-Region Inference method Code](https://github.com/zed-industries/zed/blob/main/crates/bedrock/src/models.rs#L297). + +For the most up-to-date supported regions and models, refer to the [Supported Models and Regions for Cross Region inference](https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html). + +### Anthropic {#anthropic} + +You can use Anthropic models by choosing them via the model dropdown in the Agent Panel. + +1. Sign up for Anthropic and [create an API key](https://console.anthropic.com/settings/keys) +2. Make sure that your Anthropic account has credits +3. Open the settings view (`agent: open settings`) and go to the Anthropic section +4. 
Enter your Anthropic API key + +Even if you pay for Claude Pro, you will still have to [pay for additional credits](https://console.anthropic.com/settings/plans) to use it via the API. + +Zed will also use the `ANTHROPIC_API_KEY` environment variable if it's defined. + +#### Custom Models {#anthropic-custom-models} + +You can add custom models to the Anthropic provider by adding the following to your Zed `settings.json`: + +```json +{ + "language_models": { + "anthropic": { + "available_models": [ + { + "name": "claude-3-5-sonnet-20240620", + "display_name": "Sonnet 2024-June", + "max_tokens": 128000, + "max_output_tokens": 2560, + "cache_configuration": { + "max_cache_anchors": 10, + "min_total_token": 10000, + "should_speculate": false + }, + "tool_override": "some-model-that-supports-toolcalling" + } + ] + } + } +} +``` + +Custom models will be listed in the model dropdown in the Agent Panel. + +You can configure a model to use [extended thinking](https://docs.anthropic.com/en/docs/about-claude/models/extended-thinking-models) (if it supports it) by changing the mode in your model's configuration to `thinking`, for example: + +```json +{ + "name": "claude-sonnet-4-latest", + "display_name": "claude-sonnet-4-thinking", + "max_tokens": 200000, + "mode": { + "type": "thinking", + "budget_tokens": 4_096 + } +} +``` + +### DeepSeek {#deepseek} + +1. Visit the DeepSeek platform and [create an API key](https://platform.deepseek.com/api_keys) +2. Open the settings view (`agent: open settings`) and go to the DeepSeek section +3. Enter your DeepSeek API key + +The DeepSeek API key will be saved in your keychain. + +Zed will also use the `DEEPSEEK_API_KEY` environment variable if it's defined. + +#### Custom Models {#deepseek-custom-models} + +The Zed agent comes pre-configured to use the latest version for common models (DeepSeek Chat, DeepSeek Reasoner). +If you wish to use alternate models or customize the API endpoint, you can do so by adding the following to your Zed `settings.json`: + +```json +{ + "language_models": { + "deepseek": { + "api_url": "https://api.deepseek.com", + "available_models": [ + { + "name": "deepseek-chat", + "display_name": "DeepSeek Chat", + "max_tokens": 64000 + }, + { + "name": "deepseek-reasoner", + "display_name": "DeepSeek Reasoner", + "max_tokens": 64000, + "max_output_tokens": 4096 + } + ] + } + } +} +``` + +Custom models will be listed in the model dropdown in the Agent Panel. +You can also modify the `api_url` to use a custom endpoint if needed. + +### GitHub Copilot Chat {#github-copilot-chat} + +You can use GitHub Copilot Chat with the Zed agent by choosing it via the model dropdown in the Agent Panel. + +1. Open the settings view (`agent: open settings`) and go to the GitHub Copilot Chat section +2. Click on `Sign in to use GitHub Copilot`, follow the steps shown in the modal. + +Alternatively, you can provide an OAuth token via the `GH_COPILOT_TOKEN` environment variable. + +> **Note**: If you don't see specific models in the dropdown, you may need to enable them in your [GitHub Copilot settings](https://github.com/settings/copilot/features). + +To use Copilot Enterprise with Zed (for both agent and completions), you must configure your enterprise endpoint as described in [Configuring GitHub Copilot Enterprise](./edit-prediction.md#github-copilot-enterprise). + +### Google AI {#google-ai} + +You can use Gemini models with the Zed agent by choosing it via the model dropdown in the Agent Panel. + +1. 
Go to the Google AI Studio site and [create an API key](https://aistudio.google.com/app/apikey). +2. Open the settings view (`agent: open settings`) and go to the Google AI section +3. Enter your Google AI API key and press enter. + +The Google AI API key will be saved in your keychain. + +Zed will also use the `GEMINI_API_KEY` environment variable if it's defined. See [Using Gemini API keys](https://ai.google.dev/gemini-api/docs/api-key) in the Gemini docs for more. + +#### Custom Models {#google-ai-custom-models} + +By default, Zed will use `stable` versions of models, but you can use specific versions of models, including [experimental models](https://ai.google.dev/gemini-api/docs/models/experimental-models). You can configure a model to use [thinking mode](https://ai.google.dev/gemini-api/docs/thinking) (if it supports it) by adding a `mode` configuration to your model. This is useful for controlling reasoning token usage and response speed. If not specified, Gemini will automatically choose the thinking budget. + +Here is an example of a custom Google AI model you could add to your Zed `settings.json`: + +```json +{ + "language_models": { + "google": { + "available_models": [ + { + "name": "gemini-2.5-flash-preview-05-20", + "display_name": "Gemini 2.5 Flash (Thinking)", + "max_tokens": 1000000, + "mode": { + "type": "thinking", + "budget_tokens": 24000 + } + } + ] + } + } +} +``` + +Custom models will be listed in the model dropdown in the Agent Panel. + +### LM Studio {#lmstudio} + +1. Download and install [the latest version of LM Studio](https://lmstudio.ai/download) +2. In the app press `cmd/ctrl-shift-m` and download at least one model (e.g., qwen2.5-coder-7b). Alternatively, you can get models via the LM Studio CLI: + + ```sh + lms get qwen2.5-coder-7b + ``` + +3. Make sure the LM Studio API server is running by executing: + + ```sh + lms server start + ``` + +Tip: Set [LM Studio as a login item](https://lmstudio.ai/docs/advanced/headless#run-the-llm-service-on-machine-login) to automate running the LM Studio server. + +### Mistral {#mistral} + +1. Visit the Mistral platform and [create an API key](https://console.mistral.ai/api-keys/) +2. Open the configuration view (`agent: open settings`) and navigate to the Mistral section +3. Enter your Mistral API key + +The Mistral API key will be saved in your keychain. + +Zed will also use the `MISTRAL_API_KEY` environment variable if it's defined. + +#### Custom Models {#mistral-custom-models} + +The Zed agent comes pre-configured with several Mistral models (codestral-latest, mistral-large-latest, mistral-medium-latest, mistral-small-latest, open-mistral-nemo, and open-codestral-mamba). +All the default models support tool use. +If you wish to use alternate models or customize their parameters, you can do so by adding the following to your Zed `settings.json`: + +```json +{ + "language_models": { + "mistral": { + "api_url": "https://api.mistral.ai/v1", + "available_models": [ + { + "name": "mistral-tiny-latest", + "display_name": "Mistral Tiny", + "max_tokens": 32000, + "max_output_tokens": 4096, + "max_completion_tokens": 1024, + "supports_tools": true, + "supports_images": false + } + ] + } + } +} +``` + +Custom models will be listed in the model dropdown in the Agent Panel. + +### Ollama {#ollama} + +Download and install Ollama from [ollama.com/download](https://ollama.com/download) (Linux or macOS) and ensure it's running with `ollama --version`. + +1. 
Download one of the [available models](https://ollama.com/models), for example, for `mistral`: + + ```sh + ollama pull mistral + ``` + +2. Make sure that the Ollama server is running. You can start it either via running Ollama.app (macOS) or launching: + + ```sh + ollama serve + ``` + +3. In the Agent Panel, select one of the Ollama models using the model dropdown. + +#### Ollama Context Length {#ollama-context} + +Zed has pre-configured maximum context lengths (`max_tokens`) to match the capabilities of common models. +Zed API requests to Ollama include this as the `num_ctx` parameter, but the default values do not exceed `16384` so users with ~16GB of RAM are able to use most models out of the box. + +See [get_max_tokens in ollama.rs](https://github.com/zed-industries/zed/blob/main/crates/ollama/src/ollama.rs) for a complete set of defaults. + +> **Note**: Token counts displayed in the Agent Panel are only estimates and will differ from the model's native tokenizer. + +Depending on your hardware or use-case you may wish to limit or increase the context length for a specific model via settings.json: + +```json +{ + "language_models": { + "ollama": { + "api_url": "http://localhost:11434", + "available_models": [ + { + "name": "qwen2.5-coder", + "display_name": "qwen 2.5 coder 32K", + "max_tokens": 32768, + "supports_tools": true, + "supports_thinking": true, + "supports_images": true + } + ] + } + } +} +``` + +If you specify a context length that is too large for your hardware, Ollama will log an error. +You can watch these logs by running: `tail -f ~/.ollama/logs/ollama.log` (macOS) or `journalctl -u ollama -f` (Linux). +Depending on the memory available on your machine, you may need to adjust the context length to a smaller value. + +You may also optionally specify a value for `keep_alive` for each available model. +This can be an integer (seconds) or alternatively a string duration like "5m", "10m", "1h", "1d", etc. +For example, `"keep_alive": "120s"` will allow the remote server to unload the model (freeing up GPU VRAM) after 120 seconds. + +The `supports_tools` option controls whether the model will use additional tools. +If the model is tagged with `tools` in the Ollama catalog, this option should be supplied, and the built-in profiles `Ask` and `Write` can be used. +If the model is not tagged with `tools` in the Ollama catalog, this option can still be supplied with the value `true`; however, be aware that only the `Minimal` built-in profile will work. + +The `supports_thinking` option controls whether the model will perform an explicit "thinking" (reasoning) pass before producing its final answer. +If the model is tagged with `thinking` in the Ollama catalog, set this option and you can use it in Zed. + +The `supports_images` option enables the model's vision capabilities, allowing it to process images included in the conversation context. +If the model is tagged with `vision` in the Ollama catalog, set this option and you can use it in Zed. + +### OpenAI {#openai} + +1. Visit the OpenAI platform and [create an API key](https://platform.openai.com/account/api-keys) +2. Make sure that your OpenAI account has credits +3. Open the settings view (`agent: open settings`) and go to the OpenAI section +4. Enter your OpenAI API key + +The OpenAI API key will be saved in your keychain. + +Zed will also use the `OPENAI_API_KEY` environment variable if it's defined. 
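+
+For example, if you prefer the environment-variable route over pasting the key into the settings UI, a minimal sketch (assuming you launch Zed from a shell that sources your profile, and using a placeholder value) looks like this:
+
+```sh
+# Hypothetical placeholder: replace with your real OpenAI API key,
+# or rely on the keychain-based flow described above instead.
+export OPENAI_API_KEY="sk-..."
+```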
+
+#### Custom Models {#openai-custom-models}
+
+The Zed agent comes pre-configured to use the latest version for common models (GPT-3.5 Turbo, GPT-4, GPT-4 Turbo, GPT-4o, GPT-4o mini).
+To use alternate models, perhaps a preview release or a dated model release, or if you wish to control the request parameters, you can do so by adding the following to your Zed `settings.json`:
+
+```json
+{
+  "language_models": {
+    "openai": {
+      "available_models": [
+        {
+          "name": "gpt-4o-2024-08-06",
+          "display_name": "GPT 4o Summer 2024",
+          "max_tokens": 128000
+        },
+        {
+          "name": "o1-mini",
+          "display_name": "o1-mini",
+          "max_tokens": 128000,
+          "max_completion_tokens": 20000
+        }
+      ],
+      "version": "1"
+    }
+  }
+}
+```
+
+You must provide the model's context window in the `max_tokens` parameter; this can be found in the [OpenAI model documentation](https://platform.openai.com/docs/models).
+
+OpenAI `o1` models should set `max_completion_tokens` as well to avoid incurring high reasoning token costs.
+Custom models will be listed in the model dropdown in the Agent Panel.
+
+### OpenAI API Compatible {#openai-api-compatible}
+
+Zed supports using [OpenAI compatible APIs](https://platform.openai.com/docs/api-reference/chat) by specifying a custom `api_url` and `available_models` for the OpenAI provider.
+This is useful for connecting to other hosted services (like Together AI, Anyscale, etc.) or local models.
+
+You can add a custom, OpenAI-compatible model either via the UI or by editing your `settings.json`.
+
+To do it via the UI, go to the Agent Panel settings (`agent: open settings`) and look for the "Add Provider" button to the right of the "LLM Providers" section title.
+Then, fill in the input fields available in the modal.
+
+To do it via your `settings.json`, add the following snippet under `language_models`:
+
+```json
+{
+  "language_models": {
+    "openai": {
+      "api_url": "https://api.together.xyz/v1", // Using Together AI as an example
+      "available_models": [
+        {
+          "name": "mistralai/Mixtral-8x7B-Instruct-v0.1",
+          "display_name": "Together Mixtral 8x7B",
+          "max_tokens": 32768
+        }
+      ]
+    }
+  }
+}
+```
+
+Note that LLM API keys aren't stored in your settings file.
+So, ensure your key is set in your environment variables (`OPENAI_API_KEY=`) so your settings can pick it up.
+
+### OpenRouter {#openrouter}
+
+OpenRouter provides access to multiple AI models through a single API. It supports tool use for compatible models.
+
+1. Visit [OpenRouter](https://openrouter.ai) and create an account
+2. Generate an API key from your [OpenRouter keys page](https://openrouter.ai/keys)
+3. Open the settings view (`agent: open settings`) and go to the OpenRouter section
+4. Enter your OpenRouter API key
+
+The OpenRouter API key will be saved in your keychain.
+
+Zed will also use the `OPENROUTER_API_KEY` environment variable if it's defined.
+ +#### Custom Models {#openrouter-custom-models} + +You can add custom models to the OpenRouter provider by adding the following to your Zed `settings.json`: + +```json +{ + "language_models": { + "open_router": { + "api_url": "https://openrouter.ai/api/v1", + "available_models": [ + { + "name": "google/gemini-2.0-flash-thinking-exp", + "display_name": "Gemini 2.0 Flash (Thinking)", + "max_tokens": 200000, + "max_output_tokens": 8192, + "supports_tools": true, + "supports_images": true, + "mode": { + "type": "thinking", + "budget_tokens": 8000 + } + } + ] + } + } +} +``` + +The available configuration options for each model are: + +- `name` (required): The model identifier used by OpenRouter +- `display_name` (optional): A human-readable name shown in the UI +- `max_tokens` (required): The model's context window size +- `max_output_tokens` (optional): Maximum tokens the model can generate +- `max_completion_tokens` (optional): Maximum completion tokens +- `supports_tools` (optional): Whether the model supports tool/function calling +- `supports_images` (optional): Whether the model supports image inputs +- `mode` (optional): Special mode configuration for thinking models + +You can find available models and their specifications on the [OpenRouter models page](https://openrouter.ai/models). + +Custom models will be listed in the model dropdown in the Agent Panel. + +### Vercel v0 {#vercel-v0} + +[Vercel v0](https://vercel.com/docs/v0/api) is an expert model for generating full-stack apps, with framework-aware completions optimized for modern stacks like Next.js and Vercel. +It supports text and image inputs and provides fast streaming responses. + +The v0 models are [OpenAI-compatible models](/#openai-api-compatible), but Vercel is listed as first-class provider in the panel's settings view. + +To start using it with Zed, ensure you have first created a [v0 API key](https://v0.dev/chat/settings/keys). +Once you have it, paste it directly into the Vercel provider section in the panel's settings view. + +You should then find it as `v0-1.5-md` in the model dropdown in the Agent Panel. + +### xAI {#xai} + +Zed has first-class support for [xAI](https://x.ai/) models. You can use your own API key to access Grok models. + +1. [Create an API key in the xAI Console](https://console.x.ai/team/default/api-keys) +2. Open the settings view (`agent: open settings`) and go to the **xAI** section +3. Enter your xAI API key + +The xAI API key will be saved in your keychain. Zed will also use the `XAI_API_KEY` environment variable if it's defined. + +> **Note:** While the xAI API is OpenAI-compatible, Zed has first-class support for it as a dedicated provider. For the best experience, we recommend using the dedicated `x_ai` provider configuration instead of the [OpenAI API Compatible](#openai-api-compatible) method. + +#### Custom Models {#xai-custom-models} + +The Zed agent comes pre-configured with common Grok models. 
If you wish to use alternate models or customize their parameters, you can do so by adding the following to your Zed `settings.json`: + +```json +{ + "language_models": { + "x_ai": { + "api_url": "https://api.x.ai/v1", + "available_models": [ + { + "name": "grok-1.5", + "display_name": "Grok 1.5", + "max_tokens": 131072, + "max_output_tokens": 8192 + }, + { + "name": "grok-1.5v", + "display_name": "Grok 1.5V (Vision)", + "max_tokens": 131072, + "max_output_tokens": 8192, + "supports_images": true + } + ] + } + } +} +``` + +## Custom Provider Endpoints {#custom-provider-endpoint} + +You can use a custom API endpoint for different providers, as long as it's compatible with the provider's API structure. +To do so, add the following to your `settings.json`: + +```json +{ + "language_models": { + "some-provider": { + "api_url": "http://localhost:11434" + } + } +} +``` + +Currently, `some-provider` can be any of the following values: `anthropic`, `google`, `ollama`, `openai`. + +This is the same infrastructure that powers models that are, for example, [OpenAI-compatible](#openai-api-compatible). diff --git a/docs/src/ai/mcp.md b/docs/src/ai/mcp.md index 202b14102209ae8d3dbf338ff11bb8b443432cf9..dfe3e4bdb904b911fd247b6f4eb7cde09d46a9a9 100644 --- a/docs/src/ai/mcp.md +++ b/docs/src/ai/mcp.md @@ -4,35 +4,35 @@ Zed uses the [Model Context Protocol](https://modelcontextprotocol.io/) to inter > The Model Context Protocol (MCP) is an open protocol that enables seamless integration between LLM applications and external data sources and tools. Whether you're building an AI-powered IDE, enhancing a chat interface, or creating custom AI workflows, MCP provides a standardized way to connect LLMs with the context they need. -Check out the [Anthropic news post](https://www.anthropic.com/news/model-context-protocol) and the [Zed blog post](https://zed.dev/blog/mcp) for an introduction to MCP. +Check out the [Anthropic news post](https://www.anthropic.com/news/model-context-protocol) and the [Zed blog post](https://zed.dev/blog/mcp) for a general intro to MCP. -## MCP Servers as Extensions +## Installing MCP Servers -One of the ways you can use MCP servers in Zed is by exposing them as an extension. -To learn how to do that, check out the [MCP Server Extensions](../extensions/mcp-extensions.md) page for more details. +### As Extensions -### Available extensions +One of the ways you can use MCP servers in Zed is by exposing them as an extension. +To learn how to create your own, check out the [MCP Server Extensions](../extensions/mcp-extensions.md) page for more details. -Many MCP servers have been exposed as extensions already, thanks to Zed's awesome community. -Check which ones are already available in Zed's extension store via any of these routes: +Thanks to our awesome community, many MCP servers have already been added as extensions. +You can check which ones are available via any of these routes: 1. [the Zed website](https://zed.dev/extensions?filter=context-servers) -2. in the app, run the `zed: extensions` action +2. in the app, open the Command Palette and run the `zed: extensions` action 3. 
in the app, go to the Agent Panel's top-right menu and look for the "View Server Extensions" menu item In any case, here are some of the ones available: -- [Postgres](https://github.com/zed-extensions/postgres-context-server) -- [GitHub](https://github.com/LoamStudios/zed-mcp-server-github) -- [Puppeteer](https://github.com/zed-extensions/mcp-server-puppeteer) -- [BrowserTools](https://github.com/mirageN1349/browser-tools-context-server) -- [Brave Search](https://github.com/zed-extensions/mcp-server-brave-search) +- [Context7](https://zed.dev/extensions/context7-mcp-server) +- [GitHub](https://zed.dev/extensions/github-mcp-server) +- [Puppeteer](https://zed.dev/extensions/puppeteer-mcp-server) +- [Gem](https://zed.dev/extensions/gem) +- [Brave Search](https://zed.dev/extensions/brave-search-mcp-server) - [Prisma](https://github.com/aqrln/prisma-mcp-zed) -- [Framelink Figma](https://github.com/LoamStudios/zed-mcp-server-figma) -- [Linear](https://github.com/LoamStudios/zed-mcp-server-linear) -- [Resend](https://github.com/danilo-leal/zed-resend-mcp-server) +- [Framelink Figma](https://zed.dev/extensions/framelink-figma-mcp-server) +- [Linear](https://zed.dev/extensions/linear-mcp-server) +- [Resend](https://zed.dev/extensions/resend-mcp-server) -## Add your own MCP server +### As Custom Servers Creating an extension is not the only way to use MCP servers in Zed. You can connect them by adding their commands directly to your `settings.json`, like so: @@ -50,5 +50,78 @@ You can connect them by adding their commands directly to your `settings.json`, } ``` -Alternatively, you can also add a custom server by accessing the Agent Panel's Settings view (also accessible via the `agent: open configuration` action). -From there, you can add it through the modal that appears when clicking the "Add Custom Server" button. +Alternatively, you can also add a custom server by accessing the Agent Panel's Settings view (also accessible via the `agent: open settings` action). +From there, you can add it through the modal that appears when you click the "Add Custom Server" button. + +## Using MCP Servers + +### Installation Check + +Regardless of whether you're using MCP servers as an extension or adding them directly, most servers out there need some sort of configuration as part of the set up process. + +In the case of extensions, Zed will show a modal displaying what is required for you to properly set up a given server. +For example, the GitHub MCP extension requires you to add a [Personal Access Token](https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens). + +In the case of custom servers, make sure you check the provider documentation to determine what type of command, arguments, and environment variables need to be added to the JSON. + +To check whether your MCP server is properly installed, go to the Agent Panel's settings view and watch the indicator dot next to its name. +If they're running correctly, the indicator will be green and its tooltip will say "Server is active". +If not, other colors and tooltip messages will indicate what is happening. + +### Using in the Agent Panel + +Once installation is complete, you can return to the Agent Panel and start prompting. +Mentioning your MCP server by name helps the agent pick it up. 
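+
+For instance, a prompt along these lines (a hypothetical example, assuming you have the GitHub MCP server extension installed) makes it clear which server the agent should reach for:
+
+```
+Using the GitHub MCP server, list the open pull requests in zed-industries/zed and summarize them.
+```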
+ +If you want to ensure a given server will be used, you can create [a custom profile](./agent-panel.md#custom-profiles) by turning off the built-in tools (either all of them or the ones that would cause conflicts) and turning on only the tools coming from the MCP server. + +As an example, [the Dagger team suggests](https://container-use.com/agent-integrations#add-container-use-agent-profile-optional) doing that with their [Container Use MCP server](https://zed.dev/extensions/mcp-server-container-use): + +```json +"agent": { + "profiles": { + "container-use": { + "name": "Container Use", + "tools": { + "fetch": true, + "thinking": true, + "copy_path": false, + "find_path": false, + "delete_path": false, + "create_directory": false, + "list_directory": false, + "diagnostics": false, + "read_file": false, + "open": false, + "move_path": false, + "grep": false, + "edit_file": false, + "terminal": false + }, + "enable_all_context_servers": false, + "context_servers": { + "container-use": { + "tools": { + "environment_create": true, + "environment_add_service": true, + "environment_update": true, + "environment_run_cmd": true, + "environment_open": true, + "environment_file_write": true, + "environment_file_read": true, + "environment_file_list": true, + "environment_file_delete": true, + "environment_checkpoint": true + } + } + } + } + } +} +``` + +### Tool Approval + +Zed's Agent Panel includes the `agent.always_allow_tool_actions` setting that, if set to `false`, will require you to give permission for any editing attempt as well as tool calls coming from MCP servers. + +You can change this by setting this key to `true` in either your `settings.json` or through the Agent Panel's settings view. diff --git a/docs/src/ai/overview.md b/docs/src/ai/overview.md index f437b24ba6ecf3eb52523ea56b33e16c8f14c01d..6f081cb243ffcfb77b4304373df67865cc71ee10 100644 --- a/docs/src/ai/overview.md +++ b/docs/src/ai/overview.md @@ -1,15 +1,12 @@ # AI -Zed smoothly integrates LLMs in multiple ways across the editor. -Learn how to get started with AI on Zed and all its capabilities. +Learn how to get started using AI with Zed and all its capabilities. ## Setting up AI in Zed - [Configuration](./configuration.md): Learn how to set up different language model providers like Anthropic, OpenAI, Ollama, Google AI, and more. -- [Models](./models.md): Learn about the various language models available in Zed. - -- [Subscription](./subscription.md): Learn about Zed's subscriptions and other billing-related information. +- [Subscription](./subscription.md): Learn about Zed's hosted model service and other billing-related information. - [Privacy and Security](./privacy-and-security.md): Understand how Zed handles privacy and security with AI features. diff --git a/docs/src/ai/plans-and-usage.md b/docs/src/ai/plans-and-usage.md index a1da17f50de5057740e8f0d52d87f94afe3c13e9..1e6616c79b80489b91e4f92c13b9c5fe39ff1af5 100644 --- a/docs/src/ai/plans-and-usage.md +++ b/docs/src/ai/plans-and-usage.md @@ -11,7 +11,7 @@ Please note that if you’re interested in just using Zed as the world’s faste ## Usage {#usage} -- A `prompt` in Zed is an input from the user, initiated on pressing enter, composed of one or many `requests`. A `prompt` can be initiated from the Agent Panel, or via Inline Assist. +- A `prompt` in Zed is an input from the user, initiated by pressing enter, composed of one or many `requests`. A `prompt` can be initiated from the Agent Panel, or via Inline Assist. 
- A `request` in Zed is a response to a `prompt`, plus any tool calls that are initiated as part of that response. There may be one `request` per `prompt`, or many. Most models offered by Zed are metered per-prompt. diff --git a/docs/src/ai/rules.md b/docs/src/ai/rules.md index ed916874cadb957ca45d02af00d3a4047ebd3246..653b907a7d6463025b5a4f2f8b91debfe6d749e1 100644 --- a/docs/src/ai/rules.md +++ b/docs/src/ai/rules.md @@ -5,7 +5,7 @@ Currently, Zed supports `.rules` files at the directory's root and the Rules Lib ## `.rules` files -Zed supports including `.rules` files at the top level of worktrees, and act as project-level instructions that are included in all of your interactions with the Agent Panel. +Zed supports including `.rules` files at the top level of worktrees, and they act as project-level instructions that are included in all of your interactions with the Agent Panel. Other names for this file are also supported for compatibility with other agents, but note that the first file which matches in this list will be used: - `.rules` diff --git a/docs/src/ai/temperature.md b/docs/src/ai/temperature.md deleted file mode 100644 index bb0cef6b517e73712b3531fda45d7d1c6cd6f788..0000000000000000000000000000000000000000 --- a/docs/src/ai/temperature.md +++ /dev/null @@ -1,23 +0,0 @@ -# Model Temperature - -Zed's settings allow you to specify a custom temperature for a provider and/or model: - -```json -"model_parameters": [ - // To set parameters for all requests to OpenAI models: - { - "provider": "openai", - "temperature": 0.5 - }, - // To set parameters for all requests in general: - { - "temperature": 0 - }, - // To set parameters for a specific provider and model: - { - "provider": "zed.dev", - "model": "claude-sonnet-4", - "temperature": 1.0 - } - ], -``` diff --git a/docs/src/configuring-languages.md b/docs/src/configuring-languages.md index 42128cad6f6e2ad3879d570051a5aeded47520ea..52b7a3f7b82aeb3f2f19dcd63ef64c34251f1cd8 100644 --- a/docs/src/configuring-languages.md +++ b/docs/src/configuring-languages.md @@ -221,11 +221,11 @@ Most of the servers would rely on this way of configuring only. Apart of the LSP-related server configuration options, certain servers in Zed allow configuring the way binary is launched by Zed. -Languages mention in the documentation, whether they support it or not and their defaults for the configuration values: +Language servers are automatically downloaded or launched if found in your path, if you wish to specify an explicit alternate binary you can specify that in settings: ```json - "languages": { - "Markdown": { + "lsp": { + "rust-analyzer": { "binary": { // Whether to fetch the binary from the internet, or attempt to find locally. "ignore_system_version": false, diff --git a/docs/src/configuring-zed.md b/docs/src/configuring-zed.md index eec9da60dd96d21652d78a8c4d0c0dfca17c207a..5fd27abad67574e509e8e3fc17fdc2ca4cf487e1 100644 --- a/docs/src/configuring-zed.md +++ b/docs/src/configuring-zed.md @@ -639,6 +639,12 @@ List of `string` values "snippet_sort_order": "bottom" ``` +4. Do not show snippets in the completion list at all: + +```json +"snippet_sort_order": "none" +``` + ## Editor Scrollbar - Description: Whether or not to show the editor scrollbar and various elements in it. 
@@ -2582,6 +2588,7 @@ List of `integer` column numbers "font_features": null, "font_size": null, "line_height": "comfortable", + "minimum_contrast": 45, "option_as_meta": false, "button": true, "shell": "system", @@ -2692,6 +2699,54 @@ List of `integer` column numbers } ``` +### Terminal: Cursor Shape + +- Description: Whether or not selecting text in the terminal will automatically copy to the system clipboard. +- Setting: `cursor_shape` +- Default: `null` (defaults to block) + +**Options** + +1. A block that surrounds the following character + +```json +{ + "terminal": { + "cursor_shape": "block" + } +} +``` + +2. A vertical bar + +```json +{ + "terminal": { + "cursor_shape": "bar" + } +} +``` + +3. An underline / underscore that runs along the following character + +```json +{ + "terminal": { + "cursor_shape": "underline" + } +} +``` + +4. A box drawn around the following character + +```json +{ + "terminal": { + "cursor_shape": "hollow" + } +} +``` + ### Terminal: Keep Selection On Copy - Description: Whether or not to keep the selection in the terminal after copying text. @@ -2829,6 +2884,30 @@ See Buffer Font Features } ``` +### Terminal: Minimum Contrast + +- Description: Controls the minimum contrast between foreground and background colors in the terminal. Uses the APCA (Accessible Perceptual Contrast Algorithm) for color adjustments. Set this to 0 to disable this feature. +- Setting: `minimum_contrast` +- Default: `45` + +**Options** + +`integer` values from 0 to 106. Common recommended values: + +- `0`: No contrast adjustment +- `45`: Minimum for large fluent text (default) +- `60`: Minimum for other content text +- `75`: Minimum for body text +- `90`: Preferred for body text + +```json +{ + "terminal": { + "minimum_contrast": 45 + } +} +``` + ### Terminal: Option As Meta - Description: Re-interprets the option keys to act like a 'meta' key, like in Emacs. @@ -3336,26 +3415,7 @@ Run the `theme selector: toggle` action in the command palette to see a current ## Agent -- Description: Customize agent behavior -- Setting: `agent` -- Default: - -```json -"agent": { - "version": "2", - "enabled": true, - "button": true, - "dock": "right", - "default_width": 640, - "default_height": 320, - "default_view": "thread", - "default_model": { - "provider": "zed.dev", - "model": "claude-sonnet-4" - }, - "single_file_review": true, -} -``` +Visit [the Configuration page](./ai/configuration.md) under the AI section to learn more about all the agent-related settings. ## Outline Panel diff --git a/docs/src/debugger.md b/docs/src/debugger.md index f10461a1603baeb76b30f60a36f945be0b4895df..7cfbf63cd8266f7865e948d7da1997c1d81a1f95 100644 --- a/docs/src/debugger.md +++ b/docs/src/debugger.md @@ -8,9 +8,6 @@ Zed implements the client side of the protocol, and various _debug adapters_ imp This protocol enables features like setting breakpoints, stepping through code, inspecting variables, and more, in a consistent manner across different programming languages and runtime environments. -> We currently offer onboarding support for users. We are eager to hear from you if you encounter any issues or have suggestions for improvement for our debugging experience. -> You can schedule a call via [Cal.com](https://cal.com/team/zed-research/debugger) - ## Supported Languages To debug code written in a specific language, Zed needs to find a debug adapter for that language. 
Some debug adapters are provided by Zed without additional setup, and some are provided by [language extensions](./extensions/debugger-extensions.md). The following languages currently have debug adapters available: @@ -180,8 +177,8 @@ The debug adapter will then stop whenever an exception of a given kind occurs. W ### Stepping granularity - Description: The Step granularity that the debugger will use -- Default: line -- Setting: debugger.stepping_granularity +- Default: `line` +- Setting: `debugger.stepping_granularity` **Options** @@ -220,8 +217,8 @@ The debug adapter will then stop whenever an exception of a given kind occurs. W ### Save Breakpoints - Description: Whether the breakpoints should be saved across Zed sessions. -- Default: true -- Setting: debugger.save_breakpoints +- Default: `true` +- Setting: `debugger.save_breakpoints` **Options** @@ -238,8 +235,8 @@ The debug adapter will then stop whenever an exception of a given kind occurs. W ### Button - Description: Whether the button should be displayed in the debugger toolbar. -- Default: true -- Setting: debugger.show_button +- Default: `true` +- Setting: `debugger.show_button` **Options** @@ -256,8 +253,8 @@ The debug adapter will then stop whenever an exception of a given kind occurs. W ### Timeout - Description: Time in milliseconds until timeout error when connecting to a TCP debug adapter. -- Default: 2000 -- Setting: debugger.timeout +- Default: `2000` +- Setting: `debugger.timeout` **Options** @@ -271,6 +268,24 @@ The debug adapter will then stop whenever an exception of a given kind occurs. W } ``` +### Inline Values + +- Description: Whether to enable editor inlay hints showing the values of variables in your code during debugging sessions. +- Default: `true` +- Setting: `inlay_hints.show_value_hints` + +**Options** + +```json +{ + "inlay_hints": { + "show_value_hints": false + } +} +``` + +Inline value hints can also be toggled from the Editor Controls menu in the editor toolbar. + ### Log Dap Communications - Description: Whether to log messages between active debug adapters and Zed. (Used for DAP development) diff --git a/docs/src/development.md b/docs/src/development.md index 980b47aa4d98bd639dacae39293d7d3a94560380..046d515fede061160eff9c4a4bcb7cd1cd63b09e 100644 --- a/docs/src/development.md +++ b/docs/src/development.md @@ -37,6 +37,48 @@ development build, run Zed with the following environment variable set: ZED_DEVELOPMENT_USE_KEYCHAIN=1 ``` +## Performance Measurements + +Zed includes a frame time measurement system that can be used to profile how long it takes to render each frame. This is particularly useful when comparing rendering performance between different versions or when optimizing frame rendering code. + +### Using ZED_MEASUREMENTS + +To enable performance measurements, set the `ZED_MEASUREMENTS` environment variable: + +```sh +export ZED_MEASUREMENTS=1 +``` + +When enabled, Zed will print frame rendering timing information to stderr, showing how long each frame takes to render. + +### Performance Comparison Workflow + +Here's a typical workflow for comparing frame rendering performance between different versions: + +1. **Enable measurements:** + + ```sh + export ZED_MEASUREMENTS=1 + ``` + +2. **Test the first version:** + + - Checkout the commit you want to measure + - Run Zed in release mode and use it for 5-10 seconds: `cargo run --release &> version-a` + +3. 
**Test the second version:** + + - Checkout another commit you want to compare + - Run Zed in release mode and use it for 5-10 seconds: `cargo run --release &> version-b` + +4. **Generate comparison:** + + ```sh + script/histogram version-a version-b + ``` + +The `script/histogram` tool can accept as many measurement files as you like and will generate a histogram visualization comparing the frame rendering performance data between the provided versions. + ## Contributor links - [CONTRIBUTING.md](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md) diff --git a/docs/src/development/debugging-crashes.md b/docs/src/development/debugging-crashes.md index d08ab961cc927cc5c0922f3836706678fe09c61f..ed0a5807a32946531bb85893c6bd3ce87204b1d6 100644 --- a/docs/src/development/debugging-crashes.md +++ b/docs/src/development/debugging-crashes.md @@ -6,6 +6,7 @@ When an app crashes, - macOS creates a `.ips` file in `~/Library/Logs/DiagnosticReports`. You can view these using the built in Console app (`cmd-space Console`) under "Crash Reports". - Linux creates a core dump. See the [man pages](https://man7.org/linux/man-pages/man5/core.5.html) for pointers to how your system might be configured to manage core dumps. +- Windows doesn't create crash reports by default, but can be configured to create "minidump" memory dumps upon applications crashing. If you have enabled Zed's telemetry these will be uploaded to us when you restart the app. They end up in a [Slack channel (internal only)](https://zed-industries.slack.com/archives/C04S6T1T7TQ). diff --git a/docs/src/development/linux.md b/docs/src/development/linux.md index 08ac7f116bdcb0000bcc734fe8c6a170992af1cf..d7b586be34250ff83205257df3d147513bc9e401 100644 --- a/docs/src/development/linux.md +++ b/docs/src/development/linux.md @@ -16,20 +16,9 @@ Clone down the [Zed repository](https://github.com/zed-industries/zed). If you prefer to install the system libraries manually, you can find the list of required packages in the `script/linux` file. -## Backend dependencies +### Backend Dependencies (optional) {#backend-dependencies} -> This section is still in development. The instructions are not yet complete. - -If you are developing collaborative features of Zed, you'll need to install the dependencies of zed's `collab` server: - -- Install [Postgres](https://www.postgresql.org/download/linux/) -- Install [Livekit](https://github.com/livekit/livekit-cli) and [Foreman](https://theforeman.org/manuals/3.9/quickstart_guide.html) - -Alternatively, if you have [Docker](https://www.docker.com/) installed you can bring up all the `collab` dependencies using Docker Compose: - -```sh -docker compose up -d -``` +If you are looking to develop Zed collaboration features using a local collabortation server, please see: [Local Collaboration](./local-collaboration.md) docs. ## Building from source @@ -102,7 +91,7 @@ Zed has two main binaries: - You will need to build `crates/cli` and make its binary available in `$PATH` with the name `zed`. - You will need to build `crates/zed` and put it at `$PATH/to/cli/../../libexec/zed-editor`. For example, if you are going to put the cli at `~/.local/bin/zed` put zed at `~/.local/libexec/zed-editor`. As some linux distributions (notably Arch) discourage the use of `libexec`, you can also put this binary at `$PATH/to/cli/../../lib/zed/zed-editor` (e.g. `~/.local/lib/zed/zed-editor`) instead. 
-- If you are going to provide a `.desktop` file you can find a template in `crates/zed/resources/zed.desktop.in`, and use `envsubst` to populate it with the values required. This file should also be renamed to `$APP_ID.desktop` so that the file [follows the FreeDesktop standards](https://github.com/zed-industries/zed/issues/12707#issuecomment-2168742761). +- If you are going to provide a `.desktop` file you can find a template in `crates/zed/resources/zed.desktop.in`, and use `envsubst` to populate it with the values required. This file should also be renamed to `$APP_ID.desktop` so that the file [follows the FreeDesktop standards](https://github.com/zed-industries/zed/issues/12707#issuecomment-2168742761). You should also make this desktop file executable (`chmod 755`). - You will need to ensure that the necessary libraries are installed. You can get the current list by [inspecting the built binary](https://github.com/zed-industries/zed/blob/935cf542aebf55122ce6ed1c91d0fe8711970c82/script/bundle-linux#L65-L67) on your system. - For an example of a complete build script, see [script/bundle-linux](https://github.com/zed-industries/zed/blob/935cf542aebf55122ce6ed1c91d0fe8711970c82/script/bundle-linux). - You can disable Zed's auto updates and provide instructions for users who try to update Zed manually by building (or running) Zed with the environment variable `ZED_UPDATE_EXPLANATION`. For example: `ZED_UPDATE_EXPLANATION="Please use flatpak to update zed."`. diff --git a/docs/src/development/local-collaboration.md b/docs/src/development/local-collaboration.md index 6c96c342a872868a0c86495d5772168cc8772f40..eb7f3dfc43dc29ee3d25de3dbc373f5f925ba2af 100644 --- a/docs/src/development/local-collaboration.md +++ b/docs/src/development/local-collaboration.md @@ -1,18 +1,96 @@ # Local Collaboration -First, make sure you've installed Zed's backend dependencies for your platform: +1. Ensure you have access to our cloud infrastructure. If you don't have access, you can't collaborate locally at this time. -- [macOS](./macos.md#backend-dependencies) -- [Linux](./linux.md#backend-dependencies) -- [Windows](./windows.md#backend-dependencies) +2. Make sure you've installed Zed's dependencies for your platform: + +- [macOS](#macos) +- [Linux](#linux) +- [Windows](#backend-windows) Note that `collab` can be compiled only with MSVC toolchain on Windows +3. Clone down our cloud repository and follow the instructions in the cloud README + +4. Setup the local database for your platform: + +- [macOS & Linux](#database-unix) +- [Windows](#database-windows) + +5. Run collab: + +- [macOS & Linux](#run-collab-unix) +- [Windows](#run-collab-windows) + +## Backend Dependencies + +If you are developing collaborative features of Zed, you'll need to install the dependencies of zed's `collab` server: + +- PostgreSQL +- LiveKit +- Foreman + +You can install these dependencies natively or run them under Docker. + +### macOS + +1. Install [Postgres.app](https://postgresapp.com) or [postgresql via homebrew](https://formulae.brew.sh/formula/postgresql@15): + + ```sh + brew install postgresql@15 + ``` + +2. 
Install [Livekit](https://formulae.brew.sh/formula/livekit) and [Foreman](https://formulae.brew.sh/formula/foreman) + + ```sh + brew install livekit foreman + ``` + +- Follow the steps in the [collab README](https://github.com/zed-industries/zed/blob/main/crates/collab/README.md) to configure the Postgres database for integration tests + +Alternatively, if you have [Docker](https://www.docker.com/) installed you can bring up all the `collab` dependencies using Docker Compose: + +### Linux + +1. Install [Postgres](https://www.postgresql.org/download/linux/) + + ```sh + sudo apt-get install postgresql postgresql # Ubuntu/Debian + sudo pacman -S postgresql # Arch Linux + sudo dnf install postgresql postgresql-server # RHEL/Fedora + sudo zypper install postgresql postgresql-server # OpenSUSE + ``` + +2. Install [Livekit](https://github.com/livekit/livekit-cli) + + ```sh + curl -sSL https://get.livekit.io/cli | bash + ``` + +3. Install [Foreman](https://theforeman.org/manuals/3.15/quickstart_guide.html) + +### Windows {#backend-windows} + +> This section is still in development. The instructions are not yet complete. + +- Install [Postgres](https://www.postgresql.org/download/windows/) +- Install [Livekit](https://github.com/livekit/livekit), optionally you can add the `livekit-server` binary to your `PATH`. + +Alternatively, if you have [Docker](https://www.docker.com/) installed you can bring up all the `collab` dependencies using Docker Compose. + +### Docker {#Docker} + +If you have docker or podman available, you can run the backend dependencies inside containers with Docker Compose: + +```sh +docker compose up -d +``` + ## Database setup Before you can run the `collab` server locally, you'll need to set up a `zed` Postgres database. -### On macOS and Linux +### On macOS and Linux {#database-unix} ```sh script/bootstrap @@ -35,7 +113,7 @@ To use a different set of admin users, you can create your own version of that j } ``` -### On Windows +### On Windows {#database-windows} ```powershell .\script\bootstrap.ps1 @@ -43,7 +121,7 @@ To use a different set of admin users, you can create your own version of that j ## Testing collaborative features locally -### On macOS and Linux +### On macOS and Linux {#run-collab-unix} Ensure that Postgres is configured and running, then run Zed's collaboration server and the `livekit` dev server: @@ -53,12 +131,16 @@ foreman start docker compose up ``` -Alternatively, if you're not testing voice and screenshare, you can just run `collab`, and not the `livekit` dev server: +Alternatively, if you're not testing voice and screenshare, you can just run `collab` and `cloud`, and not the `livekit` dev server: ```sh cargo run -p collab -- serve all ``` +```sh +cd ../cloud; cargo make dev +``` + In a new terminal, run two or more instances of Zed. ```sh @@ -67,7 +149,7 @@ script/zed-local -3 This script starts one to four instances of Zed, depending on the `-2`, `-3` or `-4` flags. Each instance will be connected to the local `collab` server, signed in as a different user from `.admins.json` or `.admins.default.json`. -### On Windows +### On Windows {#run-collab-windows} Since `foreman` is not available on Windows, you can run the following commands in separate terminals: @@ -87,6 +169,12 @@ Otherwise, .\path\to\livekit-serve.exe --dev ``` +You'll also need to start the cloud server: + +```powershell +cd ..\cloud; cargo make dev +``` + In a new terminal, run two or more instances of Zed. 
```powershell @@ -97,7 +185,10 @@ Note that this requires `node.exe` to be in your `PATH`. ## Running a local collab server -If you want to run your own version of the zed collaboration service, you can, but note that this is still under development, and there is no good support for authentication nor extensions. +> [!NOTE] +> Because of recent changes to our authentication system, Zed will not be able to authenticate itself with, and therefore use, a local collab server. + +If you want to run your own version of the zed collaboration service, you can, but note that this is still under development, and there is no support for authentication nor extensions. Configuration is done through environment variables. By default it will read the configuration from [`.env.toml`](https://github.com/zed-industries/zed/blob/main/crates/collab/.env.toml) and you should use that as a guide for setting this up. diff --git a/docs/src/development/macos.md b/docs/src/development/macos.md index 91adf7819386b8e60306a692fed38bf142ccc26c..f081f0b5f12b11e57ee9c82e38be03c16292311e 100644 --- a/docs/src/development/macos.md +++ b/docs/src/development/macos.md @@ -31,6 +31,10 @@ Clone down the [Zed repository](https://github.com/zed-industries/zed). brew install cmake ``` +### Backend Dependencies (optional) {#backend-dependencies} + +If you are looking to develop Zed collaboration features using a local collabortation server, please see: [Local Collaboration](./local-collaboration.md) docs. + ## Building Zed from Source Once you have the dependencies installed, you can build Zed using [Cargo](https://doc.rust-lang.org/cargo/). @@ -53,25 +57,6 @@ And to run the tests: cargo test --workspace ``` -## Backend Dependencies - -If you are developing collaborative features of Zed, you'll need to install the dependencies of zed's `collab` server: - -- Install [Postgres](https://postgresapp.com) -- Install [Livekit](https://formulae.brew.sh/formula/livekit) and [Foreman](https://formulae.brew.sh/formula/foreman) - - ```sh - brew install livekit foreman - ``` - -- Follow the steps in the [collab README](https://github.com/zed-industries/zed/blob/main/crates/collab/README.md) to configure the Postgres database for integration tests - -Alternatively, if you have [Docker](https://www.docker.com/) installed you can bring up all the `collab` dependencies using Docker Compose: - -```sh -docker compose up -d -``` - ## Troubleshooting ### Error compiling metal shaders diff --git a/docs/src/development/windows.md b/docs/src/development/windows.md index 6d67500aab9f7c65f5a746c1eedc6a004fd3d1aa..ac38e4d7d699b55c5722d2e9c56e527eea0e3620 100644 --- a/docs/src/development/windows.md +++ b/docs/src/development/windows.md @@ -66,20 +66,9 @@ The list can be obtained as follows: - Click on `More` in the `Installed` tab - Click on `Export configuration` -## Backend dependencies +### Backend Dependencies (optional) {#backend-dependencies} -> This section is still in development. The instructions are not yet complete. - -If you are developing collaborative features of Zed, you'll need to install the dependencies of zed's `collab` server: - -- Install [Postgres](https://www.postgresql.org/download/windows/) -- Install [Livekit](https://github.com/livekit/livekit), optionally you can add the `livekit-server` binary to your `PATH`. 
- -Alternatively, if you have [Docker](https://www.docker.com/) installed you can bring up all the `collab` dependencies using Docker Compose: - -```sh -docker compose up -d -``` +If you are looking to develop Zed collaboration features using a local collabortation server, please see: [Local Collaboration](./local-collaboration.md) docs. ### Notes diff --git a/docs/src/extensions/installing-extensions.md b/docs/src/extensions/installing-extensions.md index aed8bef4288d58fa9892235704f1eb160320ddeb..801fe5c55c0f47530e2656cd831619d1457ba13e 100644 --- a/docs/src/extensions/installing-extensions.md +++ b/docs/src/extensions/installing-extensions.md @@ -1,6 +1,6 @@ # Installing Extensions -You can search for extensions by launching the Zed Extension Gallery by pressing `cmd-shift-x` (macOS) or `ctrl-shift-x` (Linux), opening the command palette and selecting `zed: extensions` or by selecting "Zed > Extensions" from the menu bar. +You can search for extensions by launching the Zed Extension Gallery by pressing {#kb zed::Extensions} , opening the command palette and selecting {#action zed::Extensions} or by selecting "Zed > Extensions" from the menu bar. Here you can view the extensions that you currently have installed or search and install new ones. diff --git a/docs/src/extensions/languages.md b/docs/src/extensions/languages.md index 44c673e3e131dc433f4598ff69b43e9fe46d28e0..6756cb8a2309153a95edb23a24838295b1030266 100644 --- a/docs/src/extensions/languages.md +++ b/docs/src/extensions/languages.md @@ -402,11 +402,10 @@ If your language server supports additional languages, you can use `language_ids [language-servers.my-language-server] name = "Whatever LSP" -languages = ["JavaScript", "JSX", "HTML", "CSS"] +languages = ["JavaScript", "HTML", "CSS"] [language-servers.my-language-server.language_ids] "JavaScript" = "javascript" -"JSX" = "javascriptreact" "TSX" = "typescriptreact" "HTML" = "html" "CSS" = "css" diff --git a/docs/src/getting-started.md b/docs/src/getting-started.md index 5940c74b219da806aa4d6ceea5d855ea86f463b7..22af3b36d733f9d7eccb72cc622d6d07c942ca20 100644 --- a/docs/src/getting-started.md +++ b/docs/src/getting-started.md @@ -83,6 +83,6 @@ Visit [the AI overview page](./ai/overview.md) to learn how to quickly get start ## Set up your key bindings -To open your custom keymap to add your key bindings, use the {#kb zed::OpenKeymap} keybinding. +To edit your custom keymap and add or remap bindings, you can either use {#kb zed::OpenKeymapEditor} to spawn the Zed Keymap Editor ({#action zed::OpenKeymapEditor}) or you can directly open your Zed Keymap json (`~/.config/zed/keymap.json`) with {#action zed::OpenKeymap}. To access the default key binding set, open the Command Palette with {#kb command_palette::Toggle} and search for "zed: open default keymap". See [Key Bindings](./key-bindings.md) for more info. diff --git a/docs/src/git.md b/docs/src/git.md index 642861c7b0d7ae449f74b5891ff2f3d548635dc1..cccbad9b2e37ba55dc45f1f100883437759727f0 100644 --- a/docs/src/git.md +++ b/docs/src/git.md @@ -1,3 +1,8 @@ +--- +description: Zed is a text editor that supports lots of Git features +title: Zed Editor Git integration documentation +--- + # Git Zed currently offers a set of fundamental Git features, with support coming in the future for more advanced ones, like conflict resolution tools, line by line staging, and more. @@ -76,7 +81,7 @@ You can ask AI to generate a commit message by focusing on the message editor wi > Note that you need to have an LLM provider configured. 
Visit [the AI configuration page](./ai/configuration.md) to learn how to do so. -You can specify your preferred model to use by providing a `commit_message_model` agent setting. See [Feature-specific models](./ai/configuration.md#feature-specific-models) for more information. +You can specify your preferred model to use by providing a `commit_message_model` agent setting. See [Feature-specific models](./ai/agent-settings.md#feature-specific-models) for more information. ```json { @@ -151,3 +156,17 @@ When viewing files with changes, Zed displays diff hunks that can be expanded or | {#action editor::ToggleSelectedDiffHunks} | {#kb editor::ToggleSelectedDiffHunks} | > Not all actions have default keybindings, but can be bound by [customizing your keymap](./key-bindings.md#user-keymaps). + +## Git CLI Configuration + +If you would like to also use Zed for your [git commit message editor](https://git-scm.com/book/en/v2/Customizing-Git-Git-Configuration#_core_editor) when committing from the command line you can use `zed --wait`: + +```sh +git config --global core.editor "zed --wait" +``` + +Or add the following to your shell environment (in `~/.zshrc`, `~/.bashrc`, etc): + +```sh +export GIT_EDITOR="zed --wait" +``` diff --git a/docs/src/key-bindings.md b/docs/src/key-bindings.md index 8a956b518591720bde777a68c0ac587cef712ce2..feed9127879758a44b4db6e2164093a267138ab7 100644 --- a/docs/src/key-bindings.md +++ b/docs/src/key-bindings.md @@ -18,13 +18,13 @@ You can also enable `vim_mode`, which adds vim bindings too. ## User keymaps -Zed reads your keymap from `~/.config/zed/keymap.json`. You can open the file within Zed with {#kb zed::OpenKeymap}, or via `zed: Open Keymap` in the command palette. +Zed reads your keymap from `~/.config/zed/keymap.json`. You can open the file within Zed with {#action zed::OpenKeymap} from the command palette or to spawn the Zed Keymap Editor ({#action zed::OpenKeymapEditor}) use {#kb zed::OpenKeymapEditor}. The file contains a JSON array of objects with `"bindings"`. If no `"context"` is set the bindings are always active. If it is set the binding is only active when the [context matches](#contexts). Within each binding section a [key sequence](#keybinding-syntax) is mapped to an [action](#actions). If conflicts are detected they are resolved as [described below](#precedence). -If you are using a non-QWERTY, Latin-character keyboard, you may want to set `use_layout_keys` to `true`. See [Non-QWERTY keyboards](#non-qwerty-keyboards) for more information. +If you are using a non-QWERTY, Latin-character keyboard, you may want to set `use_key_equivalents` to `true`. See [Non-QWERTY keyboards](#non-qwerty-keyboards) for more information. For example: @@ -87,15 +87,13 @@ If a binding group has a `"context"` key it will be matched against the currentl Zed's contexts make up a tree, with the root being `Workspace`. Workspaces contain Panes and Panels, and Panes contain Editors, etc. The easiest way to see what contexts are active at a given moment is the key context view, which you can get to with `dev: Open Key Context View` in the command palette. -Contexts can contain extra attributes in addition to the name, so that you can (for example) match only in markdown files with `"context": "Editor && extension==md"`. It's worth noting that you can only use attributes at the level they are defined. 
- For example: ``` # in an editor, it might look like this: Workspace os=macos keyboard_layout=com.apple.keylayout.QWERTY Pane - Editor mode=full extension=md inline_completion vim_mode=insert + Editor mode=full extension=md vim_mode=insert # in the project panel Workspace os=macos @@ -106,9 +104,20 @@ Workspace os=macos Context expressions can contain the following syntax: - `X && Y`, `X || Y` to and/or two conditions -- `!X` to negate a condition +- `!X` to check that a condition is false - `(X)` for grouping -- `X > Y` to match if a parent in the tree matches X and this layer matches Y. +- `X > Y` to match if an ancestor in the tree matches X and this layer matches Y. + +For example: + +- `"context": "Editor"` - matches any editor (including inline inputs) +- `"context": "Editor && mode=full"` - matches the main editors used for editing code +- `"context": "!Editor && !Terminal"` - matches anywhere except where an Editor or Terminal is focused +- `"context": "os=macos > Editor"` - matches any editor on macOS. + +It's worth noting that attributes are only available on the node they are defined on. This means that if you want to (for example) only enable a keybinding when the debugger is stopped in vim normal mode, you need to do `debugger_stopped > vim_mode == normal`. + +Note: Before Zed v0.197.x, the ! operator only looked at one node at a time, and `>` meant "parent" not "ancestor". This meant that `!Editor` would match the context `Workspace > Pane > Editor`, because (confusingly) the Pane matches `!Editor`, and that `os=macos > Editor` did not match the context `Workspace > Pane > Editor` because of the intermediate `Pane` node. If you're using Vim mode, we have information on how [vim modes influence the context](./vim.md#contexts) @@ -136,17 +145,17 @@ When this happens, and both bindings are active in the current context, Zed will ### Non-QWERTY keyboards -As of Zed 0.162.0, Zed has some support for non-QWERTY keyboards on macOS. Better support for non-QWERTY keyboards on Linux is planned. +Zed's support for non-QWERTY keyboards is still a work in progress. -There are roughly three categories of keyboard to consider: +If your keyboard can type the full ASCII ranges (DVORAK, COLEMAK, etc.) then shortcuts should work as you expect. -Keyboards that support full ASCII (QWERTY, DVORAK, COLEMAK, etc.). On these keyboards bindings are resolved based on the character that would be generated by the key. So to type `cmd-[`, find the key labeled `[` and press it with command. +Otherwise, read on... -Keyboards that are mostly non-ASCII, but support full ASCII when the command key is pressed. For example Cyrillic keyboards, Armenian, Hebrew, etc. On these keyboards bindings are resolved based on the character that would be generated by typing the key with command pressed. So to type `ctrl-a`, find the key that generates `cmd-a`. For these keyboards, keyboard shortcuts are displayed in the app using their ASCII equivalents. If the ASCII-equivalents are not printed on your keyboard, you can use the macOS keyboard viewer and holding down the `cmd` key to find things (though often the ASCII equivalents are in a QWERTY layout). +#### macOS -Finally keyboards that support extended Latin alphabets (usually ISO keyboards) require the most support. For example French AZERTY, German QWERTZ, etc. On these keyboards it is often not possible to type the entire ASCII range without option. 
To ensure that shortcuts _can_ be typed without option, keyboard shortcuts are mapped to "key equivalents" in the same way as [macOS](). This mapping is defined per layout, and is a compromise between leaving keyboard shortcuts triggered by the same character they are defined with, keeping shortcuts in the same place as a QWERTY layout, and moving shortcuts out of the way of system shortcuts. +On Cyrillic, Hebrew, Armenian, and other keyboards that are mostly non-ASCII; macOS automatically maps keys to the ASCII range when `cmd` is held. Zed takes this a step further and it can always match key-presses against either the ASCII layout, or the real layout regardless of modifiers, and regardless of the `use_key_equivalents` setting. For example in Thai, pressing `ctrl-ๆ` will match bindings associated with `ctrl-q` or `ctrl-ๆ` -For example on a German QWERTZ keyboard, the `cmd->` shortcut is moved to `cmd-:` because `cmd->` is the system window switcher and this is where that shortcut is typed on a QWERTY keyboard. `cmd-+` stays the same because + is still typeable without option, and as a result, `cmd-[` and `cmd-]` become `cmd-ö` and `cmd-ä`, moving out of the way of the `+` key. +On keyboards that support extended Latin alphabets (French AZERTY, German QWERTZ, etc.) it is often not possible to type the entire ASCII range without `option`. This introduces an ambiguity, `option-2` produces `@`. To ensure that all the builtin keyboard shortcuts can still be typed on these keyboards we move key-bindings around. For example, shortcuts bound to `@` on QWERTY are moved to `"` on a Spanish layout. This mapping is based on the macOS system defaults and can be seen by running `dev: Open Key Context View` from the command palette. If you are defining shortcuts in your personal keymap, you can opt into the key equivalent mapping by setting `use_key_equivalents` to `true` in your keymap: @@ -161,6 +170,12 @@ If you are defining shortcuts in your personal keymap, you can opt into the key ] ``` +### Linux + +Since v0.196.0 on Linux if the key that you type doesn't produce an ASCII character then we use the QWERTY-layout equivalent key for keyboard shortcuts. This means that many shortcuts can be typed on many layouts. + +We do not yet move shortcuts around to ensure that all the builtin shortcuts can be typed on every layout; so if there are some ASCII characters that cannot be typed, and your keyboard layout has different ASCII characters on the same keys as would be needed to type them, you may need to add custom key bindings to make this work. We do intend to fix this at some point, and help is very much wanted! + ## Tips and tricks ### Disabling a binding diff --git a/docs/src/languages/c.md b/docs/src/languages/c.md index 14a11c0d665e9e5b3a9499284d36b133945ad866..8db1bb671257397f0bcf668af374d700142db658 100644 --- a/docs/src/languages/c.md +++ b/docs/src/languages/c.md @@ -77,7 +77,7 @@ You can use CodeLLDB or GDB to debug native binaries. (Make sure that your build "command": "make", "args": ["-j8"], "cwd": "$ZED_WORKTREE_ROOT" - } + }, "program": "$ZED_WORKTREE_ROOT/build/prog", "request": "launch", "adapter": "CodeLLDB" diff --git a/docs/src/languages/cpp.md b/docs/src/languages/cpp.md index 1273bce2ac0b6a92cbda8e63cd0f477965500a11..e84bb6ea507f264240a40e986f41c5cd3a23610d 100644 --- a/docs/src/languages/cpp.md +++ b/docs/src/languages/cpp.md @@ -127,7 +127,7 @@ You can use CodeLLDB or GDB to debug native binaries. 
(Make sure that your build "command": "make", "args": ["-j8"], "cwd": "$ZED_WORKTREE_ROOT" - } + }, "program": "$ZED_WORKTREE_ROOT/build/prog", "request": "launch", "adapter": "CodeLLDB" diff --git a/docs/src/languages/deno.md b/docs/src/languages/deno.md index c18b112326ef36cc8fdf535f6ce785b0a9e43275..c40b6531e62142de6a9597528ba1e6a4879c16e3 100644 --- a/docs/src/languages/deno.md +++ b/docs/src/languages/deno.md @@ -57,6 +57,40 @@ See [Configuring supported languages](../configuring-languages.md) in the Zed do TBD: Deno Typescript REPL instructions [docs/repl#typescript-deno](../repl.md#typescript-deno) --> +## DAP support + +To debug deno programs, add this to `.zed/debug.json` + +```json +[ + { + "adapter": "JavaScript", + "label": "Deno", + "request": "launch", + "type": "pwa-node", + "cwd": "$ZED_WORKTREE_ROOT", + "program": "$ZED_FILE", + "runtimeExecutable": "deno", + "runtimeArgs": ["run", "--allow-all", "--inspect-wait"], + "attachSimplePort": 9229 + } +] +``` + +## Runnable support + +To run deno tasks like tests from the ui, add this to `.zed/tasks.json` + +```json +[ + { + "label": "deno test", + "command": "deno test -A --filter '/^$ZED_CUSTOM_DENO_TEST_NAME$/' $ZED_FILE", + "tags": ["js-test"] + } +] +``` + ## See also: - [TypeScript](./typescript.md) diff --git a/docs/src/languages/java.md b/docs/src/languages/java.md index 70bafab476d36c83660077bfc6a25199445411de..0312cb3bd7e8b14ccedee7aacded456cc3e06e97 100644 --- a/docs/src/languages/java.md +++ b/docs/src/languages/java.md @@ -1,12 +1,8 @@ # Java -There are two extensions that provide Java language support for Zed: - -- Zed Java: [zed-extensions/java](https://github.com/zed-extensions/java) and -- Java with Eclipse JDTLS: [zed-java-eclipse-jdtls](https://github.com/ABckh/zed-java-eclipse-jdtls). - -Both use: +Java language support in Zed is provided by: +- Zed Java: [zed-extensions/java](https://github.com/zed-extensions/java) - Tree-sitter: [tree-sitter/tree-sitter-java](https://github.com/tree-sitter/tree-sitter-java) - Language Server: [eclipse-jdtls/eclipse.jdt.ls](https://github.com/eclipse-jdtls/eclipse.jdt.ls) @@ -25,11 +21,9 @@ Or manually download and install [OpenJDK 23](https://jdk.java.net/23/). You can install either by opening {#action zed::Extensions}({#kb zed::Extensions}) and searching for `java`. -We recommend you install one or the other and not both. - ## Settings / Initialization Options -Both extensions will automatically download the language server, see: [Manual JDTLS Install](#manual-jdts-install) below if you'd prefer to manage that yourself. +The extension will automatically download the language server, see: [Manual JDTLS Install](#manual-jdts-install) below if you'd prefer to manage that yourself. For available `initialization_options` please see the [Initialize Request section of the Eclipse.jdt.ls Wiki](https://github.com/eclipse-jdtls/eclipse.jdt.ls/wiki/Running-the-JAVA-LS-server-from-the-command-line#initialize-request). 
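As a rough sketch, such options are passed through your Zed settings under the `jdtls` language server entry (the specific option shown below is illustrative only; the real keys come from the Eclipse JDT LS initialize request schema linked above):

```json
{
  "lsp": {
    "jdtls": {
      "initialization_options": {
        // Illustrative placeholder; use options from the Eclipse JDT LS schema.
        "settings": {
          "java": {
            "format": { "enabled": true }
          }
        }
      }
    }
  }
}
```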
@@ -47,21 +41,25 @@ You can add these customizations to your Zed Settings by launching {#action zed: } ``` -### Java with Eclipse JDTLS settings +## Example Configs + +### JDTLS Binary + +By default, zed will look in your `PATH` for a `jdtls` binary, if you wish to specify an explicit binary you can do so via settings: ```json -{ "lsp": { - "java": { - "settings": {}, - "initialization_options": {} + "jdtls": { + "binary": { + "path": "/path/to/java/bin/jdtls", + // "arguments": [], + // "env": {}, + "ignore_system_version": true + } } } -} ``` -## Example Configs - ### Zed Java Initialization Options There are also many more options you can pass directly to the language server, for example: @@ -152,27 +150,9 @@ There are also many more options you can pass directly to the language server, f } ``` -### Java with Eclipse JTDLS Configuration {#zed-java-eclipse-configuration} - -Configuration options match those provided in the [redhat-developer/vscode-java extension](https://github.com/redhat-developer/vscode-java#supported-vs-code-settings). - -For example, to enable [Lombok Support](https://github.com/redhat-developer/vscode-java/wiki/Lombok-support): - -```json -{ - "lsp": { - "java": { - "settings": { - "java.jdt.ls.lombokSupport.enabled:": true - } - } - } -} -``` - ## Manual JDTLS Install -If you prefer, you can install JDTLS yourself and both extensions can be configured to use that instead. +If you prefer, you can install JDTLS yourself and the extension can be configured to use that instead. - MacOS: `brew install jdtls` - Arch: [`jdtls` from AUR](https://aur.archlinux.org/packages/jdtls) @@ -184,12 +164,5 @@ Or manually download install: ## See also -- [Zed Java Readme](https://github.com/zed-extensions/java) -- [Java with Eclipse JDTLS Readme](https://github.com/ABckh/zed-java-eclipse-jdtls) - -## Support - -If you have issues with either of these plugins, please open issues on their respective repositories: - +- [Zed Java Repo](https://github.com/zed-extensions/java) - [Zed Java Issues](https://github.com/zed-extensions/java/issues) -- [Java with Eclipse JDTLS Issues](https://github.com/ABckh/zed-java-eclipse-jdtls/issues) diff --git a/docs/src/languages/php.md b/docs/src/languages/php.md index 2ddb93c8d5b9465f7a68fb4f59ce8cb8225410ac..4e94c134467c5a3484ede7a2146f2f09c172e859 100644 --- a/docs/src/languages/php.md +++ b/docs/src/languages/php.md @@ -15,7 +15,7 @@ The PHP extension offers both `phpactor` and `intelephense` language server supp ## Phpactor -The Zed PHP Extension can install `phpactor` automatically but requires `php` to installed and available in your path: +The Zed PHP Extension can install `phpactor` automatically but requires `php` to be installed and available in your path: ```sh # brew install php # macOS @@ -27,7 +27,7 @@ which php ## Intelephense -[Intelephense](https://intelephense.com/) is a [proprietary](https://github.com/bmewburn/vscode-intelephense/blob/master/LICENSE.txt#L29) language server for PHP operating under a freemium model. Certain features require purchase of a [premium license](https://intelephense.com/). To use these features you must place your [licence.txt file](https://intelephense.com/faq.html) at `~/intelephense/licence.txt` inside your home directory. +[Intelephense](https://intelephense.com/) is a [proprietary](https://github.com/bmewburn/vscode-intelephense/blob/master/LICENSE.txt#L29) language server for PHP operating under a freemium model. Certain features require purchase of a [premium license](https://intelephense.com/). 
To switch to `intelephense`, add the following to your `settings.json`: @@ -41,6 +41,20 @@ To switch to `intelephense`, add the following to your `settings.json`: } ``` +To use the premium features, you can place your [licence.txt file](https://intelephense.com/faq.html) at `~/intelephense/licence.txt` inside your home directory. Alternatively, you can pass the licence key or a path to a file containing the licence key as an initialization option for the `intelephense` language server. To do this, add the following to your `settings.json`: + +```json +{ + "lsp": { + "intelephense": { + "initialization_options": { + "licenceKey": "/path/to/licence.txt" + } + } + } +} +``` + ## PHPDoc Zed supports syntax highlighting for PHPDoc comments. diff --git a/docs/src/languages/ruby.md b/docs/src/languages/ruby.md index b7856b2cd07ab15bb1b5fd402908c162a543f86d..6f530433bd0e15d2ed659dc2e1f0055ad5711cb5 100644 --- a/docs/src/languages/ruby.md +++ b/docs/src/languages/ruby.md @@ -127,7 +127,7 @@ Solargraph reads its configuration from a file called `.solargraph.yml` in the r ## Setting up `ruby-lsp` -Ruby LSP uses pull-based diagnostics which Zed doesn't support yet. We can tell Zed to disable it by adding the following to your `settings.json`: +You can pass Ruby LSP configuration to `initialization_options`, e.g. ```json { @@ -140,8 +140,7 @@ Ruby LSP uses pull-based diagnostics which Zed doesn't support yet. We can tell "ruby-lsp": { "initialization_options": { "enabledFeatures": { - // This disables diagnostics - "diagnostics": false + // "someFeature": false } } } diff --git a/docs/src/linux.md b/docs/src/linux.md index 896bfdaf3ff9e525896576717f215df7e54ba0de..309354de6d1b6e3c8f0936350708c161132fd803 100644 --- a/docs/src/linux.md +++ b/docs/src/linux.md @@ -148,7 +148,7 @@ On some systems the file `/etc/prime-discrete` can be used to enforce the use of On others, you may be able to the environment variable `DRI_PRIME=1` when running Zed to force the use of the discrete GPU. -If you're using an AMD GPU and Zed crashes when selecting long lines, try setting the `ZED_SAMPLE_COUNT=0` environment variable. (See [#26143](https://github.com/zed-industries/zed/issues/26143)) +If you're using an AMD GPU and Zed crashes when selecting long lines, try setting the `ZED_PATH_SAMPLE_COUNT=0` environment variable. (See [#26143](https://github.com/zed-industries/zed/issues/26143)) If you're using an AMD GPU, you might get a 'Broken Pipe' error. Try using the RADV or Mesa drivers. (See [#13880](https://github.com/zed-industries/zed/issues/13880)) @@ -294,3 +294,78 @@ If your system uses PipeWire: ``` 3. **Restart your system** + +### Forcing X11 scale factor + +On X11 systems, Zed automatically detects the appropriate scale factor for high-DPI displays. The scale factor is determined using the following priority order: + +1. `GPUI_X11_SCALE_FACTOR` environment variable (if set) +2. `Xft.dpi` from X resources database (xrdb) +3. Automatic detection via RandR based on monitor resolution and physical size + +If you want to customize the scale factor beyond what Zed detects automatically, you have several options: + +#### Check your current scale factor + +You can verify if you have `Xft.dpi` set: + +```sh +xrdb -query | grep Xft.dpi +``` + +If this command returns no output, Zed is using RandR (X11's monitor management extension) to automatically calculate the scale factor based on your monitor's reported resolution and physical dimensions. 
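If you want to inspect the values that the RandR-based fallback works from, you can query them directly. Output format varies by driver, but each connected output reports a resolution and a physical size in millimetres:

```sh
# Show connected outputs with their resolution and physical dimensions;
# the automatic scale factor is derived from these values.
xrandr --query | grep -w connected
```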
+ +#### Option 1: Set Xft.dpi (X Resources Database) + +`Xft.dpi` is a standard X11 setting that many applications use for consistent font and UI scaling. Setting this ensures Zed scales the same way as other X11 applications that respect this setting. + +Edit or create the `~/.Xresources` file: + +```sh +vim ~/.Xresources +``` + +Add this line with your desired DPI: + +```sh +Xft.dpi: 96 +``` + +Common DPI values: + +- `96` for standard 1x scaling +- `144` for 1.5x scaling +- `192` for 2x scaling +- `288` for 3x scaling + +Load the configuration: + +```sh +xrdb -merge ~/.Xresources +``` + +Restart Zed for the changes to take effect. + +#### Option 2: Use the GPUI_X11_SCALE_FACTOR environment variable + +This Zed-specific environment variable directly sets the scale factor, bypassing all automatic detection. + +```sh +GPUI_X11_SCALE_FACTOR=1.5 zed +``` + +You can use decimal values (e.g., `1.25`, `1.5`, `2.0`) or set `GPUI_X11_SCALE_FACTOR=randr` to force RandR-based detection even when `Xft.dpi` is set. + +To make this permanent, add it to your shell profile or desktop entry. + +#### Option 3: Adjust system-wide RandR DPI + +This changes the reported DPI for your entire X11 session, affecting how RandR calculates scaling for all applications that use it. + +Add this to your `.xprofile` or `.xinitrc`: + +```sh +xrandr --dpi 192 +``` + +Replace `192` with your desired DPI value. This affects the system globally and will be used by Zed's automatic RandR detection when `Xft.dpi` is not set. diff --git a/docs/src/telemetry.md b/docs/src/telemetry.md index 20018b920a4c5f95285f6a84d881ad20de932789..107aef5a96a8e39e90b109158f4252fda5556ab5 100644 --- a/docs/src/telemetry.md +++ b/docs/src/telemetry.md @@ -21,19 +21,20 @@ The telemetry settings can also be configured via the welcome screen, which can Telemetry is sent from the application to our servers. Data is proxied through our servers to enable us to easily switch analytics services. We currently use: -- [Axiom](https://axiom.co): Cloud-monitoring service - stores diagnostic events -- [Snowflake](https://snowflake.com): Business Intelligence platform - stores both diagnostic and metric events -- [Metabase](https://www.metabase.com): Dashboards - dashboards built around data pulled from Snowflake +- [Sentry](https://sentry.io): Crash-monitoring service - stores diagnostic events +- [Snowflake](https://snowflake.com): Data warehouse - stores both diagnostic and metric events +- [Hex](https://www.hex.tech): Dashboards and data exploration - accesses data stored in Snowflake +- [Amplitude](https://www.amplitude.com): Dashboards and data exploration - accesses data stored in Snowflake ## Types of Telemetry ### Diagnostics -Diagnostic events include debug information (stack traces) from crash reports. Reports are sent on the first application launch after the crash occurred. We've built dashboards that allow us to visualize the frequency and severity of issues experienced by users. Having these reports sent automatically allows us to begin implementing fixes without the user needing to file a report in our issue tracker. The plots in the dashboards also give us an informal measurement of the stability of Zed. +Crash reports consist of a [minidump](https://learn.microsoft.com/en-us/windows/win32/debug/minidump-files) and some extra debug information. Reports are sent on the first application launch after the crash occurred. We've built dashboards that allow us to visualize the frequency and severity of issues experienced by users. 
Having these reports sent automatically allows us to begin implementing fixes without the user needing to file a report in our issue tracker. The plots in the dashboards also give us an informal measurement of the stability of Zed. -You can see what data is sent when a panic occurs by inspecting the `Panic` struct in [crates/telemetry_events/src/telemetry_events.rs](https://github.com/zed-industries/zed/blob/main/crates/telemetry_events/src/telemetry_events.rs) in the Zed repo. You can find additional information in the [Debugging Crashes](./development/debugging-crashes.md) documentation. +You can see what extra data is sent alongside the minidump in the `Panic` struct in [crates/telemetry_events/src/telemetry_events.rs](https://github.com/zed-industries/zed/blob/main/crates/telemetry_events/src/telemetry_events.rs) in the Zed repo. You can find additional information in the [Debugging Crashes](./development/debugging-crashes.md) documentation. -### Usage Data (Metrics) {#metrics} +### Client-Side Usage Data {#client-metrics} To improve Zed and understand how it is being used in the wild, Zed optionally collects usage data like the following: @@ -50,6 +51,12 @@ You can audit the metrics data that Zed has reported by running the command {#ac You can see the full list of the event types and exactly the data sent for each by inspecting the `Event` enum and the associated structs in [crates/telemetry_events/src/telemetry_events.rs](https://github.com/zed-industries/zed/blob/main/crates/telemetry_events/src/telemetry_events.rs) in the Zed repository. +### Server-Side Usage Data {#metrics} + +When using Zed's hosted services, we may collect, generate, and Process data to allow us to support users and improve our hosted offering. Examples include metadata around rate limiting and billing metrics/token usage. Zed does not persistently store user content or use user content to evaluate and/or improve our AI features, unless it is explicitly shared with Zed, and we have a zero-data retention agreement with Anthropic. + +You can see more about our stance on data collection (and that any prompt data shared with Zed is explicitly opt-in) at [AI Improvement](./ai/ai-improvement.md). + ## Concerns and Questions If you have concerns about telemetry, please feel free to [open an issue](https://github.com/zed-industries/zed/issues/new/choose). 
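If you would rather opt out entirely than open an issue, telemetry can also be switched off in your `settings.json`. A minimal sketch, assuming the standard `telemetry.diagnostics` and `telemetry.metrics` keys:

```json
{
  "telemetry": {
    // Turn off crash/diagnostic reporting
    "diagnostics": false,
    // Turn off usage metrics
    "metrics": false
  }
}
```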
diff --git a/docs/src/visual-customization.md b/docs/src/visual-customization.md index 636e0f9c4e242e57dd96304d877534ddcbd83632..8b307d97d5851861b0a94e1834c67d1f85166afe 100644 --- a/docs/src/visual-customization.md +++ b/docs/src/visual-customization.md @@ -267,7 +267,7 @@ TBD: Centered layout related settings "display_in": "active_editor", // Where to show (active_editor, all_editor) "thumb": "always", // When to show thumb (always, hover) "thumb_border": "left_open", // Thumb border (left_open, right_open, full, none) - "max_width_columns": 80 // Maximum width of minimap + "max_width_columns": 80, // Maximum width of minimap "current_line_highlight": null // Highlight current line (null, line, gutter) }, @@ -317,7 +317,7 @@ TBD: Centered layout related settings ### Editor Completions, Snippets, Actions, Diagnostics {#editor-lsp} ```json - "snippet_sort_order": "inline", // Snippets completions: top, inline, bottom + "snippet_sort_order": "inline", // Snippets completions: top, inline, bottom, none "show_completions_on_input": true, // Show completions while typing "show_completion_documentation": true, // Show documentation in completions "auto_signature_help": false, // Show method signatures inside parentheses @@ -448,7 +448,7 @@ See [Zed AI Documentation](./ai/overview.md) for additional non-visual AI settin // Set the cursor blinking behavior in the terminal (on, off, terminal_controlled) "blinking": "terminal_controlled", - // Default cursor shape for the terminal (block, bar, underline, hollow) + // Default cursor shape for the terminal cursor (block, bar, underline, hollow) "cursor_shape": "block", // Environment variables to add to terminal's process environment diff --git a/docs/theme/index.hbs b/docs/theme/index.hbs index 8ab4f21cf167668ef3fb2f905dafb6b57496b88e..4339a02d1722d0d64e67b35de66889d9a849e9a4 100644 --- a/docs/theme/index.hbs +++ b/docs/theme/index.hbs @@ -15,7 +15,7 @@ {{> head}} - + diff --git a/extensions/emmet/Cargo.toml b/extensions/emmet/Cargo.toml index db8aaaae41ea353fc8623be485e6e48ecdfcfab1..9d72a6c5c4df38cfe1e9203641a3a2273a40084a 100644 --- a/extensions/emmet/Cargo.toml +++ b/extensions/emmet/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zed_emmet" -version = "0.0.3" +version = "0.0.4" edition.workspace = true publish.workspace = true license = "Apache-2.0" diff --git a/extensions/glsl/languages/glsl/config.toml b/extensions/glsl/languages/glsl/config.toml index 0144e981cc4d446192c4e433c6c5cc2c3929bb4a..0c71419c91e40f4b5fc65c10c882ac5c542a080c 100644 --- a/extensions/glsl/languages/glsl/config.toml +++ b/extensions/glsl/languages/glsl/config.toml @@ -12,7 +12,7 @@ path_suffixes = [ ] first_line_pattern = '^#version \d+' line_comments = ["// "] -block_comment = ["/* ", " */"] +block_comment = { start = "/* ", prefix = "* ", end = "*/", tab_size = 1 } brackets = [ { start = "{", end = "}", close = true, newline = true }, { start = "[", end = "]", close = true, newline = true }, diff --git a/extensions/html/languages/html/config.toml b/extensions/html/languages/html/config.toml index 6f52cc8f65e85bb0ec4ab0c8a32ba2f89bf41361..f74db2888eb71e6e9f9afcbb1b41ab98e232a7a7 100644 --- a/extensions/html/languages/html/config.toml +++ b/extensions/html/languages/html/config.toml @@ -2,7 +2,7 @@ name = "HTML" grammar = "html" path_suffixes = ["html", "htm", "shtml"] autoclose_before = ">})" -block_comment = [""] +block_comment = { start = "", tab_size = 0 } brackets = [ { start = "{", end = "}", close = true, newline = true }, { start = "[", end = "]", close = true, 
newline = true }, diff --git a/extensions/ruff/Cargo.toml b/extensions/ruff/Cargo.toml index 830897279aa8179315b620f8936de2bfc54cfb41..24616f963b0258bfaca01b7399a4b466eb5d9709 100644 --- a/extensions/ruff/Cargo.toml +++ b/extensions/ruff/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zed_ruff" -version = "0.1.0" +version = "0.1.1" edition.workspace = true publish.workspace = true license = "Apache-2.0" diff --git a/extensions/ruff/extension.toml b/extensions/ruff/extension.toml index 63929fc19119719638f29c6c7f7275fa43f2a113..1f5a7314f4a477679ce606626c9dc06fa6fea7b5 100644 --- a/extensions/ruff/extension.toml +++ b/extensions/ruff/extension.toml @@ -1,7 +1,7 @@ id = "ruff" name = "Ruff" description = "Support for Ruff, the Python linter and formatter" -version = "0.1.0" +version = "0.1.1" schema_version = 1 authors = [] repository = "https://github.com/zed-industries/zed" diff --git a/flake.lock b/flake.lock index fa0d51d90de9a6a9929241f6be212ea32e1432a2..80022f7b555900ad78dca230d37faeb04dd09c7d 100644 --- a/flake.lock +++ b/flake.lock @@ -2,11 +2,11 @@ "nodes": { "crane": { "locked": { - "lastModified": 1750266157, - "narHash": "sha256-tL42YoNg9y30u7zAqtoGDNdTyXTi8EALDeCB13FtbQA=", + "lastModified": 1754269165, + "narHash": "sha256-0tcS8FHd4QjbCVoxN9jI+PjHgA4vc/IjkUSp+N3zy0U=", "owner": "ipetkov", "repo": "crane", - "rev": "e37c943371b73ed87faf33f7583860f81f1d5a48", + "rev": "444e81206df3f7d92780680e45858e31d2f07a08", "type": "github" }, "original": { @@ -33,10 +33,10 @@ "nixpkgs": { "locked": { "lastModified": 315532800, - "narHash": "sha256-j+zO+IHQ7VwEam0pjPExdbLT2rVioyVS3iq4bLO3GEc=", - "rev": "61c0f513911459945e2cb8bf333dc849f1b976ff", + "narHash": "sha256-5VYevX3GccubYeccRGAXvCPA1ktrGmIX1IFC0icX07g=", + "rev": "a683adc19ff5228af548c6539dbc3440509bfed3", "type": "tarball", - "url": "https://releases.nixos.org/nixpkgs/nixpkgs-25.11pre821324.61c0f5139114/nixexprs.tar.xz" + "url": "https://releases.nixos.org/nixpkgs/nixpkgs-25.11pre840248.a683adc19ff5/nixexprs.tar.xz" }, "original": { "type": "tarball", @@ -58,11 +58,11 @@ ] }, "locked": { - "lastModified": 1750964660, - "narHash": "sha256-YQ6EyFetjH1uy5JhdhRdPe6cuNXlYpMAQePFfZj4W7M=", + "lastModified": 1754575663, + "narHash": "sha256-afOx8AG0KYtw7mlt6s6ahBBy7eEHZwws3iCRoiuRQS4=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "04f0fcfb1a50c63529805a798b4b5c21610ff390", + "rev": "6db0fb0e9cec2e9729dc52bf4898e6c135bb8a0f", "type": "github" }, "original": { diff --git a/nix/build.nix b/nix/build.nix index 873431a42768d86b28ce43f0202da713dae5ef52..70b4f76932fe3f0330b0c53163b668e3ff2c9d66 100644 --- a/nix/build.nix +++ b/nix/build.nix @@ -298,6 +298,7 @@ craneLib.buildPackage ( export APP_ARGS="%U" mkdir -p "$out/share/applications" ${lib.getExe envsubst} < "crates/zed/resources/zed.desktop.in" > "$out/share/applications/dev.zed.Zed-Nightly.desktop" + chmod +x "$out/share/applications/dev.zed.Zed-Nightly.desktop" ) runHook postInstall diff --git a/rust-toolchain.toml b/rust-toolchain.toml index f80eab8fbcbd78e5bbf3bf4e8757bc6872146e1b..3d87025a27e2a159cfd7e7e83138f0893296e939 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,5 +1,5 @@ [toolchain] -channel = "1.88" +channel = "1.89" profile = "minimal" components = [ "rustfmt", "clippy" ] targets = [ diff --git a/script/bundle-freebsd b/script/bundle-freebsd index 7222a0625692c06606aad64d1420fbfa7daae943..87c9459ffb1086e416aaa40618c4a3f0f261373e 100755 --- a/script/bundle-freebsd +++ b/script/bundle-freebsd @@ -138,6 +138,7 @@ fi # mkdir -p "${zed_dir}/share/applications" 
# envsubst <"crates/zed/resources/zed.desktop.in" >"${zed_dir}/share/applications/zed$suffix.desktop" +# chmod +x "${zed_dir}/share/applications/zed$suffix.desktop" # Copy generated licenses so they'll end up in archive too # cp "assets/licenses.md" "${zed_dir}/licenses.md" diff --git a/script/bundle-linux b/script/bundle-linux index c52312015bed92cad021a187c59146ae8aeb9800..ad67b7a0f75f8c3e5d22e1a12e175ed248ecaf57 100755 --- a/script/bundle-linux +++ b/script/bundle-linux @@ -83,6 +83,23 @@ if [[ "$remote_server_triple" == "$musl_triple" ]]; then fi cargo build --release --target "${remote_server_triple}" --package remote_server +# Upload debug info to sentry.io +if ! command -v sentry-cli >/dev/null 2>&1; then + echo "sentry-cli not found. skipping sentry upload." + echo "install with: 'curl -sL https://sentry.io/get-cli | bash'" +else + if [[ -n "${SENTRY_AUTH_TOKEN:-}" ]]; then + echo "Uploading zed debug symbols to sentry..." + # note: this uploads the unstripped binary which is needed because it contains + # .eh_frame data for stack unwinindg. see https://github.com/getsentry/symbolic/issues/783 + sentry-cli debug-files upload --include-sources --wait -p zed -o zed-dev \ + "${target_dir}/${target_triple}"/release/zed \ + "${target_dir}/${remote_server_triple}"/release/remote_server + else + echo "missing SENTRY_AUTH_TOKEN. skipping sentry upload." + fi +fi + # Strip debug symbols and save them for upload to DigitalOcean objcopy --only-keep-debug "${target_dir}/${target_triple}/release/zed" "${target_dir}/${target_triple}/release/zed.dbg" objcopy --only-keep-debug "${target_dir}/${remote_server_triple}/release/remote_server" "${target_dir}/${remote_server_triple}/release/remote_server.dbg" @@ -162,6 +179,7 @@ fi mkdir -p "${zed_dir}/share/applications" envsubst < "crates/zed/resources/zed.desktop.in" > "${zed_dir}/share/applications/zed$suffix.desktop" +chmod +x "${zed_dir}/share/applications/zed$suffix.desktop" # Copy generated licenses so they'll end up in archive too cp "assets/licenses.md" "${zed_dir}/licenses.md" diff --git a/script/bundle-mac b/script/bundle-mac index 18dfe90815243c0c948e66fab0ad6d1b5d78d44c..b2be5732355c352bbbfa2ef248acdaf63c74193d 100755 --- a/script/bundle-mac +++ b/script/bundle-mac @@ -366,3 +366,20 @@ else gzip -f --stdout --best target/x86_64-apple-darwin/release/remote_server > target/zed-remote-server-macos-x86_64.gz gzip -f --stdout --best target/aarch64-apple-darwin/release/remote_server > target/zed-remote-server-macos-aarch64.gz fi + +# Upload debug info to sentry.io +if ! command -v sentry-cli >/dev/null 2>&1; then + echo "sentry-cli not found. skipping sentry upload." + echo "install with: 'curl -sL https://sentry.io/get-cli | bash'" +else + if [[ -n "${SENTRY_AUTH_TOKEN:-}" ]]; then + echo "Uploading zed debug symbols to sentry..." + # note: this uploads the unstripped binary which is needed because it contains + # .eh_frame data for stack unwinindg. see https://github.com/getsentry/symbolic/issues/783 + sentry-cli debug-files upload --include-sources --wait -p zed -o zed-dev \ + "target/x86_64-apple-darwin/${target_dir}/" \ + "target/aarch64-apple-darwin/${target_dir}/" + else + echo "missing SENTRY_AUTH_TOKEN. skipping sentry upload." 
+ fi +fi diff --git a/script/bundle-windows.ps1 b/script/bundle-windows.ps1 index 6eaf98cd231ba4eadfaf508eb9f4ee1f6b38cc40..8ae02124918a2f7f47a1c6204f5199f6eb4e6056 100644 --- a/script/bundle-windows.ps1 +++ b/script/bundle-windows.ps1 @@ -26,6 +26,7 @@ if ($Help) { Push-Location -Path crates/zed $channel = Get-Content "RELEASE_CHANNEL" $env:ZED_RELEASE_CHANNEL = $channel +$env:RELEASE_CHANNEL = $channel Pop-Location function CheckEnvironmentVariables { @@ -56,6 +57,13 @@ function PrepareForBundle { New-Item -Path "$innoDir\tools" -ItemType Directory -Force } +function GenerateLicenses { + $oldErrorActionPreference = $ErrorActionPreference + $ErrorActionPreference = 'Continue' + . $PSScriptRoot/generate-licenses.ps1 + $ErrorActionPreference = $oldErrorActionPreference +} + function BuildZedAndItsFriends { Write-Output "Building Zed and its friends, for channel: $channel" # Build zed.exe, cli.exe and auto_update_helper.exe @@ -89,6 +97,21 @@ function ZipZedAndItsFriendsDebug { Compress-Archive -Path $items -DestinationPath ".\target\release\zed-$env:RELEASE_VERSION-$env:ZED_RELEASE_CHANNEL.dbg.zip" -Force } + +function UploadToSentry { + if (-not (Get-Command "sentry-cli" -ErrorAction SilentlyContinue)) { + Write-Output "sentry-cli not found. skipping sentry upload." + Write-Output "install with: 'winget install -e --id=Sentry.sentry-cli'" + return + } + if (-not (Test-Path "env:SENTRY_AUTH_TOKEN")) { + Write-Output "missing SENTRY_AUTH_TOKEN. skipping sentry upload." + return + } + Write-Output "Uploading zed debug symbols to sentry..." + sentry-cli debug-files upload --include-sources --wait -p zed -o zed-dev .\target\release\ +} + function MakeAppx { switch ($channel) { "stable" { @@ -113,11 +136,22 @@ function SignZedAndItsFriends { & "$innoDir\sign.ps1" $files } +function DownloadAMDGpuServices { + # If you update the AGS SDK version, please also update the version in `crates/gpui/src/platform/windows/directx_renderer.rs` + $url = "https://codeload.github.com/GPUOpen-LibrariesAndSDKs/AGS_SDK/zip/refs/tags/v6.3.0" + $zipPath = ".\AGS_SDK_v6.3.0.zip" + # Download the AGS SDK zip file + Invoke-WebRequest -Uri $url -OutFile $zipPath + # Extract the AGS SDK zip file + Expand-Archive -Path $zipPath -DestinationPath "." -Force +} + function CollectFiles { Move-Item -Path "$innoDir\zed_explorer_command_injector.appx" -Destination "$innoDir\appx\zed_explorer_command_injector.appx" -Force Move-Item -Path "$innoDir\zed_explorer_command_injector.dll" -Destination "$innoDir\appx\zed_explorer_command_injector.dll" -Force Move-Item -Path "$innoDir\cli.exe" -Destination "$innoDir\bin\zed.exe" -Force Move-Item -Path "$innoDir\auto_update_helper.exe" -Destination "$innoDir\tools\auto_update_helper.exe" -Force + Move-Item -Path ".\AGS_SDK-6.3.0\ags_lib\lib\amd_ags_x64.dll" -Destination "$innoDir\amd_ags_x64.dll" -Force } function BuildInstaller { @@ -167,7 +201,7 @@ function BuildInstaller { } "dev" { $appId = "{{8357632E-24A4-4F32-BA97-E575B4D1FE5D}" - $appIconName = "app-icon-nightly" + $appIconName = "app-icon-dev" $appName = "Zed Dev" $appDisplayName = "Zed Dev" $appSetupName = "ZedEditorUserSetup-x64-$env:RELEASE_VERSION-dev" @@ -188,7 +222,6 @@ function BuildInstaller { # Windows runner 2022 default has iscc in PATH, https://github.com/actions/runner-images/blob/main/images/windows/Windows2022-Readme.md # Currently, we are using Windows 2022 runner. 
# Windows runner 2025 doesn't have iscc in PATH for now, https://github.com/actions/runner-images/issues/11228 - # $innoSetupPath = "iscc.exe" $innoSetupPath = "C:\Program Files (x86)\Inno Setup 6\ISCC.exe" $definitions = @{ @@ -235,19 +268,22 @@ function BuildInstaller { ParseZedWorkspace $innoDir = "$env:ZED_WORKSPACE\inno" +$debugArchive = ".\target\release\zed-$env:RELEASE_VERSION-$env:ZED_RELEASE_CHANNEL.dbg.zip" +$debugStoreKey = "$env:ZED_RELEASE_CHANNEL/zed-$env:RELEASE_VERSION-$env:ZED_RELEASE_CHANNEL.dbg.zip" CheckEnvironmentVariables PrepareForBundle +GenerateLicenses BuildZedAndItsFriends MakeAppx SignZedAndItsFriends ZipZedAndItsFriendsDebug +DownloadAMDGpuServices CollectFiles BuildInstaller -$debugArchive = ".\target\release\zed-$env:RELEASE_VERSION-$env:ZED_RELEASE_CHANNEL.dbg.zip" -$debugStoreKey = "$env:ZED_RELEASE_CHANNEL/zed-$env:RELEASE_VERSION-$env:ZED_RELEASE_CHANNEL.dbg.zip" UploadToBlobStorePublic -BucketName "zed-debug-symbols" -FileToUpload $debugArchive -BlobStoreKey $debugStoreKey +UploadToSentry if ($buildSuccess) { Write-Output "Build successful" diff --git a/script/generate-licenses.ps1 b/script/generate-licenses.ps1 new file mode 100644 index 0000000000000000000000000000000000000000..52a6fe0118b9979be23d0c584bc0facdd4ce8f1e --- /dev/null +++ b/script/generate-licenses.ps1 @@ -0,0 +1,44 @@ +$CARGO_ABOUT_VERSION="0.7" +$outputFile=$args[0] ? $args[0] : "$(Get-Location)/assets/licenses.md" +$templateFile="script/licenses/template.md.hbs" + +New-Item -Path "$outputFile" -ItemType File -Value "" -Force + +@( + "# ###### THEME LICENSES ######`n" + Get-Content assets/themes/LICENSES + "`n# ###### ICON LICENSES ######`n" + Get-Content assets/icons/LICENSES + "`n# ###### CODE LICENSES ######`n" +) | Add-Content -Path $outputFile + +$versionOutput = cargo about --version +if (-not ($versionOutput -match "cargo-about $CARGO_ABOUT_VERSION")) { + Write-Host "Installing cargo-about@^$CARGO_ABOUT_VERSION..." + cargo install "cargo-about@^$CARGO_ABOUT_VERSION" +} else { + Write-Host "cargo-about@^$CARGO_ABOUT_VERSION is already installed" +} + +Write-Host "Generating cargo licenses" + +$failFlag = $env:ALLOW_MISSING_LICENSES ? "" : "--fail" +$args = @('about', 'generate', $failFlag, '-c', 'script/licenses/zed-licenses.toml', $templateFile, '-o', $outputFile) | Where-Object { $_ } +cargo @args + +Write-Host "Applying replacements" +$replacements = @{ + '&quot;' = '"' + '&#x27;' = "'" + '&#x3D;' = '=' + '&#x60;' = '`' + '&lt;' = '<' + '&gt;' = '>' +} +$content = Get-Content $outputFile +foreach ($find in $replacements.keys) { + $content = $content -replace $find, $replacements[$find] +} +$content | Set-Content $outputFile + +Write-Host "generate-licenses completed.
See $outputFile" diff --git a/script/linux b/script/linux index bc46291023c43467fcf5a623de241277e157a7bd..f1e5a20eae9a32af2e56f3d169a2ec6be038fba5 100755 --- a/script/linux +++ b/script/linux @@ -148,6 +148,7 @@ if [[ -n $zyp ]]; then gzip jq libvulkan1 + libx11-devel libxcb-devel libxkbcommon-devel libxkbcommon-x11-devel diff --git a/script/new-crate b/script/new-crate index df574981e739a465f3f4f92d8a05c8df7cffdb82..52ee900b30837cbf77fa1e3145e0282fa5e19b7c 100755 --- a/script/new-crate +++ b/script/new-crate @@ -39,7 +39,7 @@ CRATE_PATH="crates/$CRATE_NAME" mkdir -p "$CRATE_PATH/src" # Symlink the license -ln -sf "../../../$LICENSE_FILE" "$CRATE_PATH/$LICENSE_FILE" +ln -sf "../../$LICENSE_FILE" "$CRATE_PATH/$LICENSE_FILE" CARGO_TOML_TEMPLATE=$(cat << 'EOF' [package] diff --git a/script/zed-local b/script/zed-local index 256893124668c2e89a0860cd1934f841ab0692e4..99d93082326af5f5af159a28276beb45b381e735 100755 --- a/script/zed-local +++ b/script/zed-local @@ -213,7 +213,7 @@ setTimeout(() => { platform === "win32" ? "http://127.0.0.1:8080/rpc" : "http://localhost:8080/rpc", - ZED_ADMIN_API_TOKEN: "secret", + ZED_ADMIN_API_TOKEN: "internal-api-key-secret", ZED_WINDOW_SIZE: size, ZED_CLIENT_CHECKSUM_SEED: "development-checksum-seed", RUST_LOG: process.env.RUST_LOG || "info", diff --git a/tooling/workspace-hack/Cargo.toml b/tooling/workspace-hack/Cargo.toml index d4019ab85440a1ed5b44f064aa20e01dac872518..338985ed9592156af9a88d240ab2dc6a41c808ef 100644 --- a/tooling/workspace-hack/Cargo.toml +++ b/tooling/workspace-hack/Cargo.toml @@ -82,6 +82,7 @@ lyon = { version = "1", default-features = false, features = ["extra"] } lyon_path = { version = "1" } md-5 = { version = "0.10" } memchr = { version = "2" } +mime_guess = { version = "2" } miniz_oxide = { version = "0.8", features = ["simd"] } nom = { version = "7" } num-bigint = { version = "0.4" } @@ -107,7 +108,7 @@ rustc-hash = { version = "1" } rustix-d585fab2519d2d1 = { package = "rustix", version = "0.38", default-features = false, features = ["fs", "net", "std"] } rustls = { version = "0.23", features = ["ring"] } rustls-webpki = { version = "0.103", default-features = false, features = ["aws-lc-rs", "ring", "std"] } -schemars = { version = "1", features = ["chrono04", "indexmap2"] } +schemars = { version = "1", features = ["chrono04", "indexmap2", "semver1"] } sea-orm = { version = "1", features = ["runtime-tokio-rustls", "sqlx-postgres", "sqlx-sqlite"] } sea-query-binder = { version = "0.7", default-features = false, features = ["postgres-array", "sqlx-postgres", "sqlx-sqlite", "with-bigdecimal", "with-chrono", "with-json", "with-rust_decimal", "with-time", "with-uuid"] } semver = { version = "1", features = ["serde"] } @@ -212,6 +213,7 @@ lyon = { version = "1", default-features = false, features = ["extra"] } lyon_path = { version = "1" } md-5 = { version = "0.10" } memchr = { version = "2" } +mime_guess = { version = "2" } miniz_oxide = { version = "0.8", features = ["simd"] } nom = { version = "7" } num-bigint = { version = "0.4" } @@ -240,7 +242,7 @@ rustc-hash = { version = "1" } rustix-d585fab2519d2d1 = { package = "rustix", version = "0.38", default-features = false, features = ["fs", "net", "std"] } rustls = { version = "0.23", features = ["ring"] } rustls-webpki = { version = "0.103", default-features = false, features = ["aws-lc-rs", "ring", "std"] } -schemars = { version = "1", features = ["chrono04", "indexmap2"] } +schemars = { version = "1", features = ["chrono04", "indexmap2", "semver1"] } sea-orm = { version = "1", 
features = ["runtime-tokio-rustls", "sqlx-postgres", "sqlx-sqlite"] } sea-query-binder = { version = "0.7", default-features = false, features = ["postgres-array", "sqlx-postgres", "sqlx-sqlite", "with-bigdecimal", "with-chrono", "with-json", "with-rust_decimal", "with-time", "with-uuid"] } semver = { version = "1", features = ["serde"] } @@ -284,14 +286,13 @@ winnow = { version = "0.7", features = ["simd"] } codespan-reporting = { version = "0.12" } core-foundation = { version = "0.9" } core-foundation-sys = { version = "0.8" } -coreaudio-sys = { version = "0.2", default-features = false, features = ["audio_toolbox", "audio_unit", "core_audio", "core_midi", "open_al"] } foldhash = { version = "0.1", default-features = false, features = ["std"] } getrandom-468e82937335b1c9 = { package = "getrandom", version = "0.3", default-features = false, features = ["std"] } gimli = { version = "0.31", default-features = false, features = ["read", "std", "write"] } hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "native-tokio", "ring", "tls12"] } itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" } naga = { version = "25", features = ["msl-out", "wgsl-in"] } -nix = { version = "0.29", features = ["fs", "pthread", "signal", "user"] } +nix-b73a96c0a5f6a7d9 = { package = "nix", version = "0.29", features = ["fs", "pthread", "signal", "user"] } objc2 = { version = "0.6" } objc2-core-foundation = { version = "0.3", default-features = false, features = ["CFArray", "CFCGTypes", "CFData", "CFDate", "CFDictionary", "CFRunLoop", "CFString", "CFURL", "objc2", "std"] } objc2-foundation = { version = "0.3", default-features = false, features = ["NSArray", "NSAttributedString", "NSBundle", "NSCoder", "NSData", "NSDate", "NSDictionary", "NSEnumerator", "NSError", "NSGeometry", "NSNotification", "NSNull", "NSObjCRuntime", "NSObject", "NSProcessInfo", "NSRange", "NSRunLoop", "NSString", "NSURL", "NSUndoManager", "NSValue", "objc2-core-foundation", "std"] } @@ -304,24 +305,22 @@ scopeguard = { version = "1" } security-framework = { version = "3", features = ["OSX_10_14"] } security-framework-sys = { version = "2", features = ["OSX_10_14"] } sync_wrapper = { version = "1", default-features = false, features = ["futures"] } -tokio-rustls = { version = "0.26", default-features = false, features = ["ring"] } +tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "ring"] } tokio-socks = { version = "0.5", features = ["futures-io"] } tokio-stream = { version = "0.1", features = ["fs"] } tower = { version = "0.5", default-features = false, features = ["timeout", "util"] } [target.x86_64-apple-darwin.build-dependencies] -clang-sys = { version = "1", default-features = false, features = ["clang_11_0", "runtime"] } codespan-reporting = { version = "0.12" } core-foundation = { version = "0.9" } core-foundation-sys = { version = "0.8" } -coreaudio-sys = { version = "0.2", default-features = false, features = ["audio_toolbox", "audio_unit", "core_audio", "core_midi", "open_al"] } foldhash = { version = "0.1", default-features = false, features = ["std"] } getrandom-468e82937335b1c9 = { package = "getrandom", version = "0.3", default-features = false, features = ["std"] } gimli = { version = "0.31", default-features = false, features = ["read", "std", "write"] } hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "native-tokio", "ring", "tls12"] } itertools-5ef9efb8ec2df382 = { package = "itertools", 
version = "0.12" } naga = { version = "25", features = ["msl-out", "wgsl-in"] } -nix = { version = "0.29", features = ["fs", "pthread", "signal", "user"] } +nix-b73a96c0a5f6a7d9 = { package = "nix", version = "0.29", features = ["fs", "pthread", "signal", "user"] } objc2 = { version = "0.6" } objc2-core-foundation = { version = "0.3", default-features = false, features = ["CFArray", "CFCGTypes", "CFData", "CFDate", "CFDictionary", "CFRunLoop", "CFString", "CFURL", "objc2", "std"] } objc2-foundation = { version = "0.3", default-features = false, features = ["NSArray", "NSAttributedString", "NSBundle", "NSCoder", "NSData", "NSDate", "NSDictionary", "NSEnumerator", "NSError", "NSGeometry", "NSNotification", "NSNull", "NSObjCRuntime", "NSObject", "NSProcessInfo", "NSRange", "NSRunLoop", "NSString", "NSURL", "NSUndoManager", "NSValue", "objc2-core-foundation", "std"] } @@ -335,7 +334,7 @@ scopeguard = { version = "1" } security-framework = { version = "3", features = ["OSX_10_14"] } security-framework-sys = { version = "2", features = ["OSX_10_14"] } sync_wrapper = { version = "1", default-features = false, features = ["futures"] } -tokio-rustls = { version = "0.26", default-features = false, features = ["ring"] } +tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "ring"] } tokio-socks = { version = "0.5", features = ["futures-io"] } tokio-stream = { version = "0.1", features = ["fs"] } tower = { version = "0.5", default-features = false, features = ["timeout", "util"] } @@ -344,14 +343,13 @@ tower = { version = "0.5", default-features = false, features = ["timeout", "uti codespan-reporting = { version = "0.12" } core-foundation = { version = "0.9" } core-foundation-sys = { version = "0.8" } -coreaudio-sys = { version = "0.2", default-features = false, features = ["audio_toolbox", "audio_unit", "core_audio", "core_midi", "open_al"] } foldhash = { version = "0.1", default-features = false, features = ["std"] } getrandom-468e82937335b1c9 = { package = "getrandom", version = "0.3", default-features = false, features = ["std"] } gimli = { version = "0.31", default-features = false, features = ["read", "std", "write"] } hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "native-tokio", "ring", "tls12"] } itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" } naga = { version = "25", features = ["msl-out", "wgsl-in"] } -nix = { version = "0.29", features = ["fs", "pthread", "signal", "user"] } +nix-b73a96c0a5f6a7d9 = { package = "nix", version = "0.29", features = ["fs", "pthread", "signal", "user"] } objc2 = { version = "0.6" } objc2-core-foundation = { version = "0.3", default-features = false, features = ["CFArray", "CFCGTypes", "CFData", "CFDate", "CFDictionary", "CFRunLoop", "CFString", "CFURL", "objc2", "std"] } objc2-foundation = { version = "0.3", default-features = false, features = ["NSArray", "NSAttributedString", "NSBundle", "NSCoder", "NSData", "NSDate", "NSDictionary", "NSEnumerator", "NSError", "NSGeometry", "NSNotification", "NSNull", "NSObjCRuntime", "NSObject", "NSProcessInfo", "NSRange", "NSRunLoop", "NSString", "NSURL", "NSUndoManager", "NSValue", "objc2-core-foundation", "std"] } @@ -364,24 +362,22 @@ scopeguard = { version = "1" } security-framework = { version = "3", features = ["OSX_10_14"] } security-framework-sys = { version = "2", features = ["OSX_10_14"] } sync_wrapper = { version = "1", default-features = false, features = ["futures"] } -tokio-rustls = { version = "0.26", 
default-features = false, features = ["ring"] } +tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "ring"] } tokio-socks = { version = "0.5", features = ["futures-io"] } tokio-stream = { version = "0.1", features = ["fs"] } tower = { version = "0.5", default-features = false, features = ["timeout", "util"] } [target.aarch64-apple-darwin.build-dependencies] -clang-sys = { version = "1", default-features = false, features = ["clang_11_0", "runtime"] } codespan-reporting = { version = "0.12" } core-foundation = { version = "0.9" } core-foundation-sys = { version = "0.8" } -coreaudio-sys = { version = "0.2", default-features = false, features = ["audio_toolbox", "audio_unit", "core_audio", "core_midi", "open_al"] } foldhash = { version = "0.1", default-features = false, features = ["std"] } getrandom-468e82937335b1c9 = { package = "getrandom", version = "0.3", default-features = false, features = ["std"] } gimli = { version = "0.31", default-features = false, features = ["read", "std", "write"] } hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "native-tokio", "ring", "tls12"] } itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" } naga = { version = "25", features = ["msl-out", "wgsl-in"] } -nix = { version = "0.29", features = ["fs", "pthread", "signal", "user"] } +nix-b73a96c0a5f6a7d9 = { package = "nix", version = "0.29", features = ["fs", "pthread", "signal", "user"] } objc2 = { version = "0.6" } objc2-core-foundation = { version = "0.3", default-features = false, features = ["CFArray", "CFCGTypes", "CFData", "CFDate", "CFDictionary", "CFRunLoop", "CFString", "CFURL", "objc2", "std"] } objc2-foundation = { version = "0.3", default-features = false, features = ["NSArray", "NSAttributedString", "NSBundle", "NSCoder", "NSData", "NSDate", "NSDictionary", "NSEnumerator", "NSError", "NSGeometry", "NSNotification", "NSNull", "NSObjCRuntime", "NSObject", "NSProcessInfo", "NSRange", "NSRunLoop", "NSString", "NSURL", "NSUndoManager", "NSValue", "objc2-core-foundation", "std"] } @@ -395,7 +391,7 @@ scopeguard = { version = "1" } security-framework = { version = "3", features = ["OSX_10_14"] } security-framework-sys = { version = "2", features = ["OSX_10_14"] } sync_wrapper = { version = "1", default-features = false, features = ["futures"] } -tokio-rustls = { version = "0.26", default-features = false, features = ["ring"] } +tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "ring"] } tokio-socks = { version = "0.5", features = ["futures-io"] } tokio-stream = { version = "0.1", features = ["fs"] } tower = { version = "0.5", default-features = false, features = ["timeout", "util"] } @@ -420,7 +416,8 @@ linux-raw-sys-274715c4dabd11b0 = { package = "linux-raw-sys", version = "0.9", d linux-raw-sys-9fbad63c4bcf4a8f = { package = "linux-raw-sys", version = "0.4", default-features = false, features = ["elf", "errno", "general", "if_ether", "ioctl", "net", "netlink", "no_std", "prctl", "system", "xdp"] } mio = { version = "1", features = ["net", "os-ext"] } naga = { version = "25", features = ["spv-out", "wgsl-in"] } -nix = { version = "0.29", features = ["fs", "pthread", "signal", "socket", "uio", "user"] } +nix-1f5adca70f036a62 = { package = "nix", version = "0.28", features = ["fs", "mman", "ptrace", "signal", "term", "user"] } +nix-b73a96c0a5f6a7d9 = { package = "nix", version = "0.29", features = ["fs", "pthread", "signal", "socket", "uio", "user"] } num-bigint-dig = { 
version = "0.8", features = ["i128", "prime", "zeroize"] } object = { version = "0.36", default-features = false, features = ["archive", "read_core", "unaligned", "write"] } proc-macro2 = { version = "1", features = ["span-locations"] } @@ -432,7 +429,7 @@ rustix-dff4ba8e3ae991db = { package = "rustix", version = "1", features = ["fs", scopeguard = { version = "1" } syn-f595c2ba2a3f28df = { package = "syn", version = "2", features = ["extra-traits", "fold", "full", "visit", "visit-mut"] } sync_wrapper = { version = "1", default-features = false, features = ["futures"] } -tokio-rustls = { version = "0.26", default-features = false, features = ["ring"] } +tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "ring"] } tokio-socks = { version = "0.5", features = ["futures-io"] } tokio-stream = { version = "0.1", features = ["fs"] } toml_datetime = { version = "0.6", default-features = false, features = ["serde"] } @@ -460,7 +457,8 @@ linux-raw-sys-274715c4dabd11b0 = { package = "linux-raw-sys", version = "0.9", d linux-raw-sys-9fbad63c4bcf4a8f = { package = "linux-raw-sys", version = "0.4", default-features = false, features = ["elf", "errno", "general", "if_ether", "ioctl", "net", "netlink", "no_std", "prctl", "system", "xdp"] } mio = { version = "1", features = ["net", "os-ext"] } naga = { version = "25", features = ["spv-out", "wgsl-in"] } -nix = { version = "0.29", features = ["fs", "pthread", "signal", "socket", "uio", "user"] } +nix-1f5adca70f036a62 = { package = "nix", version = "0.28", features = ["fs", "mman", "ptrace", "signal", "term", "user"] } +nix-b73a96c0a5f6a7d9 = { package = "nix", version = "0.29", features = ["fs", "pthread", "signal", "socket", "uio", "user"] } num-bigint-dig = { version = "0.8", features = ["i128", "prime", "zeroize"] } object = { version = "0.36", default-features = false, features = ["archive", "read_core", "unaligned", "write"] } proc-macro2 = { version = "1", default-features = false, features = ["span-locations"] } @@ -470,7 +468,7 @@ rustix-d585fab2519d2d1 = { package = "rustix", version = "0.38", features = ["ev rustix-dff4ba8e3ae991db = { package = "rustix", version = "1", features = ["fs", "net", "process", "termios", "time"] } scopeguard = { version = "1" } sync_wrapper = { version = "1", default-features = false, features = ["futures"] } -tokio-rustls = { version = "0.26", default-features = false, features = ["ring"] } +tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "ring"] } tokio-socks = { version = "0.5", features = ["futures-io"] } tokio-stream = { version = "0.1", features = ["fs"] } toml_datetime = { version = "0.6", default-features = false, features = ["serde"] } @@ -498,7 +496,8 @@ linux-raw-sys-274715c4dabd11b0 = { package = "linux-raw-sys", version = "0.9", d linux-raw-sys-9fbad63c4bcf4a8f = { package = "linux-raw-sys", version = "0.4", default-features = false, features = ["elf", "errno", "general", "if_ether", "ioctl", "net", "netlink", "no_std", "prctl", "system", "xdp"] } mio = { version = "1", features = ["net", "os-ext"] } naga = { version = "25", features = ["spv-out", "wgsl-in"] } -nix = { version = "0.29", features = ["fs", "pthread", "signal", "socket", "uio", "user"] } +nix-1f5adca70f036a62 = { package = "nix", version = "0.28", features = ["fs", "mman", "ptrace", "signal", "term", "user"] } +nix-b73a96c0a5f6a7d9 = { package = "nix", version = "0.29", features = ["fs", "pthread", "signal", "socket", "uio", "user"] } num-bigint-dig = { version = "0.8", 
features = ["i128", "prime", "zeroize"] } object = { version = "0.36", default-features = false, features = ["archive", "read_core", "unaligned", "write"] } proc-macro2 = { version = "1", features = ["span-locations"] } @@ -510,7 +509,7 @@ rustix-dff4ba8e3ae991db = { package = "rustix", version = "1", features = ["fs", scopeguard = { version = "1" } syn-f595c2ba2a3f28df = { package = "syn", version = "2", features = ["extra-traits", "fold", "full", "visit", "visit-mut"] } sync_wrapper = { version = "1", default-features = false, features = ["futures"] } -tokio-rustls = { version = "0.26", default-features = false, features = ["ring"] } +tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "ring"] } tokio-socks = { version = "0.5", features = ["futures-io"] } tokio-stream = { version = "0.1", features = ["fs"] } toml_datetime = { version = "0.6", default-features = false, features = ["serde"] } @@ -538,7 +537,8 @@ linux-raw-sys-274715c4dabd11b0 = { package = "linux-raw-sys", version = "0.9", d linux-raw-sys-9fbad63c4bcf4a8f = { package = "linux-raw-sys", version = "0.4", default-features = false, features = ["elf", "errno", "general", "if_ether", "ioctl", "net", "netlink", "no_std", "prctl", "system", "xdp"] } mio = { version = "1", features = ["net", "os-ext"] } naga = { version = "25", features = ["spv-out", "wgsl-in"] } -nix = { version = "0.29", features = ["fs", "pthread", "signal", "socket", "uio", "user"] } +nix-1f5adca70f036a62 = { package = "nix", version = "0.28", features = ["fs", "mman", "ptrace", "signal", "term", "user"] } +nix-b73a96c0a5f6a7d9 = { package = "nix", version = "0.29", features = ["fs", "pthread", "signal", "socket", "uio", "user"] } num-bigint-dig = { version = "0.8", features = ["i128", "prime", "zeroize"] } object = { version = "0.36", default-features = false, features = ["archive", "read_core", "unaligned", "write"] } proc-macro2 = { version = "1", default-features = false, features = ["span-locations"] } @@ -548,7 +548,7 @@ rustix-d585fab2519d2d1 = { package = "rustix", version = "0.38", features = ["ev rustix-dff4ba8e3ae991db = { package = "rustix", version = "1", features = ["fs", "net", "process", "termios", "time"] } scopeguard = { version = "1" } sync_wrapper = { version = "1", default-features = false, features = ["futures"] } -tokio-rustls = { version = "0.26", default-features = false, features = ["ring"] } +tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "ring"] } tokio-socks = { version = "0.5", features = ["futures-io"] } tokio-stream = { version = "0.1", features = ["fs"] } toml_datetime = { version = "0.6", default-features = false, features = ["serde"] } @@ -564,23 +564,20 @@ getrandom-468e82937335b1c9 = { package = "getrandom", version = "0.3", default-f getrandom-6f8ce4dd05d13bba = { package = "getrandom", version = "0.2", default-features = false, features = ["js", "rdrand"] } hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "native-tokio", "ring", "tls12"] } itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" } -naga = { version = "25", features = ["spv-out", "wgsl-in"] } ring = { version = "0.17", features = ["std"] } rustix-d585fab2519d2d1 = { package = "rustix", version = "0.38", features = ["event"] } scopeguard = { version = "1" } sync_wrapper = { version = "1", default-features = false, features = ["futures"] } -tokio-rustls = { version = "0.26", default-features = false, features = ["ring"] } +tokio-rustls 
= { version = "0.26", default-features = false, features = ["logging", "ring"] } tokio-socks = { version = "0.5", features = ["futures-io"] } tokio-stream = { version = "0.1", features = ["fs"] } tower = { version = "0.5", default-features = false, features = ["timeout", "util"] } winapi = { version = "0.3", default-features = false, features = ["cfg", "commapi", "consoleapi", "errhandlingapi", "evntrace", "fileapi", "handleapi", "impl-debug", "impl-default", "in6addr", "inaddr", "ioapiset", "knownfolders", "minwinbase", "minwindef", "namedpipeapi", "ntsecapi", "objbase", "processenv", "processthreadsapi", "shlobj", "std", "synchapi", "sysinfoapi", "timezoneapi", "winbase", "windef", "winerror", "winioctl", "winnt", "winreg", "winsock2", "winuser"] } -windows = { version = "0.61", features = ["Foundation_Metadata", "Foundation_Numerics", "Graphics_Capture", "Graphics_DirectX_Direct3D11", "Graphics_Imaging", "Media_Core", "Media_MediaProperties", "Media_Transcoding", "Security_Cryptography", "Storage_Search", "Storage_Streams", "System_Threading", "UI_ViewManagement", "Wdk_System_SystemServices", "Win32_Devices_Display", "Win32_Globalization", "Win32_Graphics_Direct2D_Common", "Win32_Graphics_Direct3D", "Win32_Graphics_Direct3D11", "Win32_Graphics_DirectWrite", "Win32_Graphics_Dwm", "Win32_Graphics_Dxgi_Common", "Win32_Graphics_Gdi", "Win32_Graphics_Imaging_D2D", "Win32_Networking_WinSock", "Win32_Security_Credentials", "Win32_Storage_FileSystem", "Win32_System_Com_StructuredStorage", "Win32_System_Console", "Win32_System_DataExchange", "Win32_System_IO", "Win32_System_LibraryLoader", "Win32_System_Memory", "Win32_System_Ole", "Win32_System_Pipes", "Win32_System_SystemInformation", "Win32_System_SystemServices", "Win32_System_Threading", "Win32_System_Variant", "Win32_System_WinRT_Direct3D11", "Win32_System_WinRT_Graphics_Capture", "Win32_UI_Controls", "Win32_UI_HiDpi", "Win32_UI_Input_Ime", "Win32_UI_Input_KeyboardAndMouse", "Win32_UI_Shell_Common", "Win32_UI_Shell_PropertiesSystem", "Win32_UI_WindowsAndMessaging"] } windows-core = { version = "0.61" } -windows-future = { version = "0.2" } windows-numerics = { version = "0.2" } windows-sys-73dcd821b1037cfd = { package = "windows-sys", version = "0.59", features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Win32_Globalization", "Win32_NetworkManagement_IpHelper", "Win32_Networking_WinSock", "Win32_Security_Authentication_Identity", "Win32_Security_Credentials", "Win32_Security_Cryptography", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_Console", "Win32_System_Diagnostics_Debug", "Win32_System_IO", "Win32_System_Ioctl", "Win32_System_Kernel", "Win32_System_LibraryLoader", "Win32_System_Memory", "Win32_System_Performance", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_SystemInformation", "Win32_System_SystemServices", "Win32_System_Threading", "Win32_System_Time", "Win32_System_WindowsProgramming", "Win32_UI_Input_KeyboardAndMouse", "Win32_UI_Shell", "Win32_UI_WindowsAndMessaging"] } windows-sys-b21d60becc0929df = { package = "windows-sys", version = "0.52", features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Wdk_System_IO", "Win32_Foundation", "Win32_Networking_WinSock", "Win32_Security_Authorization", "Win32_Storage_FileSystem", "Win32_System_Console", "Win32_System_IO", "Win32_System_Memory", "Win32_System_Pipes", "Win32_System_SystemServices", "Win32_System_Threading", "Win32_System_WindowsProgramming"] } -windows-sys-c8eced492e86ede7 = { package = "windows-sys", version = "0.48", features = 
["Win32_Foundation", "Win32_Globalization", "Win32_Networking_WinSock", "Win32_Security", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_Diagnostics_Debug", "Win32_System_IO", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_Threading", "Win32_System_Time", "Win32_UI_Shell"] } +windows-sys-c8eced492e86ede7 = { package = "windows-sys", version = "0.48", features = ["Win32_Foundation", "Win32_Globalization", "Win32_Networking_WinSock", "Win32_Security", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_Diagnostics_Debug", "Win32_System_IO", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_Threading", "Win32_System_Time", "Win32_System_WindowsProgramming", "Win32_UI_Shell"] } [target.x86_64-pc-windows-msvc.build-dependencies] codespan-reporting = { version = "0.12" } @@ -590,24 +587,21 @@ getrandom-468e82937335b1c9 = { package = "getrandom", version = "0.3", default-f getrandom-6f8ce4dd05d13bba = { package = "getrandom", version = "0.2", default-features = false, features = ["js", "rdrand"] } hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "native-tokio", "ring", "tls12"] } itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" } -naga = { version = "25", features = ["spv-out", "wgsl-in"] } proc-macro2 = { version = "1", default-features = false, features = ["span-locations"] } ring = { version = "0.17", features = ["std"] } rustix-d585fab2519d2d1 = { package = "rustix", version = "0.38", features = ["event"] } scopeguard = { version = "1" } sync_wrapper = { version = "1", default-features = false, features = ["futures"] } -tokio-rustls = { version = "0.26", default-features = false, features = ["ring"] } +tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "ring"] } tokio-socks = { version = "0.5", features = ["futures-io"] } tokio-stream = { version = "0.1", features = ["fs"] } tower = { version = "0.5", default-features = false, features = ["timeout", "util"] } winapi = { version = "0.3", default-features = false, features = ["cfg", "commapi", "consoleapi", "errhandlingapi", "evntrace", "fileapi", "handleapi", "impl-debug", "impl-default", "in6addr", "inaddr", "ioapiset", "knownfolders", "minwinbase", "minwindef", "namedpipeapi", "ntsecapi", "objbase", "processenv", "processthreadsapi", "shlobj", "std", "synchapi", "sysinfoapi", "timezoneapi", "winbase", "windef", "winerror", "winioctl", "winnt", "winreg", "winsock2", "winuser"] } -windows = { version = "0.61", features = ["Foundation_Metadata", "Foundation_Numerics", "Graphics_Capture", "Graphics_DirectX_Direct3D11", "Graphics_Imaging", "Media_Core", "Media_MediaProperties", "Media_Transcoding", "Security_Cryptography", "Storage_Search", "Storage_Streams", "System_Threading", "UI_ViewManagement", "Wdk_System_SystemServices", "Win32_Devices_Display", "Win32_Globalization", "Win32_Graphics_Direct2D_Common", "Win32_Graphics_Direct3D", "Win32_Graphics_Direct3D11", "Win32_Graphics_DirectWrite", "Win32_Graphics_Dwm", "Win32_Graphics_Dxgi_Common", "Win32_Graphics_Gdi", "Win32_Graphics_Imaging_D2D", "Win32_Networking_WinSock", "Win32_Security_Credentials", "Win32_Storage_FileSystem", "Win32_System_Com_StructuredStorage", "Win32_System_Console", "Win32_System_DataExchange", "Win32_System_IO", "Win32_System_LibraryLoader", "Win32_System_Memory", "Win32_System_Ole", "Win32_System_Pipes", "Win32_System_SystemInformation", "Win32_System_SystemServices", "Win32_System_Threading", "Win32_System_Variant", 
"Win32_System_WinRT_Direct3D11", "Win32_System_WinRT_Graphics_Capture", "Win32_UI_Controls", "Win32_UI_HiDpi", "Win32_UI_Input_Ime", "Win32_UI_Input_KeyboardAndMouse", "Win32_UI_Shell_Common", "Win32_UI_Shell_PropertiesSystem", "Win32_UI_WindowsAndMessaging"] } windows-core = { version = "0.61" } -windows-future = { version = "0.2" } windows-numerics = { version = "0.2" } windows-sys-73dcd821b1037cfd = { package = "windows-sys", version = "0.59", features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Win32_Globalization", "Win32_NetworkManagement_IpHelper", "Win32_Networking_WinSock", "Win32_Security_Authentication_Identity", "Win32_Security_Credentials", "Win32_Security_Cryptography", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_Console", "Win32_System_Diagnostics_Debug", "Win32_System_IO", "Win32_System_Ioctl", "Win32_System_Kernel", "Win32_System_LibraryLoader", "Win32_System_Memory", "Win32_System_Performance", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_SystemInformation", "Win32_System_SystemServices", "Win32_System_Threading", "Win32_System_Time", "Win32_System_WindowsProgramming", "Win32_UI_Input_KeyboardAndMouse", "Win32_UI_Shell", "Win32_UI_WindowsAndMessaging"] } windows-sys-b21d60becc0929df = { package = "windows-sys", version = "0.52", features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Wdk_System_IO", "Win32_Foundation", "Win32_Networking_WinSock", "Win32_Security_Authorization", "Win32_Storage_FileSystem", "Win32_System_Console", "Win32_System_IO", "Win32_System_Memory", "Win32_System_Pipes", "Win32_System_SystemServices", "Win32_System_Threading", "Win32_System_WindowsProgramming"] } -windows-sys-c8eced492e86ede7 = { package = "windows-sys", version = "0.48", features = ["Win32_Foundation", "Win32_Globalization", "Win32_Networking_WinSock", "Win32_Security", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_Diagnostics_Debug", "Win32_System_IO", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_Threading", "Win32_System_Time", "Win32_UI_Shell"] } +windows-sys-c8eced492e86ede7 = { package = "windows-sys", version = "0.48", features = ["Win32_Foundation", "Win32_Globalization", "Win32_Networking_WinSock", "Win32_Security", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_Diagnostics_Debug", "Win32_System_IO", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_Threading", "Win32_System_Time", "Win32_System_WindowsProgramming", "Win32_UI_Shell"] } [target.x86_64-unknown-linux-musl.dependencies] aes = { version = "0.8", default-features = false, features = ["zeroize"] } @@ -629,7 +623,8 @@ linux-raw-sys-274715c4dabd11b0 = { package = "linux-raw-sys", version = "0.9", d linux-raw-sys-9fbad63c4bcf4a8f = { package = "linux-raw-sys", version = "0.4", default-features = false, features = ["elf", "errno", "general", "if_ether", "ioctl", "net", "netlink", "no_std", "prctl", "system", "xdp"] } mio = { version = "1", features = ["net", "os-ext"] } naga = { version = "25", features = ["spv-out", "wgsl-in"] } -nix = { version = "0.29", features = ["fs", "pthread", "signal", "socket", "uio", "user"] } +nix-1f5adca70f036a62 = { package = "nix", version = "0.28", features = ["fs", "mman", "ptrace", "signal", "term", "user"] } +nix-b73a96c0a5f6a7d9 = { package = "nix", version = "0.29", features = ["fs", "pthread", "signal", "socket", "uio", "user"] } num-bigint-dig = { version = "0.8", features = ["i128", "prime", "zeroize"] } object = { version = "0.36", default-features = false, features = ["archive", 
"read_core", "unaligned", "write"] } proc-macro2 = { version = "1", features = ["span-locations"] } @@ -641,7 +636,7 @@ rustix-dff4ba8e3ae991db = { package = "rustix", version = "1", features = ["fs", scopeguard = { version = "1" } syn-f595c2ba2a3f28df = { package = "syn", version = "2", features = ["extra-traits", "fold", "full", "visit", "visit-mut"] } sync_wrapper = { version = "1", default-features = false, features = ["futures"] } -tokio-rustls = { version = "0.26", default-features = false, features = ["ring"] } +tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "ring"] } tokio-socks = { version = "0.5", features = ["futures-io"] } tokio-stream = { version = "0.1", features = ["fs"] } toml_datetime = { version = "0.6", default-features = false, features = ["serde"] } @@ -669,7 +664,8 @@ linux-raw-sys-274715c4dabd11b0 = { package = "linux-raw-sys", version = "0.9", d linux-raw-sys-9fbad63c4bcf4a8f = { package = "linux-raw-sys", version = "0.4", default-features = false, features = ["elf", "errno", "general", "if_ether", "ioctl", "net", "netlink", "no_std", "prctl", "system", "xdp"] } mio = { version = "1", features = ["net", "os-ext"] } naga = { version = "25", features = ["spv-out", "wgsl-in"] } -nix = { version = "0.29", features = ["fs", "pthread", "signal", "socket", "uio", "user"] } +nix-1f5adca70f036a62 = { package = "nix", version = "0.28", features = ["fs", "mman", "ptrace", "signal", "term", "user"] } +nix-b73a96c0a5f6a7d9 = { package = "nix", version = "0.29", features = ["fs", "pthread", "signal", "socket", "uio", "user"] } num-bigint-dig = { version = "0.8", features = ["i128", "prime", "zeroize"] } object = { version = "0.36", default-features = false, features = ["archive", "read_core", "unaligned", "write"] } proc-macro2 = { version = "1", default-features = false, features = ["span-locations"] } @@ -679,7 +675,7 @@ rustix-d585fab2519d2d1 = { package = "rustix", version = "0.38", features = ["ev rustix-dff4ba8e3ae991db = { package = "rustix", version = "1", features = ["fs", "net", "process", "termios", "time"] } scopeguard = { version = "1" } sync_wrapper = { version = "1", default-features = false, features = ["futures"] } -tokio-rustls = { version = "0.26", default-features = false, features = ["ring"] } +tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "ring"] } tokio-socks = { version = "0.5", features = ["futures-io"] } tokio-stream = { version = "0.1", features = ["fs"] } toml_datetime = { version = "0.6", default-features = false, features = ["serde"] } diff --git a/typos.toml b/typos.toml index 7f1c6e04f12867f4b3b88d64d6b2cad06dd9d509..336a829a44e6ff7a36e7f8f27f8a5ddc6f3a3f87 100644 --- a/typos.toml +++ b/typos.toml @@ -71,6 +71,10 @@ extend-ignore-re = [ # Not an actual typo but an intentionally invalid color, in `color_extractor` "#fof", # Stripped version of reserved keyword `type` - "typ" + "typ", + # AMD GPU Services + "ags", + # AMD GPU Services + "AGS" ] check-filename = true