diff --git a/.config/nextest.toml b/.config/nextest.toml index ab03abd839600e1a84ebd5eea9709f60cea1c7f0..b18a3f31e4a75af0636b4d8d8fdd81f48d8d93e6 100644 --- a/.config/nextest.toml +++ b/.config/nextest.toml @@ -42,3 +42,7 @@ slow-timeout = { period = "300s", terminate-after = 1 } [[profile.default.overrides]] filter = 'package(editor) and test(test_random_split_editor)' slow-timeout = { period = "300s", terminate-after = 1 } + +[[profile.default.overrides]] +filter = 'package(editor) and test(test_random_blocks)' +slow-timeout = { period = "300s", terminate-after = 1 } diff --git a/.github/ISSUE_TEMPLATE/10_bug_report.yml b/.github/ISSUE_TEMPLATE/10_bug_report.yml index 13e43219dd65a78af4afec479330bbc5fd85fe42..5eb8e8a6299c5189384b6d060e12cd61a2249a3c 100644 --- a/.github/ISSUE_TEMPLATE/10_bug_report.yml +++ b/.github/ISSUE_TEMPLATE/10_bug_report.yml @@ -100,7 +100,7 @@ body: label: (for AI issues) Model provider details placeholder: | - Provider: (Anthropic via ZedPro, Anthropic via API key, Copilot Chat, Mistral, OpenAI, etc.) - - Model Name: (Claude Sonnet 4.5, Gemini 3 Pro, GPT-5) + - Model Name: (Claude Sonnet 4.5, Gemini 3.1 Pro, GPT-5) - Mode: (Agent Panel, Inline Assistant, Terminal Assistant or Text Threads) - Other details (ACPs, MCPs, other settings, etc.): validations: diff --git a/.github/workflows/add_commented_closed_issue_to_project.yml b/.github/workflows/add_commented_closed_issue_to_project.yml index 5871f5ae0e61f97557ce926c4a2627841f50560d..bd84eaa9446e57c5482ab818df3dbcfe587e040e 100644 --- a/.github/workflows/add_commented_closed_issue_to_project.yml +++ b/.github/workflows/add_commented_closed_issue_to_project.yml @@ -63,13 +63,18 @@ jobs: } - if: steps.is-post-close-comment.outputs.result == 'true' && steps.check-staff.outputs.result == 'true' + env: + ISSUE_NUMBER: ${{ github.event.issue.number }} run: | - echo "::notice::Skipping issue #${{ github.event.issue.number }} - commenter is staff member" + echo "::notice::Skipping issue #$ISSUE_NUMBER - commenter is staff member" # github-script outputs are JSON strings, so we compare against 'false' (string) - if: steps.is-post-close-comment.outputs.result == 'true' && steps.check-staff.outputs.result == 'false' + env: + ISSUE_NUMBER: ${{ github.event.issue.number }} + COMMENT_USER_LOGIN: ${{ github.event.comment.user.login }} run: | - echo "::notice::Adding issue #${{ github.event.issue.number }} to project (comment by ${{ github.event.comment.user.login }})" + echo "::notice::Adding issue #$ISSUE_NUMBER to project (comment by $COMMENT_USER_LOGIN)" - if: steps.is-post-close-comment.outputs.result == 'true' && steps.check-staff.outputs.result == 'false' uses: actions/add-to-project@244f685bbc3b7adfa8466e08b698b5577571133e # v1.0.2 diff --git a/.github/workflows/after_release.yml b/.github/workflows/after_release.yml index 9582e3f1956b3ecda383fc03efdb3d7ff67eaa68..95229f9f46bbd34ffe02832114b2b39da1b7e090 100644 --- a/.github/workflows/after_release.yml +++ b/.github/workflows/after_release.yml @@ -76,7 +76,7 @@ jobs: "X-GitHub-Api-Version" = "2022-11-28" } $body = @{ branch = "master" } | ConvertTo-Json - $uri = "https://api.github.com/repos/${{ github.repository_owner }}/winget-pkgs/merge-upstream" + $uri = "https://api.github.com/repos/$env:GITHUB_REPOSITORY_OWNER/winget-pkgs/merge-upstream" try { Invoke-RestMethod -Uri $uri -Method Post -Headers $headers -Body $body -ContentType "application/json" Write-Host "Successfully synced winget-pkgs fork" @@ -131,11 +131,10 @@ jobs: runs-on: namespace-profile-2x4-ubuntu-2404 steps: - name: release::send_slack_message - run: | - curl -X POST -H 'Content-type: application/json'\ - --data '{"text":"❌ ${{ github.workflow }} failed: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' "$SLACK_WEBHOOK" + run: 'curl -X POST -H ''Content-type: application/json'' --data "$(jq -n --arg text "$SLACK_MESSAGE" ''{"text": $text}'')" "$SLACK_WEBHOOK"' env: SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK_WORKFLOW_FAILURES }} + SLACK_MESSAGE: '❌ ${{ github.workflow }} failed: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}' defaults: run: shell: bash -euxo pipefail {0} diff --git a/.github/workflows/autofix_pr.yml b/.github/workflows/autofix_pr.yml index 60cc66294af2cf65e17aaad530a9df511ec61503..1fa271d168a8c3d1744439647ff50b793a854d1d 100644 --- a/.github/workflows/autofix_pr.yml +++ b/.github/workflows/autofix_pr.yml @@ -22,8 +22,9 @@ jobs: with: clean: false - name: autofix_pr::run_autofix::checkout_pr - run: gh pr checkout ${{ inputs.pr_number }} + run: gh pr checkout "$PR_NUMBER" env: + PR_NUMBER: ${{ inputs.pr_number }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - name: steps::setup_cargo_config run: | @@ -104,8 +105,9 @@ jobs: clean: false token: ${{ steps.get-app-token.outputs.token }} - name: autofix_pr::commit_changes::checkout_pr - run: gh pr checkout ${{ inputs.pr_number }} + run: gh pr checkout "$PR_NUMBER" env: + PR_NUMBER: ${{ inputs.pr_number }} GITHUB_TOKEN: ${{ steps.get-app-token.outputs.token }} - name: autofix_pr::download_patch_artifact uses: actions/download-artifact@018cc2cf5baa6db3ef3c5f8a56943fffe632ef53 diff --git a/.github/workflows/catch_blank_issues.yml b/.github/workflows/catch_blank_issues.yml index dd425afc886e86c1217a94e90eabced013f66bf0..c6f595ef2e0890ce107829f3e91490332567368a 100644 --- a/.github/workflows/catch_blank_issues.yml +++ b/.github/workflows/catch_blank_issues.yml @@ -42,8 +42,10 @@ jobs: } - if: steps.check-staff.outputs.result == 'true' + env: + ISSUE_NUMBER: ${{ github.event.issue.number }} run: | - echo "::notice::Skipping issue #${{ github.event.issue.number }} - actor is staff member" + echo "::notice::Skipping issue #$ISSUE_NUMBER - actor is staff member" - if: steps.check-staff.outputs.result == 'false' id: add-label diff --git a/.github/workflows/cherry_pick.yml b/.github/workflows/cherry_pick.yml index 9d46f300b509347b2853c00575c4e82fd9a2863c..ee0c1d35d0f9825d7c39b81fba0fe35901de2611 100644 --- a/.github/workflows/cherry_pick.yml +++ b/.github/workflows/cherry_pick.yml @@ -36,8 +36,11 @@ jobs: app-id: ${{ secrets.ZED_ZIPPY_APP_ID }} private-key: ${{ secrets.ZED_ZIPPY_APP_PRIVATE_KEY }} - name: cherry_pick::run_cherry_pick::cherry_pick - run: ./script/cherry-pick ${{ inputs.branch }} ${{ inputs.commit }} ${{ inputs.channel }} + run: ./script/cherry-pick "$BRANCH" "$COMMIT" "$CHANNEL" env: + BRANCH: ${{ inputs.branch }} + COMMIT: ${{ inputs.commit }} + CHANNEL: ${{ inputs.channel }} GIT_COMMITTER_NAME: Zed Zippy GIT_COMMITTER_EMAIL: hi@zed.dev GITHUB_TOKEN: ${{ steps.get-app-token.outputs.token }} diff --git a/.github/workflows/community_update_all_top_ranking_issues.yml b/.github/workflows/community_update_all_top_ranking_issues.yml index 59926f35563a4b21e3486ecbd454a4ccf951461e..ef3b4fc39ddb5f0db9b09c5e861547ae8cd7eb08 100644 --- a/.github/workflows/community_update_all_top_ranking_issues.yml +++ b/.github/workflows/community_update_all_top_ranking_issues.yml @@ -22,4 +22,6 @@ jobs: - name: Install dependencies run: uv sync --project script/update_top_ranking_issues -p 3.13 - name: Run script - run: uv run --project script/update_top_ranking_issues script/update_top_ranking_issues/main.py --github-token ${{ secrets.GITHUB_TOKEN }} --issue-reference-number 5393 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: uv run --project script/update_top_ranking_issues script/update_top_ranking_issues/main.py --github-token "$GITHUB_TOKEN" --issue-reference-number 5393 diff --git a/.github/workflows/community_update_weekly_top_ranking_issues.yml b/.github/workflows/community_update_weekly_top_ranking_issues.yml index 75ba66b934b5861bd51aef4238a1a4188dddefc3..53b548f2bb4286e5de86d3823e67d75c0413a1cb 100644 --- a/.github/workflows/community_update_weekly_top_ranking_issues.yml +++ b/.github/workflows/community_update_weekly_top_ranking_issues.yml @@ -22,4 +22,6 @@ jobs: - name: Install dependencies run: uv sync --project script/update_top_ranking_issues -p 3.13 - name: Run script - run: uv run --project script/update_top_ranking_issues script/update_top_ranking_issues/main.py --github-token ${{ secrets.GITHUB_TOKEN }} --issue-reference-number 6952 --query-day-interval 7 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: uv run --project script/update_top_ranking_issues script/update_top_ranking_issues/main.py --github-token "$GITHUB_TOKEN" --issue-reference-number 6952 --query-day-interval 7 diff --git a/.github/workflows/compare_perf.yml b/.github/workflows/compare_perf.yml index e5a2d4f9c928eac2d1b1cf54ed374f8b0cca5d25..f7d78dbbf6a6d04bc47212b6842f894850288fcc 100644 --- a/.github/workflows/compare_perf.yml +++ b/.github/workflows/compare_perf.yml @@ -37,27 +37,40 @@ jobs: - name: compare_perf::run_perf::install_hyperfine uses: taiki-e/install-action@hyperfine - name: steps::git_checkout - run: git fetch origin ${{ inputs.base }} && git checkout ${{ inputs.base }} + run: git fetch origin "$REF_NAME" && git checkout "$REF_NAME" + env: + REF_NAME: ${{ inputs.base }} - name: compare_perf::run_perf::cargo_perf_test run: |2- - if [ -n "${{ inputs.crate_name }}" ]; then - cargo perf-test -p ${{ inputs.crate_name }} -- --json=${{ inputs.base }}; + if [ -n "$CRATE_NAME" ]; then + cargo perf-test -p "$CRATE_NAME" -- --json="$REF_NAME"; else - cargo perf-test -p vim -- --json=${{ inputs.base }}; + cargo perf-test -p vim -- --json="$REF_NAME"; fi + env: + REF_NAME: ${{ inputs.base }} + CRATE_NAME: ${{ inputs.crate_name }} - name: steps::git_checkout - run: git fetch origin ${{ inputs.head }} && git checkout ${{ inputs.head }} + run: git fetch origin "$REF_NAME" && git checkout "$REF_NAME" + env: + REF_NAME: ${{ inputs.head }} - name: compare_perf::run_perf::cargo_perf_test run: |2- - if [ -n "${{ inputs.crate_name }}" ]; then - cargo perf-test -p ${{ inputs.crate_name }} -- --json=${{ inputs.head }}; + if [ -n "$CRATE_NAME" ]; then + cargo perf-test -p "$CRATE_NAME" -- --json="$REF_NAME"; else - cargo perf-test -p vim -- --json=${{ inputs.head }}; + cargo perf-test -p vim -- --json="$REF_NAME"; fi + env: + REF_NAME: ${{ inputs.head }} + CRATE_NAME: ${{ inputs.crate_name }} - name: compare_perf::run_perf::compare_runs - run: cargo perf-compare --save=results.md ${{ inputs.base }} ${{ inputs.head }} + run: cargo perf-compare --save=results.md "$BASE" "$HEAD" + env: + BASE: ${{ inputs.base }} + HEAD: ${{ inputs.head }} - name: '@actions/upload-artifact results.md' uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 with: diff --git a/.github/workflows/deploy_cloudflare.yml b/.github/workflows/deploy_cloudflare.yml index cb0dfc2187a06cf62255b049b7e5fe74b10c505a..37f23b20d2825e9f3d26c456903962a10c2d0081 100644 --- a/.github/workflows/deploy_cloudflare.yml +++ b/.github/workflows/deploy_cloudflare.yml @@ -26,6 +26,7 @@ jobs: CC: clang CXX: clang++ DOCS_AMPLITUDE_API_KEY: ${{ secrets.DOCS_AMPLITUDE_API_KEY }} + DOCS_CONSENT_IO_INSTANCE: ${{ secrets.DOCS_CONSENT_IO_INSTANCE }} - name: Deploy Docs uses: cloudflare/wrangler-action@da0e0dfe58b7a431659754fdf3f186c529afbe65 # v3 diff --git a/.github/workflows/deploy_collab.yml b/.github/workflows/deploy_collab.yml index b1bdaf61979452a73380226ce1935b43eb05c32b..89fb6980b65f2d09a6571f140ab016a710be230f 100644 --- a/.github/workflows/deploy_collab.yml +++ b/.github/workflows/deploy_collab.yml @@ -119,8 +119,9 @@ jobs: with: token: ${{ secrets.DIGITALOCEAN_ACCESS_TOKEN }} - name: deploy_collab::deploy::sign_into_kubernetes - run: | - doctl kubernetes cluster kubeconfig save --expiry-seconds 600 ${{ secrets.CLUSTER_NAME }} + run: doctl kubernetes cluster kubeconfig save --expiry-seconds 600 "$CLUSTER_NAME" + env: + CLUSTER_NAME: ${{ secrets.CLUSTER_NAME }} - name: deploy_collab::deploy::start_rollout run: | set -eu @@ -140,7 +141,7 @@ jobs: echo "Deploying collab:$GITHUB_SHA to $ZED_KUBE_NAMESPACE" source script/lib/deploy-helpers.sh - export_vars_for_environment $ZED_KUBE_NAMESPACE + export_vars_for_environment "$ZED_KUBE_NAMESPACE" ZED_DO_CERTIFICATE_ID="$(doctl compute certificate list --format ID --no-header)" export ZED_DO_CERTIFICATE_ID @@ -150,14 +151,14 @@ jobs: export ZED_LOAD_BALANCER_SIZE_UNIT=$ZED_COLLAB_LOAD_BALANCER_SIZE_UNIT export DATABASE_MAX_CONNECTIONS=850 envsubst < crates/collab/k8s/collab.template.yml | kubectl apply -f - - kubectl -n "$ZED_KUBE_NAMESPACE" rollout status deployment/$ZED_SERVICE_NAME --watch + kubectl -n "$ZED_KUBE_NAMESPACE" rollout status "deployment/$ZED_SERVICE_NAME" --watch echo "deployed ${ZED_SERVICE_NAME} to ${ZED_KUBE_NAMESPACE}" export ZED_SERVICE_NAME=api export ZED_LOAD_BALANCER_SIZE_UNIT=$ZED_API_LOAD_BALANCER_SIZE_UNIT export DATABASE_MAX_CONNECTIONS=60 envsubst < crates/collab/k8s/collab.template.yml | kubectl apply -f - - kubectl -n "$ZED_KUBE_NAMESPACE" rollout status deployment/$ZED_SERVICE_NAME --watch + kubectl -n "$ZED_KUBE_NAMESPACE" rollout status "deployment/$ZED_SERVICE_NAME" --watch echo "deployed ${ZED_SERVICE_NAME} to ${ZED_KUBE_NAMESPACE}" defaults: run: diff --git a/.github/workflows/extension_bump.yml b/.github/workflows/extension_bump.yml index b7bb78363ce4ff97680b2a53967938280c3de902..9cc53741e8007a1b3ddd02ad07b191b3ce171cc8 100644 --- a/.github/workflows/extension_bump.yml +++ b/.github/workflows/extension_bump.yml @@ -39,7 +39,7 @@ jobs: run: | CURRENT_VERSION="$(sed -n 's/^version = \"\(.*\)\"/\1/p' < extension.toml | tr -d '[:space:]')" - if [[ "${{ github.event_name }}" == "pull_request" ]]; then + if [[ "$GITHUB_EVENT_NAME" == "pull_request" ]]; then PR_FORK_POINT="$(git merge-base origin/main HEAD)" git checkout "$PR_FORK_POINT" elif BRANCH_PARENT_SHA="$(git merge-base origin/main origin/zed-zippy-autobump)"; then @@ -82,8 +82,6 @@ jobs: - id: bump-version name: extension_bump::bump_version run: | - OLD_VERSION="${{ needs.check_version_changed.outputs.current_version }}" - BUMP_FILES=("extension.toml") if [[ -f "Cargo.toml" ]]; then BUMP_FILES+=("Cargo.toml") @@ -93,7 +91,7 @@ jobs: --search "version = \"{current_version}"\" \ --replace "version = \"{new_version}"\" \ --current-version "$OLD_VERSION" \ - --no-configured-files ${{ inputs.bump-type }} "${BUMP_FILES[@]}" + --no-configured-files "$BUMP_TYPE" "${BUMP_FILES[@]}" if [[ -f "Cargo.toml" ]]; then cargo update --workspace @@ -102,6 +100,9 @@ jobs: NEW_VERSION="$(sed -n 's/^version = \"\(.*\)\"/\1/p' < extension.toml | tr -d '[:space:]')" echo "new_version=${NEW_VERSION}" >> "$GITHUB_OUTPUT" + env: + OLD_VERSION: ${{ needs.check_version_changed.outputs.current_version }} + BUMP_TYPE: ${{ inputs.bump-type }} - name: extension_bump::create_pull_request uses: peter-evans/create-pull-request@v7 with: diff --git a/.github/workflows/extension_tests.yml b/.github/workflows/extension_tests.yml index 5160aba2869b1a3234c686a6508460784b0536b1..53de373c1b79dc3ca9a3637642e10998c781580a 100644 --- a/.github/workflows/extension_tests.yml +++ b/.github/workflows/extension_tests.yml @@ -32,7 +32,7 @@ jobs: git fetch origin "$GITHUB_BASE_REF" --depth=350 COMPARE_REV="$(git merge-base "origin/${GITHUB_BASE_REF}" HEAD)" fi - CHANGED_FILES="$(git diff --name-only "$COMPARE_REV" ${{ github.sha }})" + CHANGED_FILES="$(git diff --name-only "$COMPARE_REV" "$GITHUB_SHA")" check_pattern() { local output_name="$1" @@ -129,7 +129,7 @@ jobs: run: | CURRENT_VERSION="$(sed -n 's/^version = \"\(.*\)\"/\1/p' < extension.toml | tr -d '[:space:]')" - if [[ "${{ github.event_name }}" == "pull_request" ]]; then + if [[ "$GITHUB_EVENT_NAME" == "pull_request" ]]; then PR_FORK_POINT="$(git merge-base origin/main HEAD)" git checkout "$PR_FORK_POINT" elif BRANCH_PARENT_SHA="$(git merge-base origin/main origin/zed-zippy-autobump)"; then @@ -147,11 +147,14 @@ jobs: echo "current_version=${CURRENT_VERSION}" >> "$GITHUB_OUTPUT" - name: extension_tests::verify_version_did_not_change run: | - if [[ ${{ steps.compare-versions-check.outputs.version_changed }} == "true" && "${{ github.event_name }}" == "pull_request" && "${{ github.event.pull_request.user.login }}" != "zed-zippy[bot]" ]] ; then + if [[ "$VERSION_CHANGED" == "true" && "$GITHUB_EVENT_NAME" == "pull_request" && "$PR_USER_LOGIN" != "zed-zippy[bot]" ]] ; then echo "Version change detected in your change!" echo "Version changes happen in separate PRs and will be performed by the zed-zippy bot" exit 42 fi + env: + VERSION_CHANGED: ${{ steps.compare-versions-check.outputs.version_changed }} + PR_USER_LOGIN: ${{ github.event.pull_request.user.login }} timeout-minutes: 6 tests_pass: needs: @@ -171,11 +174,15 @@ jobs: if [[ "$2" != "skipped" && "$2" != "success" ]]; then EXIT_CODE=1; fi } - check_result "orchestrate" "${{ needs.orchestrate.result }}" - check_result "check_rust" "${{ needs.check_rust.result }}" - check_result "check_extension" "${{ needs.check_extension.result }}" + check_result "orchestrate" "$RESULT_ORCHESTRATE" + check_result "check_rust" "$RESULT_CHECK_RUST" + check_result "check_extension" "$RESULT_CHECK_EXTENSION" exit $EXIT_CODE + env: + RESULT_ORCHESTRATE: ${{ needs.orchestrate.result }} + RESULT_CHECK_RUST: ${{ needs.check_rust.result }} + RESULT_CHECK_EXTENSION: ${{ needs.check_extension.result }} concurrency: group: ${{ github.workflow }}-${{ github.ref_name }}-${{ github.ref_name == 'main' && github.sha || 'anysha' }} cancel-in-progress: true diff --git a/.github/workflows/extension_workflow_rollout.yml b/.github/workflows/extension_workflow_rollout.yml index 709956fc1bc0b25190638d9f1b5d4cd3cadd7ba2..9bfac06d4527985553ba3d04e64c656ee5bf85e4 100644 --- a/.github/workflows/extension_workflow_rollout.yml +++ b/.github/workflows/extension_workflow_rollout.yml @@ -80,9 +80,7 @@ jobs: - id: calc-changes name: extension_workflow_rollout::rollout_workflows_to_extension::get_removed_files run: | - PREV_COMMIT="${{ steps.prev-tag.outputs.prev_commit }}" - - if [ "${{ matrix.repo }}" = "workflows" ]; then + if [ "$MATRIX_REPO" = "workflows" ]; then WORKFLOW_DIR="extensions/workflows" else WORKFLOW_DIR="extensions/workflows/shared" @@ -101,11 +99,12 @@ jobs: echo "Files to remove: $REMOVED_FILES" echo "removed_files=$REMOVED_FILES" >> "$GITHUB_OUTPUT" + env: + PREV_COMMIT: ${{ steps.prev-tag.outputs.prev_commit }} + MATRIX_REPO: ${{ matrix.repo }} working-directory: zed - name: extension_workflow_rollout::rollout_workflows_to_extension::sync_workflow_files run: | - REMOVED_FILES="${{ steps.calc-changes.outputs.removed_files }}" - mkdir -p extension/.github/workflows cd extension/.github/workflows @@ -119,11 +118,14 @@ jobs: cd - > /dev/null - if [ "${{ matrix.repo }}" = "workflows" ]; then + if [ "$MATRIX_REPO" = "workflows" ]; then cp zed/extensions/workflows/*.yml extension/.github/workflows/ else cp zed/extensions/workflows/shared/*.yml extension/.github/workflows/ fi + env: + REMOVED_FILES: ${{ steps.calc-changes.outputs.removed_files }} + MATRIX_REPO: ${{ matrix.repo }} - id: short-sha name: extension_workflow_rollout::rollout_workflows_to_extension::get_short_sha run: | @@ -148,13 +150,13 @@ jobs: sign-commits: true - name: extension_workflow_rollout::rollout_workflows_to_extension::enable_auto_merge run: | - PR_NUMBER="${{ steps.create-pr.outputs.pull-request-number }}" if [ -n "$PR_NUMBER" ]; then cd extension gh pr merge "$PR_NUMBER" --auto --squash fi env: GH_TOKEN: ${{ steps.generate-token.outputs.token }} + PR_NUMBER: ${{ steps.create-pr.outputs.pull-request-number }} timeout-minutes: 10 create_rollout_tag: needs: diff --git a/.github/workflows/publish_extension_cli.yml b/.github/workflows/publish_extension_cli.yml index 391baac1cb3aa9da76c4fde39aa6909525541a58..75f1b16b007e33d0c4f346a33a1403648f1cd6c6 100644 --- a/.github/workflows/publish_extension_cli.yml +++ b/.github/workflows/publish_extension_cli.yml @@ -27,7 +27,7 @@ jobs: - name: publish_extension_cli::publish_job::build_extension_cli run: cargo build --release --package extension_cli - name: publish_extension_cli::publish_job::upload_binary - run: script/upload-extension-cli ${{ github.sha }} + run: script/upload-extension-cli "$GITHUB_SHA" env: DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }} DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }} @@ -55,10 +55,10 @@ jobs: - id: short-sha name: publish_extension_cli::get_short_sha run: | - echo "sha_short=$(echo "${{ github.sha }}" | cut -c1-7)" >> "$GITHUB_OUTPUT" + echo "sha_short=$(echo "$GITHUB_SHA" | cut -c1-7)" >> "$GITHUB_OUTPUT" - name: publish_extension_cli::update_sha_in_zed::replace_sha run: | - sed -i "s/ZED_EXTENSION_CLI_SHA: &str = \"[a-f0-9]*\"/ZED_EXTENSION_CLI_SHA: \&str = \"${{ github.sha }}\"/" \ + sed -i "s/ZED_EXTENSION_CLI_SHA: &str = \"[a-f0-9]*\"/ZED_EXTENSION_CLI_SHA: \&str = \"$GITHUB_SHA\"/" \ tooling/xtask/src/tasks/workflows/extension_tests.rs - name: publish_extension_cli::update_sha_in_zed::regenerate_workflows run: cargo xtask workflows @@ -97,7 +97,7 @@ jobs: - id: short-sha name: publish_extension_cli::get_short_sha run: | - echo "sha_short=$(echo "${{ github.sha }}" | cut -c1-7)" >> "$GITHUB_OUTPUT" + echo "sha_short=$(echo "$GITHUB_SHA" | cut -c1-7)" >> "$GITHUB_OUTPUT" - name: publish_extension_cli::update_sha_in_extensions::checkout_extensions_repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 with: @@ -105,7 +105,7 @@ jobs: token: ${{ steps.generate-token.outputs.token }} - name: publish_extension_cli::update_sha_in_extensions::replace_sha run: | - sed -i "s/ZED_EXTENSION_CLI_SHA: [a-f0-9]*/ZED_EXTENSION_CLI_SHA: ${{ github.sha }}/" \ + sed -i "s/ZED_EXTENSION_CLI_SHA: [a-f0-9]*/ZED_EXTENSION_CLI_SHA: $GITHUB_SHA/" \ .github/workflows/ci.yml - name: publish_extension_cli::create_pull_request_extensions uses: peter-evans/create-pull-request@v7 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 8ac5eeb998f5102d5af9b2775a82093b6ea29858..8adad5cfba278dc68dd227b86455510278c7a1ae 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -257,8 +257,14 @@ jobs: name: run_tests::check_scripts::download_actionlint run: bash <(curl https://raw.githubusercontent.com/rhysd/actionlint/main/scripts/download-actionlint.bash) - name: run_tests::check_scripts::run_actionlint - run: | - ${{ steps.get_actionlint.outputs.executable }} -color + run: '"$ACTIONLINT_BIN" -color' + env: + ACTIONLINT_BIN: ${{ steps.get_actionlint.outputs.executable }} + - name: steps::cache_rust_dependencies_namespace + uses: namespacelabs/nscloud-cache-action@v1 + with: + cache: rust + path: ~/.rustup - name: run_tests::check_scripts::check_xtask_workflows run: | cargo xtask workflows @@ -654,12 +660,7 @@ jobs: - id: generate-webhook-message name: release::generate_slack_message run: | - MESSAGE=$(DRAFT_RESULT="${{ needs.create_draft_release.result }}" - UPLOAD_RESULT="${{ needs.upload_release_assets.result }}" - VALIDATE_RESULT="${{ needs.validate_release_assets.result }}" - AUTO_RELEASE_RESULT="${{ needs.auto_release_preview.result }}" - TAG="$GITHUB_REF_NAME" - RUN_URL="${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}" + MESSAGE=$(TAG="$GITHUB_REF_NAME" if [ "$DRAFT_RESULT" == "failure" ]; then echo "❌ Draft release creation failed for $TAG: $RUN_URL" @@ -669,19 +670,19 @@ jobs: echo "❌ Release asset upload failed for $TAG: $RELEASE_URL" elif [ "$UPLOAD_RESULT" == "cancelled" ] || [ "$UPLOAD_RESULT" == "skipped" ]; then FAILED_JOBS="" - if [ "${{ needs.run_tests_mac.result }}" == "failure" ];then FAILED_JOBS="$FAILED_JOBS run_tests_mac"; fi - if [ "${{ needs.run_tests_linux.result }}" == "failure" ];then FAILED_JOBS="$FAILED_JOBS run_tests_linux"; fi - if [ "${{ needs.run_tests_windows.result }}" == "failure" ];then FAILED_JOBS="$FAILED_JOBS run_tests_windows"; fi - if [ "${{ needs.clippy_mac.result }}" == "failure" ];then FAILED_JOBS="$FAILED_JOBS clippy_mac"; fi - if [ "${{ needs.clippy_linux.result }}" == "failure" ];then FAILED_JOBS="$FAILED_JOBS clippy_linux"; fi - if [ "${{ needs.clippy_windows.result }}" == "failure" ];then FAILED_JOBS="$FAILED_JOBS clippy_windows"; fi - if [ "${{ needs.check_scripts.result }}" == "failure" ];then FAILED_JOBS="$FAILED_JOBS check_scripts"; fi - if [ "${{ needs.bundle_linux_aarch64.result }}" == "failure" ];then FAILED_JOBS="$FAILED_JOBS bundle_linux_aarch64"; fi - if [ "${{ needs.bundle_linux_x86_64.result }}" == "failure" ];then FAILED_JOBS="$FAILED_JOBS bundle_linux_x86_64"; fi - if [ "${{ needs.bundle_mac_aarch64.result }}" == "failure" ];then FAILED_JOBS="$FAILED_JOBS bundle_mac_aarch64"; fi - if [ "${{ needs.bundle_mac_x86_64.result }}" == "failure" ];then FAILED_JOBS="$FAILED_JOBS bundle_mac_x86_64"; fi - if [ "${{ needs.bundle_windows_aarch64.result }}" == "failure" ];then FAILED_JOBS="$FAILED_JOBS bundle_windows_aarch64"; fi - if [ "${{ needs.bundle_windows_x86_64.result }}" == "failure" ];then FAILED_JOBS="$FAILED_JOBS bundle_windows_x86_64"; fi + if [ "$RESULT_RUN_TESTS_MAC" == "failure" ];then FAILED_JOBS="$FAILED_JOBS run_tests_mac"; fi + if [ "$RESULT_RUN_TESTS_LINUX" == "failure" ];then FAILED_JOBS="$FAILED_JOBS run_tests_linux"; fi + if [ "$RESULT_RUN_TESTS_WINDOWS" == "failure" ];then FAILED_JOBS="$FAILED_JOBS run_tests_windows"; fi + if [ "$RESULT_CLIPPY_MAC" == "failure" ];then FAILED_JOBS="$FAILED_JOBS clippy_mac"; fi + if [ "$RESULT_CLIPPY_LINUX" == "failure" ];then FAILED_JOBS="$FAILED_JOBS clippy_linux"; fi + if [ "$RESULT_CLIPPY_WINDOWS" == "failure" ];then FAILED_JOBS="$FAILED_JOBS clippy_windows"; fi + if [ "$RESULT_CHECK_SCRIPTS" == "failure" ];then FAILED_JOBS="$FAILED_JOBS check_scripts"; fi + if [ "$RESULT_BUNDLE_LINUX_AARCH64" == "failure" ];then FAILED_JOBS="$FAILED_JOBS bundle_linux_aarch64"; fi + if [ "$RESULT_BUNDLE_LINUX_X86_64" == "failure" ];then FAILED_JOBS="$FAILED_JOBS bundle_linux_x86_64"; fi + if [ "$RESULT_BUNDLE_MAC_AARCH64" == "failure" ];then FAILED_JOBS="$FAILED_JOBS bundle_mac_aarch64"; fi + if [ "$RESULT_BUNDLE_MAC_X86_64" == "failure" ];then FAILED_JOBS="$FAILED_JOBS bundle_mac_x86_64"; fi + if [ "$RESULT_BUNDLE_WINDOWS_AARCH64" == "failure" ];then FAILED_JOBS="$FAILED_JOBS bundle_windows_aarch64"; fi + if [ "$RESULT_BUNDLE_WINDOWS_X86_64" == "failure" ];then FAILED_JOBS="$FAILED_JOBS bundle_windows_x86_64"; fi FAILED_JOBS=$(echo "$FAILED_JOBS" | xargs) if [ "$UPLOAD_RESULT" == "cancelled" ]; then if [ -n "$FAILED_JOBS" ]; then @@ -710,12 +711,29 @@ jobs: echo "message=$MESSAGE" >> "$GITHUB_OUTPUT" env: GH_TOKEN: ${{ github.token }} + DRAFT_RESULT: ${{ needs.create_draft_release.result }} + UPLOAD_RESULT: ${{ needs.upload_release_assets.result }} + VALIDATE_RESULT: ${{ needs.validate_release_assets.result }} + AUTO_RELEASE_RESULT: ${{ needs.auto_release_preview.result }} + RUN_URL: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }} + RESULT_RUN_TESTS_MAC: ${{ needs.run_tests_mac.result }} + RESULT_RUN_TESTS_LINUX: ${{ needs.run_tests_linux.result }} + RESULT_RUN_TESTS_WINDOWS: ${{ needs.run_tests_windows.result }} + RESULT_CLIPPY_MAC: ${{ needs.clippy_mac.result }} + RESULT_CLIPPY_LINUX: ${{ needs.clippy_linux.result }} + RESULT_CLIPPY_WINDOWS: ${{ needs.clippy_windows.result }} + RESULT_CHECK_SCRIPTS: ${{ needs.check_scripts.result }} + RESULT_BUNDLE_LINUX_AARCH64: ${{ needs.bundle_linux_aarch64.result }} + RESULT_BUNDLE_LINUX_X86_64: ${{ needs.bundle_linux_x86_64.result }} + RESULT_BUNDLE_MAC_AARCH64: ${{ needs.bundle_mac_aarch64.result }} + RESULT_BUNDLE_MAC_X86_64: ${{ needs.bundle_mac_x86_64.result }} + RESULT_BUNDLE_WINDOWS_AARCH64: ${{ needs.bundle_windows_aarch64.result }} + RESULT_BUNDLE_WINDOWS_X86_64: ${{ needs.bundle_windows_x86_64.result }} - name: release::send_slack_message - run: | - curl -X POST -H 'Content-type: application/json'\ - --data '{"text":"${{ steps.generate-webhook-message.outputs.message }}"}' "$SLACK_WEBHOOK" + run: 'curl -X POST -H ''Content-type: application/json'' --data "$(jq -n --arg text "$SLACK_MESSAGE" ''{"text": $text}'')" "$SLACK_WEBHOOK"' env: SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK_WORKFLOW_FAILURES }} + SLACK_MESSAGE: ${{ steps.generate-webhook-message.outputs.message }} concurrency: group: ${{ github.workflow }}-${{ github.ref_name }}-${{ github.ref_name == 'main' && github.sha || 'anysha' }} cancel-in-progress: true diff --git a/.github/workflows/release_nightly.yml b/.github/workflows/release_nightly.yml index 7f243411b4f540d6c7bc611df4883f5341d6a83b..46d8732b08ea658275e1fb21117a09b9e0668933 100644 --- a/.github/workflows/release_nightly.yml +++ b/.github/workflows/release_nightly.yml @@ -554,11 +554,10 @@ jobs: runs-on: namespace-profile-2x4-ubuntu-2404 steps: - name: release::send_slack_message - run: | - curl -X POST -H 'Content-type: application/json'\ - --data '{"text":"❌ ${{ github.workflow }} failed: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' "$SLACK_WEBHOOK" + run: 'curl -X POST -H ''Content-type: application/json'' --data "$(jq -n --arg text "$SLACK_MESSAGE" ''{"text": $text}'')" "$SLACK_WEBHOOK"' env: SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK_WORKFLOW_FAILURES }} + SLACK_MESSAGE: '❌ ${{ github.workflow }} failed: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}' defaults: run: shell: bash -euxo pipefail {0} diff --git a/.github/workflows/run_cron_unit_evals.yml b/.github/workflows/run_cron_unit_evals.yml index e57b54e4f2249b92630b2d3636ce2316a0814625..2a204a9d40d78bf52f38825b4db060216e348a87 100644 --- a/.github/workflows/run_cron_unit_evals.yml +++ b/.github/workflows/run_cron_unit_evals.yml @@ -16,7 +16,7 @@ jobs: model: - anthropic/claude-sonnet-4-5-latest - anthropic/claude-opus-4-5-latest - - google/gemini-3-pro + - google/gemini-3.1-pro - openai/gpt-5 fail-fast: false steps: diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index 29f888cbb596593052c6adebe2341171eac9055d..00d69639a53868386157e67aeab5ce7383d32426 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -35,7 +35,7 @@ jobs: git fetch origin "$GITHUB_BASE_REF" --depth=350 COMPARE_REV="$(git merge-base "origin/${GITHUB_BASE_REF}" HEAD)" fi - CHANGED_FILES="$(git diff --name-only "$COMPARE_REV" ${{ github.sha }})" + CHANGED_FILES="$(git diff --name-only "$COMPARE_REV" "$GITHUB_SHA")" check_pattern() { local output_name="$1" @@ -653,8 +653,14 @@ jobs: name: run_tests::check_scripts::download_actionlint run: bash <(curl https://raw.githubusercontent.com/rhysd/actionlint/main/scripts/download-actionlint.bash) - name: run_tests::check_scripts::run_actionlint - run: | - ${{ steps.get_actionlint.outputs.executable }} -color + run: '"$ACTIONLINT_BIN" -color' + env: + ACTIONLINT_BIN: ${{ steps.get_actionlint.outputs.executable }} + - name: steps::cache_rust_dependencies_namespace + uses: namespacelabs/nscloud-cache-action@v1 + with: + cache: rust + path: ~/.rustup - name: run_tests::check_scripts::check_xtask_workflows run: | cargo xtask workflows @@ -735,23 +741,39 @@ jobs: if [[ "$2" != "skipped" && "$2" != "success" ]]; then EXIT_CODE=1; fi } - check_result "orchestrate" "${{ needs.orchestrate.result }}" - check_result "check_style" "${{ needs.check_style.result }}" - check_result "clippy_windows" "${{ needs.clippy_windows.result }}" - check_result "clippy_linux" "${{ needs.clippy_linux.result }}" - check_result "clippy_mac" "${{ needs.clippy_mac.result }}" - check_result "run_tests_windows" "${{ needs.run_tests_windows.result }}" - check_result "run_tests_linux" "${{ needs.run_tests_linux.result }}" - check_result "run_tests_mac" "${{ needs.run_tests_mac.result }}" - check_result "doctests" "${{ needs.doctests.result }}" - check_result "check_workspace_binaries" "${{ needs.check_workspace_binaries.result }}" - check_result "check_wasm" "${{ needs.check_wasm.result }}" - check_result "check_dependencies" "${{ needs.check_dependencies.result }}" - check_result "check_docs" "${{ needs.check_docs.result }}" - check_result "check_licenses" "${{ needs.check_licenses.result }}" - check_result "check_scripts" "${{ needs.check_scripts.result }}" + check_result "orchestrate" "$RESULT_ORCHESTRATE" + check_result "check_style" "$RESULT_CHECK_STYLE" + check_result "clippy_windows" "$RESULT_CLIPPY_WINDOWS" + check_result "clippy_linux" "$RESULT_CLIPPY_LINUX" + check_result "clippy_mac" "$RESULT_CLIPPY_MAC" + check_result "run_tests_windows" "$RESULT_RUN_TESTS_WINDOWS" + check_result "run_tests_linux" "$RESULT_RUN_TESTS_LINUX" + check_result "run_tests_mac" "$RESULT_RUN_TESTS_MAC" + check_result "doctests" "$RESULT_DOCTESTS" + check_result "check_workspace_binaries" "$RESULT_CHECK_WORKSPACE_BINARIES" + check_result "check_wasm" "$RESULT_CHECK_WASM" + check_result "check_dependencies" "$RESULT_CHECK_DEPENDENCIES" + check_result "check_docs" "$RESULT_CHECK_DOCS" + check_result "check_licenses" "$RESULT_CHECK_LICENSES" + check_result "check_scripts" "$RESULT_CHECK_SCRIPTS" exit $EXIT_CODE + env: + RESULT_ORCHESTRATE: ${{ needs.orchestrate.result }} + RESULT_CHECK_STYLE: ${{ needs.check_style.result }} + RESULT_CLIPPY_WINDOWS: ${{ needs.clippy_windows.result }} + RESULT_CLIPPY_LINUX: ${{ needs.clippy_linux.result }} + RESULT_CLIPPY_MAC: ${{ needs.clippy_mac.result }} + RESULT_RUN_TESTS_WINDOWS: ${{ needs.run_tests_windows.result }} + RESULT_RUN_TESTS_LINUX: ${{ needs.run_tests_linux.result }} + RESULT_RUN_TESTS_MAC: ${{ needs.run_tests_mac.result }} + RESULT_DOCTESTS: ${{ needs.doctests.result }} + RESULT_CHECK_WORKSPACE_BINARIES: ${{ needs.check_workspace_binaries.result }} + RESULT_CHECK_WASM: ${{ needs.check_wasm.result }} + RESULT_CHECK_DEPENDENCIES: ${{ needs.check_dependencies.result }} + RESULT_CHECK_DOCS: ${{ needs.check_docs.result }} + RESULT_CHECK_LICENSES: ${{ needs.check_licenses.result }} + RESULT_CHECK_SCRIPTS: ${{ needs.check_scripts.result }} concurrency: group: ${{ github.workflow }}-${{ github.ref_name }}-${{ github.ref_name == 'main' && github.sha || 'anysha' }} cancel-in-progress: true diff --git a/.github/workflows/slack_notify_first_responders.yml b/.github/workflows/slack_notify_first_responders.yml index a6f2d557a574778aea6c2a90f9721b5a41bd0724..538d02b582f18db627693b62e439f4142ea29056 100644 --- a/.github/workflows/slack_notify_first_responders.yml +++ b/.github/workflows/slack_notify_first_responders.yml @@ -17,8 +17,9 @@ jobs: id: check-label env: LABEL_NAME: ${{ github.event.label.name }} + FIRST_RESPONDER_LABELS: ${{ env.FIRST_RESPONDER_LABELS }} run: | - if echo '${{ env.FIRST_RESPONDER_LABELS }}' | jq -e --arg label "$LABEL_NAME" 'index($label) != null' > /dev/null; then + if echo "$FIRST_RESPONDER_LABELS" | jq -e --arg label "$LABEL_NAME" 'index($label) != null' > /dev/null; then echo "should_notify=true" >> "$GITHUB_OUTPUT" echo "Label '$LABEL_NAME' requires first responder notification" else diff --git a/.github/workflows/update_duplicate_magnets.yml b/.github/workflows/update_duplicate_magnets.yml index 1c6c5a562532891eb97ceb11f44b81f35612c026..c3832b7bdbec13f74a8136cb1120a682f6e53920 100644 --- a/.github/workflows/update_duplicate_magnets.yml +++ b/.github/workflows/update_duplicate_magnets.yml @@ -21,7 +21,9 @@ jobs: run: pip install requests - name: Update duplicate magnets issue + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: | python script/github-find-top-duplicated-bugs.py \ - --github-token ${{ secrets.GITHUB_TOKEN }} \ + --github-token "$GITHUB_TOKEN" \ --issue-number 46355 diff --git a/Cargo.lock b/Cargo.lock index c813e6a4f2c9facdc68cc526c7ea8bb33a4ccf14..f5fe136c8f62fb14b5ebd1e29b636e82a3193c38 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -78,6 +78,7 @@ dependencies = [ "clock", "collections", "ctor", + "fs", "futures 0.3.31", "gpui", "indoc", @@ -171,7 +172,7 @@ dependencies = [ "context_server", "ctor", "db", - "derive_more 0.99.20", + "derive_more", "editor", "env_logger 0.11.8", "eval_utils", @@ -243,7 +244,7 @@ dependencies = [ "anyhow", "async-broadcast", "async-trait", - "derive_more 2.0.1", + "derive_more", "futures 0.3.31", "log", "serde", @@ -257,7 +258,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "44bc1fef9c32f03bce2ab44af35b6f483bfd169bf55cc59beeb2e3b1a00ae4d1" dependencies = [ "anyhow", - "derive_more 2.0.1", + "derive_more", "schemars", "serde", "serde_json", @@ -370,6 +371,7 @@ dependencies = [ "fs", "futures 0.3.31", "fuzzy", + "git", "gpui", "gpui_tokio", "html_to_markdown", @@ -603,6 +605,17 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" +[[package]] +name = "annotate-snippets" +version = "0.12.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c86cd1c51b95d71dde52bca69ed225008f6ff4c8cc825b08042aa1ef823e1980" +dependencies = [ + "anstyle", + "memchr", + "unicode-width", +] + [[package]] name = "anstream" version = "0.6.21" @@ -815,7 +828,7 @@ dependencies = [ "anyhow", "async-trait", "collections", - "derive_more 0.99.20", + "derive_more", "extension", "futures 0.3.31", "gpui", @@ -1353,6 +1366,7 @@ version = "0.1.0" dependencies = [ "anyhow", "log", + "scopeguard", "simplelog", "tempfile", "windows 0.61.3", @@ -3001,7 +3015,7 @@ dependencies = [ "cloud_llm_client", "collections", "credentials_provider", - "derive_more 0.99.20", + "derive_more", "feature_flags", "fs", "futures 0.3.31", @@ -3439,7 +3453,7 @@ name = "command_palette_hooks" version = "0.1.0" dependencies = [ "collections", - "derive_more 0.99.20", + "derive_more", "gpui", "workspace", ] @@ -3615,15 +3629,18 @@ dependencies = [ [[package]] name = "convert_case" -version = "0.4.0" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e" +checksum = "baaaa0ecca5b51987b9423ccdc971514dd8b0bb7b4060b983d3664dad3f1f89f" +dependencies = [ + "unicode-segmentation", +] [[package]] name = "convert_case" -version = "0.8.0" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baaaa0ecca5b51987b9423ccdc971514dd8b0bb7b4060b983d3664dad3f1f89f" +checksum = "633458d4ef8c78b72454de2d54fd6ab2e60f9e02be22f3c6104cdc8a4e0fceb9" dependencies = [ "unicode-segmentation", ] @@ -4122,13 +4139,13 @@ dependencies = [ name = "crashes" version = "0.1.0" dependencies = [ - "bincode", "cfg-if", "crash-handler", "futures 0.3.31", "log", "mach2 0.5.0", "minidumper", + "parking_lot", "paths", "release_channel", "serde", @@ -4342,6 +4359,20 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "csv_preview" +version = "0.1.0" +dependencies = [ + "anyhow", + "editor", + "feature_flags", + "gpui", + "log", + "text", + "ui", + "workspace", +] + [[package]] name = "ctor" version = "0.4.3" @@ -4779,34 +4810,23 @@ dependencies = [ [[package]] name = "derive_more" -version = "0.99.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6edb4b64a43d977b8e99788fe3a04d483834fba1215a7e02caa415b626497f7f" -dependencies = [ - "convert_case 0.4.0", - "proc-macro2", - "quote", - "rustc_version", - "syn 2.0.106", -] - -[[package]] -name = "derive_more" -version = "2.0.1" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "093242cf7570c207c83073cf82f79706fe7b8317e98620a47d5be7c3d8497678" +checksum = "d751e9e49156b02b44f9c1815bcb94b984cdcc4396ecc32521c739452808b134" dependencies = [ "derive_more-impl", ] [[package]] name = "derive_more-impl" -version = "2.0.1" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bda628edc44c4bb645fbe0f758797143e4e07926f7ebf4e9bdfbd3d2ce621df3" +checksum = "799a97264921d8623a957f6c3b9011f3b5492f557bbb7a5a19b7fa6d06ba8dcb" dependencies = [ + "convert_case 0.10.0", "proc-macro2", "quote", + "rustc_version", "syn 2.0.106", "unicode-xid", ] @@ -7115,7 +7135,7 @@ version = "0.8.0" source = "git+https://github.com/zed-industries/gh-workflow?rev=c9eac0ed361583e1072860d96776fa52775b82ac#c9eac0ed361583e1072860d96776fa52775b82ac" dependencies = [ "async-trait", - "derive_more 2.0.1", + "derive_more", "derive_setters", "gh-workflow-macros", "indexmap", @@ -7184,7 +7204,7 @@ dependencies = [ "askpass", "async-trait", "collections", - "derive_more 0.99.20", + "derive_more", "futures 0.3.31", "git2", "gpui", @@ -7563,7 +7583,7 @@ dependencies = [ "core-text", "core-video", "ctor", - "derive_more 0.99.20", + "derive_more", "embed-resource", "env_logger 0.11.8", "etagere", @@ -7584,7 +7604,7 @@ dependencies = [ "mach2 0.5.0", "media", "metal", - "naga", + "naga 28.0.0", "num_cpus", "objc", "objc2", @@ -7691,7 +7711,7 @@ dependencies = [ "core-text", "core-video", "ctor", - "derive_more 0.99.20", + "derive_more", "dispatch2", "etagere", "foreign-types 0.5.0", @@ -8249,7 +8269,7 @@ dependencies = [ "async-fs", "async-tar", "bytes 1.11.1", - "derive_more 0.99.20", + "derive_more", "futures 0.3.31", "http 1.3.1", "http-body 1.0.1", @@ -9126,9 +9146,9 @@ dependencies = [ [[package]] name = "jupyter-protocol" -version = "1.2.1" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c75a69caf8b8e781224badfb76c4a8da4d49856de36ce72ae3cf5d4a1c94e42" +checksum = "4649647741f9794a7a02e3be976f1b248ba28a37dbfc626d5089316fd4fbf4c8" dependencies = [ "async-trait", "bytes 1.11.1", @@ -10018,6 +10038,7 @@ dependencies = [ "ctor", "futures 0.3.31", "gpui", + "gpui_util", "log", "lsp-types", "parking_lot", @@ -10688,6 +10709,30 @@ name = "naga" version = "28.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "618f667225063219ddfc61251087db8a9aec3c3f0950c916b614e403486f1135" +dependencies = [ + "arrayvec", + "bit-set", + "bitflags 2.10.0", + "cfg-if", + "cfg_aliases 0.2.1", + "codespan-reporting 0.12.0", + "half", + "hashbrown 0.16.1", + "hexf-parse", + "indexmap", + "libm", + "log", + "num-traits", + "once_cell", + "rustc-hash 1.1.0", + "thiserror 2.0.17", + "unicode-ident", +] + +[[package]] +name = "naga" +version = "28.0.1" +source = "git+https://github.com/zed-industries/wgpu?rev=9459e95113c5bd116b2cc2c87e8424b28059e17c#9459e95113c5bd116b2cc2c87e8424b28059e17c" dependencies = [ "arrayvec", "bit-set", @@ -10746,9 +10791,9 @@ dependencies = [ [[package]] name = "nbformat" -version = "1.1.0" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b10a89a2d910233ec3fca4de359b16ebe95e833c8b2162643ef98c6053a0549d" +checksum = "d4983a40792c45e8639f77ef8e4461c55679cbc618f4b9e83830e8c7e79c8383" dependencies = [ "anyhow", "chrono", @@ -14609,9 +14654,9 @@ dependencies = [ [[package]] name = "runtimelib" -version = "1.2.0" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d80685459e1e5fa5603182058351ae91c98ca458dfef4e85f0a37be4f7cf1e6c" +checksum = "fa84884e45ed4a1e663120cef3fc11f14d1a2a1933776e1c31599f7bd2dd0c9e" dependencies = [ "async-dispatcher", "async-std", @@ -15517,7 +15562,7 @@ version = "0.1.0" dependencies = [ "anyhow", "collections", - "derive_more 0.99.20", + "derive_more", "gpui", "log", "schemars", @@ -17300,7 +17345,7 @@ version = "0.1.0" dependencies = [ "anyhow", "collections", - "derive_more 0.99.20", + "derive_more", "fs", "futures 0.3.31", "gpui", @@ -19812,6 +19857,7 @@ version = "0.1.0" dependencies = [ "anyhow", "client", + "cloud_api_types", "cloud_llm_client", "futures 0.3.31", "gpui", @@ -19876,9 +19922,8 @@ checksum = "a751b3277700db47d3e574514de2eced5e54dc8a5436a3bf7a0b248b2cee16f3" [[package]] name = "wgpu" -version = "28.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9cb534d5ffd109c7d1135f34cdae29e60eab94855a625dcfe1705f8bc7ad79f" +version = "28.0.1" +source = "git+https://github.com/zed-industries/wgpu?rev=9459e95113c5bd116b2cc2c87e8424b28059e17c#9459e95113c5bd116b2cc2c87e8424b28059e17c" dependencies = [ "arrayvec", "bitflags 2.10.0", @@ -19889,7 +19934,7 @@ dependencies = [ "hashbrown 0.16.1", "js-sys", "log", - "naga", + "naga 28.0.1", "parking_lot", "portable-atomic", "profiling", @@ -19906,9 +19951,8 @@ dependencies = [ [[package]] name = "wgpu-core" -version = "28.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8bb4c8b5db5f00e56f1f08869d870a0dff7c8bc7ebc01091fec140b0cf0211a9" +version = "28.0.1" +source = "git+https://github.com/zed-industries/wgpu?rev=9459e95113c5bd116b2cc2c87e8424b28059e17c#9459e95113c5bd116b2cc2c87e8424b28059e17c" dependencies = [ "arrayvec", "bit-set", @@ -19920,7 +19964,7 @@ dependencies = [ "hashbrown 0.16.1", "indexmap", "log", - "naga", + "naga 28.0.1", "once_cell", "parking_lot", "portable-atomic", @@ -19938,36 +19982,32 @@ dependencies = [ [[package]] name = "wgpu-core-deps-apple" -version = "28.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87b7b696b918f337c486bf93142454080a32a37832ba8a31e4f48221890047da" +version = "28.0.1" +source = "git+https://github.com/zed-industries/wgpu?rev=9459e95113c5bd116b2cc2c87e8424b28059e17c#9459e95113c5bd116b2cc2c87e8424b28059e17c" dependencies = [ "wgpu-hal", ] [[package]] name = "wgpu-core-deps-emscripten" -version = "28.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34b251c331f84feac147de3c4aa3aa45112622a95dd7ee1b74384fa0458dbd79" +version = "28.0.1" +source = "git+https://github.com/zed-industries/wgpu?rev=9459e95113c5bd116b2cc2c87e8424b28059e17c#9459e95113c5bd116b2cc2c87e8424b28059e17c" dependencies = [ "wgpu-hal", ] [[package]] name = "wgpu-core-deps-windows-linux-android" -version = "28.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68ca976e72b2c9964eb243e281f6ce7f14a514e409920920dcda12ae40febaae" +version = "28.0.1" +source = "git+https://github.com/zed-industries/wgpu?rev=9459e95113c5bd116b2cc2c87e8424b28059e17c#9459e95113c5bd116b2cc2c87e8424b28059e17c" dependencies = [ "wgpu-hal", ] [[package]] name = "wgpu-hal" -version = "28.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "293080d77fdd14d6b08a67c5487dfddbf874534bb7921526db56a7b75d7e3bef" +version = "28.0.1" +source = "git+https://github.com/zed-industries/wgpu?rev=9459e95113c5bd116b2cc2c87e8424b28059e17c#9459e95113c5bd116b2cc2c87e8424b28059e17c" dependencies = [ "android_system_properties", "arrayvec", @@ -19990,7 +20030,7 @@ dependencies = [ "libloading", "log", "metal", - "naga", + "naga 28.0.1", "ndk-sys", "objc", "once_cell", @@ -20013,9 +20053,8 @@ dependencies = [ [[package]] name = "wgpu-types" -version = "28.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e18308757e594ed2cd27dddbb16a139c42a683819d32a2e0b1b0167552f5840c" +version = "28.0.1" +source = "git+https://github.com/zed-industries/wgpu?rev=9459e95113c5bd116b2cc2c87e8424b28059e17c#9459e95113c5bd116b2cc2c87e8424b28059e17c" dependencies = [ "bitflags 2.10.0", "bytemuck", @@ -21239,7 +21278,6 @@ dependencies = [ "any_vec", "anyhow", "async-recursion", - "call", "chrono", "client", "clock", @@ -21504,6 +21542,7 @@ checksum = "ec7a2a501ed189703dba8b08142f057e887dfc4b2cc4db2d343ac6376ba3e0b9" name = "xtask" version = "0.1.0" dependencies = [ + "annotate-snippets", "anyhow", "backtrace", "cargo_metadata", @@ -21512,8 +21551,12 @@ dependencies = [ "gh-workflow", "indexmap", "indoc", + "itertools 0.14.0", + "regex", "serde", "serde_json", + "serde_yaml", + "strum 0.27.2", "toml 0.8.23", "toml_edit 0.22.27", ] @@ -21692,7 +21735,7 @@ dependencies = [ [[package]] name = "zed" -version = "0.227.0" +version = "0.228.0" dependencies = [ "acp_thread", "acp_tools", @@ -21710,7 +21753,6 @@ dependencies = [ "audio", "auto_update", "auto_update_ui", - "bincode", "breadcrumbs", "call", "channel", @@ -21729,6 +21771,7 @@ dependencies = [ "copilot_chat", "copilot_ui", "crashes", + "csv_preview", "dap", "dap_adapters", "db", diff --git a/Cargo.toml b/Cargo.toml index 98fccfaeb21bc6107323378605c8299d5bd5838f..b8e57bda7e46ea45451fedd6759268235c7d71ab 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,6 +45,7 @@ members = [ "crates/copilot_chat", "crates/crashes", "crates/credentials_provider", + "crates/csv_preview", "crates/dap", "crates/dap_adapters", "crates/db", @@ -298,6 +299,7 @@ copilot_ui = { path = "crates/copilot_ui" } crashes = { path = "crates/crashes" } credentials_provider = { path = "crates/credentials_provider" } crossbeam = "0.8.4" +csv_preview = { path = "crates/csv_preview"} dap = { path = "crates/dap" } dap_adapters = { path = "crates/dap_adapters" } db = { path = "crates/db" } @@ -536,7 +538,16 @@ criterion = { version = "0.5", features = ["html_reports"] } ctor = "0.4.0" dap-types = { git = "https://github.com/zed-industries/dap-types", rev = "1b461b310481d01e02b2603c16d7144b926339f8" } dashmap = "6.0" -derive_more = "0.99.17" +derive_more = { version = "2.1.1", features = [ + "add", + "add_assign", + "deref", + "deref_mut", + "from_str", + "mul", + "mul_assign", + "not", +] } dirs = "4.0" documented = "0.9.1" dotenvy = "0.15.0" @@ -572,7 +583,7 @@ itertools = "0.14.0" json_dotpath = "1.1" jsonschema = "0.37.0" jsonwebtoken = "10.0" -jupyter-protocol = "1.2.0" +jupyter-protocol = "1.4.0" jupyter-websocket-client = "1.0.0" libc = "0.2" libsqlite3-sys = { version = "0.30.1", features = ["bundled"] } @@ -588,7 +599,7 @@ minidumper = "0.8" moka = { version = "0.12.10", features = ["sync"] } naga = { version = "28.0", features = ["wgsl-in"] } nanoid = "0.4" -nbformat = "1.1.0" +nbformat = "1.2.0" nix = "0.29" num-format = "0.4.4" objc = "0.2" @@ -658,7 +669,7 @@ reqwest = { git = "https://github.com/zed-industries/reqwest.git", rev = "c15662 "stream", ], package = "zed-reqwest", version = "0.12.15-zed" } rsa = "0.9.6" -runtimelib = { version = "1.2.0", default-features = false, features = [ +runtimelib = { version = "1.4.0", default-features = false, features = [ "async-dispatcher-runtime", "aws-lc-rs" ] } rust-embed = { version = "8.4", features = ["include-exclude"] } @@ -768,7 +779,7 @@ wax = "0.7" which = "6.0.0" wasm-bindgen = "0.2.113" web-time = "1.1.0" -wgpu = "28.0" +wgpu = { git = "https://github.com/zed-industries/wgpu", rev = "9459e95113c5bd116b2cc2c87e8424b28059e17c" } windows-core = "0.61" yawc = "0.2.5" zeroize = "1.8" @@ -813,6 +824,7 @@ features = [ "Win32_System_Ole", "Win32_System_Performance", "Win32_System_Pipes", + "Win32_System_RestartManager", "Win32_System_SystemInformation", "Win32_System_SystemServices", "Win32_System_Threading", diff --git a/assets/icons/file_icons/gitlab.svg b/assets/icons/file_icons/gitlab.svg new file mode 100644 index 0000000000000000000000000000000000000000..f0faf570b125c7764e769ae60f7a6ce6f7825ceb --- /dev/null +++ b/assets/icons/file_icons/gitlab.svg @@ -0,0 +1 @@ + diff --git a/assets/icons/file_icons/helm.svg b/assets/icons/file_icons/helm.svg new file mode 100644 index 0000000000000000000000000000000000000000..03e702f2d5081c4e96ff4db7ba7428817b08748f --- /dev/null +++ b/assets/icons/file_icons/helm.svg @@ -0,0 +1 @@ + diff --git a/assets/icons/file_icons/yaml.svg b/assets/icons/file_icons/yaml.svg new file mode 100644 index 0000000000000000000000000000000000000000..2c3efd46cd45ff67d6c46d84476d563dd5ac3a73 --- /dev/null +++ b/assets/icons/file_icons/yaml.svg @@ -0,0 +1 @@ + diff --git a/assets/keymaps/default-linux.json b/assets/keymaps/default-linux.json index 9b8f2d337b1f1073bca818cf0b9c66773a3ce4e9..7e01245ec62b2590a1c88fef5946b7d06463968d 100644 --- a/assets/keymaps/default-linux.json +++ b/assets/keymaps/default-linux.json @@ -204,6 +204,7 @@ { "context": "Editor && editor_agent_diff", "bindings": { + "alt-y": "agent::Keep", "ctrl-alt-y": "agent::Keep", "ctrl-alt-z": "agent::Reject", "shift-alt-y": "agent::KeepAll", @@ -214,6 +215,7 @@ { "context": "AgentDiff", "bindings": { + "alt-y": "agent::Keep", "ctrl-alt-y": "agent::Keep", "ctrl-alt-z": "agent::Reject", "shift-alt-y": "agent::KeepAll", @@ -1310,6 +1312,7 @@ "bindings": { "ctrl-shift-space": "git::WorktreeFromDefaultOnWindow", "ctrl-space": "git::WorktreeFromDefault", + "ctrl-shift-backspace": "git::DeleteWorktree", }, }, { diff --git a/assets/keymaps/default-macos.json b/assets/keymaps/default-macos.json index 5f210cb4da35f9909767035c941289ee24a2ee3f..43d6419575fc698110cd5a033c01127ac6543f9a 100644 --- a/assets/keymaps/default-macos.json +++ b/assets/keymaps/default-macos.json @@ -242,6 +242,7 @@ "context": "AgentDiff", "use_key_equivalents": true, "bindings": { + "cmd-y": "agent::Keep", "cmd-alt-y": "agent::Keep", "cmd-alt-z": "agent::Reject", "shift-alt-y": "agent::KeepAll", @@ -252,6 +253,7 @@ "context": "Editor && editor_agent_diff", "use_key_equivalents": true, "bindings": { + "cmd-y": "agent::Keep", "cmd-alt-y": "agent::Keep", "cmd-alt-z": "agent::Reject", "shift-alt-y": "agent::KeepAll", @@ -448,6 +450,13 @@ "down": "search::NextHistoryQuery", }, }, + { + "context": "BufferSearchBar || ProjectSearchBar", + "use_key_equivalents": true, + "bindings": { + "ctrl-enter": "editor::Newline", + }, + }, { "context": "ProjectSearchBar", "use_key_equivalents": true, @@ -1408,6 +1417,7 @@ "bindings": { "ctrl-shift-space": "git::WorktreeFromDefaultOnWindow", "ctrl-space": "git::WorktreeFromDefault", + "cmd-shift-backspace": "git::DeleteWorktree", }, }, { diff --git a/assets/keymaps/default-windows.json b/assets/keymaps/default-windows.json index 19f75f858cd45192c4cf30dd6bd0799046c26268..22541368cecfc6a645e2b8b7ce55a6711491a012 100644 --- a/assets/keymaps/default-windows.json +++ b/assets/keymaps/default-windows.json @@ -203,6 +203,7 @@ "context": "Editor && editor_agent_diff", "use_key_equivalents": true, "bindings": { + "alt-y": "agent::Keep", "ctrl-alt-y": "agent::Keep", "ctrl-alt-z": "agent::Reject", "shift-alt-y": "agent::KeepAll", @@ -214,6 +215,7 @@ "context": "AgentDiff", "use_key_equivalents": true, "bindings": { + "alt-y": "agent::Keep", "ctrl-alt-y": "agent::Keep", "ctrl-alt-z": "agent::Reject", "shift-alt-y": "agent::KeepAll", @@ -1331,6 +1333,7 @@ "bindings": { "ctrl-shift-space": "git::WorktreeFromDefaultOnWindow", "ctrl-space": "git::WorktreeFromDefault", + "ctrl-shift-backspace": "git::DeleteWorktree", }, }, { diff --git a/assets/settings/default.json b/assets/settings/default.json index b193c0f60d0087972381f4f85f2b864b52fdbc7d..0a824bbe93a0d68a23d934a63eb1fdab1e2f1b02 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -361,8 +361,11 @@ // bracket, brace, single or double quote characters. // For example, when you select text and type '(', Zed will surround the text with (). "use_auto_surround": true, - // Whether indentation should be adjusted based on the context whilst typing. - "auto_indent": true, + // Controls automatic indentation behavior when typing. + // - "syntax_aware": Adjusts indentation based on syntax context (default) + // - "preserve_indent": Preserves current line's indentation on new lines + // - "none": No automatic indentation + "auto_indent": "syntax_aware", // Whether indentation of pasted content should be adjusted based on the context. "auto_indent_on_paste": true, // Controls how the editor handles the autoclosed characters. @@ -1831,8 +1834,8 @@ " (", " # multi-char path: first char (not opening delimiter, space, or box drawing char)", " [^({\\[<\"'`\\ \\u2500-\\u257F]", - " # middle chars: non-space, and colon/paren only if not followed by digit/paren", - " ([^\\ :(]|[:(][^0-9()])*", + " # middle chars: non-space, and colon/paren only if not followed by digit/paren/space", + " ([^\\ :(]|[:(][^0-9()\\ ])*", " # last char: not closing delimiter or colon", " [^()}\\]>\"'`.,;:\\ ]", " |", diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index f57ce1f4d188e260624bd90187a21890379fe6b6..1b9271918884dc020986577926d9578e3a6f049c 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -972,6 +972,8 @@ pub struct AcpThread { had_error: bool, /// The user's unsent prompt text, persisted so it can be restored when reloading the thread. draft_prompt: Option>, + /// The initial scroll position for the thread view, set during session registration. + ui_scroll_position: Option, } impl From<&AcpThread> for ActionLogTelemetry { @@ -1210,6 +1212,7 @@ impl AcpThread { pending_terminal_exit: HashMap::default(), had_error: false, draft_prompt: None, + ui_scroll_position: None, } } @@ -1229,6 +1232,14 @@ impl AcpThread { self.draft_prompt = prompt; } + pub fn ui_scroll_position(&self) -> Option { + self.ui_scroll_position + } + + pub fn set_ui_scroll_position(&mut self, position: Option) { + self.ui_scroll_position = position; + } + pub fn connection(&self) -> &Rc { &self.connection } diff --git a/crates/action_log/Cargo.toml b/crates/action_log/Cargo.toml index 8488df691e40ea3bcfc04f4f6f74964fba7863dd..b1a1bf824fb770b8378e596fd0c799a7cf98b13d 100644 --- a/crates/action_log/Cargo.toml +++ b/crates/action_log/Cargo.toml @@ -20,6 +20,7 @@ buffer_diff.workspace = true log.workspace = true clock.workspace = true collections.workspace = true +fs.workspace = true futures.workspace = true gpui.workspace = true language.workspace = true diff --git a/crates/action_log/src/action_log.rs b/crates/action_log/src/action_log.rs index 5f8a639c0559c10546fc5640dc240aeba9dde487..5679f3c58fe52057f7a4a0faa24d5b5db2b5e497 100644 --- a/crates/action_log/src/action_log.rs +++ b/crates/action_log/src/action_log.rs @@ -1,14 +1,20 @@ use anyhow::{Context as _, Result}; use buffer_diff::BufferDiff; use clock; -use collections::BTreeMap; +use collections::{BTreeMap, HashMap}; +use fs::MTime; use futures::{FutureExt, StreamExt, channel::mpsc}; use gpui::{ App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity, }; use language::{Anchor, Buffer, BufferEvent, Point, ToOffset, ToPoint}; use project::{Project, ProjectItem, lsp_store::OpenLspBufferHandle}; -use std::{cmp, ops::Range, sync::Arc}; +use std::{ + cmp, + ops::Range, + path::{Path, PathBuf}, + sync::Arc, +}; use text::{Edit, Patch, Rope}; use util::{RangeExt, ResultExt as _}; @@ -54,6 +60,8 @@ pub struct ActionLog { linked_action_log: Option>, /// Stores undo information for the most recent reject operation last_reject_undo: Option, + /// Tracks the last time files were read by the agent, to detect external modifications + file_read_times: HashMap, } impl ActionLog { @@ -64,6 +72,7 @@ impl ActionLog { project, linked_action_log: None, last_reject_undo: None, + file_read_times: HashMap::default(), } } @@ -76,6 +85,32 @@ impl ActionLog { &self.project } + pub fn file_read_time(&self, path: &Path) -> Option { + self.file_read_times.get(path).copied() + } + + fn update_file_read_time(&mut self, buffer: &Entity, cx: &App) { + let buffer = buffer.read(cx); + if let Some(file) = buffer.file() { + if let Some(local_file) = file.as_local() { + if let Some(mtime) = file.disk_state().mtime() { + let abs_path = local_file.abs_path(cx); + self.file_read_times.insert(abs_path, mtime); + } + } + } + } + + fn remove_file_read_time(&mut self, buffer: &Entity, cx: &App) { + let buffer = buffer.read(cx); + if let Some(file) = buffer.file() { + if let Some(local_file) = file.as_local() { + let abs_path = local_file.abs_path(cx); + self.file_read_times.remove(&abs_path); + } + } + } + fn track_buffer_internal( &mut self, buffer: Entity, @@ -506,24 +541,69 @@ impl ActionLog { /// Track a buffer as read by agent, so we can notify the model about user edits. pub fn buffer_read(&mut self, buffer: Entity, cx: &mut Context) { - if let Some(linked_action_log) = &mut self.linked_action_log { - linked_action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx)); + self.buffer_read_impl(buffer, true, cx); + } + + fn buffer_read_impl( + &mut self, + buffer: Entity, + record_file_read_time: bool, + cx: &mut Context, + ) { + if let Some(linked_action_log) = &self.linked_action_log { + // We don't want to share read times since the other agent hasn't read it necessarily + linked_action_log.update(cx, |log, cx| { + log.buffer_read_impl(buffer.clone(), false, cx); + }); + } + if record_file_read_time { + self.update_file_read_time(&buffer, cx); } self.track_buffer_internal(buffer, false, cx); } /// Mark a buffer as created by agent, so we can refresh it in the context pub fn buffer_created(&mut self, buffer: Entity, cx: &mut Context) { - if let Some(linked_action_log) = &mut self.linked_action_log { - linked_action_log.update(cx, |log, cx| log.buffer_created(buffer.clone(), cx)); + self.buffer_created_impl(buffer, true, cx); + } + + fn buffer_created_impl( + &mut self, + buffer: Entity, + record_file_read_time: bool, + cx: &mut Context, + ) { + if let Some(linked_action_log) = &self.linked_action_log { + // We don't want to share read times since the other agent hasn't read it necessarily + linked_action_log.update(cx, |log, cx| { + log.buffer_created_impl(buffer.clone(), false, cx); + }); + } + if record_file_read_time { + self.update_file_read_time(&buffer, cx); } self.track_buffer_internal(buffer, true, cx); } /// Mark a buffer as edited by agent, so we can refresh it in the context pub fn buffer_edited(&mut self, buffer: Entity, cx: &mut Context) { - if let Some(linked_action_log) = &mut self.linked_action_log { - linked_action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); + self.buffer_edited_impl(buffer, true, cx); + } + + fn buffer_edited_impl( + &mut self, + buffer: Entity, + record_file_read_time: bool, + cx: &mut Context, + ) { + if let Some(linked_action_log) = &self.linked_action_log { + // We don't want to share read times since the other agent hasn't read it necessarily + linked_action_log.update(cx, |log, cx| { + log.buffer_edited_impl(buffer.clone(), false, cx); + }); + } + if record_file_read_time { + self.update_file_read_time(&buffer, cx); } let new_version = buffer.read(cx).version(); let tracked_buffer = self.track_buffer_internal(buffer, false, cx); @@ -536,6 +616,8 @@ impl ActionLog { } pub fn will_delete_buffer(&mut self, buffer: Entity, cx: &mut Context) { + // Ok to propagate file read time removal to linked action log + self.remove_file_read_time(&buffer, cx); let has_linked_action_log = self.linked_action_log.is_some(); let tracked_buffer = self.track_buffer_internal(buffer.clone(), false, cx); match tracked_buffer.status { @@ -2976,6 +3058,196 @@ mod tests { ); } + #[gpui::test] + async fn test_file_read_time_recorded_on_buffer_read(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/dir"), json!({"file": "hello world"})) + .await; + let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + + let file_path = project + .read_with(cx, |project, cx| project.find_project_path("dir/file", cx)) + .unwrap(); + let buffer = project + .update(cx, |project, cx| project.open_buffer(file_path, cx)) + .await + .unwrap(); + + let abs_path = PathBuf::from(path!("/dir/file")); + assert!( + action_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_none()), + "file_read_time should be None before buffer_read" + ); + + cx.update(|cx| { + action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx)); + }); + + assert!( + action_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_some()), + "file_read_time should be recorded after buffer_read" + ); + } + + #[gpui::test] + async fn test_file_read_time_recorded_on_buffer_edited(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/dir"), json!({"file": "hello world"})) + .await; + let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + + let file_path = project + .read_with(cx, |project, cx| project.find_project_path("dir/file", cx)) + .unwrap(); + let buffer = project + .update(cx, |project, cx| project.open_buffer(file_path, cx)) + .await + .unwrap(); + + let abs_path = PathBuf::from(path!("/dir/file")); + assert!( + action_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_none()), + "file_read_time should be None before buffer_edited" + ); + + cx.update(|cx| { + action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); + }); + + assert!( + action_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_some()), + "file_read_time should be recorded after buffer_edited" + ); + } + + #[gpui::test] + async fn test_file_read_time_recorded_on_buffer_created(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/dir"), json!({"file": "existing content"})) + .await; + let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + + let file_path = project + .read_with(cx, |project, cx| project.find_project_path("dir/file", cx)) + .unwrap(); + let buffer = project + .update(cx, |project, cx| project.open_buffer(file_path, cx)) + .await + .unwrap(); + + let abs_path = PathBuf::from(path!("/dir/file")); + assert!( + action_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_none()), + "file_read_time should be None before buffer_created" + ); + + cx.update(|cx| { + action_log.update(cx, |log, cx| log.buffer_created(buffer.clone(), cx)); + }); + + assert!( + action_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_some()), + "file_read_time should be recorded after buffer_created" + ); + } + + #[gpui::test] + async fn test_file_read_time_removed_on_delete(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/dir"), json!({"file": "hello world"})) + .await; + let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + + let file_path = project + .read_with(cx, |project, cx| project.find_project_path("dir/file", cx)) + .unwrap(); + let buffer = project + .update(cx, |project, cx| project.open_buffer(file_path, cx)) + .await + .unwrap(); + + let abs_path = PathBuf::from(path!("/dir/file")); + + cx.update(|cx| { + action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx)); + }); + assert!( + action_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_some()), + "file_read_time should exist after buffer_read" + ); + + cx.update(|cx| { + action_log.update(cx, |log, cx| log.will_delete_buffer(buffer.clone(), cx)); + }); + assert!( + action_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_none()), + "file_read_time should be removed after will_delete_buffer" + ); + } + + #[gpui::test] + async fn test_file_read_time_not_forwarded_to_linked_action_log(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/dir"), json!({"file": "hello world"})) + .await; + let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await; + let parent_log = cx.new(|_| ActionLog::new(project.clone())); + let child_log = + cx.new(|_| ActionLog::new(project.clone()).with_linked_action_log(parent_log.clone())); + + let file_path = project + .read_with(cx, |project, cx| project.find_project_path("dir/file", cx)) + .unwrap(); + let buffer = project + .update(cx, |project, cx| project.open_buffer(file_path, cx)) + .await + .unwrap(); + + let abs_path = PathBuf::from(path!("/dir/file")); + + cx.update(|cx| { + child_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx)); + }); + assert!( + child_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_some()), + "child should record file_read_time on buffer_read" + ); + assert!( + parent_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_none()), + "parent should NOT get file_read_time from child's buffer_read" + ); + + cx.update(|cx| { + child_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); + }); + assert!( + parent_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_none()), + "parent should NOT get file_read_time from child's buffer_edited" + ); + + cx.update(|cx| { + child_log.update(cx, |log, cx| log.buffer_created(buffer.clone(), cx)); + }); + assert!( + parent_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_none()), + "parent should NOT get file_read_time from child's buffer_created" + ); + } + #[derive(Debug, PartialEq)] struct HunkStatus { range: Range, diff --git a/crates/agent/src/agent.rs b/crates/agent/src/agent.rs index 7cf9416840a6bd2870327c9c68135857c01f7c9b..5421538ca736028a4ea7290c09ef81036e055b81 100644 --- a/crates/agent/src/agent.rs +++ b/crates/agent/src/agent.rs @@ -352,6 +352,8 @@ impl NativeAgent { let parent_session_id = thread.parent_thread_id(); let title = thread.title(); let draft_prompt = thread.draft_prompt().map(Vec::from); + let scroll_position = thread.ui_scroll_position(); + let token_usage = thread.latest_token_usage(); let project = thread.project.clone(); let action_log = thread.action_log.clone(); let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone(); @@ -367,6 +369,8 @@ impl NativeAgent { cx, ); acp_thread.set_draft_prompt(draft_prompt); + acp_thread.set_ui_scroll_position(scroll_position); + acp_thread.update_token_usage(token_usage, cx); acp_thread }); @@ -1917,7 +1921,9 @@ mod internal_tests { use gpui::TestAppContext; use indoc::formatdoc; use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider}; - use language_model::{LanguageModelProviderId, LanguageModelProviderName}; + use language_model::{ + LanguageModelCompletionEvent, LanguageModelProviderId, LanguageModelProviderName, + }; use serde_json::json; use settings::SettingsStore; use util::{path, rel_path::rel_path}; @@ -2549,6 +2555,13 @@ mod internal_tests { cx.run_until_parked(); model.send_last_completion_stream_text_chunk("Lorem."); + model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate( + language_model::TokenUsage { + input_tokens: 150, + output_tokens: 75, + ..Default::default() + }, + )); model.end_last_completion_stream(); cx.run_until_parked(); summary_model @@ -2587,6 +2600,12 @@ mod internal_tests { acp_thread.update(cx, |thread, _cx| { thread.set_draft_prompt(Some(draft_blocks.clone())); }); + thread.update(cx, |thread, _cx| { + thread.set_ui_scroll_position(Some(gpui::ListOffset { + item_ix: 5, + offset_in_item: gpui::px(12.5), + })); + }); thread.update(cx, |_thread, cx| cx.notify()); cx.run_until_parked(); @@ -2632,6 +2651,24 @@ mod internal_tests { acp_thread.read_with(cx, |thread, _| { assert_eq!(thread.draft_prompt(), Some(draft_blocks.as_slice())); }); + + // Ensure token usage survived the round-trip. + acp_thread.read_with(cx, |thread, _| { + let usage = thread + .token_usage() + .expect("token usage should be restored after reload"); + assert_eq!(usage.input_tokens, 150); + assert_eq!(usage.output_tokens, 75); + }); + + // Ensure scroll position survived the round-trip. + acp_thread.read_with(cx, |thread, _| { + let scroll = thread + .ui_scroll_position() + .expect("scroll position should be restored after reload"); + assert_eq!(scroll.item_ix, 5); + assert_eq!(scroll.offset_in_item, gpui::px(12.5)); + }); } fn thread_entries( diff --git a/crates/agent/src/db.rs b/crates/agent/src/db.rs index 3a7af37cac85065d8853fbb5332093ef3fd20592..10ecb643b9a17dd6b02b47a416c526a662d12632 100644 --- a/crates/agent/src/db.rs +++ b/crates/agent/src/db.rs @@ -66,6 +66,14 @@ pub struct DbThread { pub thinking_effort: Option, #[serde(default)] pub draft_prompt: Option>, + #[serde(default)] + pub ui_scroll_position: Option, +} + +#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)] +pub struct SerializedScrollPosition { + pub item_ix: usize, + pub offset_in_item: f32, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -108,6 +116,7 @@ impl SharedThread { thinking_enabled: false, thinking_effort: None, draft_prompt: None, + ui_scroll_position: None, } } @@ -286,6 +295,7 @@ impl DbThread { thinking_enabled: false, thinking_effort: None, draft_prompt: None, + ui_scroll_position: None, }) } } @@ -637,6 +647,7 @@ mod tests { thinking_enabled: false, thinking_effort: None, draft_prompt: None, + ui_scroll_position: None, } } @@ -841,4 +852,53 @@ mod tests { assert_eq!(threads.len(), 1); assert!(threads[0].folder_paths.is_empty()); } + + #[test] + fn test_scroll_position_defaults_to_none() { + let json = r#"{ + "title": "Old Thread", + "messages": [], + "updated_at": "2024-01-01T00:00:00Z" + }"#; + + let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize"); + + assert!( + db_thread.ui_scroll_position.is_none(), + "Legacy threads without scroll_position field should default to None" + ); + } + + #[gpui::test] + async fn test_scroll_position_roundtrips_through_save_load(cx: &mut TestAppContext) { + let database = ThreadsDatabase::new(cx.executor()).unwrap(); + + let thread_id = session_id("thread-with-scroll"); + + let mut thread = make_thread( + "Thread With Scroll", + Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(), + ); + thread.ui_scroll_position = Some(SerializedScrollPosition { + item_ix: 42, + offset_in_item: 13.5, + }); + + database + .save_thread(thread_id.clone(), thread, PathList::default()) + .await + .unwrap(); + + let loaded = database + .load_thread(thread_id) + .await + .unwrap() + .expect("thread should exist"); + + let scroll = loaded + .ui_scroll_position + .expect("scroll_position should be restored"); + assert_eq!(scroll.item_ix, 42); + assert!((scroll.offset_in_item - 13.5).abs() < f32::EPSILON); + } } diff --git a/crates/agent/src/tests/edit_file_thread_test.rs b/crates/agent/src/tests/edit_file_thread_test.rs index 069bf0349299e6f4952f673cbf7607e52d48d9c5..3beb5cb0d51abc55fbf3cf0849ced248a9d1fa5c 100644 --- a/crates/agent/src/tests/edit_file_thread_test.rs +++ b/crates/agent/src/tests/edit_file_thread_test.rs @@ -50,9 +50,9 @@ async fn test_edit_file_tool_in_thread_context(cx: &mut TestAppContext) { // Add just the tools we need for this test let language_registry = project.read(cx).languages().clone(); thread.add_tool(crate::ReadFileTool::new( - cx.weak_entity(), project.clone(), thread.action_log().clone(), + true, )); thread.add_tool(crate::EditFileTool::new( project.clone(), diff --git a/crates/agent/src/tests/mod.rs b/crates/agent/src/tests/mod.rs index 8d75aae7e2948ef9c0934a72da112b926f633941..23ebe41d3c42654cb8fcdc0266009416686858aa 100644 --- a/crates/agent/src/tests/mod.rs +++ b/crates/agent/src/tests/mod.rs @@ -2631,6 +2631,84 @@ async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) { assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]); } +#[gpui::test] +async fn test_retry_cancelled_promptly_on_new_send(cx: &mut TestAppContext) { + // Regression test: when a completion fails with a retryable error (e.g. upstream 500), + // the retry loop waits on a timer. If the user switches models and sends a new message + // during that delay, the old turn should exit immediately instead of retrying with the + // stale model. + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + let model_a = model.as_fake(); + + // Start a turn with model_a. + let events_1 = thread + .update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Hello"], cx) + }) + .unwrap(); + cx.run_until_parked(); + assert_eq!(model_a.completion_count(), 1); + + // Model returns a retryable upstream 500. The turn enters the retry delay. + model_a.send_last_completion_stream_error( + LanguageModelCompletionError::UpstreamProviderError { + message: "Internal server error".to_string(), + status: http_client::StatusCode::INTERNAL_SERVER_ERROR, + retry_after: None, + }, + ); + model_a.end_last_completion_stream(); + cx.run_until_parked(); + + // The old completion was consumed; model_a has no pending requests yet because the + // retry timer hasn't fired. + assert_eq!(model_a.completion_count(), 0); + + // Switch to model_b and send a new message. This cancels the old turn. + let model_b = Arc::new(FakeLanguageModel::with_id_and_thinking( + "fake", "model-b", "Model B", false, + )); + thread.update(cx, |thread, cx| { + thread.set_model(model_b.clone(), cx); + }); + let events_2 = thread + .update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Continue"], cx) + }) + .unwrap(); + cx.run_until_parked(); + + // model_b should have received its completion request. + assert_eq!(model_b.as_fake().completion_count(), 1); + + // Advance the clock well past the retry delay (BASE_RETRY_DELAY = 5s). + cx.executor().advance_clock(Duration::from_secs(10)); + cx.run_until_parked(); + + // model_a must NOT have received another completion request — the cancelled turn + // should have exited during the retry delay rather than retrying with the old model. + assert_eq!( + model_a.completion_count(), + 0, + "old model should not receive a retry request after cancellation" + ); + + // Complete model_b's turn. + model_b + .as_fake() + .send_last_completion_stream_text_chunk("Done!"); + model_b + .as_fake() + .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)); + model_b.as_fake().end_last_completion_stream(); + + let events_1 = events_1.collect::>().await; + assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]); + + let events_2 = events_2.collect::>().await; + assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]); +} + #[gpui::test] async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) { let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index c5ca1118ace28b66d555d67aa40c718da292f644..148702e1bafeae05ac67c6127d8259581aff93dd 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -893,14 +893,13 @@ pub struct Thread { pub(crate) prompt_capabilities_rx: watch::Receiver, pub(crate) project: Entity, pub(crate) action_log: Entity, - /// Tracks the last time files were read by the agent, to detect external modifications - pub(crate) file_read_times: HashMap, /// True if this thread was imported from a shared thread and can be synced. imported: bool, /// If this is a subagent thread, contains context about the parent subagent_context: Option, /// The user's unsent prompt text, persisted so it can be restored when reloading the thread. draft_prompt: Option>, + ui_scroll_position: Option, /// Weak references to running subagent threads for cancellation propagation running_subagents: Vec>, } @@ -1013,10 +1012,10 @@ impl Thread { prompt_capabilities_rx, project, action_log, - file_read_times: HashMap::default(), imported: false, subagent_context: None, draft_prompt: None, + ui_scroll_position: None, running_subagents: Vec::new(), } } @@ -1229,10 +1228,13 @@ impl Thread { updated_at: db_thread.updated_at, prompt_capabilities_tx, prompt_capabilities_rx, - file_read_times: HashMap::default(), imported: db_thread.imported, subagent_context: db_thread.subagent_context, draft_prompt: db_thread.draft_prompt, + ui_scroll_position: db_thread.ui_scroll_position.map(|sp| gpui::ListOffset { + item_ix: sp.item_ix, + offset_in_item: gpui::px(sp.offset_in_item), + }), running_subagents: Vec::new(), } } @@ -1258,6 +1260,12 @@ impl Thread { thinking_enabled: self.thinking_enabled, thinking_effort: self.thinking_effort.clone(), draft_prompt: self.draft_prompt.clone(), + ui_scroll_position: self.ui_scroll_position.map(|lo| { + crate::db::SerializedScrollPosition { + item_ix: lo.item_ix, + offset_in_item: lo.offset_in_item.as_f32(), + } + }), }; cx.background_spawn(async move { @@ -1307,6 +1315,14 @@ impl Thread { self.draft_prompt = prompt; } + pub fn ui_scroll_position(&self) -> Option { + self.ui_scroll_position + } + + pub fn set_ui_scroll_position(&mut self, position: Option) { + self.ui_scroll_position = position; + } + pub fn model(&self) -> Option<&Arc> { self.model.as_ref() } @@ -1416,6 +1432,9 @@ impl Thread { environment: Rc, cx: &mut Context, ) { + // Only update the agent location for the root thread, not for subagents. + let update_agent_location = self.parent_thread_id().is_none(); + let language_registry = self.project.read(cx).languages().clone(); self.add_tool(CopyPathTool::new(self.project.clone())); self.add_tool(CreateDirectoryTool::new(self.project.clone())); @@ -1433,6 +1452,7 @@ impl Thread { self.add_tool(StreamingEditFileTool::new( self.project.clone(), cx.weak_entity(), + self.action_log.clone(), language_registry, )); self.add_tool(FetchTool::new(self.project.read(cx).client().http_client())); @@ -1443,9 +1463,9 @@ impl Thread { self.add_tool(NowTool); self.add_tool(OpenTool::new(self.project.clone())); self.add_tool(ReadFileTool::new( - cx.weak_entity(), self.project.clone(), self.action_log.clone(), + update_agent_location, )); self.add_tool(SaveFileTool::new(self.project.clone())); self.add_tool(RestoreFileFromDiskTool::new(self.project.clone())); @@ -1940,7 +1960,15 @@ impl Thread { })??; let timer = cx.background_executor().timer(retry.duration); event_stream.send_retry(retry); - timer.await; + futures::select! { + _ = timer.fuse() => {} + _ = cancellation_rx.changed().fuse() => { + if *cancellation_rx.borrow() { + log::debug!("Turn cancelled during retry delay, exiting"); + return Ok(()); + } + } + } this.update(cx, |this, _cx| { if let Some(Message::Agent(message)) = this.messages.last() { if message.tool_results.is_empty() { @@ -2308,20 +2336,18 @@ impl Thread { ) { // Ensure the last message ends in the current tool use let last_message = self.pending_message(); - let push_new_tool_use = last_message.content.last_mut().is_none_or(|content| { + + let has_tool_use = last_message.content.iter_mut().rev().any(|content| { if let AgentMessageContent::ToolUse(last_tool_use) = content { if last_tool_use.id == tool_use.id { *last_tool_use = tool_use.clone(); - false - } else { - true + return true; } - } else { - true } + false }); - if push_new_tool_use { + if !has_tool_use { event_stream.send_tool_call( &tool_use.id, &tool_use.name, @@ -2609,7 +2635,8 @@ impl Thread { } } - let use_streaming_edit_tool = cx.has_flag::(); + let use_streaming_edit_tool = + cx.has_flag::() && model.supports_streaming_tools(); let mut tools = self .tools diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index f944377e489a88ac0fa6dbb802edf9702e86f5f2..e26820ddacc3132d42946de3b27d25f4424fae02 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -146,6 +146,7 @@ mod tests { thinking_enabled: false, thinking_effort: None, draft_prompt: None, + ui_scroll_position: None, } } diff --git a/crates/agent/src/tools/edit_file_tool.rs b/crates/agent/src/tools/edit_file_tool.rs index d8c380eba326d089b848563cca04557e903ba0f4..29b08ac09db4417123403fd3915b8575791b2a4e 100644 --- a/crates/agent/src/tools/edit_file_tool.rs +++ b/crates/agent/src/tools/edit_file_tool.rs @@ -305,13 +305,13 @@ impl AgentTool for EditFileTool { // Check if the file has been modified since the agent last read it if let Some(abs_path) = abs_path.as_ref() { - let (last_read_mtime, current_mtime, is_dirty, has_save_tool, has_restore_tool) = self.thread.update(cx, |thread, cx| { - let last_read = thread.file_read_times.get(abs_path).copied(); + let last_read_mtime = action_log.read_with(cx, |log, _| log.file_read_time(abs_path)); + let (current_mtime, is_dirty, has_save_tool, has_restore_tool) = self.thread.read_with(cx, |thread, cx| { let current = buffer.read(cx).file().and_then(|file| file.disk_state().mtime()); let dirty = buffer.read(cx).is_dirty(); let has_save = thread.has_tool(SaveFileTool::NAME); let has_restore = thread.has_tool(RestoreFileFromDiskTool::NAME); - (last_read, current, dirty, has_save, has_restore) + (current, dirty, has_save, has_restore) })?; // Check for unsaved changes first - these indicate modifications we don't know about @@ -470,17 +470,6 @@ impl AgentTool for EditFileTool { log.buffer_edited(buffer.clone(), cx); }); - // Update the recorded read time after a successful edit so consecutive edits work - if let Some(abs_path) = abs_path.as_ref() { - if let Some(new_mtime) = buffer.read_with(cx, |buffer, _| { - buffer.file().and_then(|file| file.disk_state().mtime()) - }) { - self.thread.update(cx, |thread, _| { - thread.file_read_times.insert(abs_path.to_path_buf(), new_mtime); - })?; - } - } - let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); let (new_text, unified_diff) = cx .background_spawn({ @@ -2212,14 +2201,18 @@ mod tests { let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); // Initially, file_read_times should be empty - let is_empty = thread.read_with(cx, |thread, _| thread.file_read_times.is_empty()); + let is_empty = action_log.read_with(cx, |action_log, _| { + action_log + .file_read_time(path!("/root/test.txt").as_ref()) + .is_none() + }); assert!(is_empty, "file_read_times should start empty"); // Create read tool let read_tool = Arc::new(crate::ReadFileTool::new( - thread.downgrade(), project.clone(), - action_log, + action_log.clone(), + true, )); // Read the file to record the read time @@ -2238,12 +2231,9 @@ mod tests { .unwrap(); // Verify that file_read_times now contains an entry for the file - let has_entry = thread.read_with(cx, |thread, _| { - thread.file_read_times.len() == 1 - && thread - .file_read_times - .keys() - .any(|path| path.ends_with("test.txt")) + let has_entry = action_log.read_with(cx, |log, _| { + log.file_read_time(path!("/root/test.txt").as_ref()) + .is_some() }); assert!( has_entry, @@ -2265,11 +2255,14 @@ mod tests { .await .unwrap(); - // Should still have exactly one entry - let has_one_entry = thread.read_with(cx, |thread, _| thread.file_read_times.len() == 1); + // Should still have an entry after re-reading + let has_entry = action_log.read_with(cx, |log, _| { + log.file_read_time(path!("/root/test.txt").as_ref()) + .is_some() + }); assert!( - has_one_entry, - "file_read_times should still have one entry after re-reading" + has_entry, + "file_read_times should still have an entry after re-reading" ); } @@ -2309,11 +2302,7 @@ mod tests { let languages = project.read_with(cx, |project, _| project.languages().clone()); let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); - let read_tool = Arc::new(crate::ReadFileTool::new( - thread.downgrade(), - project.clone(), - action_log, - )); + let read_tool = Arc::new(crate::ReadFileTool::new(project.clone(), action_log, true)); let edit_tool = Arc::new(EditFileTool::new( project.clone(), thread.downgrade(), @@ -2423,11 +2412,7 @@ mod tests { let languages = project.read_with(cx, |project, _| project.languages().clone()); let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); - let read_tool = Arc::new(crate::ReadFileTool::new( - thread.downgrade(), - project.clone(), - action_log, - )); + let read_tool = Arc::new(crate::ReadFileTool::new(project.clone(), action_log, true)); let edit_tool = Arc::new(EditFileTool::new( project.clone(), thread.downgrade(), @@ -2534,11 +2519,7 @@ mod tests { let languages = project.read_with(cx, |project, _| project.languages().clone()); let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); - let read_tool = Arc::new(crate::ReadFileTool::new( - thread.downgrade(), - project.clone(), - action_log, - )); + let read_tool = Arc::new(crate::ReadFileTool::new(project.clone(), action_log, true)); let edit_tool = Arc::new(EditFileTool::new( project.clone(), thread.downgrade(), diff --git a/crates/agent/src/tools/read_file_tool.rs b/crates/agent/src/tools/read_file_tool.rs index 8cfc16ddf6174a190ffe7cc11921dc204b05b79d..f7a75bc63a1c461b65c3a2e6f74f2c70e0ca15f6 100644 --- a/crates/agent/src/tools/read_file_tool.rs +++ b/crates/agent/src/tools/read_file_tool.rs @@ -2,7 +2,7 @@ use action_log::ActionLog; use agent_client_protocol::{self as acp, ToolCallUpdateFields}; use anyhow::{Context as _, Result, anyhow}; use futures::FutureExt as _; -use gpui::{App, Entity, SharedString, Task, WeakEntity}; +use gpui::{App, Entity, SharedString, Task}; use indoc::formatdoc; use language::Point; use language_model::{LanguageModelImage, LanguageModelToolResultContent}; @@ -21,7 +21,7 @@ use super::tool_permissions::{ ResolvedProjectPath, authorize_symlink_access, canonicalize_worktree_roots, resolve_project_path, }; -use crate::{AgentTool, Thread, ToolCallEventStream, ToolInput, outline}; +use crate::{AgentTool, ToolCallEventStream, ToolInput, outline}; /// Reads the content of the given file in the project. /// @@ -56,21 +56,21 @@ pub struct ReadFileToolInput { } pub struct ReadFileTool { - thread: WeakEntity, project: Entity, action_log: Entity, + update_agent_location: bool, } impl ReadFileTool { pub fn new( - thread: WeakEntity, project: Entity, action_log: Entity, + update_agent_location: bool, ) -> Self { Self { - thread, project, action_log, + update_agent_location, } } } @@ -119,7 +119,6 @@ impl AgentTool for ReadFileTool { cx: &mut App, ) -> Task> { let project = self.project.clone(); - let thread = self.thread.clone(); let action_log = self.action_log.clone(); cx.spawn(async move |cx| { let input = input @@ -257,20 +256,6 @@ impl AgentTool for ReadFileTool { return Err(tool_content_err(format!("{file_path} not found"))); } - // Record the file read time and mtime - if let Some(mtime) = buffer.read_with(cx, |buffer, _| { - buffer.file().and_then(|file| file.disk_state().mtime()) - }) { - thread - .update(cx, |thread, _| { - thread.file_read_times.insert(abs_path.to_path_buf(), mtime); - }) - .ok(); - } - - - let update_agent_location = self.thread.read_with(cx, |thread, _cx| !thread.is_subagent()).unwrap_or_default(); - let mut anchor = None; // Check if specific line ranges are provided @@ -330,7 +315,7 @@ impl AgentTool for ReadFileTool { }; project.update(cx, |project, cx| { - if update_agent_location { + if self.update_agent_location { project.set_agent_location( Some(AgentLocation { buffer: buffer.downgrade(), @@ -362,13 +347,10 @@ impl AgentTool for ReadFileTool { #[cfg(test)] mod test { use super::*; - use crate::{ContextServerRegistry, Templates, Thread}; use agent_client_protocol as acp; use fs::Fs as _; use gpui::{AppContext, TestAppContext, UpdateGlobal as _}; - use language_model::fake_provider::FakeLanguageModel; use project::{FakeFs, Project}; - use prompt_store::ProjectContext; use serde_json::json; use settings::SettingsStore; use std::path::PathBuf; @@ -383,20 +365,7 @@ mod test { fs.insert_tree(path!("/root"), json!({})).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let tool = Arc::new(ReadFileTool::new(thread.downgrade(), project, action_log)); + let tool = Arc::new(ReadFileTool::new(project, action_log, true)); let (event_stream, _) = ToolCallEventStream::test(); let result = cx @@ -429,20 +398,7 @@ mod test { .await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let tool = Arc::new(ReadFileTool::new(thread.downgrade(), project, action_log)); + let tool = Arc::new(ReadFileTool::new(project, action_log, true)); let result = cx .update(|cx| { let input = ReadFileToolInput { @@ -476,20 +432,7 @@ mod test { let language_registry = project.read_with(cx, |project, _| project.languages().clone()); language_registry.add(language::rust_lang()); let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let tool = Arc::new(ReadFileTool::new(thread.downgrade(), project, action_log)); + let tool = Arc::new(ReadFileTool::new(project, action_log, true)); let result = cx .update(|cx| { let input = ReadFileToolInput { @@ -569,20 +512,7 @@ mod test { let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let tool = Arc::new(ReadFileTool::new(thread.downgrade(), project, action_log)); + let tool = Arc::new(ReadFileTool::new(project, action_log, true)); let result = cx .update(|cx| { let input = ReadFileToolInput { @@ -614,20 +544,7 @@ mod test { .await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let tool = Arc::new(ReadFileTool::new(thread.downgrade(), project, action_log)); + let tool = Arc::new(ReadFileTool::new(project, action_log, true)); // start_line of 0 should be treated as 1 let result = cx @@ -757,20 +674,7 @@ mod test { let project = Project::test(fs.clone(), [path!("/project_root").as_ref()], cx).await; let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let tool = Arc::new(ReadFileTool::new(thread.downgrade(), project, action_log)); + let tool = Arc::new(ReadFileTool::new(project, action_log, true)); // Reading a file outside the project worktree should fail let result = cx @@ -965,20 +869,7 @@ mod test { let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let tool = Arc::new(ReadFileTool::new(thread.downgrade(), project, action_log)); + let tool = Arc::new(ReadFileTool::new(project, action_log, true)); let (event_stream, mut event_rx) = ToolCallEventStream::test(); let read_task = cx.update(|cx| { @@ -1084,24 +975,7 @@ mod test { .await; let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let tool = Arc::new(ReadFileTool::new( - thread.downgrade(), - project.clone(), - action_log.clone(), - )); + let tool = Arc::new(ReadFileTool::new(project.clone(), action_log.clone(), true)); // Test reading allowed files in worktree1 let result = cx @@ -1288,24 +1162,7 @@ mod test { cx.executor().run_until_parked(); let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let tool = Arc::new(ReadFileTool::new( - thread.downgrade(), - project.clone(), - action_log, - )); + let tool = Arc::new(ReadFileTool::new(project.clone(), action_log, true)); let (event_stream, mut event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| { @@ -1364,24 +1221,7 @@ mod test { cx.executor().run_until_parked(); let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let tool = Arc::new(ReadFileTool::new( - thread.downgrade(), - project.clone(), - action_log, - )); + let tool = Arc::new(ReadFileTool::new(project.clone(), action_log, true)); let (event_stream, mut event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| { @@ -1444,24 +1284,7 @@ mod test { cx.executor().run_until_parked(); let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let tool = Arc::new(ReadFileTool::new( - thread.downgrade(), - project.clone(), - action_log, - )); + let tool = Arc::new(ReadFileTool::new(project.clone(), action_log, true)); let (event_stream, mut event_rx) = ToolCallEventStream::test(); let result = cx diff --git a/crates/agent/src/tools/spawn_agent_tool.rs b/crates/agent/src/tools/spawn_agent_tool.rs index b75c41775258db49577024dca3eb1770937e52e8..162de68b86115056e9579d22a8623d675245cc91 100644 --- a/crates/agent/src/tools/spawn_agent_tool.rs +++ b/crates/agent/src/tools/spawn_agent_tool.rs @@ -161,29 +161,42 @@ impl AgentTool for SpawnAgentTool { Ok((subagent, session_info)) })?; - match subagent.send(input.message, cx).await { - Ok(output) => { - session_info.message_end_index = - cx.update(|cx| Some(subagent.num_entries(cx).saturating_sub(1))); - event_stream.update_fields_with_meta( - acp::ToolCallUpdateFields::new().content(vec![output.clone().into()]), - Some(acp::Meta::from_iter([( - SUBAGENT_SESSION_INFO_META_KEY.into(), - serde_json::json!(&session_info), - )])), - ); + let send_result = subagent.send(input.message, cx).await; + + session_info.message_end_index = + cx.update(|cx| Some(subagent.num_entries(cx).saturating_sub(1))); + + let meta = Some(acp::Meta::from_iter([( + SUBAGENT_SESSION_INFO_META_KEY.into(), + serde_json::json!(&session_info), + )])); + + let (output, result) = match send_result { + Ok(output) => ( + output.clone(), Ok(SpawnAgentToolOutput::Success { session_id: session_info.session_id.clone(), session_info, output, - }) + }), + ), + Err(e) => { + let error = e.to_string(); + ( + error.clone(), + Err(SpawnAgentToolOutput::Error { + session_id: Some(session_info.session_id.clone()), + error, + session_info: Some(session_info), + }), + ) } - Err(e) => Err(SpawnAgentToolOutput::Error { - session_id: Some(session_info.session_id.clone()), - error: e.to_string(), - session_info: Some(session_info), - }), - } + }; + event_stream.update_fields_with_meta( + acp::ToolCallUpdateFields::new().content(vec![output.into()]), + meta, + ); + result }) } diff --git a/crates/agent/src/tools/streaming_edit_file_tool.rs b/crates/agent/src/tools/streaming_edit_file_tool.rs index 6b1c70931ce00842a7cf427c492b2512bc7a3750..74e91ee1d2607ad1f68a5d327cd0519699cce88b 100644 --- a/crates/agent/src/tools/streaming_edit_file_tool.rs +++ b/crates/agent/src/tools/streaming_edit_file_tool.rs @@ -73,7 +73,7 @@ pub struct StreamingEditFileToolInput { /// /// `frontend/db.js` /// - pub path: String, + pub path: PathBuf, /// The mode of operation on the file. Possible values: /// - 'write': Replace the entire contents of the file. If the file doesn't exist, it will be created. Requires 'content' field. @@ -93,7 +93,7 @@ pub struct StreamingEditFileToolInput { pub edits: Option>, } -#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[derive(Clone, Copy, Debug, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "snake_case")] pub enum StreamingEditFileMode { /// Overwrite the file with new content (replacing any existing content). @@ -187,20 +187,23 @@ impl From for LanguageModelToolResultContent { } pub struct StreamingEditFileTool { + project: Entity, thread: WeakEntity, + action_log: Entity, language_registry: Arc, - project: Entity, } impl StreamingEditFileTool { pub fn new( project: Entity, thread: WeakEntity, + action_log: Entity, language_registry: Arc, ) -> Self { Self { project, thread, + action_log, language_registry, } } @@ -264,11 +267,11 @@ impl AgentTool for StreamingEditFileTool { .read(cx) .short_full_path_for_project_path(&project_path, cx) }) - .unwrap_or(input.path) + .unwrap_or(input.path.to_string_lossy().into_owned()) .into(), Err(raw_input) => { - if let Some(input) = - serde_json::from_value::(raw_input).ok() + if let Ok(input) = + serde_json::from_value::(raw_input) { let path = input.path.unwrap_or_default(); let path = path.trim(); @@ -311,24 +314,37 @@ impl AgentTool for StreamingEditFileTool { partial = input.recv_partial().fuse() => { let Some(partial_value) = partial else { break }; if let Ok(parsed) = serde_json::from_value::(partial_value) { - if state.is_none() && let Some(path_str) = &parsed.path - && let Some(display_description) = &parsed.display_description - && let Some(mode) = parsed.mode.clone() { - state = Some( - EditSession::new( - path_str, - display_description, - mode, - &self, - &event_stream, - cx, - ) - .await?, - ); + if state.is_none() + && let StreamingEditFileToolPartialInput { + path: Some(path), + display_description: Some(display_description), + mode: Some(mode), + .. + } = &parsed + { + match EditSession::new( + &PathBuf::from(path), + display_description, + *mode, + &self, + &event_stream, + cx, + ) + .await + { + Ok(session) => state = Some(session), + Err(e) => { + log::error!("Failed to create edit session: {}", e); + return Err(e); + } + } } if let Some(state) = &mut state { - state.process(parsed, &self, &event_stream, cx)?; + if let Err(e) = state.process(parsed, &self, &event_stream, cx) { + log::error!("Failed to process edit: {}", e); + return Err(e); + } } } } @@ -341,22 +357,39 @@ impl AgentTool for StreamingEditFileTool { input .recv() .await - .map_err(|e| StreamingEditFileToolOutput::error(format!("Failed to receive tool input: {e}")))?; + .map_err(|e| { + let err = StreamingEditFileToolOutput::error(format!("Failed to receive tool input: {e}")); + log::error!("Failed to receive tool input: {e}"); + err + })?; let mut state = if let Some(state) = state { state } else { - EditSession::new( + match EditSession::new( &full_input.path, &full_input.display_description, - full_input.mode.clone(), + full_input.mode, &self, &event_stream, cx, ) - .await? + .await + { + Ok(session) => session, + Err(e) => { + log::error!("Failed to create edit session: {}", e); + return Err(e); + } + } }; - state.finalize(full_input, &self, &event_stream, cx).await + match state.finalize(full_input, &self, &event_stream, cx).await { + Ok(output) => Ok(output), + Err(e) => { + log::error!("Failed to finalize edit: {}", e); + Err(e) + } + } }) } @@ -409,7 +442,7 @@ enum EditPipeline { original_snapshot: text::BufferSnapshot, }, Edit { - edits: Vec, + current_edit: Option, }, } @@ -424,73 +457,51 @@ enum EditPipelineEntry { reindenter: Reindenter, original_snapshot: text::BufferSnapshot, }, - Done, } impl EditPipeline { - fn new(mode: StreamingEditFileMode, snapshot: text::BufferSnapshot) -> Self { + fn new(mode: StreamingEditFileMode, original_snapshot: text::BufferSnapshot) -> Self { match mode { StreamingEditFileMode::Write => Self::Write { - content_written: false, - streaming_diff: StreamingDiff::new(snapshot.text()), + streaming_diff: StreamingDiff::new(original_snapshot.text()), line_diff: LineDiff::default(), - original_snapshot: snapshot, + content_written: false, + original_snapshot, }, - StreamingEditFileMode::Edit => Self::Edit { edits: Vec::new() }, - } - } - - fn edits(&mut self) -> &mut [EditPipelineEntry] { - match self { - EditPipeline::Write { .. } => &mut [], - EditPipeline::Edit { edits } => edits, + StreamingEditFileMode::Edit => Self::Edit { current_edit: None }, } } - fn ensure_resolving_old_text( - &mut self, - edit_index: usize, - buffer: &Entity, - cx: &mut AsyncApp, - ) { - match self { - EditPipeline::Write { .. } => {} - EditPipeline::Edit { edits } => { - while edits.len() <= edit_index { - let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.text_snapshot()); - edits.push(EditPipelineEntry::ResolvingOldText { - matcher: StreamingFuzzyMatcher::new(snapshot), - }); - } - } + fn ensure_resolving_old_text(&mut self, buffer: &Entity, cx: &mut AsyncApp) { + if let Self::Edit { current_edit } = self + && current_edit.is_none() + { + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.text_snapshot()); + *current_edit = Some(EditPipelineEntry::ResolvingOldText { + matcher: StreamingFuzzyMatcher::new(snapshot), + }); } } } -/// Compute the `LineIndent` of the first line in a set of query lines. -fn query_first_line_indent(query_lines: &[String]) -> text::LineIndent { - let first_line = query_lines.first().map(|s| s.as_str()).unwrap_or(""); - text::LineIndent::from_iter(first_line.chars()) -} - impl EditSession { async fn new( - path_str: &str, + path: &PathBuf, display_description: &str, mode: StreamingEditFileMode, tool: &StreamingEditFileTool, event_stream: &ToolCallEventStream, cx: &mut AsyncApp, ) -> Result { - let path = PathBuf::from(path_str); let project_path = cx - .update(|cx| resolve_path(mode.clone(), &path, &tool.project, cx)) + .update(|cx| resolve_path(mode, &path, &tool.project, cx)) .map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?; let Some(abs_path) = cx.update(|cx| tool.project.read(cx).absolute_path(&project_path, cx)) else { return Err(StreamingEditFileToolOutput::error(format!( - "Worktree at '{path_str}' does not exist" + "Worktree at '{}' does not exist", + path.to_string_lossy() ))); }; @@ -520,13 +531,8 @@ impl EditSession { } }) as Box); - tool.thread - .update(cx, |thread, cx| { - thread - .action_log() - .update(cx, |log, cx| log.buffer_read(buffer.clone(), cx)) - }) - .ok(); + tool.action_log + .update(cx, |log, cx| log.buffer_read(buffer.clone(), cx)); let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); let old_text = cx @@ -555,69 +561,31 @@ impl EditSession { event_stream: &ToolCallEventStream, cx: &mut AsyncApp, ) -> Result { - let Self { - buffer, - old_text, - diff, - abs_path, - parser, - pipeline, - .. - } = self; - - let action_log = tool - .thread - .read_with(cx, |thread, _cx| thread.action_log().clone()) - .map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?; + let old_text = self.old_text.clone(); match input.mode { StreamingEditFileMode::Write => { - action_log.update(cx, |log, cx| { - log.buffer_created(buffer.clone(), cx); - }); let content = input.content.ok_or_else(|| { StreamingEditFileToolOutput::error("'content' field is required for write mode") })?; - let events = parser.finalize_content(&content); - Self::process_events( - &events, - buffer, - diff, - pipeline, - abs_path, - tool, - event_stream, - cx, - )?; + let events = self.parser.finalize_content(&content); + self.process_events(&events, tool, event_stream, cx)?; + + tool.action_log.update(cx, |log, cx| { + log.buffer_created(self.buffer.clone(), cx); + }); } StreamingEditFileMode::Edit => { let edits = input.edits.ok_or_else(|| { StreamingEditFileToolOutput::error("'edits' field is required for edit mode") })?; - - let final_edits = edits - .into_iter() - .map(|e| Edit { - old_text: e.old_text, - new_text: e.new_text, - }) - .collect::>(); - let events = parser.finalize_edits(&final_edits); - Self::process_events( - &events, - buffer, - diff, - pipeline, - abs_path, - tool, - event_stream, - cx, - )?; + let events = self.parser.finalize_edits(&edits); + self.process_events(&events, tool, event_stream, cx)?; } } - let format_on_save_enabled = buffer.read_with(cx, |buffer, cx| { + let format_on_save_enabled = self.buffer.read_with(cx, |buffer, cx| { let settings = language_settings::language_settings( buffer.language().map(|l| l.name()), buffer.file(), @@ -627,13 +595,13 @@ impl EditSession { }); if format_on_save_enabled { - action_log.update(cx, |log, cx| { - log.buffer_edited(buffer.clone(), cx); + tool.action_log.update(cx, |log, cx| { + log.buffer_edited(self.buffer.clone(), cx); }); let format_task = tool.project.update(cx, |project, cx| { project.format( - HashSet::from_iter([buffer.clone()]), + HashSet::from_iter([self.buffer.clone()]), LspFormatTarget::Buffers, false, FormatTrigger::Save, @@ -648,9 +616,9 @@ impl EditSession { }; } - let save_task = tool - .project - .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)); + let save_task = tool.project.update(cx, |project, cx| { + project.save_buffer(self.buffer.clone(), cx) + }); futures::select! { result = save_task.fuse() => { result.map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?; }, _ = event_stream.cancelled_by_user().fuse() => { @@ -658,23 +626,11 @@ impl EditSession { } }; - action_log.update(cx, |log, cx| { - log.buffer_edited(buffer.clone(), cx); + tool.action_log.update(cx, |log, cx| { + log.buffer_edited(self.buffer.clone(), cx); }); - if let Some(new_mtime) = buffer.read_with(cx, |buffer, _| { - buffer.file().and_then(|file| file.disk_state().mtime()) - }) { - tool.thread - .update(cx, |thread, _| { - thread - .file_read_times - .insert(abs_path.to_path_buf(), new_mtime); - }) - .map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?; - } - - let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let new_snapshot = self.buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); let (new_text, unified_diff) = cx .background_spawn({ let new_snapshot = new_snapshot.clone(); @@ -688,7 +644,7 @@ impl EditSession { .await; let output = StreamingEditFileToolOutput::Success { - input_path: PathBuf::from(input.path), + input_path: input.path, new_text, old_text: old_text.clone(), diff: unified_diff, @@ -707,31 +663,13 @@ impl EditSession { StreamingEditFileMode::Write => { if let Some(content) = &partial.content { let events = self.parser.push_content(content); - Self::process_events( - &events, - &self.buffer, - &self.diff, - &mut self.pipeline, - &self.abs_path, - tool, - event_stream, - cx, - )?; + self.process_events(&events, tool, event_stream, cx)?; } } StreamingEditFileMode::Edit => { if let Some(edits) = partial.edits { let events = self.parser.push_edits(&edits); - Self::process_events( - &events, - &self.buffer, - &self.diff, - &mut self.pipeline, - &self.abs_path, - tool, - event_stream, - cx, - )?; + self.process_events(&events, tool, event_stream, cx)?; } } } @@ -739,52 +677,43 @@ impl EditSession { } fn process_events( + &mut self, events: &[ToolEditEvent], - buffer: &Entity, - diff: &Entity, - pipeline: &mut EditPipeline, - abs_path: &PathBuf, tool: &StreamingEditFileTool, event_stream: &ToolCallEventStream, cx: &mut AsyncApp, ) -> Result<(), StreamingEditFileToolOutput> { - let action_log = tool - .thread - .read_with(cx, |thread, _cx| thread.action_log().clone()) - .ok(); - for event in events { match event { ToolEditEvent::ContentChunk { chunk } => { let EditPipeline::Write { - original_snapshot, - content_written, streaming_diff, line_diff, - } = pipeline + content_written, + original_snapshot, + } = &mut self.pipeline else { continue; }; - let (buffer_id, insert_at) = buffer.read_with(cx, |buffer, _cx| { - let insert_at = if !*content_written && buffer.len() > 0 { - 0..buffer.len() - } else { - let len = buffer.len(); - len..len - }; - (buffer.remote_id(), insert_at) - }); + let (buffer_id, buffer_len) = self + .buffer + .read_with(cx, |buffer, _cx| (buffer.remote_id(), buffer.len())); + let edit_range = if *content_written { + buffer_len..buffer_len + } else { + 0..buffer_len + }; - let char_ops = streaming_diff.push_new(chunk); agent_edit_buffer( - buffer, - [(insert_at, chunk.as_str())], - action_log.as_ref(), + &self.buffer, + [(edit_range, chunk.as_str())], + &tool.action_log, cx, ); + let char_ops = streaming_diff.push_new(chunk); line_diff.push_char_operations(&char_ops, original_snapshot.as_rope()); - diff.update(cx, |diff, cx| { + self.diff.update(cx, |diff, cx| { diff.update_pending( line_diff.line_operations(), original_snapshot.clone(), @@ -794,7 +723,7 @@ impl EditSession { cx.update(|cx| { tool.set_agent_location( - buffer.downgrade(), + self.buffer.downgrade(), text::Anchor::max_for_buffer(buffer_id), cx, ); @@ -803,27 +732,27 @@ impl EditSession { } ToolEditEvent::OldTextChunk { - edit_index, - chunk, - done: false, + chunk, done: false, .. } => { - pipeline.ensure_resolving_old_text(*edit_index, buffer, cx); + self.pipeline.ensure_resolving_old_text(&self.buffer, cx); + let EditPipeline::Edit { current_edit } = &mut self.pipeline else { + continue; + }; - if let EditPipelineEntry::ResolvingOldText { matcher } = - &mut pipeline.edits()[*edit_index] + if let Some(EditPipelineEntry::ResolvingOldText { matcher }) = current_edit + && !chunk.is_empty() { - if !chunk.is_empty() { - if let Some(match_range) = matcher.push(chunk, None) { - let anchor_range = buffer.read_with(cx, |buffer, _cx| { - buffer.anchor_range_between(match_range.clone()) - }); - diff.update(cx, |diff, cx| diff.reveal_range(anchor_range, cx)); - - cx.update(|cx| { - let position = buffer.read(cx).anchor_before(match_range.end); - tool.set_agent_location(buffer.downgrade(), position, cx); - }); - } + if let Some(match_range) = matcher.push(chunk, None) { + let anchor_range = self.buffer.read_with(cx, |buffer, _cx| { + buffer.anchor_range_between(match_range.clone()) + }); + self.diff + .update(cx, |diff, cx| diff.reveal_range(anchor_range, cx)); + + cx.update(|cx| { + let position = self.buffer.read(cx).anchor_before(match_range.end); + tool.set_agent_location(self.buffer.downgrade(), position, cx); + }); } } } @@ -833,100 +762,81 @@ impl EditSession { chunk, done: true, } => { - pipeline.ensure_resolving_old_text(*edit_index, buffer, cx); + self.pipeline.ensure_resolving_old_text(&self.buffer, cx); + let EditPipeline::Edit { current_edit } = &mut self.pipeline else { + continue; + }; - let EditPipelineEntry::ResolvingOldText { matcher } = - &mut pipeline.edits()[*edit_index] - else { + let Some(EditPipelineEntry::ResolvingOldText { matcher }) = current_edit else { continue; }; if !chunk.is_empty() { matcher.push(chunk, None); } - let matches = matcher.finish(); - - if matches.is_empty() { - return Err(StreamingEditFileToolOutput::error(format!( - "Could not find matching text for edit at index {}. \ - The old_text did not match any content in the file. \ - Please read the file again to get the current content.", - edit_index, - ))); - } - if matches.len() > 1 { - let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); - let lines = matches - .iter() - .map(|r| (snapshot.offset_to_point(r.start).row + 1).to_string()) - .collect::>() - .join(", "); - return Err(StreamingEditFileToolOutput::error(format!( - "Edit {} matched multiple locations in the file at lines: {}. \ - Please provide more context in old_text to uniquely \ - identify the location.", - edit_index, lines - ))); - } - - let range = matches.into_iter().next().expect("checked len above"); + let range = extract_match(matcher.finish(), &self.buffer, edit_index, cx)?; - let anchor_range = buffer + let anchor_range = self + .buffer .read_with(cx, |buffer, _cx| buffer.anchor_range_between(range.clone())); - diff.update(cx, |diff, cx| diff.reveal_range(anchor_range, cx)); + self.diff + .update(cx, |diff, cx| diff.reveal_range(anchor_range, cx)); - let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let snapshot = self.buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); let line = snapshot.offset_to_point(range.start).row; event_stream.update_fields( - ToolCallUpdateFields::new() - .locations(vec![ToolCallLocation::new(abs_path).line(Some(line))]), + ToolCallUpdateFields::new().locations(vec![ + ToolCallLocation::new(&self.abs_path).line(Some(line)), + ]), ); - let EditPipelineEntry::ResolvingOldText { matcher } = - &pipeline.edits()[*edit_index] - else { - continue; - }; - let buffer_indent = - snapshot.line_indent_for_row(snapshot.offset_to_point(range.start).row); - let query_indent = query_first_line_indent(matcher.query_lines()); + let buffer_indent = snapshot.line_indent_for_row(line); + let query_indent = text::LineIndent::from_iter( + matcher + .query_lines() + .first() + .map(|s| s.as_str()) + .unwrap_or("") + .chars(), + ); let indent_delta = compute_indent_delta(buffer_indent, query_indent); let old_text_in_buffer = snapshot.text_for_range(range.clone()).collect::(); - let text_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.text_snapshot()); - pipeline.edits()[*edit_index] = EditPipelineEntry::StreamingNewText { + let text_snapshot = self + .buffer + .read_with(cx, |buffer, _cx| buffer.text_snapshot()); + *current_edit = Some(EditPipelineEntry::StreamingNewText { streaming_diff: StreamingDiff::new(old_text_in_buffer), line_diff: LineDiff::default(), edit_cursor: range.start, reindenter: Reindenter::new(indent_delta), original_snapshot: text_snapshot, - }; + }); cx.update(|cx| { - let position = buffer.read(cx).anchor_before(range.end); - tool.set_agent_location(buffer.downgrade(), position, cx); + let position = self.buffer.read(cx).anchor_before(range.end); + tool.set_agent_location(self.buffer.downgrade(), position, cx); }); } ToolEditEvent::NewTextChunk { - edit_index, - chunk, - done: false, + chunk, done: false, .. } => { - if *edit_index >= pipeline.edits().len() { + let EditPipeline::Edit { current_edit } = &mut self.pipeline else { continue; - } - let EditPipelineEntry::StreamingNewText { + }; + + let Some(EditPipelineEntry::StreamingNewText { streaming_diff, line_diff, edit_cursor, reindenter, original_snapshot, .. - } = &mut pipeline.edits()[*edit_index] + }) = current_edit else { continue; }; @@ -937,16 +847,16 @@ impl EditSession { } let char_ops = streaming_diff.push_new(&reindented); - Self::apply_char_operations( + apply_char_operations( &char_ops, - buffer, + &self.buffer, original_snapshot, edit_cursor, - action_log.as_ref(), + &tool.action_log, cx, ); line_diff.push_char_operations(&char_ops, original_snapshot.as_rope()); - diff.update(cx, |diff, cx| { + self.diff.update(cx, |diff, cx| { diff.update_pending( line_diff.line_operations(), original_snapshot.clone(), @@ -956,29 +866,23 @@ impl EditSession { let position = original_snapshot.anchor_before(*edit_cursor); cx.update(|cx| { - tool.set_agent_location(buffer.downgrade(), position, cx); + tool.set_agent_location(self.buffer.downgrade(), position, cx); }); } ToolEditEvent::NewTextChunk { - edit_index, - chunk, - done: true, + chunk, done: true, .. } => { - if *edit_index >= pipeline.edits().len() { + let EditPipeline::Edit { current_edit } = &mut self.pipeline else { continue; - } - - let EditPipelineEntry::StreamingNewText { + }; + let Some(EditPipelineEntry::StreamingNewText { mut streaming_diff, mut line_diff, mut edit_cursor, mut reindenter, original_snapshot, - } = std::mem::replace( - &mut pipeline.edits()[*edit_index], - EditPipelineEntry::Done, - ) + }) = current_edit.take() else { continue; }; @@ -989,16 +893,16 @@ impl EditSession { if !final_text.is_empty() { let char_ops = streaming_diff.push_new(&final_text); - Self::apply_char_operations( + apply_char_operations( &char_ops, - buffer, + &self.buffer, &original_snapshot, &mut edit_cursor, - action_log.as_ref(), + &tool.action_log, cx, ); line_diff.push_char_operations(&char_ops, original_snapshot.as_rope()); - diff.update(cx, |diff, cx| { + self.diff.update(cx, |diff, cx| { diff.update_pending( line_diff.line_operations(), original_snapshot.clone(), @@ -1008,17 +912,17 @@ impl EditSession { } let remaining_ops = streaming_diff.finish(); - Self::apply_char_operations( + apply_char_operations( &remaining_ops, - buffer, + &self.buffer, &original_snapshot, &mut edit_cursor, - action_log.as_ref(), + &tool.action_log, cx, ); line_diff.push_char_operations(&remaining_ops, original_snapshot.as_rope()); line_diff.finish(original_snapshot.as_rope()); - diff.update(cx, |diff, cx| { + self.diff.update(cx, |diff, cx| { diff.update_pending( line_diff.line_operations(), original_snapshot.clone(), @@ -1028,42 +932,73 @@ impl EditSession { let position = original_snapshot.anchor_before(edit_cursor); cx.update(|cx| { - tool.set_agent_location(buffer.downgrade(), position, cx); + tool.set_agent_location(self.buffer.downgrade(), position, cx); }); } } } Ok(()) } +} - fn apply_char_operations( - ops: &[CharOperation], - buffer: &Entity, - snapshot: &text::BufferSnapshot, - edit_cursor: &mut usize, - action_log: Option<&Entity>, - cx: &mut AsyncApp, - ) { - for op in ops { - match op { - CharOperation::Insert { text } => { - let anchor = snapshot.anchor_after(*edit_cursor); - agent_edit_buffer(&buffer, [(anchor..anchor, text.as_str())], action_log, cx); - } - CharOperation::Delete { bytes } => { - let delete_end = *edit_cursor + bytes; - let anchor_range = snapshot.anchor_range_around(*edit_cursor..delete_end); - agent_edit_buffer(&buffer, [(anchor_range, "")], action_log, cx); - *edit_cursor = delete_end; - } - CharOperation::Keep { bytes } => { - *edit_cursor += bytes; - } +fn apply_char_operations( + ops: &[CharOperation], + buffer: &Entity, + snapshot: &text::BufferSnapshot, + edit_cursor: &mut usize, + action_log: &Entity, + cx: &mut AsyncApp, +) { + for op in ops { + match op { + CharOperation::Insert { text } => { + let anchor = snapshot.anchor_after(*edit_cursor); + agent_edit_buffer(&buffer, [(anchor..anchor, text.as_str())], action_log, cx); + } + CharOperation::Delete { bytes } => { + let delete_end = *edit_cursor + bytes; + let anchor_range = snapshot.anchor_range_around(*edit_cursor..delete_end); + agent_edit_buffer(&buffer, [(anchor_range, "")], action_log, cx); + *edit_cursor = delete_end; + } + CharOperation::Keep { bytes } => { + *edit_cursor += bytes; } } } } +fn extract_match( + matches: Vec>, + buffer: &Entity, + edit_index: &usize, + cx: &mut AsyncApp, +) -> Result, StreamingEditFileToolOutput> { + match matches.len() { + 0 => Err(StreamingEditFileToolOutput::error(format!( + "Could not find matching text for edit at index {}. \ + The old_text did not match any content in the file. \ + Please read the file again to get the current content.", + edit_index, + ))), + 1 => Ok(matches.into_iter().next().unwrap()), + _ => { + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let lines = matches + .iter() + .map(|r| (snapshot.offset_to_point(r.start).row + 1).to_string()) + .collect::>() + .join(", "); + Err(StreamingEditFileToolOutput::error(format!( + "Edit {} matched multiple locations in the file at lines: {}. \ + Please provide more context in old_text to uniquely \ + identify the location.", + edit_index, lines + ))) + } + } +} + /// Edits a buffer and reports the edit to the action log in the same effect /// cycle. This ensures the action log's subscription handler sees the version /// already updated by `buffer_edited`, so it does not misattribute the agent's @@ -1071,7 +1006,7 @@ impl EditSession { fn agent_edit_buffer( buffer: &Entity, edits: I, - action_log: Option<&Entity>, + action_log: &Entity, cx: &mut AsyncApp, ) where I: IntoIterator, T)>, @@ -1082,9 +1017,7 @@ fn agent_edit_buffer( buffer.update(cx, |buffer, cx| { buffer.edit(edits, None, cx); }); - if let Some(action_log) = action_log { - action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); - } + action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); }); } @@ -1094,8 +1027,10 @@ fn ensure_buffer_saved( tool: &StreamingEditFileTool, cx: &mut AsyncApp, ) -> Result<(), StreamingEditFileToolOutput> { - let check_result = tool.thread.update(cx, |thread, cx| { - let last_read = thread.file_read_times.get(abs_path).copied(); + let last_read_mtime = tool + .action_log + .read_with(cx, |log, _| log.file_read_time(abs_path)); + let check_result = tool.thread.read_with(cx, |thread, cx| { let current = buffer .read(cx) .file() @@ -1103,12 +1038,10 @@ fn ensure_buffer_saved( let dirty = buffer.read(cx).is_dirty(); let has_save = thread.has_tool(SaveFileTool::NAME); let has_restore = thread.has_tool(RestoreFileFromDiskTool::NAME); - (last_read, current, dirty, has_save, has_restore) + (current, dirty, has_save, has_restore) }); - let Ok((last_read_mtime, current_mtime, is_dirty, has_save_tool, has_restore_tool)) = - check_result - else { + let Ok((current_mtime, is_dirty, has_save_tool, has_restore_tool)) = check_result else { return Ok(()); }; @@ -1225,42 +1158,17 @@ mod tests { #[gpui::test] async fn test_streaming_edit_create_file(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree("/root", json!({"dir": {}})).await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - + let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"dir": {}})).await; let result = cx .update(|cx| { - let input = StreamingEditFileToolInput { - display_description: "Create new file".into(), - path: "root/dir/new_file.txt".into(), - mode: StreamingEditFileMode::Write, - content: Some("Hello, World!".into()), - edits: None, - }; - Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )) - .run( - ToolInput::resolved(input), + tool.clone().run( + ToolInput::resolved(StreamingEditFileToolInput { + display_description: "Create new file".into(), + path: "root/dir/new_file.txt".into(), + mode: StreamingEditFileMode::Write, + content: Some("Hello, World!".into()), + edits: None, + }), ToolCallEventStream::test().0, cx, ) @@ -1276,43 +1184,18 @@ mod tests { #[gpui::test] async fn test_streaming_edit_overwrite_file(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree("/root", json!({"file.txt": "old content"})) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - + let (tool, _project, _action_log, _fs, _thread) = + setup_test(cx, json!({"file.txt": "old content"})).await; let result = cx .update(|cx| { - let input = StreamingEditFileToolInput { - display_description: "Overwrite file".into(), - path: "root/file.txt".into(), - mode: StreamingEditFileMode::Write, - content: Some("new content".into()), - edits: None, - }; - Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )) - .run( - ToolInput::resolved(input), + tool.clone().run( + ToolInput::resolved(StreamingEditFileToolInput { + display_description: "Overwrite file".into(), + path: "root/file.txt".into(), + mode: StreamingEditFileMode::Write, + content: Some("new content".into()), + edits: None, + }), ToolCallEventStream::test().0, cx, ) @@ -1331,51 +1214,21 @@ mod tests { #[gpui::test] async fn test_streaming_edit_granular_edits(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "file.txt": "line 1\nline 2\nline 3\n" - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - + let (tool, _project, _action_log, _fs, _thread) = + setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await; let result = cx .update(|cx| { - let input = StreamingEditFileToolInput { - display_description: "Edit lines".into(), - path: "root/file.txt".into(), - mode: StreamingEditFileMode::Edit, - content: None, - edits: Some(vec![Edit { - old_text: "line 2".into(), - new_text: "modified line 2".into(), - }]), - }; - Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )) - .run( - ToolInput::resolved(input), + tool.clone().run( + ToolInput::resolved(StreamingEditFileToolInput { + display_description: "Edit lines".into(), + path: "root/file.txt".into(), + mode: StreamingEditFileMode::Edit, + content: None, + edits: Some(vec![Edit { + old_text: "line 2".into(), + new_text: "modified line 2".into(), + }]), + }), ToolCallEventStream::test().0, cx, ) @@ -1390,57 +1243,30 @@ mod tests { #[gpui::test] async fn test_streaming_edit_multiple_edits(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n" - }), + let (tool, _project, _action_log, _fs, _thread) = setup_test( + cx, + json!({"file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"}), ) .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let result = cx .update(|cx| { - let input = StreamingEditFileToolInput { - display_description: "Edit multiple lines".into(), - path: "root/file.txt".into(), - mode: StreamingEditFileMode::Edit, - content: None, - edits: Some(vec![ - Edit { - old_text: "line 5".into(), - new_text: "modified line 5".into(), - }, - Edit { - old_text: "line 1".into(), - new_text: "modified line 1".into(), - }, - ]), - }; - Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )) - .run( - ToolInput::resolved(input), + tool.clone().run( + ToolInput::resolved(StreamingEditFileToolInput { + display_description: "Edit multiple lines".into(), + path: "root/file.txt".into(), + mode: StreamingEditFileMode::Edit, + content: None, + edits: Some(vec![ + Edit { + old_text: "line 5".into(), + new_text: "modified line 5".into(), + }, + Edit { + old_text: "line 1".into(), + new_text: "modified line 1".into(), + }, + ]), + }), ToolCallEventStream::test().0, cx, ) @@ -1458,57 +1284,30 @@ mod tests { #[gpui::test] async fn test_streaming_edit_adjacent_edits(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n" - }), + let (tool, _project, _action_log, _fs, _thread) = setup_test( + cx, + json!({"file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"}), ) .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let result = cx .update(|cx| { - let input = StreamingEditFileToolInput { - display_description: "Edit adjacent lines".into(), - path: "root/file.txt".into(), - mode: StreamingEditFileMode::Edit, - content: None, - edits: Some(vec![ - Edit { - old_text: "line 2".into(), - new_text: "modified line 2".into(), - }, - Edit { - old_text: "line 3".into(), - new_text: "modified line 3".into(), - }, - ]), - }; - Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )) - .run( - ToolInput::resolved(input), + tool.clone().run( + ToolInput::resolved(StreamingEditFileToolInput { + display_description: "Edit adjacent lines".into(), + path: "root/file.txt".into(), + mode: StreamingEditFileMode::Edit, + content: None, + edits: Some(vec![ + Edit { + old_text: "line 2".into(), + new_text: "modified line 2".into(), + }, + Edit { + old_text: "line 3".into(), + new_text: "modified line 3".into(), + }, + ]), + }), ToolCallEventStream::test().0, cx, ) @@ -1526,57 +1325,30 @@ mod tests { #[gpui::test] async fn test_streaming_edit_ascending_order_edits(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n" - }), + let (tool, _project, _action_log, _fs, _thread) = setup_test( + cx, + json!({"file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"}), ) .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let result = cx .update(|cx| { - let input = StreamingEditFileToolInput { - display_description: "Edit multiple lines in ascending order".into(), - path: "root/file.txt".into(), - mode: StreamingEditFileMode::Edit, - content: None, - edits: Some(vec![ - Edit { - old_text: "line 1".into(), - new_text: "modified line 1".into(), - }, - Edit { - old_text: "line 5".into(), - new_text: "modified line 5".into(), - }, - ]), - }; - Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )) - .run( - ToolInput::resolved(input), + tool.clone().run( + ToolInput::resolved(StreamingEditFileToolInput { + display_description: "Edit multiple lines in ascending order".into(), + path: "root/file.txt".into(), + mode: StreamingEditFileMode::Edit, + content: None, + edits: Some(vec![ + Edit { + old_text: "line 1".into(), + new_text: "modified line 1".into(), + }, + Edit { + old_text: "line 5".into(), + new_text: "modified line 5".into(), + }, + ]), + }), ToolCallEventStream::test().0, cx, ) @@ -1594,45 +1366,20 @@ mod tests { #[gpui::test] async fn test_streaming_edit_nonexistent_file(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree("/root", json!({})).await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - + let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({})).await; let result = cx .update(|cx| { - let input = StreamingEditFileToolInput { - display_description: "Some edit".into(), - path: "root/nonexistent_file.txt".into(), - mode: StreamingEditFileMode::Edit, - content: None, - edits: Some(vec![Edit { - old_text: "foo".into(), - new_text: "bar".into(), - }]), - }; - Arc::new(StreamingEditFileTool::new( - project, - thread.downgrade(), - language_registry, - )) - .run( - ToolInput::resolved(input), + tool.clone().run( + ToolInput::resolved(StreamingEditFileToolInput { + display_description: "Some edit".into(), + path: "root/nonexistent_file.txt".into(), + mode: StreamingEditFileMode::Edit, + content: None, + edits: Some(vec![Edit { + old_text: "foo".into(), + new_text: "bar".into(), + }]), + }), ToolCallEventStream::test().0, cx, ) @@ -1647,46 +1394,21 @@ mod tests { #[gpui::test] async fn test_streaming_edit_failed_match(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree("/root", json!({"file.txt": "hello world"})) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - + let (tool, _project, _action_log, _fs, _thread) = + setup_test(cx, json!({"file.txt": "hello world"})).await; let result = cx .update(|cx| { - let input = StreamingEditFileToolInput { - display_description: "Edit file".into(), - path: "root/file.txt".into(), - mode: StreamingEditFileMode::Edit, - content: None, - edits: Some(vec![Edit { - old_text: "nonexistent text that is not in the file".into(), - new_text: "replacement".into(), - }]), - }; - Arc::new(StreamingEditFileTool::new( - project, - thread.downgrade(), - language_registry, - )) - .run( - ToolInput::resolved(input), + tool.clone().run( + ToolInput::resolved(StreamingEditFileToolInput { + display_description: "Edit file".into(), + path: "root/file.txt".into(), + mode: StreamingEditFileMode::Edit, + content: None, + edits: Some(vec![Edit { + old_text: "nonexistent text that is not in the file".into(), + new_text: "replacement".into(), + }]), + }), ToolCallEventStream::test().0, cx, ) @@ -1704,42 +1426,11 @@ mod tests { #[gpui::test] async fn test_streaming_early_buffer_open(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "file.txt": "line 1\nline 2\nline 3\n" - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - + let (tool, _project, _action_log, _fs, _thread) = + setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await; let (sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); - - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); // Send partials simulating LLM streaming: description first, then path, then mode sender.send_partial(json!({"display_description": "Edit lines"})); @@ -1776,42 +1467,11 @@ mod tests { #[gpui::test] async fn test_streaming_path_completeness_heuristic(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "file.txt": "hello world" - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - + let (tool, _project, _action_log, _fs, _thread) = + setup_test(cx, json!({"file.txt": "hello world"})).await; let (sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); - - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); // Send partial with path but NO mode — path should NOT be treated as complete sender.send_partial(json!({ @@ -1845,43 +1505,12 @@ mod tests { #[gpui::test] async fn test_streaming_cancellation_during_partials(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "file.txt": "hello world" - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - + let (tool, _project, _action_log, _fs, _thread) = + setup_test(cx, json!({"file.txt": "hello world"})).await; let (sender, input) = ToolInput::::test(); let (event_stream, _receiver, mut cancellation_tx) = ToolCallEventStream::test_with_cancellation(); - - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); - - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); // Send a partial sender.send_partial(json!({"display_description": "Edit"})); @@ -1907,42 +1536,14 @@ mod tests { #[gpui::test] async fn test_streaming_edit_with_multiple_partials(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n" - }), + let (tool, _project, _action_log, _fs, _thread) = setup_test( + cx, + json!({"file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"}), ) .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let (sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); - - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); // Simulate fine-grained streaming of the JSON sender.send_partial(json!({"display_description": "Edit multiple"})); @@ -2003,36 +1604,10 @@ mod tests { #[gpui::test] async fn test_streaming_create_file_with_partials(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree("/root", json!({"dir": {}})).await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - + let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"dir": {}})).await; let (sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); - - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); // Stream partials for create mode sender.send_partial(json!({"display_description": "Create new file"})); @@ -2070,42 +1645,11 @@ mod tests { #[gpui::test] async fn test_streaming_no_partials_direct_final(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "file.txt": "line 1\nline 2\nline 3\n" - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - + let (tool, _project, _action_log, _fs, _thread) = + setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await; let (sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); - - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); // Send final immediately with no partials (simulates non-streaming path) sender.send_final(json!({ @@ -2124,42 +1668,14 @@ mod tests { #[gpui::test] async fn test_streaming_incremental_edit_application(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n" - }), + let (tool, project, _action_log, _fs, _thread) = setup_test( + cx, + json!({"file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"}), ) .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let (sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); - - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); // Stream description, path, mode sender.send_partial(json!({"display_description": "Edit multiple lines"})); @@ -2253,42 +1769,11 @@ mod tests { #[gpui::test] async fn test_streaming_incremental_three_edits(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "file.txt": "aaa\nbbb\nccc\nddd\neee\n" - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - + let (tool, project, _action_log, _fs, _thread) = + setup_test(cx, json!({"file.txt": "aaa\nbbb\nccc\nddd\neee\n"})).await; let (sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); - - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); // Setup: description + path + mode sender.send_partial(json!({ @@ -2373,43 +1858,12 @@ mod tests { } #[gpui::test] - async fn test_streaming_edit_failure_mid_stream(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "file.txt": "line 1\nline 2\nline 3\n" - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - + async fn test_streaming_edit_failure_mid_stream(cx: &mut TestAppContext) { + let (tool, project, _action_log, _fs, _thread) = + setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await; let (sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); - - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); // Setup sender.send_partial(json!({ @@ -2486,42 +1940,11 @@ mod tests { #[gpui::test] async fn test_streaming_single_edit_no_incremental(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "file.txt": "hello world\n" - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - + let (tool, project, _action_log, _fs, _thread) = + setup_test(cx, json!({"file.txt": "hello world\n"})).await; let (sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); - - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); // Setup + single edit that stays in-progress (no second edit to prove completion) sender.send_partial(json!({ @@ -2565,44 +1988,12 @@ mod tests { #[gpui::test] async fn test_streaming_input_partials_then_final(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "file.txt": "line 1\nline 2\nline 3\n" - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - + let (tool, _project, _action_log, _fs, _thread) = + setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await; let (sender, input): (ToolInputSender, ToolInput) = ToolInput::test(); - let (event_stream, _event_rx) = ToolCallEventStream::test(); - let task = cx.update(|cx| { - Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )) - .run(input, event_stream, cx) - }); + let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); // Send progressively more complete partial snapshots, as the LLM would sender.send_partial(json!({ @@ -2642,44 +2033,12 @@ mod tests { #[gpui::test] async fn test_streaming_input_sender_dropped_before_final(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "file.txt": "hello world\n" - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - + let (tool, _project, _action_log, _fs, _thread) = + setup_test(cx, json!({"file.txt": "hello world\n"})).await; let (sender, input): (ToolInputSender, ToolInput) = ToolInput::test(); - let (event_stream, _event_rx) = ToolCallEventStream::test(); - let task = cx.update(|cx| { - Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )) - .run(input, event_stream, cx) - }); + let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); // Send a partial then drop the sender without sending final sender.send_partial(json!({ @@ -2698,41 +2057,14 @@ mod tests { #[gpui::test] async fn test_streaming_input_recv_drains_partials(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree("/root", json!({"dir": {}})).await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - + let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"dir": {}})).await; // Create a channel and send multiple partials before a final, then use // ToolInput::resolved-style immediate delivery to confirm recv() works // when partials are already buffered. let (sender, input): (ToolInputSender, ToolInput) = ToolInput::test(); - let (event_stream, _event_rx) = ToolCallEventStream::test(); - let task = cx.update(|cx| { - Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )) - .run(input, event_stream, cx) - }); + let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); // Buffer several partials before sending the final sender.send_partial(json!({"display_description": "Create"})); @@ -2831,7 +2163,7 @@ mod tests { .await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - cx.update(|cx| resolve_path(mode.clone(), &PathBuf::from(path), &project, cx)) + cx.update(|cx| resolve_path(*mode, &PathBuf::from(path), &project, cx)) } #[track_caller] @@ -2846,8 +2178,8 @@ mod tests { let fs = project::FakeFs::new(cx.executor()); fs.insert_tree("/root", json!({"src": {}})).await; - - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let (tool, project, action_log, fs, thread) = + setup_test_with_fs(cx, fs, &[path!("/root").as_ref()]).await; let rust_language = Arc::new(language::Language::new( language::LanguageConfig { @@ -2896,9 +2228,10 @@ mod tests { project.register_buffer_with_language_servers(&buffer, cx) }); - const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n"; - const FORMATTED_CONTENT: &str = - "This file was formatted by the fake formatter in the test.\n"; + const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\ +"; + const FORMATTED_CONTENT: &str = "This file was formatted by the fake formatter in the test.\ +"; // Get the fake language server and set up formatting handler let fake_language_server = fake_language_servers.next().await.unwrap(); @@ -2911,20 +2244,6 @@ mod tests { } }); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model.clone()), - cx, - ) - }); - // Test with format_on_save enabled cx.update(|cx| { SettingsStore::update_global(cx, |store, cx| { @@ -2940,13 +2259,7 @@ mod tests { let (sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry.clone(), - )); - - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); sender.send_partial(json!({ "display_description": "Create main function", @@ -2997,13 +2310,14 @@ mod tests { let (sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - let tool = Arc::new(StreamingEditFileTool::new( + let tool2 = Arc::new(StreamingEditFileTool::new( project.clone(), thread.downgrade(), + action_log.clone(), language_registry, )); - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool2.run(input, event_stream, cx)); sender.send_partial(json!({ "display_description": "Update main function", @@ -3038,7 +2352,6 @@ mod tests { let fs = project::FakeFs::new(cx.executor()); fs.insert_tree("/root", json!({"src": {}})).await; - fs.save( path!("/root/src/main.rs").as_ref(), &"initial content".into(), @@ -3046,22 +2359,9 @@ mod tests { ) .await .unwrap(); - - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model.clone()), - cx, - ) - }); + let (tool, project, action_log, fs, thread) = + setup_test_with_fs(cx, fs, &[path!("/root").as_ref()]).await; + let language_registry = project.read_with(cx, |p, _cx| p.languages().clone()); // Test with remove_trailing_whitespace_on_save enabled cx.update(|cx| { @@ -3081,20 +2381,14 @@ mod tests { let result = cx .update(|cx| { - let input = StreamingEditFileToolInput { - display_description: "Create main function".into(), - path: "root/src/main.rs".into(), - mode: StreamingEditFileMode::Write, - content: Some(CONTENT_WITH_TRAILING_WHITESPACE.into()), - edits: None, - }; - Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry.clone(), - )) - .run( - ToolInput::resolved(input), + tool.clone().run( + ToolInput::resolved(StreamingEditFileToolInput { + display_description: "Create main function".into(), + path: "root/src/main.rs".into(), + mode: StreamingEditFileMode::Write, + content: Some(CONTENT_WITH_TRAILING_WHITESPACE.into()), + edits: None, + }), ToolCallEventStream::test().0, cx, ) @@ -3126,22 +2420,23 @@ mod tests { }); }); + let tool2 = Arc::new(StreamingEditFileTool::new( + project.clone(), + thread.downgrade(), + action_log.clone(), + language_registry, + )); + let result = cx .update(|cx| { - let input = StreamingEditFileToolInput { - display_description: "Update main function".into(), - path: "root/src/main.rs".into(), - mode: StreamingEditFileMode::Write, - content: Some(CONTENT_WITH_TRAILING_WHITESPACE.into()), - edits: None, - }; - Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )) - .run( - ToolInput::resolved(input), + tool2.run( + ToolInput::resolved(StreamingEditFileToolInput { + display_description: "Update main function".into(), + path: "root/src/main.rs".into(), + mode: StreamingEditFileMode::Write, + content: Some(CONTENT_WITH_TRAILING_WHITESPACE.into()), + edits: None, + }), ToolCallEventStream::test().0, cx, ) @@ -3161,29 +2456,7 @@ mod tests { #[gpui::test] async fn test_streaming_authorize(cx: &mut TestAppContext) { - init_test(cx); - let fs = project::FakeFs::new(cx.executor()); - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model.clone()), - cx, - ) - }); - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); - fs.insert_tree("/root", json!({})).await; + let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({})).await; // Test 1: Path with .zed component should require confirmation let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); @@ -3304,27 +2577,8 @@ mod tests { fs.insert_tree("/outside", json!({})).await; fs.insert_symlink("/root/link", PathBuf::from("/outside")) .await; - - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let tool = Arc::new(StreamingEditFileTool::new( - project, - thread.downgrade(), - language_registry, - )); + let (tool, _project, _action_log, _fs, _thread) = + setup_test_with_fs(cx, fs, &[path!("/root").as_ref()]).await; cx.update(|cx| { let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); @@ -3378,38 +2632,17 @@ mod tests { path!("/outside"), json!({ "config.txt": "old content" - }), - ) - .await; - fs.create_symlink( - path!("/root/link_to_external").as_ref(), - PathBuf::from("/outside"), - ) - .await - .unwrap(); - - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - cx.executor().run_until_parked(); - - let language_registry = project.read_with(cx, |project, _| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); + }), + ) + .await; + fs.create_symlink( + path!("/root/link_to_external").as_ref(), + PathBuf::from("/outside"), + ) + .await + .unwrap(); + let (tool, _project, _action_log, _fs, _thread) = + setup_test_with_fs(cx, fs, &[path!("/root").as_ref()]).await; let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); let _authorize_task = cx.update(|cx| { @@ -3454,29 +2687,8 @@ mod tests { ) .await .unwrap(); - - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - cx.executor().run_until_parked(); - - let language_registry = project.read_with(cx, |project, _| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); + let (tool, _project, _action_log, _fs, _thread) = + setup_test_with_fs(cx, fs, &[path!("/root").as_ref()]).await; let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); let authorize_task = cx.update(|cx| { @@ -3531,29 +2743,8 @@ mod tests { ) .await .unwrap(); - - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - cx.executor().run_until_parked(); - - let language_registry = project.read_with(cx, |project, _| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); + let (tool, _project, _action_log, _fs, _thread) = + setup_test_with_fs(cx, fs, &[path!("/root").as_ref()]).await; let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); let result = cx @@ -3582,26 +2773,8 @@ mod tests { init_test(cx); let fs = project::FakeFs::new(cx.executor()); fs.insert_tree("/project", json!({})).await; - let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model.clone()), - cx, - ) - }); - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); + let (tool, _project, _action_log, _fs, _thread) = + setup_test_with_fs(cx, fs, &[path!("/project").as_ref()]).await; let test_cases = vec![ ( @@ -3644,7 +2817,6 @@ mod tests { async fn test_streaming_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) { init_test(cx); let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( "/workspace/frontend", json!({ @@ -3672,36 +2844,16 @@ mod tests { }), ) .await; - - let project = Project::test( - fs.clone(), - [ + let (tool, _project, _action_log, _fs, _thread) = setup_test_with_fs( + cx, + fs, + &[ path!("/workspace/frontend").as_ref(), path!("/workspace/backend").as_ref(), path!("/workspace/shared").as_ref(), ], - cx, ) .await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry.clone(), - Templates::new(), - Some(model.clone()), - cx, - ) - }); - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); let test_cases = vec![ ("frontend/src/main.js", false, "File in first worktree"), @@ -3756,26 +2908,8 @@ mod tests { }), ) .await; - let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry.clone(), - Templates::new(), - Some(model.clone()), - cx, - ) - }); - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); + let (tool, _project, _action_log, _fs, _thread) = + setup_test_with_fs(cx, fs, &[path!("/project").as_ref()]).await; let test_cases = vec![ ("", false, "Empty path is treated as project root"), @@ -3831,26 +2965,8 @@ mod tests { }), ) .await; - let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry.clone(), - Templates::new(), - Some(model.clone()), - cx, - ) - }); - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); + let (tool, _project, _action_log, _fs, _thread) = + setup_test_with_fs(cx, fs, &[path!("/project").as_ref()]).await; let modes = vec![StreamingEditFileMode::Edit, StreamingEditFileMode::Write]; @@ -3901,26 +3017,9 @@ mod tests { async fn test_streaming_initial_title_with_partial_input(cx: &mut TestAppContext) { init_test(cx); let fs = project::FakeFs::new(cx.executor()); - let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model.clone()), - cx, - ) - }); - let tool = Arc::new(StreamingEditFileTool::new( - project, - thread.downgrade(), - language_registry, - )); + fs.insert_tree("/project", json!({})).await; + let (tool, _project, _action_log, _fs, _thread) = + setup_test_with_fs(cx, fs, &[path!("/project").as_ref()]).await; cx.update(|cx| { assert_eq!( @@ -3975,33 +3074,15 @@ mod tests { init_test(cx); let fs = project::FakeFs::new(cx.executor()); fs.insert_tree("/", json!({"main.rs": ""})).await; - - let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await; - let languages = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry.clone(), - Templates::new(), - Some(model.clone()), - cx, - ) - }); + let (tool, project, action_log, _fs, thread) = + setup_test_with_fs(cx, fs, &[path!("/").as_ref()]).await; + let language_registry = project.read_with(cx, |p, _cx| p.languages().clone()); // Ensure the diff is finalized after the edit completes. { - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - languages.clone(), - )); let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); let edit = cx.update(|cx| { - tool.run( + tool.clone().run( ToolInput::resolved(StreamingEditFileToolInput { display_description: "Edit file".into(), path: path!("/main.rs").into(), @@ -4026,7 +3107,8 @@ mod tests { let tool = Arc::new(StreamingEditFileTool::new( project.clone(), thread.downgrade(), - languages.clone(), + action_log, + language_registry, )); let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); let edit = cx.update(|cx| { @@ -4053,42 +3135,12 @@ mod tests { #[gpui::test] async fn test_streaming_consecutive_edits_work(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "test.txt": "original content" - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model.clone()), - cx, - ) - }); - let languages = project.read_with(cx, |project, _| project.languages().clone()); - let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); - + let (tool, project, action_log, _fs, _thread) = + setup_test(cx, json!({"test.txt": "original content"})).await; let read_tool = Arc::new(crate::ReadFileTool::new( - thread.downgrade(), - project.clone(), - action_log, - )); - let edit_tool = Arc::new(StreamingEditFileTool::new( project.clone(), - thread.downgrade(), - languages, + action_log.clone(), + true, )); // Read the file first @@ -4109,7 +3161,7 @@ mod tests { // First edit should work let edit_result = cx .update(|cx| { - edit_tool.clone().run( + tool.clone().run( ToolInput::resolved(StreamingEditFileToolInput { display_description: "First edit".into(), path: "root/test.txt".into(), @@ -4134,7 +3186,7 @@ mod tests { // Second edit should also work because the edit updated the recorded read time let edit_result = cx .update(|cx| { - edit_tool.clone().run( + tool.clone().run( ToolInput::resolved(StreamingEditFileToolInput { display_description: "Second edit".into(), path: "root/test.txt".into(), @@ -4159,42 +3211,12 @@ mod tests { #[gpui::test] async fn test_streaming_external_modification_detected(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "test.txt": "original content" - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model.clone()), - cx, - ) - }); - let languages = project.read_with(cx, |project, _| project.languages().clone()); - let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); - + let (tool, project, action_log, fs, _thread) = + setup_test(cx, json!({"test.txt": "original content"})).await; let read_tool = Arc::new(crate::ReadFileTool::new( - thread.downgrade(), - project.clone(), - action_log, - )); - let edit_tool = Arc::new(StreamingEditFileTool::new( project.clone(), - thread.downgrade(), - languages, + action_log.clone(), + true, )); // Read the file first @@ -4243,7 +3265,7 @@ mod tests { // Try to edit - should fail because file was modified externally let result = cx .update(|cx| { - edit_tool.clone().run( + tool.clone().run( ToolInput::resolved(StreamingEditFileToolInput { display_description: "Edit after external change".into(), path: "root/test.txt".into(), @@ -4262,52 +3284,22 @@ mod tests { let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else { panic!("expected error"); - }; - assert!( - error.contains("has been modified since you last read it"), - "Error should mention file modification, got: {}", - error - ); - } - - #[gpui::test] - async fn test_streaming_dirty_buffer_detected(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "test.txt": "original content" - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model.clone()), - cx, - ) - }); - let languages = project.read_with(cx, |project, _| project.languages().clone()); - let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); + }; + assert!( + error.contains("has been modified since you last read it"), + "Error should mention file modification, got: {}", + error + ); + } + #[gpui::test] + async fn test_streaming_dirty_buffer_detected(cx: &mut TestAppContext) { + let (tool, project, action_log, _fs, _thread) = + setup_test(cx, json!({"test.txt": "original content"})).await; let read_tool = Arc::new(crate::ReadFileTool::new( - thread.downgrade(), - project.clone(), - action_log, - )); - let edit_tool = Arc::new(StreamingEditFileTool::new( project.clone(), - thread.downgrade(), - languages, + action_log.clone(), + true, )); // Read the file first @@ -4347,7 +3339,7 @@ mod tests { // Try to edit - should fail because buffer has unsaved changes let result = cx .update(|cx| { - edit_tool.clone().run( + tool.clone().run( ToolInput::resolved(StreamingEditFileToolInput { display_description: "Edit with dirty buffer".into(), path: "root/test.txt".into(), @@ -4386,46 +3378,15 @@ mod tests { #[gpui::test] async fn test_streaming_overlapping_edits_resolved_sequentially(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); // Edit 1's replacement introduces text that contains edit 2's // old_text as a substring. Because edits resolve sequentially // against the current buffer, edit 2 finds a unique match in // the modified buffer and succeeds. - fs.insert_tree( - "/root", - json!({ - "file.txt": "aaa\nbbb\nccc\nddd\neee\n" - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - + let (tool, _project, _action_log, _fs, _thread) = + setup_test(cx, json!({"file.txt": "aaa\nbbb\nccc\nddd\neee\n"})).await; let (sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); - - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); // Setup: resolve the buffer sender.send_partial(json!({ @@ -4473,36 +3434,10 @@ mod tests { #[gpui::test] async fn test_streaming_create_content_streamed(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree("/root", json!({"dir": {}})).await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - + let (tool, project, _action_log, _fs, _thread) = setup_test(cx, json!({"dir": {}})).await; let (sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); - - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); // Transition to BufferResolved sender.send_partial(json!({ @@ -4570,42 +3505,14 @@ mod tests { #[gpui::test] async fn test_streaming_overwrite_diff_revealed_during_streaming(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "file.txt": "old line 1\nold line 2\nold line 3\n" - }), + let (tool, _project, _action_log, _fs, _thread) = setup_test( + cx, + json!({"file.txt": "old line 1\nold line 2\nold line 3\n"}), ) .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let (sender, input) = ToolInput::::test(); let (event_stream, mut receiver) = ToolCallEventStream::test(); - - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); - - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); // Transition to BufferResolved sender.send_partial(json!({ @@ -4663,42 +3570,14 @@ mod tests { #[gpui::test] async fn test_streaming_overwrite_content_streamed(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "file.txt": "old line 1\nold line 2\nold line 3\n" - }), + let (tool, project, _action_log, _fs, _thread) = setup_test( + cx, + json!({"file.txt": "old line 1\nold line 2\nold line 3\n"}), ) .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let (sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); - - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); // Transition to BufferResolved sender.send_partial(json!({ @@ -4762,42 +3641,11 @@ mod tests { #[gpui::test] async fn test_streaming_edit_json_fixer_escape_corruption(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "file.txt": "hello\nworld\nfoo\n" - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - + let (tool, _project, _action_log, _fs, _thread) = + setup_test(cx, json!({"file.txt": "hello\nworld\nfoo\n"})).await; let (sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); - - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); sender.send_partial(json!({ "display_description": "Edit", @@ -4847,47 +3695,17 @@ mod tests { // reports changed buffers so that the Accept All / Reject All review UI appears. #[gpui::test] async fn test_streaming_edit_file_tool_registers_changed_buffers(cx: &mut TestAppContext) { - init_test(cx); + let (tool, _project, action_log, _fs, _thread) = + setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await; cx.update(|cx| { let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); settings.tool_permissions.default = settings::ToolPermissionMode::Allow; agent_settings::AgentSettings::override_global(settings, cx); }); - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - path!("/root"), - json!({ - "file.txt": "line 1\nline 2\nline 3\n" - }), - ) - .await; - - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - None, - cx, - ) - }); - let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); - - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); let (event_stream, _rx) = ToolCallEventStream::test(); - let task = cx.update(|cx| { - tool.run( + tool.clone().run( ToolInput::resolved(StreamingEditFileToolInput { display_description: "Edit lines".to_string(), path: "root/file.txt".into(), @@ -4911,7 +3729,7 @@ mod tests { let changed = action_log.read_with(cx, |log, cx| log.changed_buffers(cx)); assert!( !changed.is_empty(), - "action_log.changed_buffers() should be non-empty after streaming edit, \ + "action_log.changed_buffers() should be non-empty after streaming edit, but no changed buffers were found \u{2014} Accept All / Reject All will not appear" ); } @@ -4921,47 +3739,17 @@ mod tests { async fn test_streaming_edit_file_tool_write_mode_registers_changed_buffers( cx: &mut TestAppContext, ) { - init_test(cx); + let (tool, _project, action_log, _fs, _thread) = + setup_test(cx, json!({"file.txt": "original content"})).await; cx.update(|cx| { let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); settings.tool_permissions.default = settings::ToolPermissionMode::Allow; agent_settings::AgentSettings::override_global(settings, cx); }); - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - path!("/root"), - json!({ - "file.txt": "original content" - }), - ) - .await; - - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - None, - cx, - ) - }); - let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); - - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); let (event_stream, _rx) = ToolCallEventStream::test(); - let task = cx.update(|cx| { - tool.run( + tool.clone().run( ToolInput::resolved(StreamingEditFileToolInput { display_description: "Overwrite file".to_string(), path: "root/file.txt".into(), @@ -4987,6 +3775,58 @@ mod tests { ); } + async fn setup_test_with_fs( + cx: &mut TestAppContext, + fs: Arc, + worktree_paths: &[&std::path::Path], + ) -> ( + Arc, + Entity, + Entity, + Arc, + Entity, + ) { + let project = Project::test(fs.clone(), worktree_paths.iter().copied(), cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|cx| { + crate::Thread::new( + project.clone(), + cx.new(|_cx| ProjectContext::default()), + context_server_registry, + Templates::new(), + Some(model), + cx, + ) + }); + let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); + let tool = Arc::new(StreamingEditFileTool::new( + project.clone(), + thread.downgrade(), + action_log.clone(), + language_registry, + )); + (tool, project, action_log, fs, thread) + } + + async fn setup_test( + cx: &mut TestAppContext, + initial_tree: serde_json::Value, + ) -> ( + Arc, + Entity, + Entity, + Arc, + Entity, + ) { + init_test(cx); + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree("/root", initial_tree).await; + setup_test_with_fs(cx, fs, &[path!("/root").as_ref()]).await + } + fn init_test(cx: &mut TestAppContext) { cx.update(|cx| { let settings_store = SettingsStore::test(cx); diff --git a/crates/agent_ui/Cargo.toml b/crates/agent_ui/Cargo.toml index 2a31781054fd29b30a3c8119e87491edbfb1e658..3e46e14b53c46a2aec3ac9552246a10ffc2aeee9 100644 --- a/crates/agent_ui/Cargo.toml +++ b/crates/agent_ui/Cargo.toml @@ -58,6 +58,7 @@ feature_flags.workspace = true file_icons.workspace = true fs.workspace = true futures.workspace = true +git.workspace = true fuzzy.workspace = true gpui.workspace = true gpui_tokio.workspace = true diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 7097e5be156eb33382a1a0f47c1b4256c84ce9b1..0f1cd3ebf0fdf1df939ccc6f2b0d1a40545bf082 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -1,6 +1,6 @@ use std::{ ops::Range, - path::Path, + path::{Path, PathBuf}, rc::Rc, sync::{ Arc, @@ -22,15 +22,18 @@ use project::{ use serde::{Deserialize, Serialize}; use settings::{LanguageModelProviderSetting, LanguageModelSelection}; +use feature_flags::{AgentGitWorktreesFeatureFlag, AgentV2FeatureFlag, FeatureFlagAppExt as _}; use zed_actions::agent::{OpenClaudeAgentOnboardingModal, ReauthenticateAgent, ReviewBranchDiff}; +use crate::ManageProfiles; use crate::ui::{AcpOnboardingModal, ClaudeCodeOnboardingModal}; use crate::{ AddContextServer, AgentDiffPane, ConnectionView, CopyThreadToClipboard, Follow, InlineAssistant, LoadThreadFromClipboard, NewTextThread, NewThread, OpenActiveThreadAsMarkdown, - OpenAgentDiff, OpenHistory, ResetTrialEndUpsell, ResetTrialUpsell, ToggleNavigationMenu, - ToggleNewThreadMenu, ToggleOptionsMenu, + OpenAgentDiff, OpenHistory, ResetTrialEndUpsell, ResetTrialUpsell, StartThreadIn, + ToggleNavigationMenu, ToggleNewThreadMenu, ToggleOptionsMenu, agent_configuration::{AgentConfiguration, AssistantConfigurationEvent}, + connection_view::{AcpThreadViewEvent, ThreadView}, slash_command::SlashCommandCompletionProvider, text_thread_editor::{AgentPanelDelegate, TextThreadEditor, make_lsp_adapter_delegate}, ui::EndTrialUpsell, @@ -42,7 +45,6 @@ use crate::{ ExpandMessageEditor, ThreadHistory, ThreadHistoryEvent, text_thread_history::{TextThreadHistory, TextThreadHistoryEvent}, }; -use crate::{ManageProfiles, connection_view::ThreadView}; use agent_settings::AgentSettings; use ai_onboarding::AgentPanelOnboarding; use anyhow::{Result, anyhow}; @@ -54,6 +56,7 @@ use editor::{Anchor, AnchorRangeExt as _, Editor, EditorEvent, MultiBuffer}; use extension::ExtensionEvents; use extension_host::ExtensionStore; use fs::Fs; +use git::repository::validate_worktree_directory; use gpui::{ Action, Animation, AnimationExt, AnyElement, App, AsyncWindowContext, ClipboardItem, Corner, DismissEvent, Entity, EventEmitter, ExternalPaths, FocusHandle, Focusable, KeyContext, Pixels, @@ -61,6 +64,7 @@ use gpui::{ }; use language::LanguageRegistry; use language_model::{ConfigurationError, LanguageModelRegistry}; +use project::project_settings::ProjectSettings; use project::{Project, ProjectPath, Worktree}; use prompt_store::{PromptBuilder, PromptStore, UserPromptId}; use rules_library::{RulesLibrary, open_rules_library}; @@ -68,8 +72,8 @@ use search::{BufferSearchBar, buffer_search}; use settings::{Settings, update_settings_file}; use theme::ThemeSettings; use ui::{ - Callout, ContextMenu, ContextMenuEntry, KeyBinding, PopoverMenu, PopoverMenuHandle, Tab, - Tooltip, prelude::*, utils::WithRemSize, + Button, Callout, ContextMenu, ContextMenuEntry, DocumentationSide, KeyBinding, PopoverMenu, + PopoverMenuHandle, SpinnerLabel, Tab, Tooltip, prelude::*, utils::WithRemSize, }; use util::ResultExt as _; use workspace::{ @@ -123,6 +127,8 @@ struct SerializedAgentPanel { selected_agent: Option, #[serde(default)] last_active_thread: Option, + #[serde(default)] + start_thread_in: Option, } #[derive(Serialize, Deserialize, Debug, Clone)] @@ -324,6 +330,13 @@ pub fn init(cx: &mut App) { cx, ); }); + }) + .register_action(|workspace, action: &StartThreadIn, _window, cx| { + if let Some(panel) = workspace.panel::(cx) { + panel.update(cx, |panel, cx| { + panel.set_start_thread_in(action, cx); + }); + } }); }, ) @@ -371,6 +384,10 @@ pub enum AgentType { } impl AgentType { + pub fn is_native(&self) -> bool { + matches!(self, Self::NativeAgent) + } + fn label(&self) -> SharedString { match self { Self::NativeAgent | Self::TextThread => "Zed Agent".into(), @@ -395,6 +412,29 @@ impl From for AgentType { } } +impl StartThreadIn { + fn label(&self) -> SharedString { + match self { + Self::LocalProject => "Local Project".into(), + Self::NewWorktree => "New Worktree".into(), + } + } + + fn icon(&self) -> IconName { + match self { + Self::LocalProject => IconName::Screen, + Self::NewWorktree => IconName::GitBranchPlus, + } + } +} + +#[derive(Clone, Debug)] +#[allow(dead_code)] +pub enum WorktreeCreationStatus { + Creating, + Error(SharedString), +} + impl ActiveView { pub fn which_font_size_used(&self) -> WhichFontSize { match self { @@ -515,6 +555,7 @@ pub struct AgentPanel { previous_view: Option, _active_view_observation: Option, new_thread_menu_handle: PopoverMenuHandle, + start_thread_in_menu_handle: PopoverMenuHandle, agent_panel_menu_handle: PopoverMenuHandle, agent_navigation_menu_handle: PopoverMenuHandle, agent_navigation_menu: Option>, @@ -525,6 +566,10 @@ pub struct AgentPanel { pending_serialization: Option>>, onboarding: Entity, selected_agent: AgentType, + start_thread_in: StartThreadIn, + worktree_creation_status: Option, + _thread_view_subscription: Option, + _worktree_creation_task: Option>, show_trust_workspace_message: bool, last_configuration_error_telemetry: Option, on_boarding_upsell_dismissed: AtomicBool, @@ -538,6 +583,7 @@ impl AgentPanel { let width = self.width; let selected_agent = self.selected_agent.clone(); + let start_thread_in = Some(self.start_thread_in); let last_active_thread = self.active_agent_thread(cx).map(|thread| { let thread = thread.read(cx); @@ -561,6 +607,7 @@ impl AgentPanel { width, selected_agent: Some(selected_agent), last_active_thread, + start_thread_in, }, ) .await?; @@ -605,6 +652,37 @@ impl AgentPanel { })? .await?; + let last_active_thread = if let Some(thread_info) = serialized_panel + .as_ref() + .and_then(|p| p.last_active_thread.clone()) + { + if thread_info.agent_type.is_native() { + let session_id = acp::SessionId::new(thread_info.session_id.clone()); + let load_result = cx.update(|_window, cx| { + let thread_store = ThreadStore::global(cx); + thread_store.update(cx, |store, cx| store.load_thread(session_id, cx)) + }); + let thread_exists = if let Ok(task) = load_result { + task.await.ok().flatten().is_some() + } else { + false + }; + if thread_exists { + Some(thread_info) + } else { + log::warn!( + "last active thread {} not found in database, skipping restoration", + thread_info.session_id + ); + None + } + } else { + Some(thread_info) + } + } else { + None + }; + let panel = workspace.update_in(cx, |workspace, window, cx| { let panel = cx.new(|cx| Self::new(workspace, text_thread_store, prompt_store, window, cx)); @@ -615,44 +693,45 @@ impl AgentPanel { if let Some(selected_agent) = serialized_panel.selected_agent.clone() { panel.selected_agent = selected_agent; } + if let Some(start_thread_in) = serialized_panel.start_thread_in { + let is_worktree_flag_enabled = + cx.has_flag::(); + let is_valid = match &start_thread_in { + StartThreadIn::LocalProject => true, + StartThreadIn::NewWorktree => { + let project = panel.project.read(cx); + is_worktree_flag_enabled && !project.is_via_collab() + } + }; + if is_valid { + panel.start_thread_in = start_thread_in; + } else { + log::info!( + "deserialized start_thread_in {:?} is no longer valid, falling back to LocalProject", + start_thread_in, + ); + } + } cx.notify(); }); } - panel - })?; - - if let Some(thread_info) = serialized_panel.and_then(|p| p.last_active_thread) { - let session_id = acp::SessionId::new(thread_info.session_id.clone()); - let load_task = panel.update(cx, |panel, cx| { - let thread_store = panel.thread_store.clone(); - thread_store.update(cx, |store, cx| store.load_thread(session_id, cx)) - }); - let thread_exists = load_task - .await - .map(|thread: Option| thread.is_some()) - .unwrap_or(false); - - if thread_exists { - panel.update_in(cx, |panel, window, cx| { - panel.selected_agent = thread_info.agent_type.clone(); - let session_info = AgentSessionInfo { - session_id: acp::SessionId::new(thread_info.session_id), - cwd: thread_info.cwd, - title: thread_info.title.map(SharedString::from), - updated_at: None, - meta: None, - }; + if let Some(thread_info) = last_active_thread { + let agent_type = thread_info.agent_type.clone(); + let session_info = AgentSessionInfo { + session_id: acp::SessionId::new(thread_info.session_id), + cwd: thread_info.cwd, + title: thread_info.title.map(SharedString::from), + updated_at: None, + meta: None, + }; + panel.update(cx, |panel, cx| { + panel.selected_agent = agent_type; panel.load_agent_thread(session_info, window, cx); - })?; - } else { - log::error!( - "could not restore last active thread: \ - no thread found in database with ID {:?}", - thread_info.session_id - ); + }); } - } + panel + })?; Ok(panel) }) @@ -800,6 +879,7 @@ impl AgentPanel { previous_view: None, _active_view_observation: None, new_thread_menu_handle: PopoverMenuHandle::default(), + start_thread_in_menu_handle: PopoverMenuHandle::default(), agent_panel_menu_handle: PopoverMenuHandle::default(), agent_navigation_menu_handle: PopoverMenuHandle::default(), agent_navigation_menu: None, @@ -813,6 +893,10 @@ impl AgentPanel { text_thread_history, thread_store, selected_agent: AgentType::default(), + start_thread_in: StartThreadIn::default(), + worktree_creation_status: None, + _thread_view_subscription: None, + _worktree_creation_task: None, show_trust_workspace_message: false, last_configuration_error_telemetry: None, on_boarding_upsell_dismissed: AtomicBool::new(OnboardingUpsell::dismissed()), @@ -1044,7 +1128,7 @@ impl AgentPanel { let server = ext_agent.server(fs, thread_store); this.update_in(cx, |agent_panel, window, cx| { - agent_panel._external_thread( + agent_panel.create_external_thread( server, resume_thread, initial_content, @@ -1618,15 +1702,28 @@ impl AgentPanel { self.active_view = new_view; } + // Subscribe to the active ThreadView's events (e.g. FirstSendRequested) + // so the panel can intercept the first send for worktree creation. + // Re-subscribe whenever the ConnectionView changes, since the inner + // ThreadView may have been replaced (e.g. navigating between threads). self._active_view_observation = match &self.active_view { ActiveView::AgentThread { server_view } => { - Some(cx.observe(server_view, |this, _, cx| { - cx.emit(AgentPanelEvent::ActiveViewChanged); - this.serialize(cx); - cx.notify(); - })) + self._thread_view_subscription = + Self::subscribe_to_active_thread_view(server_view, window, cx); + Some( + cx.observe_in(server_view, window, |this, server_view, window, cx| { + this._thread_view_subscription = + Self::subscribe_to_active_thread_view(&server_view, window, cx); + cx.emit(AgentPanelEvent::ActiveViewChanged); + this.serialize(cx); + cx.notify(); + }), + ) + } + _ => { + self._thread_view_subscription = None; + None } - _ => None, }; let is_in_agent_history = matches!( @@ -1740,6 +1837,56 @@ impl AgentPanel { self.selected_agent.clone() } + fn subscribe_to_active_thread_view( + server_view: &Entity, + window: &mut Window, + cx: &mut Context, + ) -> Option { + server_view.read(cx).active_thread().cloned().map(|tv| { + cx.subscribe_in( + &tv, + window, + |this, view, event: &AcpThreadViewEvent, window, cx| match event { + AcpThreadViewEvent::FirstSendRequested { content } => { + this.handle_first_send_requested(view.clone(), content.clone(), window, cx); + } + }, + ) + }) + } + + pub fn start_thread_in(&self) -> &StartThreadIn { + &self.start_thread_in + } + + fn set_start_thread_in(&mut self, action: &StartThreadIn, cx: &mut Context) { + if matches!(action, StartThreadIn::NewWorktree) + && !cx.has_flag::() + { + return; + } + + let new_target = match *action { + StartThreadIn::LocalProject => StartThreadIn::LocalProject, + StartThreadIn::NewWorktree => { + if !self.project_has_git_repository(cx) { + log::error!( + "set_start_thread_in: cannot use NewWorktree without a git repository" + ); + return; + } + if self.project.read(cx).is_via_collab() { + log::error!("set_start_thread_in: cannot use NewWorktree in a collab project"); + return; + } + StartThreadIn::NewWorktree + } + }; + self.start_thread_in = new_target; + self.serialize(cx); + cx.notify(); + } + fn selected_external_agent(&self) -> Option { match &self.selected_agent { AgentType::NativeAgent => Some(ExternalAgent::NativeAgent), @@ -1830,7 +1977,7 @@ impl AgentPanel { self.external_thread(Some(agent), Some(thread), None, window, cx); } - fn _external_thread( + pub(crate) fn create_external_thread( &mut self, server: Rc, resume_thread: Option, @@ -1869,135 +2016,641 @@ impl AgentPanel { self.set_active_view(ActiveView::AgentThread { server_view }, true, window, cx); } -} -impl Focusable for AgentPanel { - fn focus_handle(&self, cx: &App) -> FocusHandle { - match &self.active_view { - ActiveView::Uninitialized => self.focus_handle.clone(), - ActiveView::AgentThread { server_view, .. } => server_view.focus_handle(cx), - ActiveView::History { kind } => match kind { - HistoryKind::AgentThreads => self.acp_history.focus_handle(cx), - HistoryKind::TextThreads => self.text_thread_history.focus_handle(cx), - }, - ActiveView::TextThread { - text_thread_editor, .. - } => text_thread_editor.focus_handle(cx), - ActiveView::Configuration => { - if let Some(configuration) = self.configuration.as_ref() { - configuration.focus_handle(cx) - } else { - self.focus_handle.clone() - } - } - } + fn active_thread_has_messages(&self, cx: &App) -> bool { + self.active_agent_thread(cx) + .is_some_and(|thread| !thread.read(cx).entries().is_empty()) } -} -fn agent_panel_dock_position(cx: &App) -> DockPosition { - AgentSettings::get_global(cx).dock.into() -} + fn handle_first_send_requested( + &mut self, + thread_view: Entity, + content: Vec, + window: &mut Window, + cx: &mut Context, + ) { + if self.start_thread_in == StartThreadIn::NewWorktree { + self.handle_worktree_creation_requested(content, window, cx); + } else { + cx.defer_in(window, move |_this, window, cx| { + thread_view.update(cx, |thread_view, cx| { + let editor = thread_view.message_editor.clone(); + thread_view.send_impl(editor, window, cx); + }); + }); + } + } -pub enum AgentPanelEvent { - ActiveViewChanged, -} + /// Partitions the project's visible worktrees into git-backed repositories + /// and plain (non-git) paths. Git repos will have worktrees created for + /// them; non-git paths are carried over to the new workspace as-is. + /// + /// When multiple worktrees map to the same repository, the most specific + /// match wins (deepest work directory path), with a deterministic + /// tie-break on entity id. Each repository appears at most once. + fn classify_worktrees( + &self, + cx: &App, + ) -> (Vec>, Vec) { + let project = &self.project; + let repositories = project.read(cx).repositories(cx).clone(); + let mut git_repos: Vec> = Vec::new(); + let mut non_git_paths: Vec = Vec::new(); + let mut seen_repo_ids = std::collections::HashSet::new(); + + for worktree in project.read(cx).visible_worktrees(cx) { + let wt_path = worktree.read(cx).abs_path(); + + let matching_repo = repositories + .iter() + .filter_map(|(id, repo)| { + let work_dir = repo.read(cx).work_directory_abs_path.clone(); + if wt_path.starts_with(work_dir.as_ref()) + || work_dir.starts_with(wt_path.as_ref()) + { + Some((*id, repo.clone(), work_dir.as_ref().components().count())) + } else { + None + } + }) + .max_by( + |(left_id, _left_repo, left_depth), (right_id, _right_repo, right_depth)| { + left_depth + .cmp(right_depth) + .then_with(|| left_id.cmp(right_id)) + }, + ); -impl EventEmitter for AgentPanel {} -impl EventEmitter for AgentPanel {} + if let Some((id, repo, _)) = matching_repo { + if seen_repo_ids.insert(id) { + git_repos.push(repo); + } + } else { + non_git_paths.push(wt_path.to_path_buf()); + } + } -impl Panel for AgentPanel { - fn persistent_name() -> &'static str { - "AgentPanel" + (git_repos, non_git_paths) } - fn panel_key() -> &'static str { - AGENT_PANEL_KEY - } + /// Kicks off an async git-worktree creation for each repository. Returns: + /// + /// - `creation_infos`: a vec of `(repo, new_path, receiver)` tuples—the + /// receiver resolves once the git worktree command finishes. + /// - `path_remapping`: `(old_work_dir, new_worktree_path)` pairs used + /// later to remap open editor tabs into the new workspace. + fn start_worktree_creations( + git_repos: &[Entity], + branch_name: &str, + worktree_directory_setting: &str, + cx: &mut Context, + ) -> Result<( + Vec<( + Entity, + PathBuf, + futures::channel::oneshot::Receiver>, + )>, + Vec<(PathBuf, PathBuf)>, + )> { + let mut creation_infos = Vec::new(); + let mut path_remapping = Vec::new(); + + for repo in git_repos { + let (work_dir, new_path, receiver) = repo.update(cx, |repo, _cx| { + let original_repo = repo.original_repo_abs_path.clone(); + let directory = + validate_worktree_directory(&original_repo, worktree_directory_setting)?; + let new_path = directory.join(branch_name); + let receiver = repo.create_worktree(branch_name.to_string(), directory, None); + let work_dir = repo.work_directory_abs_path.clone(); + anyhow::Ok((work_dir, new_path, receiver)) + })?; + path_remapping.push((work_dir.to_path_buf(), new_path.clone())); + creation_infos.push((repo.clone(), new_path, receiver)); + } - fn position(&self, _window: &Window, cx: &App) -> DockPosition { - agent_panel_dock_position(cx) + Ok((creation_infos, path_remapping)) } - fn position_is_valid(&self, position: DockPosition) -> bool { - position != DockPosition::Bottom - } + /// Waits for every in-flight worktree creation to complete. If any + /// creation fails, all successfully-created worktrees are rolled back + /// (removed) so the project isn't left in a half-migrated state. + async fn await_and_rollback_on_failure( + creation_infos: Vec<( + Entity, + PathBuf, + futures::channel::oneshot::Receiver>, + )>, + cx: &mut AsyncWindowContext, + ) -> Result> { + let mut created_paths: Vec = Vec::new(); + let mut repos_and_paths: Vec<(Entity, PathBuf)> = + Vec::new(); + let mut first_error: Option = None; + + for (repo, new_path, receiver) in creation_infos { + match receiver.await { + Ok(Ok(())) => { + created_paths.push(new_path.clone()); + repos_and_paths.push((repo, new_path)); + } + Ok(Err(err)) => { + if first_error.is_none() { + first_error = Some(err); + } + } + Err(_canceled) => { + if first_error.is_none() { + first_error = Some(anyhow!("Worktree creation was canceled")); + } + } + } + } - fn set_position(&mut self, position: DockPosition, _: &mut Window, cx: &mut Context) { - settings::update_settings_file(self.fs.clone(), cx, move |settings, _| { - settings - .agent - .get_or_insert_default() - .set_dock(position.into()); - }); - } + let Some(err) = first_error else { + return Ok(created_paths); + }; - fn size(&self, window: &Window, cx: &App) -> Pixels { - let settings = AgentSettings::get_global(cx); - match self.position(window, cx) { - DockPosition::Left | DockPosition::Right => { - self.width.unwrap_or(settings.default_width) + // Rollback all successfully created worktrees + let mut rollback_receivers = Vec::new(); + for (rollback_repo, rollback_path) in &repos_and_paths { + if let Ok(receiver) = cx.update(|_, cx| { + rollback_repo.update(cx, |repo, _cx| { + repo.remove_worktree(rollback_path.clone(), true) + }) + }) { + rollback_receivers.push((rollback_path.clone(), receiver)); } - DockPosition::Bottom => self.height.unwrap_or(settings.default_height), } - } - - fn set_size(&mut self, size: Option, window: &mut Window, cx: &mut Context) { - match self.position(window, cx) { - DockPosition::Left | DockPosition::Right => self.width = size, - DockPosition::Bottom => self.height = size, + let mut rollback_failures: Vec = Vec::new(); + for (path, receiver) in rollback_receivers { + match receiver.await { + Ok(Ok(())) => {} + Ok(Err(rollback_err)) => { + log::error!( + "failed to rollback worktree at {}: {rollback_err}", + path.display() + ); + rollback_failures.push(format!("{}: {rollback_err}", path.display())); + } + Err(rollback_err) => { + log::error!( + "failed to rollback worktree at {}: {rollback_err}", + path.display() + ); + rollback_failures.push(format!("{}: {rollback_err}", path.display())); + } + } } - self.serialize(cx); - cx.notify(); + let mut error_message = format!("Failed to create worktree: {err}"); + if !rollback_failures.is_empty() { + error_message.push_str("\n\nFailed to clean up: "); + error_message.push_str(&rollback_failures.join(", ")); + } + Err(anyhow!(error_message)) } - fn set_active(&mut self, active: bool, window: &mut Window, cx: &mut Context) { - if active && matches!(self.active_view, ActiveView::Uninitialized) { + fn set_worktree_creation_error( + &mut self, + message: SharedString, + window: &mut Window, + cx: &mut Context, + ) { + self.worktree_creation_status = Some(WorktreeCreationStatus::Error(message)); + if matches!(self.active_view, ActiveView::Uninitialized) { let selected_agent = self.selected_agent.clone(); self.new_agent_thread(selected_agent, window, cx); } + cx.notify(); } - fn remote_id() -> Option { - Some(proto::PanelId::AssistantPanel) - } - - fn icon(&self, _window: &Window, cx: &App) -> Option { - (self.enabled(cx) && AgentSettings::get_global(cx).button).then_some(IconName::ZedAssistant) - } + fn handle_worktree_creation_requested( + &mut self, + content: Vec, + window: &mut Window, + cx: &mut Context, + ) { + if matches!( + self.worktree_creation_status, + Some(WorktreeCreationStatus::Creating) + ) { + return; + } - fn icon_tooltip(&self, _window: &Window, _cx: &App) -> Option<&'static str> { - Some("Agent Panel") - } + self.worktree_creation_status = Some(WorktreeCreationStatus::Creating); + cx.notify(); - fn toggle_action(&self) -> Box { - Box::new(ToggleFocus) - } + let (git_repos, non_git_paths) = self.classify_worktrees(cx); - fn activation_priority(&self) -> u32 { - 3 - } + if git_repos.is_empty() { + self.set_worktree_creation_error( + "No git repositories found in the project".into(), + window, + cx, + ); + return; + } - fn enabled(&self, cx: &App) -> bool { - AgentSettings::get_global(cx).enabled(cx) - } + // Kick off branch listing as early as possible so it can run + // concurrently with the remaining synchronous setup work. + let branch_receivers: Vec<_> = git_repos + .iter() + .map(|repo| repo.update(cx, |repo, _cx| repo.branches())) + .collect(); + + let worktree_directory_setting = ProjectSettings::get_global(cx) + .git + .worktree_directory + .clone(); + + let (dock_structure, open_file_paths) = self + .workspace + .upgrade() + .map(|workspace| { + let dock_structure = workspace.read(cx).capture_dock_state(window, cx); + let open_file_paths = workspace.read(cx).open_item_abs_paths(cx); + (dock_structure, open_file_paths) + }) + .unwrap_or_default(); - fn is_zoomed(&self, _window: &Window, _cx: &App) -> bool { - self.zoomed - } + let workspace = self.workspace.clone(); + let window_handle = window + .window_handle() + .downcast::(); + + let task = cx.spawn_in(window, async move |this, cx| { + // Await the branch listings we kicked off earlier. + let mut existing_branches = Vec::new(); + for result in futures::future::join_all(branch_receivers).await { + match result { + Ok(Ok(branches)) => { + for branch in branches { + existing_branches.push(branch.name().to_string()); + } + } + Ok(Err(err)) => { + Err::<(), _>(err).log_err(); + } + Err(_) => {} + } + } - fn set_zoomed(&mut self, zoomed: bool, _window: &mut Window, cx: &mut Context) { - self.zoomed = zoomed; - cx.notify(); - } -} + let existing_branch_refs: Vec<&str> = + existing_branches.iter().map(|s| s.as_str()).collect(); + let mut rng = rand::rng(); + let branch_name = + match crate::branch_names::generate_branch_name(&existing_branch_refs, &mut rng) { + Some(name) => name, + None => { + this.update_in(cx, |this, window, cx| { + this.set_worktree_creation_error( + "Failed to generate a branch name: all typewriter names are taken" + .into(), + window, + cx, + ); + })?; + return anyhow::Ok(()); + } + }; -impl AgentPanel { - fn render_title_view(&self, _window: &mut Window, cx: &Context) -> AnyElement { - const LOADING_SUMMARY_PLACEHOLDER: &str = "Loading Summary…"; + let (creation_infos, path_remapping) = match this.update_in(cx, |_this, _window, cx| { + Self::start_worktree_creations( + &git_repos, + &branch_name, + &worktree_directory_setting, + cx, + ) + }) { + Ok(Ok(result)) => result, + Ok(Err(err)) | Err(err) => { + this.update_in(cx, |this, window, cx| { + this.set_worktree_creation_error( + format!("Failed to validate worktree directory: {err}").into(), + window, + cx, + ); + }) + .log_err(); + return anyhow::Ok(()); + } + }; - let content = match &self.active_view { - ActiveView::AgentThread { server_view } => { - let is_generating_title = server_view + let created_paths = match Self::await_and_rollback_on_failure(creation_infos, cx).await + { + Ok(paths) => paths, + Err(err) => { + this.update_in(cx, |this, window, cx| { + this.set_worktree_creation_error(format!("{err}").into(), window, cx); + })?; + return anyhow::Ok(()); + } + }; + + let mut all_paths = created_paths; + let has_non_git = !non_git_paths.is_empty(); + all_paths.extend(non_git_paths.iter().cloned()); + + let app_state = match workspace.upgrade() { + Some(workspace) => cx.update(|_, cx| workspace.read(cx).app_state().clone())?, + None => { + this.update_in(cx, |this, window, cx| { + this.set_worktree_creation_error( + "Workspace no longer available".into(), + window, + cx, + ); + })?; + return anyhow::Ok(()); + } + }; + + let this_for_error = this.clone(); + if let Err(err) = Self::setup_new_workspace( + this, + all_paths, + app_state, + window_handle, + dock_structure, + open_file_paths, + path_remapping, + non_git_paths, + has_non_git, + content, + cx, + ) + .await + { + this_for_error + .update_in(cx, |this, window, cx| { + this.set_worktree_creation_error( + format!("Failed to set up workspace: {err}").into(), + window, + cx, + ); + }) + .log_err(); + } + anyhow::Ok(()) + }); + + self._worktree_creation_task = Some(cx.foreground_executor().spawn(async move { + task.await.log_err(); + })); + } + + async fn setup_new_workspace( + this: WeakEntity, + all_paths: Vec, + app_state: Arc, + window_handle: Option>, + dock_structure: workspace::DockStructure, + open_file_paths: Vec, + path_remapping: Vec<(PathBuf, PathBuf)>, + non_git_paths: Vec, + has_non_git: bool, + content: Vec, + cx: &mut AsyncWindowContext, + ) -> Result<()> { + let init: Option< + Box) + Send>, + > = Some(Box::new(move |workspace, window, cx| { + workspace.set_dock_structure(dock_structure, window, cx); + })); + + let (new_window_handle, _) = cx + .update(|_window, cx| { + Workspace::new_local(all_paths, app_state, window_handle, None, init, false, cx) + })? + .await?; + + let new_workspace = new_window_handle.update(cx, |multi_workspace, _window, _cx| { + let workspaces = multi_workspace.workspaces(); + workspaces.last().cloned() + })?; + + let Some(new_workspace) = new_workspace else { + anyhow::bail!("New workspace was not added to MultiWorkspace"); + }; + + let panels_task = new_window_handle.update(cx, |_, _, cx| { + new_workspace.update(cx, |workspace, _cx| workspace.take_panels_task()) + })?; + if let Some(task) = panels_task { + task.await.log_err(); + } + + let initial_content = AgentInitialContent::ContentBlock { + blocks: content, + auto_submit: true, + }; + + new_window_handle.update(cx, |_multi_workspace, window, cx| { + new_workspace.update(cx, |workspace, cx| { + if has_non_git { + let toast_id = workspace::notifications::NotificationId::unique::(); + workspace.show_toast( + workspace::Toast::new( + toast_id, + "Some project folders are not git repositories. \ + They were included as-is without creating a worktree.", + ), + cx, + ); + } + + let remapped_paths: Vec = open_file_paths + .iter() + .filter_map(|original_path| { + let best_match = path_remapping + .iter() + .filter_map(|(old_root, new_root)| { + original_path.strip_prefix(old_root).ok().map(|relative| { + (old_root.components().count(), new_root.join(relative)) + }) + }) + .max_by_key(|(depth, _)| *depth); + + if let Some((_, remapped_path)) = best_match { + return Some(remapped_path); + } + + for non_git in &non_git_paths { + if original_path.starts_with(non_git) { + return Some(original_path.clone()); + } + } + None + }) + .collect(); + + if !remapped_paths.is_empty() { + workspace + .open_paths( + remapped_paths, + workspace::OpenOptions::default(), + None, + window, + cx, + ) + .detach(); + } + + workspace.focus_panel::(window, cx); + if let Some(panel) = workspace.panel::(cx) { + panel.update(cx, |panel, cx| { + panel.external_thread(None, None, Some(initial_content), window, cx); + }); + } + }); + })?; + + new_window_handle.update(cx, |multi_workspace, _window, cx| { + multi_workspace.activate(new_workspace.clone(), cx); + })?; + + this.update_in(cx, |this, _window, cx| { + this.worktree_creation_status = None; + cx.notify(); + })?; + + anyhow::Ok(()) + } +} + +impl Focusable for AgentPanel { + fn focus_handle(&self, cx: &App) -> FocusHandle { + match &self.active_view { + ActiveView::Uninitialized => self.focus_handle.clone(), + ActiveView::AgentThread { server_view, .. } => server_view.focus_handle(cx), + ActiveView::History { kind } => match kind { + HistoryKind::AgentThreads => self.acp_history.focus_handle(cx), + HistoryKind::TextThreads => self.text_thread_history.focus_handle(cx), + }, + ActiveView::TextThread { + text_thread_editor, .. + } => text_thread_editor.focus_handle(cx), + ActiveView::Configuration => { + if let Some(configuration) = self.configuration.as_ref() { + configuration.focus_handle(cx) + } else { + self.focus_handle.clone() + } + } + } + } +} + +fn agent_panel_dock_position(cx: &App) -> DockPosition { + AgentSettings::get_global(cx).dock.into() +} + +pub enum AgentPanelEvent { + ActiveViewChanged, +} + +impl EventEmitter for AgentPanel {} +impl EventEmitter for AgentPanel {} + +impl Panel for AgentPanel { + fn persistent_name() -> &'static str { + "AgentPanel" + } + + fn panel_key() -> &'static str { + AGENT_PANEL_KEY + } + + fn position(&self, _window: &Window, cx: &App) -> DockPosition { + agent_panel_dock_position(cx) + } + + fn position_is_valid(&self, position: DockPosition) -> bool { + position != DockPosition::Bottom + } + + fn set_position(&mut self, position: DockPosition, _: &mut Window, cx: &mut Context) { + settings::update_settings_file(self.fs.clone(), cx, move |settings, _| { + settings + .agent + .get_or_insert_default() + .set_dock(position.into()); + }); + } + + fn size(&self, window: &Window, cx: &App) -> Pixels { + let settings = AgentSettings::get_global(cx); + match self.position(window, cx) { + DockPosition::Left | DockPosition::Right => { + self.width.unwrap_or(settings.default_width) + } + DockPosition::Bottom => self.height.unwrap_or(settings.default_height), + } + } + + fn set_size(&mut self, size: Option, window: &mut Window, cx: &mut Context) { + match self.position(window, cx) { + DockPosition::Left | DockPosition::Right => self.width = size, + DockPosition::Bottom => self.height = size, + } + self.serialize(cx); + cx.notify(); + } + + fn set_active(&mut self, active: bool, window: &mut Window, cx: &mut Context) { + if active + && matches!(self.active_view, ActiveView::Uninitialized) + && !matches!( + self.worktree_creation_status, + Some(WorktreeCreationStatus::Creating) + ) + { + let selected_agent = self.selected_agent.clone(); + self.new_agent_thread(selected_agent, window, cx); + } + } + + fn remote_id() -> Option { + Some(proto::PanelId::AssistantPanel) + } + + fn icon(&self, _window: &Window, cx: &App) -> Option { + (self.enabled(cx) && AgentSettings::get_global(cx).button).then_some(IconName::ZedAssistant) + } + + fn icon_tooltip(&self, _window: &Window, _cx: &App) -> Option<&'static str> { + Some("Agent Panel") + } + + fn toggle_action(&self) -> Box { + Box::new(ToggleFocus) + } + + fn activation_priority(&self) -> u32 { + 3 + } + + fn enabled(&self, cx: &App) -> bool { + AgentSettings::get_global(cx).enabled(cx) + } + + fn is_zoomed(&self, _window: &Window, _cx: &App) -> bool { + self.zoomed + } + + fn set_zoomed(&mut self, zoomed: bool, _window: &mut Window, cx: &mut Context) { + self.zoomed = zoomed; + cx.notify(); + } +} + +impl AgentPanel { + fn render_title_view(&self, _window: &mut Window, cx: &Context) -> AnyElement { + const LOADING_SUMMARY_PLACEHOLDER: &str = "Loading Summary…"; + + let content = match &self.active_view { + ActiveView::AgentThread { server_view } => { + let is_generating_title = server_view .read(cx) .as_native_thread(cx) .map_or(false, |t| t.read(cx).is_generating_title()); @@ -2331,6 +2984,99 @@ impl AgentPanel { }) } + fn project_has_git_repository(&self, cx: &App) -> bool { + !self.project.read(cx).repositories(cx).is_empty() + } + + fn render_start_thread_in_selector(&self, cx: &mut Context) -> impl IntoElement { + let has_git_repo = self.project_has_git_repository(cx); + let is_via_collab = self.project.read(cx).is_via_collab(); + + let is_creating = matches!( + self.worktree_creation_status, + Some(WorktreeCreationStatus::Creating) + ); + + let current_target = self.start_thread_in; + let trigger_label = self.start_thread_in.label(); + + let icon = if self.start_thread_in_menu_handle.is_deployed() { + IconName::ChevronUp + } else { + IconName::ChevronDown + }; + + let trigger_button = Button::new("thread-target-trigger", trigger_label) + .label_size(LabelSize::Small) + .color(Color::Muted) + .icon(icon) + .icon_size(IconSize::XSmall) + .icon_position(IconPosition::End) + .icon_color(Color::Muted) + .disabled(is_creating); + + let dock_position = AgentSettings::get_global(cx).dock; + let documentation_side = match dock_position { + settings::DockPosition::Left => DocumentationSide::Right, + settings::DockPosition::Bottom | settings::DockPosition::Right => { + DocumentationSide::Left + } + }; + + PopoverMenu::new("thread-target-selector") + .trigger(trigger_button) + .anchor(gpui::Corner::BottomRight) + .with_handle(self.start_thread_in_menu_handle.clone()) + .menu(move |window, cx| { + let current_target = current_target; + Some(ContextMenu::build(window, cx, move |menu, _window, _cx| { + let is_local_selected = current_target == StartThreadIn::LocalProject; + let is_new_worktree_selected = current_target == StartThreadIn::NewWorktree; + + let new_worktree_disabled = !has_git_repo || is_via_collab; + + menu.header("Start Thread In…") + .item( + ContextMenuEntry::new("Local Project") + .icon(StartThreadIn::LocalProject.icon()) + .icon_color(Color::Muted) + .toggleable(IconPosition::End, is_local_selected) + .handler(|window, cx| { + window + .dispatch_action(Box::new(StartThreadIn::LocalProject), cx); + }), + ) + .item({ + let entry = ContextMenuEntry::new("New Worktree") + .icon(StartThreadIn::NewWorktree.icon()) + .icon_color(Color::Muted) + .toggleable(IconPosition::End, is_new_worktree_selected) + .disabled(new_worktree_disabled) + .handler(|window, cx| { + window + .dispatch_action(Box::new(StartThreadIn::NewWorktree), cx); + }); + + if new_worktree_disabled { + entry.documentation_aside(documentation_side, move |_| { + let reason = if !has_git_repo { + "No git repository found in this project." + } else { + "Not available for remote/collab projects yet." + }; + Label::new(reason) + .color(Color::Muted) + .size(LabelSize::Small) + .into_any_element() + }) + } else { + entry + } + }) + })) + }) + } + fn render_toolbar(&self, window: &mut Window, cx: &mut Context) -> impl IntoElement { let agent_server_store = self.project.read(cx).agent_server_store().clone(); let focus_handle = self.focus_handle(cx); @@ -2718,6 +3464,7 @@ impl AgentPanel { }; let show_history_menu = self.history_kind_for_selected_agent(cx).is_some(); + let has_v2_flag = cx.has_flag::(); h_flex() .id("agent-panel-toolbar") @@ -2748,6 +3495,12 @@ impl AgentPanel { .gap(DynamicSpacing::Base02.rems(cx)) .pl(DynamicSpacing::Base04.rems(cx)) .pr(DynamicSpacing::Base06.rems(cx)) + .when( + has_v2_flag + && cx.has_flag::() + && !self.active_thread_has_messages(cx), + |this| this.child(self.render_start_thread_in_selector(cx)), + ) .child(new_thread_menu) .when(show_history_menu, |this| { this.child(self.render_recent_entries_menu( @@ -2760,6 +3513,51 @@ impl AgentPanel { ) } + fn render_worktree_creation_status(&self, cx: &mut Context) -> Option { + let status = self.worktree_creation_status.as_ref()?; + match status { + WorktreeCreationStatus::Creating => Some( + h_flex() + .w_full() + .px(DynamicSpacing::Base06.rems(cx)) + .py(DynamicSpacing::Base02.rems(cx)) + .gap_2() + .bg(cx.theme().colors().surface_background) + .border_b_1() + .border_color(cx.theme().colors().border) + .child(SpinnerLabel::new().size(LabelSize::Small)) + .child( + Label::new("Creating worktree…") + .color(Color::Muted) + .size(LabelSize::Small), + ) + .into_any_element(), + ), + WorktreeCreationStatus::Error(message) => Some( + h_flex() + .w_full() + .px(DynamicSpacing::Base06.rems(cx)) + .py(DynamicSpacing::Base02.rems(cx)) + .gap_2() + .bg(cx.theme().colors().surface_background) + .border_b_1() + .border_color(cx.theme().colors().border) + .child( + Icon::new(IconName::Warning) + .size(IconSize::Small) + .color(Color::Warning), + ) + .child( + Label::new(message.clone()) + .color(Color::Warning) + .size(LabelSize::Small) + .truncate(), + ) + .into_any_element(), + ), + } + } + fn should_render_trial_end_upsell(&self, cx: &mut Context) -> bool { if TrialEndUpsell::dismissed() { return false; @@ -3191,6 +3989,7 @@ impl Render for AgentPanel { } })) .child(self.render_toolbar(window, cx)) + .children(self.render_worktree_creation_status(cx)) .children(self.render_workspace_trust_message(cx)) .children(self.render_onboarding(window, cx)) .map(|parent| { @@ -3456,7 +4255,7 @@ impl AgentPanel { name: server.name(), }; - self._external_thread( + self.create_external_thread( server, None, None, workspace, project, ext_agent, window, cx, ); } @@ -3468,6 +4267,61 @@ impl AgentPanel { pub fn active_thread_view_for_tests(&self) -> Option<&Entity> { self.active_thread_view() } + + /// Sets the start_thread_in value directly, bypassing validation. + /// + /// This is a test-only helper for visual tests that need to show specific + /// start_thread_in states without requiring a real git repository. + pub fn set_start_thread_in_for_tests(&mut self, target: StartThreadIn, cx: &mut Context) { + self.start_thread_in = target; + cx.notify(); + } + + /// Returns the current worktree creation status. + /// + /// This is a test-only helper for visual tests. + pub fn worktree_creation_status_for_tests(&self) -> Option<&WorktreeCreationStatus> { + self.worktree_creation_status.as_ref() + } + + /// Sets the worktree creation status directly. + /// + /// This is a test-only helper for visual tests that need to show the + /// "Creating worktree…" spinner or error banners. + pub fn set_worktree_creation_status_for_tests( + &mut self, + status: Option, + cx: &mut Context, + ) { + self.worktree_creation_status = status; + cx.notify(); + } + + /// Opens the history view. + /// + /// This is a test-only helper that exposes the private `open_history()` + /// method for visual tests. + pub fn open_history_for_tests(&mut self, window: &mut Window, cx: &mut Context) { + self.open_history(window, cx); + } + + /// Opens the start_thread_in selector popover menu. + /// + /// This is a test-only helper for visual tests. + pub fn open_start_thread_in_menu_for_tests( + &mut self, + window: &mut Window, + cx: &mut Context, + ) { + self.start_thread_in_menu_handle.show(window, cx); + } + + /// Dismisses the start_thread_in dropdown menu. + /// + /// This is a test-only helper for visual tests. + pub fn close_start_thread_in_menu_for_tests(&mut self, cx: &mut Context) { + self.start_thread_in_menu_handle.hide(cx); + } } #[cfg(test)] @@ -3479,6 +4333,7 @@ mod tests { use fs::FakeFs; use gpui::{TestAppContext, VisualTestContext}; use project::Project; + use serde_json::json; use workspace::MultiWorkspace; #[gpui::test] @@ -3581,9 +4436,7 @@ mod tests { .expect("panel B load should succeed"); cx.run_until_parked(); - // Workspace A should restore width and agent type, but the thread - // should NOT be restored because the stub agent never persisted it - // to the database (the load-side validation skips missing threads). + // Workspace A should restore its thread, width, and agent type loaded_a.read_with(cx, |panel, _cx| { assert_eq!( panel.width, @@ -3594,6 +4447,10 @@ mod tests { panel.selected_agent, agent_type_a, "workspace A agent type should be restored" ); + assert!( + panel.active_thread_view().is_some(), + "workspace A should have its active thread restored" + ); }); // Workspace B should restore its own width and agent type, with no thread @@ -3663,4 +4520,383 @@ mod tests { cx.run_until_parked(); } + + #[gpui::test] + async fn test_thread_target_local_project(cx: &mut TestAppContext) { + init_test(cx); + cx.update(|cx| { + cx.update_flags(true, vec!["agent-v2".to_string()]); + agent::ThreadStore::init_global(cx); + language_model::LanguageModelRegistry::test(cx); + }); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/project", + json!({ + ".git": {}, + "src": { + "main.rs": "fn main() {}" + } + }), + ) + .await; + fs.set_branch_name(Path::new("/project/.git"), Some("main")); + + let project = Project::test(fs.clone(), [Path::new("/project")], cx).await; + + let multi_workspace = + cx.add_window(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx)); + + let workspace = multi_workspace + .read_with(cx, |multi_workspace, _cx| { + multi_workspace.workspace().clone() + }) + .unwrap(); + + workspace.update(cx, |workspace, _cx| { + workspace.set_random_database_id(); + }); + + let cx = &mut VisualTestContext::from_window(multi_workspace.into(), cx); + + // Wait for the project to discover the git repository. + cx.run_until_parked(); + + let panel = workspace.update_in(cx, |workspace, window, cx| { + let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx)); + let panel = + cx.new(|cx| AgentPanel::new(workspace, text_thread_store, None, window, cx)); + workspace.add_panel(panel.clone(), window, cx); + panel + }); + + cx.run_until_parked(); + + // Default thread target should be LocalProject. + panel.read_with(cx, |panel, _cx| { + assert_eq!( + *panel.start_thread_in(), + StartThreadIn::LocalProject, + "default thread target should be LocalProject" + ); + }); + + // Start a new thread with the default LocalProject target. + // Use StubAgentServer so the thread connects immediately in tests. + panel.update_in(cx, |panel, window, cx| { + panel.open_external_thread_with_server( + Rc::new(StubAgentServer::default_response()), + window, + cx, + ); + }); + + cx.run_until_parked(); + + // MultiWorkspace should still have exactly one workspace (no worktree created). + multi_workspace + .read_with(cx, |multi_workspace, _cx| { + assert_eq!( + multi_workspace.workspaces().len(), + 1, + "LocalProject should not create a new workspace" + ); + }) + .unwrap(); + + // The thread should be active in the panel. + panel.read_with(cx, |panel, cx| { + assert!( + panel.active_agent_thread(cx).is_some(), + "a thread should be running in the current workspace" + ); + }); + + // The thread target should still be LocalProject (unchanged). + panel.read_with(cx, |panel, _cx| { + assert_eq!( + *panel.start_thread_in(), + StartThreadIn::LocalProject, + "thread target should remain LocalProject" + ); + }); + + // No worktree creation status should be set. + panel.read_with(cx, |panel, _cx| { + assert!( + panel.worktree_creation_status.is_none(), + "no worktree creation should have occurred" + ); + }); + } + + #[gpui::test] + async fn test_thread_target_serialization_round_trip(cx: &mut TestAppContext) { + init_test(cx); + cx.update(|cx| { + cx.update_flags( + true, + vec!["agent-v2".to_string(), "agent-git-worktrees".to_string()], + ); + agent::ThreadStore::init_global(cx); + language_model::LanguageModelRegistry::test(cx); + }); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/project", + json!({ + ".git": {}, + "src": { + "main.rs": "fn main() {}" + } + }), + ) + .await; + fs.set_branch_name(Path::new("/project/.git"), Some("main")); + + let project = Project::test(fs.clone(), [Path::new("/project")], cx).await; + + let multi_workspace = + cx.add_window(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx)); + + let workspace = multi_workspace + .read_with(cx, |multi_workspace, _cx| { + multi_workspace.workspace().clone() + }) + .unwrap(); + + workspace.update(cx, |workspace, _cx| { + workspace.set_random_database_id(); + }); + + let cx = &mut VisualTestContext::from_window(multi_workspace.into(), cx); + + // Wait for the project to discover the git repository. + cx.run_until_parked(); + + let panel = workspace.update_in(cx, |workspace, window, cx| { + let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx)); + let panel = + cx.new(|cx| AgentPanel::new(workspace, text_thread_store, None, window, cx)); + workspace.add_panel(panel.clone(), window, cx); + panel + }); + + cx.run_until_parked(); + + // Default should be LocalProject. + panel.read_with(cx, |panel, _cx| { + assert_eq!(*panel.start_thread_in(), StartThreadIn::LocalProject); + }); + + // Change thread target to NewWorktree. + panel.update(cx, |panel, cx| { + panel.set_start_thread_in(&StartThreadIn::NewWorktree, cx); + }); + + panel.read_with(cx, |panel, _cx| { + assert_eq!( + *panel.start_thread_in(), + StartThreadIn::NewWorktree, + "thread target should be NewWorktree after set_thread_target" + ); + }); + + // Let serialization complete. + cx.run_until_parked(); + + // Load a fresh panel from the serialized data. + let prompt_builder = Arc::new(prompt_store::PromptBuilder::new(None).unwrap()); + let async_cx = cx.update(|window, cx| window.to_async(cx)); + let loaded_panel = + AgentPanel::load(workspace.downgrade(), prompt_builder.clone(), async_cx) + .await + .expect("panel load should succeed"); + cx.run_until_parked(); + + loaded_panel.read_with(cx, |panel, _cx| { + assert_eq!( + *panel.start_thread_in(), + StartThreadIn::NewWorktree, + "thread target should survive serialization round-trip" + ); + }); + } + + #[gpui::test] + async fn test_thread_target_deserialization_falls_back_when_worktree_flag_disabled( + cx: &mut TestAppContext, + ) { + init_test(cx); + cx.update(|cx| { + cx.update_flags( + true, + vec!["agent-v2".to_string(), "agent-git-worktrees".to_string()], + ); + agent::ThreadStore::init_global(cx); + language_model::LanguageModelRegistry::test(cx); + }); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/project", + json!({ + ".git": {}, + "src": { + "main.rs": "fn main() {}" + } + }), + ) + .await; + fs.set_branch_name(Path::new("/project/.git"), Some("main")); + + let project = Project::test(fs.clone(), [Path::new("/project")], cx).await; + + let multi_workspace = + cx.add_window(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx)); + + let workspace = multi_workspace + .read_with(cx, |multi_workspace, _cx| { + multi_workspace.workspace().clone() + }) + .unwrap(); + + workspace.update(cx, |workspace, _cx| { + workspace.set_random_database_id(); + }); + + let cx = &mut VisualTestContext::from_window(multi_workspace.into(), cx); + + // Wait for the project to discover the git repository. + cx.run_until_parked(); + + let panel = workspace.update_in(cx, |workspace, window, cx| { + let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx)); + let panel = + cx.new(|cx| AgentPanel::new(workspace, text_thread_store, None, window, cx)); + workspace.add_panel(panel.clone(), window, cx); + panel + }); + + cx.run_until_parked(); + + panel.update(cx, |panel, cx| { + panel.set_start_thread_in(&StartThreadIn::NewWorktree, cx); + }); + + panel.read_with(cx, |panel, _cx| { + assert_eq!( + *panel.start_thread_in(), + StartThreadIn::NewWorktree, + "thread target should be NewWorktree before reload" + ); + }); + + // Let serialization complete. + cx.run_until_parked(); + + // Disable worktree flag and reload panel from serialized data. + cx.update(|_, cx| { + cx.update_flags(true, vec!["agent-v2".to_string()]); + }); + + let prompt_builder = Arc::new(prompt_store::PromptBuilder::new(None).unwrap()); + let async_cx = cx.update(|window, cx| window.to_async(cx)); + let loaded_panel = + AgentPanel::load(workspace.downgrade(), prompt_builder.clone(), async_cx) + .await + .expect("panel load should succeed"); + cx.run_until_parked(); + + loaded_panel.read_with(cx, |panel, _cx| { + assert_eq!( + *panel.start_thread_in(), + StartThreadIn::LocalProject, + "thread target should fall back to LocalProject when worktree flag is disabled" + ); + }); + } + + #[gpui::test] + async fn test_set_active_blocked_during_worktree_creation(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + cx.update(|cx| { + cx.update_flags(true, vec!["agent-v2".to_string()]); + agent::ThreadStore::init_global(cx); + language_model::LanguageModelRegistry::test(cx); + ::set_global(fs.clone(), cx); + }); + + fs.insert_tree( + "/project", + json!({ + ".git": {}, + "src": { + "main.rs": "fn main() {}" + } + }), + ) + .await; + + let project = Project::test(fs.clone(), [Path::new("/project")], cx).await; + + let multi_workspace = + cx.add_window(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx)); + + let workspace = multi_workspace + .read_with(cx, |multi_workspace, _cx| { + multi_workspace.workspace().clone() + }) + .unwrap(); + + let cx = &mut VisualTestContext::from_window(multi_workspace.into(), cx); + + let panel = workspace.update_in(cx, |workspace, window, cx| { + let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx)); + let panel = + cx.new(|cx| AgentPanel::new(workspace, text_thread_store, None, window, cx)); + workspace.add_panel(panel.clone(), window, cx); + panel + }); + + cx.run_until_parked(); + + // Simulate worktree creation in progress and reset to Uninitialized + panel.update_in(cx, |panel, window, cx| { + panel.worktree_creation_status = Some(WorktreeCreationStatus::Creating); + panel.active_view = ActiveView::Uninitialized; + Panel::set_active(panel, true, window, cx); + assert!( + matches!(panel.active_view, ActiveView::Uninitialized), + "set_active should not create a thread while worktree is being created" + ); + }); + + // Clear the creation status and use open_external_thread_with_server + // (which bypasses new_agent_thread) to verify the panel can transition + // out of Uninitialized. We can't call set_active directly because + // new_agent_thread requires full agent server infrastructure. + panel.update_in(cx, |panel, window, cx| { + panel.worktree_creation_status = None; + panel.active_view = ActiveView::Uninitialized; + panel.open_external_thread_with_server( + Rc::new(StubAgentServer::default_response()), + window, + cx, + ); + }); + + cx.run_until_parked(); + + panel.read_with(cx, |panel, _cx| { + assert!( + !matches!(panel.active_view, ActiveView::Uninitialized), + "panel should transition out of Uninitialized once worktree creation is cleared" + ); + }); + } } diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index ad778ca496f7815d0155f98187c8fad3e81365eb..5ae2d677ba6dd4622127b39938f2bf005e7fcab9 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -3,6 +3,7 @@ mod agent_diff; mod agent_model_selector; mod agent_panel; mod agent_registry_ui; +mod branch_names; mod buffer_codegen; mod completion_provider; mod config_options; @@ -55,7 +56,9 @@ use std::any::TypeId; use workspace::Workspace; use crate::agent_configuration::{ConfigureContextServerModal, ManageProfilesModal}; -pub use crate::agent_panel::{AgentPanel, AgentPanelEvent, ConcreteAssistantPanelDelegate}; +pub use crate::agent_panel::{ + AgentPanel, AgentPanelEvent, ConcreteAssistantPanelDelegate, WorktreeCreationStatus, +}; use crate::agent_registry_ui::AgentRegistryPage; pub use crate::inline_assistant::InlineAssistant; pub use agent_diff::{AgentDiffPane, AgentDiffToolbar}; @@ -222,6 +225,18 @@ impl ExternalAgent { } } +/// Sets where new threads will run. +#[derive( + Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Action, +)] +#[action(namespace = agent)] +#[serde(rename_all = "snake_case", tag = "kind")] +pub enum StartThreadIn { + #[default] + LocalProject, + NewWorktree, +} + /// Content to initialize new external agent with. pub enum AgentInitialContent { ThreadSummary(acp_thread::AgentSessionInfo), diff --git a/crates/agent_ui/src/branch_names.rs b/crates/agent_ui/src/branch_names.rs new file mode 100644 index 0000000000000000000000000000000000000000..74e3dbc76b729309403606dfbecc8ea87f271913 --- /dev/null +++ b/crates/agent_ui/src/branch_names.rs @@ -0,0 +1,847 @@ +use collections::HashSet; +use rand::Rng; + +/// Names of historical typewriter brands, for use in auto-generated branch names. +/// (Hyphens and parens have been dropped so that the branch names are one-word.) +/// +/// Thanks to https://typewriterdatabase.com/alph.0.brands for the names! +const TYPEWRITER_NAMES: &[&str] = &[ + "abeille", + "acme", + "addo", + "adler", + "adlerette", + "adlerita", + "admiral", + "agamli", + "agar", + "agidel", + "agil", + "aguia", + "aguila", + "ahram", + "aigle", + "ajax", + "aktiv", + "ala", + "alba", + "albus", + "alexander", + "alexis", + "alfa", + "allen", + "alonso", + "alpina", + "amata", + "amaya", + "amka", + "anavi", + "anderson", + "andina", + "antares", + "apex", + "apsco", + "aquila", + "archo", + "ardita", + "argyle", + "aristocrat", + "aristokrat", + "arlington", + "armstrong", + "arpha", + "artus", + "astoria", + "atlantia", + "atlantic", + "atlas", + "augusta", + "aurora", + "austro", + "automatic", + "avanti", + "avona", + "azzurra", + "bajnok", + "baldwin", + "balkan", + "baltica", + "baltimore", + "barlock", + "barr", + "barrat", + "bartholomew", + "bashkiriya", + "bavaria", + "beaucourt", + "beko", + "belka", + "bennett", + "bennington", + "berni", + "bianca", + "bijou", + "bing", + "bisei", + "biser", + "bluebird", + "bolida", + "borgo", + "boston", + "boyce", + "bradford", + "brandenburg", + "brigitte", + "briton", + "brooks", + "brosette", + "buddy", + "burns", + "burroughs", + "byron", + "calanda", + "caligraph", + "cappel", + "cardinal", + "carissima", + "carlem", + "carlton", + "carmen", + "cawena", + "cella", + "celtic", + "century", + "champignon", + "cherryland", + "chevron", + "chicago", + "cicero", + "cifra", + "citizen", + "claudia", + "cleveland", + "clover", + "coffman", + "cole", + "columbia", + "commercial", + "companion", + "concentra", + "concord", + "concordia", + "conover", + "constanta", + "consul", + "conta", + "contenta", + "contimat", + "contina", + "continento", + "cornelia", + "coronado", + "cosmopolita", + "courier", + "craftamatic", + "crandall", + "crown", + "culema", + "dactyle", + "dankers", + "dart", + "daugherty", + "davis", + "dayton", + "dea", + "delmar", + "densmore", + "depantio", + "diadema", + "dial", + "diamant", + "diana", + "dictatype", + "diplomat", + "diskret", + "dolfus", + "dollar", + "domus", + "drake", + "draper", + "duplex", + "durabel", + "dynacord", + "eagle", + "eclipse", + "edelmann", + "edelweiss", + "edison", + "edita", + "edland", + "efka", + "eldorado", + "electa", + "electromatic", + "elektro", + "elgin", + "elliot", + "emerson", + "emka", + "emona", + "empire", + "engadine", + "engler", + "erfurt", + "erika", + "esko", + "essex", + "eureka", + "europa", + "everest", + "everlux", + "excelsior", + "express", + "fabers", + "facit", + "fairbanks", + "faktotum", + "famos", + "federal", + "felio", + "fidat", + "filius", + "fips", + "fish", + "fitch", + "fleet", + "florida", + "flott", + "flyer", + "flying", + "fontana", + "ford", + "forto", + "fortuna", + "fox", + "framo", + "franconia", + "franklin", + "friden", + "frolio", + "furstenberg", + "galesburg", + "galiette", + "gallia", + "garbell", + "gardner", + "geka", + "generation", + "genia", + "geniatus", + "gerda", + "gisela", + "glashutte", + "gloria", + "godrej", + "gossen", + "gourland", + "grandjean", + "granta", + "granville", + "graphic", + "gritzner", + "groma", + "guhl", + "guidonia", + "gundka", + "hacabo", + "haddad", + "halberg", + "halda", + "hall", + "hammond", + "hammonia", + "hanford", + "hansa", + "harmony", + "harris", + "hartford", + "hassia", + "hatch", + "heady", + "hebronia", + "hebros", + "hega", + "helios", + "helma", + "herald", + "hercules", + "hermes", + "herold", + "heros", + "hesperia", + "hogar", + "hooven", + "hopkins", + "horton", + "hugin", + "hungaria", + "hurtu", + "iberia", + "idea", + "ideal", + "imperia", + "impo", + "industria", + "industrio", + "ingersoll", + "international", + "invicta", + "irene", + "iris", + "iskra", + "ivitsa", + "ivriah", + "jackson", + "janalif", + "janos", + "jolux", + "juki", + "junior", + "juventa", + "juwel", + "kamkap", + "kamo", + "kanzler", + "kappel", + "karli", + "karstadt", + "keaton", + "kenbar", + "keystone", + "kim", + "klein", + "kneist", + "knoch", + "koh", + "kolibri", + "kolumbus", + "komet", + "kondor", + "koniger", + "konryu", + "kontor", + "kosmopolit", + "krypton", + "lambert", + "lasalle", + "lectra", + "leframa", + "lemair", + "lemco", + "liberty", + "libia", + "liga", + "lignose", + "lilliput", + "lindeteves", + "linowriter", + "listvitsa", + "ludolf", + "lutece", + "luxa", + "lyubava", + "mafra", + "magnavox", + "maher", + "majestic", + "majitouch", + "manhattan", + "mapuua", + "marathon", + "marburger", + "maritsa", + "maruzen", + "maskelyne", + "masspro", + "matous", + "mccall", + "mccool", + "mcloughlin", + "mead", + "mechno", + "mehano", + "meiselbach", + "melbi", + "melior", + "melotyp", + "mentor", + "mepas", + "mercedesia", + "mercurius", + "mercury", + "merkur", + "merritt", + "merz", + "messa", + "meteco", + "meteor", + "micron", + "mignon", + "mikro", + "minerva", + "mirian", + "mirina", + "mitex", + "molle", + "monac", + "monarch", + "mondiale", + "monica", + "monofix", + "monopol", + "monpti", + "monta", + "montana", + "montgomery", + "moon", + "morgan", + "morris", + "morse", + "moya", + "moyer", + "munson", + "musicwriter", + "nadex", + "nakajima", + "neckermann", + "neubert", + "neya", + "ninety", + "nisa", + "noiseless", + "noor", + "nora", + "nord", + "norden", + "norica", + "norma", + "norman", + "north", + "nototyp", + "nova", + "novalevi", + "odell", + "odhner", + "odo", + "odoma", + "ohio", + "ohtani", + "oliva", + "oliver", + "olivetti", + "olympia", + "omega", + "optima", + "orbis", + "orel", + "orga", + "oriette", + "orion", + "orn", + "orplid", + "pacior", + "pagina", + "parisienne", + "passat", + "pearl", + "peerless", + "perfect", + "perfecta", + "perkeo", + "perkins", + "perlita", + "pettypet", + "phoenix", + "piccola", + "picht", + "pinnock", + "pionier", + "plurotyp", + "plutarch", + "pneumatic", + "pocket", + "polyglott", + "polygraph", + "pontiac", + "portable", + "portex", + "pozzi", + "premier", + "presto", + "primavera", + "progress", + "protos", + "pterotype", + "pullman", + "pulsatta", + "quick", + "racer", + "radio", + "rally", + "rand", + "readers", + "reed", + "referent", + "reff", + "regent", + "regia", + "regina", + "rekord", + "reliable", + "reliance", + "remagg", + "rembrandt", + "remer", + "remington", + "remsho", + "remstar", + "remtor", + "reporters", + "resko", + "rex", + "rexpel", + "rheinita", + "rheinmetall", + "rival", + "roberts", + "robotron", + "rocher", + "rochester", + "roebuck", + "rofa", + "roland", + "rooy", + "rover", + "roxy", + "roy", + "royal", + "rundstatler", + "sabaudia", + "sabb", + "saleem", + "salter", + "sampo", + "sarafan", + "saturn", + "saxonia", + "schade", + "schapiro", + "schreibi", + "scripta", + "sears", + "secor", + "selectric", + "selekta", + "senator", + "sense", + "senta", + "serd", + "shilling", + "shimade", + "shimer", + "sholes", + "shuang", + "siegfried", + "siemag", + "silma", + "silver", + "simplex", + "simtype", + "singer", + "smith", + "soemtron", + "sonja", + "speedwriter", + "sphinx", + "starlet", + "stearns", + "steel", + "stella", + "steno", + "sterling", + "stoewer", + "stolzenberg", + "stott", + "strangfeld", + "sture", + "stylotyp", + "sun", + "superba", + "superia", + "supermetall", + "surety", + "swintec", + "swissa", + "talbos", + "talleres", + "tatrapoint", + "taurus", + "taylorix", + "tell", + "tempotype", + "tippco", + "titania", + "tops", + "towa", + "toyo", + "tradition", + "transatlantic", + "traveller", + "trebla", + "triumph", + "turia", + "typatune", + "typen", + "typorium", + "ugro", + "ultima", + "unda", + "underwood", + "unica", + "unitype", + "ursula", + "utax", + "varityper", + "vasanta", + "vendex", + "venus", + "victor", + "victoria", + "video", + "viking", + "vira", + "virotyp", + "visigraph", + "vittoria", + "volcan", + "vornado", + "voss", + "vultur", + "waltons", + "wanamaker", + "wanderer", + "ward", + "warner", + "waterloo", + "waverley", + "wayne", + "webster", + "wedgefield", + "welco", + "wellington", + "wellon", + "weltblick", + "westphalia", + "wiedmer", + "williams", + "wilson", + "winkel", + "winsor", + "wizard", + "woodstock", + "woodwards", + "yatran", + "yost", + "zenit", + "zentronik", + "zeta", + "zeya", +]; + +/// Picks a typewriter name that isn't already taken by an existing branch. +/// +/// Each entry in `existing_branches` is expected to be a full branch name +/// like `"olivetti-a3f9b2c1"`. The prefix before the last `'-'` is treated +/// as the taken typewriter name. Branches without a `'-'` are ignored. +/// +/// Returns `None` when every name in the pool is already taken. +pub fn pick_typewriter_name( + existing_branches: &[&str], + rng: &mut impl Rng, +) -> Option<&'static str> { + let disallowed: HashSet<&str> = existing_branches + .iter() + .filter_map(|branch| branch.rsplit_once('-').map(|(prefix, _)| prefix)) + .collect(); + + let available: Vec<&'static str> = TYPEWRITER_NAMES + .iter() + .copied() + .filter(|name| !disallowed.contains(name)) + .collect(); + + if available.is_empty() { + return None; + } + + let index = rng.random_range(0..available.len()); + Some(available[index]) +} + +/// Generates a branch name like `"olivetti-a3f9b2c1"` by picking a typewriter +/// name that isn't already taken and appending an 8-character alphanumeric hash. +/// +/// Returns `None` when every typewriter name in the pool is already taken. +pub fn generate_branch_name(existing_branches: &[&str], rng: &mut impl Rng) -> Option { + let typewriter_name = pick_typewriter_name(existing_branches, rng)?; + let hash: String = (0..8) + .map(|_| { + let idx: u8 = rng.random_range(0..36); + if idx < 10 { + (b'0' + idx) as char + } else { + (b'a' + idx - 10) as char + } + }) + .collect(); + Some(format!("{typewriter_name}-{hash}")) +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::rngs::StdRng; + + #[gpui::test(iterations = 10)] + fn test_pick_typewriter_name_with_no_disallowed(mut rng: StdRng) { + let name = pick_typewriter_name(&[], &mut rng); + assert!(name.is_some()); + assert!(TYPEWRITER_NAMES.contains(&name.unwrap())); + } + + #[gpui::test(iterations = 10)] + fn test_pick_typewriter_name_excludes_taken_names(mut rng: StdRng) { + let branch_names = &["olivetti-abc12345", "selectric-def67890"]; + let name = pick_typewriter_name(branch_names, &mut rng).unwrap(); + assert_ne!(name, "olivetti"); + assert_ne!(name, "selectric"); + } + + #[gpui::test] + fn test_pick_typewriter_name_all_taken(mut rng: StdRng) { + let branch_names: Vec = TYPEWRITER_NAMES + .iter() + .map(|name| format!("{name}-00000000")) + .collect(); + let branch_name_refs: Vec<&str> = branch_names.iter().map(|s| s.as_str()).collect(); + let name = pick_typewriter_name(&branch_name_refs, &mut rng); + assert!(name.is_none()); + } + + #[gpui::test(iterations = 10)] + fn test_pick_typewriter_name_ignores_branches_without_hyphen(mut rng: StdRng) { + let branch_names = &["main", "develop", "feature"]; + let name = pick_typewriter_name(branch_names, &mut rng); + assert!(name.is_some()); + assert!(TYPEWRITER_NAMES.contains(&name.unwrap())); + } + + #[gpui::test(iterations = 10)] + fn test_generate_branch_name_format(mut rng: StdRng) { + let branch_name = generate_branch_name(&[], &mut rng).unwrap(); + let (prefix, suffix) = branch_name.rsplit_once('-').unwrap(); + assert!(TYPEWRITER_NAMES.contains(&prefix)); + assert_eq!(suffix.len(), 8); + assert!(suffix.chars().all(|c| c.is_ascii_alphanumeric())); + } + + #[gpui::test] + fn test_generate_branch_name_returns_none_when_exhausted(mut rng: StdRng) { + let branch_names: Vec = TYPEWRITER_NAMES + .iter() + .map(|name| format!("{name}-00000000")) + .collect(); + let branch_name_refs: Vec<&str> = branch_names.iter().map(|s| s.as_str()).collect(); + let result = generate_branch_name(&branch_name_refs, &mut rng); + assert!(result.is_none()); + } + + #[gpui::test(iterations = 100)] + fn test_generate_branch_name_never_reuses_taken_prefix(mut rng: StdRng) { + let existing = &["olivetti-123abc", "selectric-def456"]; + let branch_name = generate_branch_name(existing, &mut rng).unwrap(); + let (prefix, _) = branch_name.rsplit_once('-').unwrap(); + assert_ne!(prefix, "olivetti"); + assert_ne!(prefix, "selectric"); + } + + #[gpui::test(iterations = 100)] + fn test_generate_branch_name_avoids_multiple_taken_prefixes(mut rng: StdRng) { + let existing = &[ + "olivetti-aaa11111", + "selectric-bbb22222", + "corona-ccc33333", + "remington-ddd44444", + "underwood-eee55555", + ]; + let taken_prefixes: HashSet<&str> = existing + .iter() + .filter_map(|b| b.rsplit_once('-').map(|(prefix, _)| prefix)) + .collect(); + let branch_name = generate_branch_name(existing, &mut rng).unwrap(); + let (prefix, _) = branch_name.rsplit_once('-').unwrap(); + assert!( + !taken_prefixes.contains(prefix), + "generated prefix {prefix:?} collides with an existing branch" + ); + } + + #[gpui::test(iterations = 100)] + fn test_generate_branch_name_with_varied_hash_suffixes(mut rng: StdRng) { + let existing = &[ + "olivetti-aaaaaaaa", + "olivetti-bbbbbbbb", + "olivetti-cccccccc", + ]; + let branch_name = generate_branch_name(existing, &mut rng).unwrap(); + let (prefix, _) = branch_name.rsplit_once('-').unwrap(); + assert_ne!( + prefix, "olivetti", + "should avoid olivetti regardless of how many variants exist" + ); + } + + #[test] + fn test_typewriter_names_are_valid() { + let mut seen = HashSet::default(); + for &name in TYPEWRITER_NAMES { + assert!( + seen.insert(name), + "duplicate entry in TYPEWRITER_NAMES: {name:?}" + ); + } + + for window in TYPEWRITER_NAMES.windows(2) { + assert!( + window[0] <= window[1], + "TYPEWRITER_NAMES is not sorted: {0:?} should come after {1:?}", + window[1], + window[0], + ); + } + + for &name in TYPEWRITER_NAMES { + assert!( + !name.contains('-'), + "TYPEWRITER_NAMES entry contains a hyphen: {name:?}" + ); + } + + for &name in TYPEWRITER_NAMES { + assert!( + name.chars().all(|c| c.is_lowercase() || !c.is_alphabetic()), + "TYPEWRITER_NAMES entry is not lowercase: {name:?}" + ); + } + } +} diff --git a/crates/agent_ui/src/connection_view.rs b/crates/agent_ui/src/connection_view.rs index 93bf7c98098530b23522c60f987f9e341ebc69ca..07e34ccd56f0bd867135fe62894a5a3ff388c85e 100644 --- a/crates/agent_ui/src/connection_view.rs +++ b/crates/agent_ui/src/connection_view.rs @@ -26,10 +26,10 @@ use fs::Fs; use futures::FutureExt as _; use gpui::{ Action, Animation, AnimationExt, AnyView, App, ClickEvent, ClipboardItem, CursorStyle, - ElementId, Empty, Entity, FocusHandle, Focusable, Hsla, ListOffset, ListState, ObjectFit, - PlatformDisplay, ScrollHandle, SharedString, Subscription, Task, TextStyle, WeakEntity, Window, - WindowHandle, div, ease_in_out, img, linear_color_stop, linear_gradient, list, point, - pulsating_between, + ElementId, Empty, Entity, EventEmitter, FocusHandle, Focusable, Hsla, ListOffset, ListState, + ObjectFit, PlatformDisplay, ScrollHandle, SharedString, Subscription, Task, TextStyle, + WeakEntity, Window, WindowHandle, div, ease_in_out, img, linear_color_stop, linear_gradient, + list, point, pulsating_between, }; use language::Buffer; use language_model::LanguageModelRegistry; @@ -295,6 +295,12 @@ impl Conversation { } } +pub enum AcpServerViewEvent { + ActiveThreadChanged, +} + +impl EventEmitter for ConnectionView {} + pub struct ConnectionView { agent: Rc, agent_server_store: Entity, @@ -386,6 +392,7 @@ impl ConnectionView { if let Some(view) = self.active_thread() { view.focus_handle(cx).focus(window, cx); } + cx.emit(AcpServerViewEvent::ActiveThreadChanged); cx.notify(); } } @@ -524,6 +531,7 @@ impl ConnectionView { } self.server_state = state; + cx.emit(AcpServerViewEvent::ActiveThreadChanged); cx.notify(); } @@ -728,6 +736,14 @@ impl ConnectionView { } let id = current.read(cx).thread.read(cx).session_id().clone(); + let session_list = if connection.supports_session_history() { + connection.session_list(cx) + } else { + None + }; + this.history.update(cx, |history, cx| { + history.set_session_list(session_list, cx); + }); this.set_server_state( ServerState::Connected(ConnectedServerState { connection, @@ -829,18 +845,14 @@ impl ConnectionView { ); }); + if let Some(scroll_position) = thread.read(cx).ui_scroll_position() { + list_state.scroll_to(scroll_position); + } + AgentDiff::set_active_thread(&self.workspace, thread.clone(), window, cx); let connection = thread.read(cx).connection().clone(); let session_id = thread.read(cx).session_id().clone(); - let session_list = if connection.supports_session_history() { - connection.session_list(cx) - } else { - None - }; - self.history.update(cx, |history, cx| { - history.set_session_list(session_list, cx); - }); // Check for config options first // Config options take precedence over legacy mode/model selectors @@ -2835,6 +2847,33 @@ pub(crate) mod tests { }); } + #[gpui::test] + async fn test_new_thread_creation_triggers_session_list_refresh(cx: &mut TestAppContext) { + init_test(cx); + + let session = AgentSessionInfo::new(SessionId::new("history-session")); + let (thread_view, history, cx) = setup_thread_view_with_history( + StubAgentServer::new(SessionHistoryConnection::new(vec![session.clone()])), + cx, + ) + .await; + + history.read_with(cx, |history, _cx| { + assert!( + history.has_session_list(), + "session list should be attached after thread creation" + ); + }); + + active_thread(&thread_view, cx).read_with(cx, |view, _cx| { + assert_eq!(view.recent_history_entries.len(), 1); + assert_eq!( + view.recent_history_entries[0].session_id, + session.session_id + ); + }); + } + #[gpui::test] async fn test_resume_without_history_adds_notice(cx: &mut TestAppContext) { init_test(cx); @@ -3482,6 +3521,18 @@ pub(crate) mod tests { agent: impl AgentServer + 'static, cx: &mut TestAppContext, ) -> (Entity, &mut VisualTestContext) { + let (thread_view, _history, cx) = setup_thread_view_with_history(agent, cx).await; + (thread_view, cx) + } + + async fn setup_thread_view_with_history( + agent: impl AgentServer + 'static, + cx: &mut TestAppContext, + ) -> ( + Entity, + Entity, + &mut VisualTestContext, + ) { let fs = FakeFs::new(cx.executor()); let project = Project::test(fs, [], cx).await; let (multi_workspace, cx) = @@ -3501,14 +3552,14 @@ pub(crate) mod tests { project, Some(thread_store), None, - history, + history.clone(), window, cx, ) }) }); cx.run_until_parked(); - (thread_view, cx) + (thread_view, history, cx) } fn add_to_workspace(thread_view: Entity, cx: &mut VisualTestContext) { @@ -3648,6 +3699,102 @@ pub(crate) mod tests { ) -> Task> { Task::ready(Ok(AgentSessionListResponse::new(self.sessions.clone()))) } + + fn into_any(self: Rc) -> Rc { + self + } + } + + #[derive(Clone)] + struct SessionHistoryConnection { + sessions: Vec, + } + + impl SessionHistoryConnection { + fn new(sessions: Vec) -> Self { + Self { sessions } + } + } + + fn build_test_thread( + connection: Rc, + project: Entity, + name: &'static str, + session_id: SessionId, + cx: &mut App, + ) -> Entity { + let action_log = cx.new(|_| ActionLog::new(project.clone())); + cx.new(|cx| { + AcpThread::new( + None, + name, + connection, + project, + action_log, + session_id, + watch::Receiver::constant( + acp::PromptCapabilities::new() + .image(true) + .audio(true) + .embedded_context(true), + ), + cx, + ) + }) + } + + impl AgentConnection for SessionHistoryConnection { + fn telemetry_id(&self) -> SharedString { + "history-connection".into() + } + + fn new_session( + self: Rc, + project: Entity, + _cwd: &Path, + cx: &mut App, + ) -> Task>> { + let thread = build_test_thread( + self, + project, + "SessionHistoryConnection", + SessionId::new("history-session"), + cx, + ); + Task::ready(Ok(thread)) + } + + fn supports_load_session(&self) -> bool { + true + } + + fn session_list(&self, _cx: &mut App) -> Option> { + Some(Rc::new(StubSessionList::new(self.sessions.clone()))) + } + + fn auth_methods(&self) -> &[acp::AuthMethod] { + &[] + } + + fn authenticate( + &self, + _method_id: acp::AuthMethodId, + _cx: &mut App, + ) -> Task> { + Task::ready(Ok(())) + } + + fn prompt( + &self, + _id: Option, + _params: acp::PromptRequest, + _cx: &mut App, + ) -> Task> { + Task::ready(Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))) + } + + fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {} + fn into_any(self: Rc) -> Rc { self } @@ -3667,24 +3814,13 @@ pub(crate) mod tests { _cwd: &Path, cx: &mut gpui::App, ) -> Task>> { - let action_log = cx.new(|_| ActionLog::new(project.clone())); - let thread = cx.new(|cx| { - AcpThread::new( - None, - "ResumeOnlyAgentConnection", - self.clone(), - project, - action_log, - SessionId::new("new-session"), - watch::Receiver::constant( - acp::PromptCapabilities::new() - .image(true) - .audio(true) - .embedded_context(true), - ), - cx, - ) - }); + let thread = build_test_thread( + self, + project, + "ResumeOnlyAgentConnection", + SessionId::new("new-session"), + cx, + ); Task::ready(Ok(thread)) } @@ -3699,24 +3835,13 @@ pub(crate) mod tests { _cwd: &Path, cx: &mut App, ) -> Task>> { - let action_log = cx.new(|_| ActionLog::new(project.clone())); - let thread = cx.new(|cx| { - AcpThread::new( - None, - "ResumeOnlyAgentConnection", - self.clone(), - project, - action_log, - session.session_id, - watch::Receiver::constant( - acp::PromptCapabilities::new() - .image(true) - .audio(true) - .embedded_context(true), - ), - cx, - ) - }); + let thread = build_test_thread( + self, + project, + "ResumeOnlyAgentConnection", + session.session_id, + cx, + ); Task::ready(Ok(thread)) } diff --git a/crates/agent_ui/src/connection_view/thread_view.rs b/crates/agent_ui/src/connection_view/thread_view.rs index 2544305bc8f8666b897d11285ffa7711f3af8794..8a1a7d2ea5b0f01ba559e83051861b9d6324985f 100644 --- a/crates/agent_ui/src/connection_view/thread_view.rs +++ b/crates/agent_ui/src/connection_view/thread_view.rs @@ -1,6 +1,8 @@ use acp_thread::ContentBlock; use cloud_api_types::{SubmitAgentThreadFeedbackBody, SubmitAgentThreadFeedbackCommentsBody}; use editor::actions::OpenExcerpts; + +use crate::StartThreadIn; use gpui::{Corner, List}; use language_model::{LanguageModelEffortLevel, Speed}; use settings::update_settings_file; @@ -191,6 +193,12 @@ impl DiffStats { } } +pub enum AcpThreadViewEvent { + FirstSendRequested { content: Vec }, +} + +impl EventEmitter for ThreadView {} + pub struct ThreadView { pub id: acp::SessionId, pub parent_id: Option, @@ -240,7 +248,8 @@ pub struct ThreadView { pub resumed_without_history: bool, pub resume_thread_metadata: Option, pub _cancel_task: Option>, - _draft_save_task: Option>, + _save_task: Option>, + _draft_resolve_task: Option>, pub skip_queue_processing_count: usize, pub user_interrupted_generation: bool, pub can_fast_track_queue: bool, @@ -388,7 +397,7 @@ impl ThreadView { } else { Some(editor.update(cx, |editor, cx| editor.draft_contents(cx))) }; - this._draft_save_task = Some(cx.spawn(async move |this, cx| { + this._draft_resolve_task = Some(cx.spawn(async move |this, cx| { let draft = if let Some(task) = draft_contents_task { let blocks = task.await.ok().filter(|b| !b.is_empty()); blocks @@ -399,15 +408,7 @@ impl ThreadView { this.thread.update(cx, |thread, _cx| { thread.set_draft_prompt(draft); }); - }) - .ok(); - cx.background_executor() - .timer(SERIALIZATION_THROTTLE_TIME) - .await; - this.update(cx, |this, cx| { - if let Some(thread) = this.as_native_thread(cx) { - thread.update(cx, |_thread, cx| cx.notify()); - } + this.schedule_save(cx); }) .ok(); })); @@ -463,7 +464,8 @@ impl ThreadView { is_loading_contents: false, new_server_version_available: None, _cancel_task: None, - _draft_save_task: None, + _save_task: None, + _draft_resolve_task: None, skip_queue_processing_count: 0, user_interrupted_generation: false, can_fast_track_queue: false, @@ -479,12 +481,50 @@ impl ThreadView { _history_subscription: history_subscription, show_codex_windows_warning, }; + let list_state_for_scroll = this.list_state.clone(); + let thread_view = cx.entity().downgrade(); + this.list_state + .set_scroll_handler(move |_event, _window, cx| { + let list_state = list_state_for_scroll.clone(); + let thread_view = thread_view.clone(); + // N.B. We must defer because the scroll handler is called while the + // ListState's RefCell is mutably borrowed. Reading logical_scroll_top() + // directly would panic from a double borrow. + cx.defer(move |cx| { + let scroll_top = list_state.logical_scroll_top(); + let _ = thread_view.update(cx, |this, cx| { + if let Some(thread) = this.as_native_thread(cx) { + thread.update(cx, |thread, _cx| { + thread.set_ui_scroll_position(Some(scroll_top)); + }); + } + this.schedule_save(cx); + }); + }); + }); + if should_auto_submit { this.send(window, cx); } this } + /// Schedule a throttled save of the thread state (draft prompt, scroll position, etc.). + /// Multiple calls within `SERIALIZATION_THROTTLE_TIME` are coalesced into a single save. + fn schedule_save(&mut self, cx: &mut Context) { + self._save_task = Some(cx.spawn(async move |this, cx| { + cx.background_executor() + .timer(SERIALIZATION_THROTTLE_TIME) + .await; + this.update(cx, |this, cx| { + if let Some(thread) = this.as_native_thread(cx) { + thread.update(cx, |_thread, cx| cx.notify()); + } + }) + .ok(); + })); + } + pub fn handle_message_editor_event( &mut self, _editor: &Entity, @@ -518,6 +558,24 @@ impl ThreadView { .thread(acp_thread.session_id(), cx) } + /// Resolves the message editor's contents into content blocks. For profiles + /// that do not enable any tools, directory mentions are expanded to inline + /// file contents since the agent can't read files on its own. + fn resolve_message_contents( + &self, + message_editor: &Entity, + cx: &mut App, + ) -> Task, Vec>)>> { + let expand = self.as_native_thread(cx).is_some_and(|thread| { + let thread = thread.read(cx); + AgentSettings::get_global(cx) + .profiles + .get(thread.profile()) + .is_some_and(|profile| profile.tools.is_empty()) + }); + message_editor.update(cx, |message_editor, cx| message_editor.contents(expand, cx)) + } + pub fn current_model_id(&self, cx: &App) -> Option { let selector = self.model_selector.as_ref()?; let model = selector.read(cx).active_model(cx)?; @@ -731,6 +789,46 @@ impl ThreadView { } let message_editor = self.message_editor.clone(); + + // Intercept the first send so the agent panel can capture the full + // content blocks — needed for "Start thread in New Worktree", + // which must create a workspace before sending the message there. + let intercept_first_send = self.thread.read(cx).entries().is_empty() + && !message_editor.read(cx).is_empty(cx) + && self + .workspace + .upgrade() + .and_then(|workspace| workspace.read(cx).panel::(cx)) + .is_some_and(|panel| { + panel.read(cx).start_thread_in() == &StartThreadIn::NewWorktree + }); + + if intercept_first_send { + let content_task = self.resolve_message_contents(&message_editor, cx); + + cx.spawn(async move |this, cx| match content_task.await { + Ok((content, _tracked_buffers)) => { + if content.is_empty() { + return; + } + + this.update(cx, |_, cx| { + cx.emit(AcpThreadViewEvent::FirstSendRequested { content }); + }) + .ok(); + } + Err(error) => { + this.update(cx, |this, cx| { + this.handle_thread_error(error, cx); + }) + .ok(); + } + }) + .detach(); + + return; + } + let is_editor_empty = message_editor.read(cx).is_empty(cx); let is_generating = thread.read(cx).status() != ThreadStatus::Idle; @@ -794,18 +892,7 @@ impl ThreadView { window: &mut Window, cx: &mut Context, ) { - let full_mention_content = self.as_native_thread(cx).is_some_and(|thread| { - // Include full contents when using minimal profile - let thread = thread.read(cx); - AgentSettings::get_global(cx) - .profiles - .get(thread.profile()) - .is_some_and(|profile| profile.tools.is_empty()) - }); - - let contents = message_editor.update(cx, |message_editor, cx| { - message_editor.contents(full_mention_content, cx) - }); + let contents = self.resolve_message_contents(&message_editor, cx); self.thread_error.take(); self.thread_feedback.clear(); @@ -1140,21 +1227,11 @@ impl ThreadView { let is_idle = self.thread.read(cx).status() == acp_thread::ThreadStatus::Idle; if is_idle { - self.send_impl(message_editor.clone(), window, cx); + self.send_impl(message_editor, window, cx); return; } - let full_mention_content = self.as_native_thread(cx).is_some_and(|thread| { - let thread = thread.read(cx); - AgentSettings::get_global(cx) - .profiles - .get(thread.profile()) - .is_some_and(|profile| profile.tools.is_empty()) - }); - - let contents = message_editor.update(cx, |message_editor, cx| { - message_editor.contents(full_mention_content, cx) - }); + let contents = self.resolve_message_contents(&message_editor, cx); cx.spawn_in(window, async move |this, cx| { let (content, tracked_buffers) = contents.await?; @@ -6691,6 +6768,31 @@ impl ThreadView { .read(cx) .pending_tool_call(thread.read(cx).session_id(), cx); + let session_id = thread.read(cx).session_id().clone(); + + let fullscreen_toggle = h_flex() + .id(entry_ix) + .py_1() + .w_full() + .justify_center() + .border_t_1() + .when(is_failed, |this| this.border_dashed()) + .border_color(self.tool_card_border_color(cx)) + .hover(|s| s.bg(cx.theme().colors().element_hover)) + .child( + Icon::new(IconName::Maximize) + .color(Color::Muted) + .size(IconSize::Small), + ) + .tooltip(Tooltip::text("Make Subagent Full Screen")) + .on_click(cx.listener(move |this, _event, window, cx| { + this.server_view + .update(cx, |this, cx| { + this.navigate_to_session(session_id.clone(), window, cx); + }) + .ok(); + })); + if is_running && let Some((_, subagent_tool_call_id, _)) = pending_tool_call { if let Some((entry_ix, tool_call)) = thread.read(cx).tool_call(&subagent_tool_call_id) @@ -6705,11 +6807,11 @@ impl ThreadView { window, cx, )) + .child(fullscreen_toggle) } else { this } } else { - let session_id = thread.read(cx).session_id().clone(); this.when(is_expanded, |this| { this.child(self.render_subagent_expanded_content( thread_view, @@ -6726,34 +6828,7 @@ impl ThreadView { .title(message), ) }) - .child( - h_flex() - .id(entry_ix) - .py_1() - .w_full() - .justify_center() - .border_t_1() - .when(is_failed, |this| this.border_dashed()) - .border_color(self.tool_card_border_color(cx)) - .hover(|s| s.bg(cx.theme().colors().element_hover)) - .child( - Icon::new(IconName::Maximize) - .color(Color::Muted) - .size(IconSize::Small), - ) - .tooltip(Tooltip::text("Make Subagent Full Screen")) - .on_click(cx.listener(move |this, _event, window, cx| { - this.server_view - .update(cx, |this, cx| { - this.navigate_to_session( - session_id.clone(), - window, - cx, - ); - }) - .ok(); - })), - ) + .child(fullscreen_toggle) }) } }) diff --git a/crates/agent_ui/src/mention_set.rs b/crates/agent_ui/src/mention_set.rs index 58e7e4cdfc196862bb3b8936f8582ba1ad54bda5..792bfc11a63471e02b22835823fa8c59cdfc9bcf 100644 --- a/crates/agent_ui/src/mention_set.rs +++ b/crates/agent_ui/src/mention_set.rs @@ -234,6 +234,8 @@ impl MentionSet { mention_uri.name().into(), IconName::Image.path().into(), mention_uri.tooltip_text(), + Some(mention_uri.clone()), + Some(workspace.downgrade()), Some(image), editor.clone(), window, @@ -247,6 +249,8 @@ impl MentionSet { crease_text, mention_uri.icon_path(cx), mention_uri.tooltip_text(), + Some(mention_uri.clone()), + Some(workspace.downgrade()), None, editor.clone(), window, @@ -699,6 +703,8 @@ pub(crate) async fn insert_images_as_context( MentionUri::PastedImage.name().into(), IconName::Image.path().into(), None, + None, + None, Some(Task::ready(Ok(image.clone())).shared()), editor.clone(), window, @@ -810,6 +816,8 @@ pub(crate) fn insert_crease_for_mention( crease_label: SharedString, crease_icon: SharedString, crease_tooltip: Option, + mention_uri: Option, + workspace: Option>, image: Option, String>>>>, editor: Entity, window: &mut Window, @@ -830,6 +838,8 @@ pub(crate) fn insert_crease_for_mention( crease_label.clone(), crease_icon.clone(), crease_tooltip, + mention_uri.clone(), + workspace.clone(), start..end, rx, image, @@ -1029,6 +1039,8 @@ fn render_mention_fold_button( label: SharedString, icon: SharedString, tooltip: Option, + mention_uri: Option, + workspace: Option>, range: Range, mut loading_finished: postage::barrier::Receiver, image_task: Option, String>>>>, @@ -1049,6 +1061,8 @@ fn render_mention_fold_button( label, icon, tooltip, + mention_uri: mention_uri.clone(), + workspace: workspace.clone(), range, editor, loading: Some(loading), @@ -1063,6 +1077,8 @@ struct LoadingContext { label: SharedString, icon: SharedString, tooltip: Option, + mention_uri: Option, + workspace: Option>, range: Range, editor: WeakEntity, loading: Option>, @@ -1079,6 +1095,8 @@ impl Render for LoadingContext { let id = ElementId::from(("loading_context", self.id)); MentionCrease::new(id, self.icon.clone(), self.label.clone()) + .mention_uri(self.mention_uri.clone()) + .workspace(self.workspace.clone()) .is_toggled(is_in_text_selection) .is_loading(self.loading.is_some()) .when_some(self.tooltip.clone(), |this, tooltip_text| { diff --git a/crates/agent_ui/src/message_editor.rs b/crates/agent_ui/src/message_editor.rs index 50b297847b43e4d147978fbcf14dce492fc572d0..c75d0479b7bf16229cc487544d2c87403b3da430 100644 --- a/crates/agent_ui/src/message_editor.rs +++ b/crates/agent_ui/src/message_editor.rs @@ -154,6 +154,7 @@ impl MessageEditor { Box::new(editor::actions::Copy), ) .action("Paste", Box::new(editor::actions::Paste)) + .action("Paste as Plain Text", Box::new(PasteRaw)) })) }); @@ -722,6 +723,8 @@ impl MessageEditor { crease_text.into(), mention_uri.icon_path(cx), mention_uri.tooltip_text(), + Some(mention_uri.clone()), + Some(self.workspace.clone()), None, self.editor.clone(), window, @@ -833,6 +836,8 @@ impl MessageEditor { mention_uri.name().into(), mention_uri.icon_path(cx), mention_uri.tooltip_text(), + Some(mention_uri.clone()), + Some(self.workspace.clone()), None, self.editor.clone(), window, @@ -1014,6 +1019,8 @@ impl MessageEditor { mention_uri.name().into(), mention_uri.icon_path(cx), mention_uri.tooltip_text(), + Some(mention_uri.clone()), + Some(self.workspace.clone()), None, self.editor.clone(), window, @@ -1370,6 +1377,8 @@ impl MessageEditor { mention_uri.name().into(), mention_uri.icon_path(cx), mention_uri.tooltip_text(), + Some(mention_uri.clone()), + Some(self.workspace.clone()), None, self.editor.clone(), window, diff --git a/crates/agent_ui/src/ui/mention_crease.rs b/crates/agent_ui/src/ui/mention_crease.rs index 2d464039dc552203ad76979239673ec27d5568c7..0a61b8e4ef2ec69714f158a72f83cc0528cc8a8f 100644 --- a/crates/agent_ui/src/ui/mention_crease.rs +++ b/crates/agent_ui/src/ui/mention_crease.rs @@ -1,15 +1,25 @@ -use std::time::Duration; +use std::{ops::RangeInclusive, path::PathBuf, time::Duration}; -use gpui::{Animation, AnimationExt, AnyView, IntoElement, Window, pulsating_between}; +use acp_thread::MentionUri; +use agent_client_protocol as acp; +use editor::{Editor, SelectionEffects, scroll::Autoscroll}; +use gpui::{ + Animation, AnimationExt, AnyView, Context, IntoElement, WeakEntity, Window, pulsating_between, +}; +use prompt_store::PromptId; +use rope::Point; use settings::Settings; use theme::ThemeSettings; use ui::{ButtonLike, TintColor, Tooltip, prelude::*}; +use workspace::{OpenOptions, Workspace}; #[derive(IntoElement)] pub struct MentionCrease { id: ElementId, icon: SharedString, label: SharedString, + mention_uri: Option, + workspace: Option>, is_toggled: bool, is_loading: bool, tooltip: Option, @@ -26,6 +36,8 @@ impl MentionCrease { id: id.into(), icon: icon.into(), label: label.into(), + mention_uri: None, + workspace: None, is_toggled: false, is_loading: false, tooltip: None, @@ -33,6 +45,16 @@ impl MentionCrease { } } + pub fn mention_uri(mut self, mention_uri: Option) -> Self { + self.mention_uri = mention_uri; + self + } + + pub fn workspace(mut self, workspace: Option>) -> Self { + self.workspace = workspace; + self + } + pub fn is_toggled(mut self, is_toggled: bool) -> Self { self.is_toggled = is_toggled; self @@ -76,6 +98,14 @@ impl RenderOnce for MentionCrease { .height(button_height) .selected_style(ButtonStyle::Tinted(TintColor::Accent)) .toggle_state(self.is_toggled) + .when_some( + self.mention_uri.clone().zip(self.workspace.clone()), + |this, (mention_uri, workspace)| { + this.on_click(move |_event, window, cx| { + open_mention_uri(mention_uri.clone(), &workspace, window, cx); + }) + }, + ) .child( h_flex() .pb_px() @@ -114,3 +144,168 @@ impl RenderOnce for MentionCrease { }) } } + +fn open_mention_uri( + mention_uri: MentionUri, + workspace: &WeakEntity, + window: &mut Window, + cx: &mut App, +) { + let Some(workspace) = workspace.upgrade() else { + return; + }; + + workspace.update(cx, |workspace, cx| match mention_uri { + MentionUri::File { abs_path } => { + open_file(workspace, abs_path, None, window, cx); + } + MentionUri::Symbol { + abs_path, + line_range, + .. + } + | MentionUri::Selection { + abs_path: Some(abs_path), + line_range, + } => { + open_file(workspace, abs_path, Some(line_range), window, cx); + } + MentionUri::Directory { abs_path } => { + reveal_in_project_panel(workspace, abs_path, cx); + } + MentionUri::Thread { id, name } => { + open_thread(workspace, id, name, window, cx); + } + MentionUri::TextThread { .. } => {} + MentionUri::Rule { id, .. } => { + open_rule(workspace, id, window, cx); + } + MentionUri::Fetch { url } => { + cx.open_url(url.as_str()); + } + MentionUri::PastedImage + | MentionUri::Selection { abs_path: None, .. } + | MentionUri::Diagnostics { .. } + | MentionUri::TerminalSelection { .. } + | MentionUri::GitDiff { .. } => {} + }); +} + +fn open_file( + workspace: &mut Workspace, + abs_path: PathBuf, + line_range: Option>, + window: &mut Window, + cx: &mut Context, +) { + let project = workspace.project(); + + if let Some(project_path) = + project.update(cx, |project, cx| project.find_project_path(&abs_path, cx)) + { + let item = workspace.open_path(project_path, None, true, window, cx); + if let Some(line_range) = line_range { + window + .spawn(cx, async move |cx| { + let Some(editor) = item.await?.downcast::() else { + return Ok(()); + }; + editor + .update_in(cx, |editor, window, cx| { + let range = Point::new(*line_range.start(), 0) + ..Point::new(*line_range.start(), 0); + editor.change_selections( + SelectionEffects::scroll(Autoscroll::center()), + window, + cx, + |selections| selections.select_ranges(vec![range]), + ); + }) + .ok(); + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + } else { + item.detach_and_log_err(cx); + } + } else if abs_path.exists() { + workspace + .open_abs_path( + abs_path, + OpenOptions { + focus: Some(true), + ..Default::default() + }, + window, + cx, + ) + .detach_and_log_err(cx); + } +} + +fn reveal_in_project_panel( + workspace: &mut Workspace, + abs_path: PathBuf, + cx: &mut Context, +) { + let project = workspace.project(); + let Some(entry_id) = project.update(cx, |project, cx| { + let path = project.find_project_path(&abs_path, cx)?; + project.entry_for_path(&path, cx).map(|entry| entry.id) + }) else { + return; + }; + + project.update(cx, |_, cx| { + cx.emit(project::Event::RevealInProjectPanel(entry_id)); + }); +} + +fn open_thread( + workspace: &mut Workspace, + id: acp::SessionId, + name: String, + window: &mut Window, + cx: &mut Context, +) { + use crate::AgentPanel; + use acp_thread::AgentSessionInfo; + + let Some(panel) = workspace.panel::(cx) else { + return; + }; + + panel.update(cx, |panel, cx| { + panel.load_agent_thread( + AgentSessionInfo { + session_id: id, + cwd: None, + title: Some(name.into()), + updated_at: None, + meta: None, + }, + window, + cx, + ) + }); +} + +fn open_rule( + _workspace: &mut Workspace, + id: PromptId, + window: &mut Window, + cx: &mut Context, +) { + use zed_actions::assistant::OpenRulesLibrary; + + let PromptId::User { uuid } = id else { + return; + }; + + window.dispatch_action( + Box::new(OpenRulesLibrary { + prompt_to_select: Some(uuid.0), + }), + cx, + ); +} diff --git a/crates/auto_update_helper/Cargo.toml b/crates/auto_update_helper/Cargo.toml index 73c38d80dd12e9c42daa42b7e6f2c9d6975cf47b..aa5bf6ac40b0e1ab20cbde510be5d7f389c7ade8 100644 --- a/crates/auto_update_helper/Cargo.toml +++ b/crates/auto_update_helper/Cargo.toml @@ -19,6 +19,7 @@ log.workspace = true simplelog.workspace = true [target.'cfg(target_os = "windows")'.dependencies] +scopeguard = "1.2" windows.workspace = true [target.'cfg(target_os = "windows")'.dev-dependencies] diff --git a/crates/auto_update_helper/src/updater.rs b/crates/auto_update_helper/src/updater.rs index 076e11fb4eef1e5c53e2bdc290be7117330c3e61..7821c908c40873637c4ac3993c320416e2a4b978 100644 --- a/crates/auto_update_helper/src/updater.rs +++ b/crates/auto_update_helper/src/updater.rs @@ -1,13 +1,22 @@ use std::{ + ffi::OsStr, + os::windows::ffi::OsStrExt, path::Path, sync::LazyLock, time::{Duration, Instant}, }; use anyhow::{Context as _, Result}; -use windows::Win32::{ - Foundation::{HWND, LPARAM, WPARAM}, - UI::WindowsAndMessaging::PostMessageW, +use windows::{ + Win32::{ + Foundation::{HWND, LPARAM, WPARAM}, + System::RestartManager::{ + CCH_RM_SESSION_KEY, RmEndSession, RmGetList, RmRegisterResources, RmShutdown, + RmStartSession, + }, + UI::WindowsAndMessaging::PostMessageW, + }, + core::{PCWSTR, PWSTR}, }; use crate::windows_impl::WM_JOB_UPDATED; @@ -262,9 +271,106 @@ pub(crate) static JOBS: LazyLock<[Job; 9]> = LazyLock::new(|| { ] }); +/// Attempts to use Windows Restart Manager to release file handles held by other processes +/// (e.g., Explorer.exe) on the files we need to move during the update. +/// +/// This is a best-effort operation - if it fails, we'll still try the update and rely on +/// the retry logic. +fn release_file_handles(app_dir: &Path) -> Result<()> { + // Files that commonly get locked by Explorer or other processes + let files_to_release = [ + app_dir.join("Zed.exe"), + app_dir.join("bin\\Zed.exe"), + app_dir.join("bin\\zed"), + app_dir.join("conpty.dll"), + ]; + + log::info!("Attempting to release file handles using Restart Manager..."); + + let mut session: u32 = 0; + let mut session_key = [0u16; CCH_RM_SESSION_KEY as usize + 1]; + + // Start a Restart Manager session + let err = unsafe { + RmStartSession( + &mut session, + Some(0), + PWSTR::from_raw(session_key.as_mut_ptr()), + ) + }; + if err.is_err() { + anyhow::bail!("RmStartSession failed: {err:?}"); + } + + // Ensure we end the session when done + let _session_guard = scopeguard::guard(session, |s| { + let _ = unsafe { RmEndSession(s) }; + }); + + // Convert paths to wide strings for Windows API + let wide_paths: Vec> = files_to_release + .iter() + .filter(|p| p.exists()) + .map(|p| { + OsStr::new(p) + .encode_wide() + .chain(std::iter::once(0)) + .collect() + }) + .collect(); + + if wide_paths.is_empty() { + log::info!("No files to release handles for"); + return Ok(()); + } + + let pcwstr_paths: Vec = wide_paths + .iter() + .map(|p| PCWSTR::from_raw(p.as_ptr())) + .collect(); + + // Register the files we want to modify + let err = unsafe { RmRegisterResources(session, Some(&pcwstr_paths), None, None) }; + if err.is_err() { + anyhow::bail!("RmRegisterResources failed: {err:?}"); + } + + // Check if any processes are using these files + let mut needed: u32 = 0; + let mut count: u32 = 0; + let mut reboot_reasons: u32 = 0; + let _ = unsafe { RmGetList(session, &mut needed, &mut count, None, &mut reboot_reasons) }; + + if needed == 0 { + log::info!("No processes are holding handles to the files"); + return Ok(()); + } + + log::info!( + "{} process(es) are holding handles to the files, requesting release...", + needed + ); + + // Request processes to release their handles + // RmShutdown with flags=0 asks applications to release handles gracefully + // For Explorer, this typically releases icon cache handles without closing Explorer + let err = unsafe { RmShutdown(session, 0, None) }; + if err.is_err() { + anyhow::bail!("RmShutdown failed: {:?}", err); + } + + log::info!("Successfully requested handle release"); + Ok(()) +} + pub(crate) fn perform_update(app_dir: &Path, hwnd: Option, launch: bool) -> Result<()> { let hwnd = hwnd.map(|ptr| HWND(ptr as _)); + // Try to release file handles before starting the update + if let Err(e) = release_file_handles(app_dir) { + log::warn!("Restart Manager failed (will continue anyway): {}", e); + } + let mut last_successful_job = None; 'outer: for (i, job) in JOBS.iter().enumerate() { let start = Instant::now(); @@ -279,19 +385,22 @@ pub(crate) fn perform_update(app_dir: &Path, hwnd: Option, launch: bool) unsafe { PostMessageW(hwnd, WM_JOB_UPDATED, WPARAM(0), LPARAM(0))? }; break; } - Err(err) => { - // Check if it's a "not found" error - let io_err = err.downcast_ref::().unwrap(); - if io_err.kind() == std::io::ErrorKind::NotFound { - log::warn!("File or folder not found."); - last_successful_job = Some(i); - unsafe { PostMessageW(hwnd, WM_JOB_UPDATED, WPARAM(0), LPARAM(0))? }; - break; + Err(err) => match err.downcast_ref::() { + Some(io_err) => match io_err.kind() { + std::io::ErrorKind::NotFound => { + log::error!("Operation failed with file not found, aborting: {}", err); + break 'outer; + } + _ => { + log::error!("Operation failed (retrying): {}", err); + std::thread::sleep(Duration::from_millis(50)); + } + }, + None => { + log::error!("Operation failed with unexpected error, aborting: {}", err); + break 'outer; } - - log::error!("Operation failed: {} ({:?})", err, io_err.kind()); - std::thread::sleep(Duration::from_millis(50)); - } + }, } } } diff --git a/crates/cloud_api_client/src/cloud_api_client.rs b/crates/cloud_api_client/src/cloud_api_client.rs index f485e2d20c619715ea342fccd2a5cec0ecaa6f4e..13d67838b216f4990f15ec22c1701aa7aef9dbf2 100644 --- a/crates/cloud_api_client/src/cloud_api_client.rs +++ b/crates/cloud_api_client/src/cloud_api_client.rs @@ -9,7 +9,9 @@ use futures::AsyncReadExt as _; use gpui::{App, Task}; use gpui_tokio::Tokio; use http_client::http::request; -use http_client::{AsyncBody, HttpClientWithUrl, HttpRequestExt, Method, Request, StatusCode}; +use http_client::{ + AsyncBody, HttpClientWithUrl, HttpRequestExt, Json, Method, Request, StatusCode, +}; use parking_lot::RwLock; use thiserror::Error; use yawc::WebSocket; @@ -141,6 +143,7 @@ impl CloudApiClient { pub async fn create_llm_token( &self, system_id: Option, + organization_id: Option, ) -> Result { let request_builder = Request::builder() .method(Method::POST) @@ -153,7 +156,10 @@ impl CloudApiClient { builder.header(ZED_SYSTEM_ID_HEADER_NAME, system_id) }); - let request = self.build_request(request_builder, AsyncBody::default())?; + let request = self.build_request( + request_builder, + Json(CreateLlmTokenBody { organization_id }), + )?; let mut response = self.http_client.send(request).await?; diff --git a/crates/cloud_api_types/src/cloud_api_types.rs b/crates/cloud_api_types/src/cloud_api_types.rs index 2d457fc6630d5b32f049e67a6a460047e925973a..42d3442bfc016f5cb1a39ba421ccdfe386bcbc65 100644 --- a/crates/cloud_api_types/src/cloud_api_types.rs +++ b/crates/cloud_api_types/src/cloud_api_types.rs @@ -52,6 +52,12 @@ pub struct AcceptTermsOfServiceResponse { #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] pub struct LlmToken(pub String); +#[derive(Debug, Default, PartialEq, Clone, Serialize, Deserialize)] +pub struct CreateLlmTokenBody { + #[serde(default)] + pub organization_id: Option, +} + #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] pub struct CreateLlmTokenResponse { pub token: LlmToken, diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index 71e39fb595656e0dcdc53d97705b87a216ceb0f3..3e4b5c2ce211f68ef7e12895b542db5e6e3ea47c 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -122,6 +122,8 @@ CREATE TABLE "project_repository_statuses" ( "status_kind" INT4 NOT NULL, "first_status" INT4 NULL, "second_status" INT4 NULL, + "lines_added" INT4 NULL, + "lines_deleted" INT4 NULL, "scan_id" INT8 NOT NULL, "is_deleted" BOOL NOT NULL, PRIMARY KEY (project_id, repository_id, repo_path) diff --git a/crates/collab/migrations/20251208000000_test_schema.sql b/crates/collab/migrations/20251208000000_test_schema.sql index 493be3823e25a433d4a6a27a21c508f218dc68d1..0f4e4f2d2e3925ea1e4d2b964c5e4f159f393b4f 100644 --- a/crates/collab/migrations/20251208000000_test_schema.sql +++ b/crates/collab/migrations/20251208000000_test_schema.sql @@ -315,6 +315,8 @@ CREATE TABLE public.project_repository_statuses ( status_kind integer NOT NULL, first_status integer, second_status integer, + lines_added integer, + lines_deleted integer, scan_id bigint NOT NULL, is_deleted boolean NOT NULL ); diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 57fb0df86495dc2013e7cd780c2e62e57298bd11..d8803c253f5feef8ef5e040f3ea112abcc688f52 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -732,6 +732,8 @@ fn db_status_to_proto( status: Some(proto::GitFileStatus { variant: Some(variant), }), + diff_stat_added: entry.lines_added.map(|v| v as u32), + diff_stat_deleted: entry.lines_deleted.map(|v| v as u32), }) } diff --git a/crates/collab/src/db/queries/projects.rs b/crates/collab/src/db/queries/projects.rs index ed6325c62173358c8deac2dcd6289ce0b8ae5e71..24cf639a715aa9b88da80375b389debaea0c4295 100644 --- a/crates/collab/src/db/queries/projects.rs +++ b/crates/collab/src/db/queries/projects.rs @@ -334,147 +334,6 @@ impl Database { .await?; } - // Backward-compatibility for old Zed clients. - // - // Remove this block when Zed 1.80 stable has been out for a week. - { - if !update.updated_repositories.is_empty() { - project_repository::Entity::insert_many( - update.updated_repositories.iter().map(|repository| { - project_repository::ActiveModel { - project_id: ActiveValue::set(project_id), - legacy_worktree_id: ActiveValue::set(Some(worktree_id)), - id: ActiveValue::set(repository.repository_id as i64), - scan_id: ActiveValue::set(update.scan_id as i64), - is_deleted: ActiveValue::set(false), - branch_summary: ActiveValue::Set( - repository - .branch_summary - .as_ref() - .map(|summary| serde_json::to_string(summary).unwrap()), - ), - current_merge_conflicts: ActiveValue::Set(Some( - serde_json::to_string(&repository.current_merge_conflicts) - .unwrap(), - )), - // Old clients do not use abs path, entry ids, head_commit_details, or merge_message. - abs_path: ActiveValue::set(String::new()), - entry_ids: ActiveValue::set("[]".into()), - head_commit_details: ActiveValue::set(None), - merge_message: ActiveValue::set(None), - remote_upstream_url: ActiveValue::set(None), - remote_origin_url: ActiveValue::set(None), - } - }), - ) - .on_conflict( - OnConflict::columns([ - project_repository::Column::ProjectId, - project_repository::Column::Id, - ]) - .update_columns([ - project_repository::Column::ScanId, - project_repository::Column::BranchSummary, - project_repository::Column::CurrentMergeConflicts, - ]) - .to_owned(), - ) - .exec(&*tx) - .await?; - - let has_any_statuses = update - .updated_repositories - .iter() - .any(|repository| !repository.updated_statuses.is_empty()); - - if has_any_statuses { - project_repository_statuses::Entity::insert_many( - update.updated_repositories.iter().flat_map( - |repository: &proto::RepositoryEntry| { - repository.updated_statuses.iter().map(|status_entry| { - let (repo_path, status_kind, first_status, second_status) = - proto_status_to_db(status_entry.clone()); - project_repository_statuses::ActiveModel { - project_id: ActiveValue::set(project_id), - repository_id: ActiveValue::set( - repository.repository_id as i64, - ), - scan_id: ActiveValue::set(update.scan_id as i64), - is_deleted: ActiveValue::set(false), - repo_path: ActiveValue::set(repo_path), - status: ActiveValue::set(0), - status_kind: ActiveValue::set(status_kind), - first_status: ActiveValue::set(first_status), - second_status: ActiveValue::set(second_status), - } - }) - }, - ), - ) - .on_conflict( - OnConflict::columns([ - project_repository_statuses::Column::ProjectId, - project_repository_statuses::Column::RepositoryId, - project_repository_statuses::Column::RepoPath, - ]) - .update_columns([ - project_repository_statuses::Column::ScanId, - project_repository_statuses::Column::StatusKind, - project_repository_statuses::Column::FirstStatus, - project_repository_statuses::Column::SecondStatus, - ]) - .to_owned(), - ) - .exec(&*tx) - .await?; - } - - for repo in &update.updated_repositories { - if !repo.removed_statuses.is_empty() { - project_repository_statuses::Entity::update_many() - .filter( - project_repository_statuses::Column::ProjectId - .eq(project_id) - .and( - project_repository_statuses::Column::RepositoryId - .eq(repo.repository_id), - ) - .and( - project_repository_statuses::Column::RepoPath - .is_in(repo.removed_statuses.iter()), - ), - ) - .set(project_repository_statuses::ActiveModel { - is_deleted: ActiveValue::Set(true), - scan_id: ActiveValue::Set(update.scan_id as i64), - ..Default::default() - }) - .exec(&*tx) - .await?; - } - } - } - - if !update.removed_repositories.is_empty() { - project_repository::Entity::update_many() - .filter( - project_repository::Column::ProjectId - .eq(project_id) - .and(project_repository::Column::LegacyWorktreeId.eq(worktree_id)) - .and(project_repository::Column::Id.is_in( - update.removed_repositories.iter().map(|id| *id as i64), - )), - ) - .set(project_repository::ActiveModel { - is_deleted: ActiveValue::Set(true), - scan_id: ActiveValue::Set(update.scan_id as i64), - ..Default::default() - }) - .exec(&*tx) - .await?; - } - } - let connection_ids = self.project_guest_connection_ids(project_id, &tx).await?; Ok(connection_ids) }) @@ -552,6 +411,12 @@ impl Database { status_kind: ActiveValue::set(status_kind), first_status: ActiveValue::set(first_status), second_status: ActiveValue::set(second_status), + lines_added: ActiveValue::set( + status_entry.diff_stat_added.map(|v| v as i32), + ), + lines_deleted: ActiveValue::set( + status_entry.diff_stat_deleted.map(|v| v as i32), + ), } }), ) @@ -566,6 +431,8 @@ impl Database { project_repository_statuses::Column::StatusKind, project_repository_statuses::Column::FirstStatus, project_repository_statuses::Column::SecondStatus, + project_repository_statuses::Column::LinesAdded, + project_repository_statuses::Column::LinesDeleted, ]) .to_owned(), ) @@ -1002,7 +869,7 @@ impl Database { repositories.push(proto::UpdateRepository { project_id: db_repository_entry.project_id.0 as u64, id: db_repository_entry.id as u64, - abs_path: db_repository_entry.abs_path, + abs_path: db_repository_entry.abs_path.clone(), entry_ids, updated_statuses, removed_statuses: Vec::new(), @@ -1015,6 +882,7 @@ impl Database { stash_entries: Vec::new(), remote_upstream_url: db_repository_entry.remote_upstream_url.clone(), remote_origin_url: db_repository_entry.remote_origin_url.clone(), + original_repo_abs_path: Some(db_repository_entry.abs_path), }); } } diff --git a/crates/collab/src/db/queries/rooms.rs b/crates/collab/src/db/queries/rooms.rs index d8fca0306f5b2ae5668a735db578061275192b58..b4cbd83167b227542d8de1022b7e2cf49f5a7645 100644 --- a/crates/collab/src/db/queries/rooms.rs +++ b/crates/collab/src/db/queries/rooms.rs @@ -738,7 +738,7 @@ impl Database { while let Some(db_status) = db_statuses.next().await { let db_status: project_repository_statuses::Model = db_status?; if db_status.is_deleted { - removed_statuses.push(db_status.repo_path); + removed_statuses.push(db_status.repo_path.clone()); } else { updated_statuses.push(db_status_to_proto(db_status)?); } @@ -791,13 +791,14 @@ impl Database { head_commit_details, project_id: project_id.to_proto(), id: db_repository.id as u64, - abs_path: db_repository.abs_path, + abs_path: db_repository.abs_path.clone(), scan_id: db_repository.scan_id as u64, is_last_update: true, merge_message: db_repository.merge_message, stash_entries: Vec::new(), remote_upstream_url: db_repository.remote_upstream_url.clone(), remote_origin_url: db_repository.remote_origin_url.clone(), + original_repo_abs_path: Some(db_repository.abs_path), }); } } diff --git a/crates/collab/src/db/tables/project_repository_statuses.rs b/crates/collab/src/db/tables/project_repository_statuses.rs index 7bb903d45085467a3285a58f8afdd7a29339731a..8160d8a03c2a3b4dd0db7675489eeafcef020a9a 100644 --- a/crates/collab/src/db/tables/project_repository_statuses.rs +++ b/crates/collab/src/db/tables/project_repository_statuses.rs @@ -17,6 +17,8 @@ pub struct Model { pub first_status: Option, /// For unmerged entries, this is the `second_head` status. For tracked entries, this is the `worktree_status`. pub second_status: Option, + pub lines_added: Option, + pub lines_deleted: Option, pub scan_id: i64, pub is_deleted: bool, } diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 087dbe2a0ba23851689e75401c62b64775cf2282..b521f6b083ae311d98ec46c900ce821fd8042e4a 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -437,6 +437,8 @@ impl Server { .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) + .add_request_handler(forward_read_only_project_request::) + .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) .add_message_handler(broadcast_project_message_from_host::) .add_message_handler(update_context) diff --git a/crates/collab/tests/integration/git_tests.rs b/crates/collab/tests/integration/git_tests.rs index f3abb5bc3f3e1a12e7ecb56c985f2cff46582cee..dccc99a07769e66a3eb318a8201d8e14a29ef4f2 100644 --- a/crates/collab/tests/integration/git_tests.rs +++ b/crates/collab/tests/integration/git_tests.rs @@ -1,17 +1,40 @@ -use std::path::Path; +use std::path::{Path, PathBuf}; use call::ActiveCall; -use git::status::{FileStatus, StatusCode, TrackedStatus}; -use git_ui::project_diff::ProjectDiff; -use gpui::{AppContext as _, TestAppContext, VisualTestContext}; +use collections::HashMap; +use git::{ + repository::RepoPath, + status::{DiffStat, FileStatus, StatusCode, TrackedStatus}, +}; +use git_ui::{git_panel::GitPanel, project_diff::ProjectDiff}; +use gpui::{AppContext as _, BackgroundExecutor, TestAppContext, VisualTestContext}; use project::ProjectPath; use serde_json::json; + use util::{path, rel_path::rel_path}; use workspace::{MultiWorkspace, Workspace}; -// use crate::TestServer; +fn collect_diff_stats( + panel: &gpui::Entity, + cx: &C, +) -> HashMap { + panel.read_with(cx, |panel, cx| { + let Some(repo) = panel.active_repository() else { + return HashMap::default(); + }; + let snapshot = repo.read(cx).snapshot(); + let mut stats = HashMap::default(); + for entry in snapshot.statuses_by_path.iter() { + if let Some(diff_stat) = entry.diff_stat { + stats.insert(entry.repo_path.clone(), diff_stat); + } + } + stats + }) +} + #[gpui::test] async fn test_project_diff(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { let mut server = TestServer::start(cx_a.background_executor.clone()).await; @@ -141,3 +164,337 @@ async fn test_project_diff(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) ); }); } + +#[gpui::test] +async fn test_remote_git_worktrees( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + server + .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) + .await; + let active_call_a = cx_a.read(ActiveCall::global); + + client_a + .fs() + .insert_tree( + path!("/project"), + json!({ ".git": {}, "file.txt": "content" }), + ) + .await; + + let (project_a, _) = client_a.build_local_project(path!("/project"), cx_a).await; + + let project_id = active_call_a + .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) + .await + .unwrap(); + let project_b = client_b.join_remote_project(project_id, cx_b).await; + + executor.run_until_parked(); + + let repo_b = cx_b.update(|cx| project_b.read(cx).active_repository(cx).unwrap()); + + // Initially only the main worktree (the repo itself) should be present + let worktrees = cx_b + .update(|cx| repo_b.update(cx, |repository, _| repository.worktrees())) + .await + .unwrap() + .unwrap(); + assert_eq!(worktrees.len(), 1); + assert_eq!(worktrees[0].path, PathBuf::from(path!("/project"))); + + // Client B creates a git worktree via the remote project + let worktree_directory = PathBuf::from(path!("/project")); + cx_b.update(|cx| { + repo_b.update(cx, |repository, _| { + repository.create_worktree( + "feature-branch".to_string(), + worktree_directory.clone(), + Some("abc123".to_string()), + ) + }) + }) + .await + .unwrap() + .unwrap(); + + executor.run_until_parked(); + + // Client B lists worktrees — should see main + the one just created + let worktrees = cx_b + .update(|cx| repo_b.update(cx, |repository, _| repository.worktrees())) + .await + .unwrap() + .unwrap(); + assert_eq!(worktrees.len(), 2); + assert_eq!(worktrees[0].path, PathBuf::from(path!("/project"))); + assert_eq!(worktrees[1].path, worktree_directory.join("feature-branch")); + assert_eq!(worktrees[1].ref_name.as_ref(), "refs/heads/feature-branch"); + assert_eq!(worktrees[1].sha.as_ref(), "abc123"); + + // Verify from the host side that the worktree was actually created + let host_worktrees = { + let repo_a = cx_a.update(|cx| { + project_a + .read(cx) + .repositories(cx) + .values() + .next() + .unwrap() + .clone() + }); + cx_a.update(|cx| repo_a.update(cx, |repository, _| repository.worktrees())) + .await + .unwrap() + .unwrap() + }; + assert_eq!(host_worktrees.len(), 2); + assert_eq!(host_worktrees[0].path, PathBuf::from(path!("/project"))); + assert_eq!( + host_worktrees[1].path, + worktree_directory.join("feature-branch") + ); + + // Client B creates a second git worktree without an explicit commit + cx_b.update(|cx| { + repo_b.update(cx, |repository, _| { + repository.create_worktree( + "bugfix-branch".to_string(), + worktree_directory.clone(), + None, + ) + }) + }) + .await + .unwrap() + .unwrap(); + + executor.run_until_parked(); + + // Client B lists worktrees — should now have main + two created + let worktrees = cx_b + .update(|cx| repo_b.update(cx, |repository, _| repository.worktrees())) + .await + .unwrap() + .unwrap(); + assert_eq!(worktrees.len(), 3); + + let feature_worktree = worktrees + .iter() + .find(|worktree| worktree.ref_name.as_ref() == "refs/heads/feature-branch") + .expect("should find feature-branch worktree"); + assert_eq!( + feature_worktree.path, + worktree_directory.join("feature-branch") + ); + + let bugfix_worktree = worktrees + .iter() + .find(|worktree| worktree.ref_name.as_ref() == "refs/heads/bugfix-branch") + .expect("should find bugfix-branch worktree"); + assert_eq!( + bugfix_worktree.path, + worktree_directory.join("bugfix-branch") + ); + assert_eq!(bugfix_worktree.sha.as_ref(), "fake-sha"); +} + +#[gpui::test] +async fn test_diff_stat_sync_between_host_and_downstream_client( + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, + cx_c: &mut TestAppContext, +) { + let mut server = TestServer::start(cx_a.background_executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + let client_c = server.create_client(cx_c, "user_c").await; + + server + .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b), (&client_c, cx_c)]) + .await; + + let fs = client_a.fs(); + fs.insert_tree( + path!("/code"), + json!({ + "project1": { + ".git": {}, + "src": { + "lib.rs": "line1\nline2\nline3\n", + "new_file.rs": "added1\nadded2\n", + }, + "README.md": "# project 1", + } + }), + ) + .await; + + let dot_git = Path::new(path!("/code/project1/.git")); + fs.set_head_for_repo( + dot_git, + &[ + ("src/lib.rs", "line1\nold_line2\n".into()), + ("src/deleted.rs", "was_here\n".into()), + ], + "deadbeef", + ); + fs.set_index_for_repo( + dot_git, + &[ + ("src/lib.rs", "line1\nold_line2\nline3\nline4\n".into()), + ("src/staged_only.rs", "x\ny\n".into()), + ("src/new_file.rs", "added1\nadded2\n".into()), + ("README.md", "# project 1".into()), + ], + ); + + let (project_a, worktree_id) = client_a + .build_local_project(path!("/code/project1"), cx_a) + .await; + let active_call_a = cx_a.read(ActiveCall::global); + let project_id = active_call_a + .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) + .await + .unwrap(); + let project_b = client_b.join_remote_project(project_id, cx_b).await; + let _project_c = client_c.join_remote_project(project_id, cx_c).await; + cx_a.run_until_parked(); + + let (workspace_a, cx_a) = client_a.build_workspace(&project_a, cx_a); + let (workspace_b, cx_b) = client_b.build_workspace(&project_b, cx_b); + + let panel_a = workspace_a.update_in(cx_a, GitPanel::new_test); + workspace_a.update_in(cx_a, |workspace, window, cx| { + workspace.add_panel(panel_a.clone(), window, cx); + }); + + let panel_b = workspace_b.update_in(cx_b, GitPanel::new_test); + workspace_b.update_in(cx_b, |workspace, window, cx| { + workspace.add_panel(panel_b.clone(), window, cx); + }); + + cx_a.run_until_parked(); + + let stats_a = collect_diff_stats(&panel_a, cx_a); + let stats_b = collect_diff_stats(&panel_b, cx_b); + + let mut expected: HashMap = HashMap::default(); + expected.insert( + RepoPath::new("src/lib.rs").unwrap(), + DiffStat { + added: 3, + deleted: 2, + }, + ); + expected.insert( + RepoPath::new("src/deleted.rs").unwrap(), + DiffStat { + added: 0, + deleted: 1, + }, + ); + expected.insert( + RepoPath::new("src/new_file.rs").unwrap(), + DiffStat { + added: 2, + deleted: 0, + }, + ); + expected.insert( + RepoPath::new("README.md").unwrap(), + DiffStat { + added: 1, + deleted: 0, + }, + ); + assert_eq!(stats_a, expected, "host diff stats should match expected"); + assert_eq!(stats_a, stats_b, "host and remote should agree"); + + let buffer_a = project_a + .update(cx_a, |p, cx| { + p.open_buffer((worktree_id, rel_path("src/lib.rs")), cx) + }) + .await + .unwrap(); + + let _buffer_b = project_b + .update(cx_b, |p, cx| { + p.open_buffer((worktree_id, rel_path("src/lib.rs")), cx) + }) + .await + .unwrap(); + cx_a.run_until_parked(); + + buffer_a.update(cx_a, |buf, cx| { + buf.edit([(buf.len()..buf.len(), "line4\n")], None, cx); + }); + project_a + .update(cx_a, |project, cx| { + project.save_buffer(buffer_a.clone(), cx) + }) + .await + .unwrap(); + cx_a.run_until_parked(); + + let stats_a = collect_diff_stats(&panel_a, cx_a); + let stats_b = collect_diff_stats(&panel_b, cx_b); + + let mut expected_after_edit = expected.clone(); + expected_after_edit.insert( + RepoPath::new("src/lib.rs").unwrap(), + DiffStat { + added: 4, + deleted: 2, + }, + ); + assert_eq!( + stats_a, expected_after_edit, + "host diff stats should reflect the edit" + ); + assert_eq!( + stats_b, expected_after_edit, + "remote diff stats should reflect the host's edit" + ); + + let active_call_b = cx_b.read(ActiveCall::global); + active_call_b + .update(cx_b, |call, cx| call.hang_up(cx)) + .await + .unwrap(); + cx_a.run_until_parked(); + + let user_id_b = client_b.current_user_id(cx_b).to_proto(); + active_call_a + .update(cx_a, |call, cx| call.invite(user_id_b, None, cx)) + .await + .unwrap(); + cx_b.run_until_parked(); + let active_call_b = cx_b.read(ActiveCall::global); + active_call_b + .update(cx_b, |call, cx| call.accept_incoming(cx)) + .await + .unwrap(); + cx_a.run_until_parked(); + + let project_b = client_b.join_remote_project(project_id, cx_b).await; + cx_a.run_until_parked(); + + let (workspace_b, cx_b) = client_b.build_workspace(&project_b, cx_b); + let panel_b = workspace_b.update_in(cx_b, GitPanel::new_test); + workspace_b.update_in(cx_b, |workspace, window, cx| { + workspace.add_panel(panel_b.clone(), window, cx); + }); + cx_b.run_until_parked(); + + let stats_b = collect_diff_stats(&panel_b, cx_b); + assert_eq!( + stats_b, expected_after_edit, + "remote diff stats should be restored from the database after rejoining the call" + ); +} diff --git a/crates/collab/tests/integration/integration_tests.rs b/crates/collab/tests/integration/integration_tests.rs index c26f20c1e294326f275dbfda1d2d41603719cd3e..3bad9c82c26392a935f67efc578b5d293b2cab3d 100644 --- a/crates/collab/tests/integration/integration_tests.rs +++ b/crates/collab/tests/integration/integration_tests.rs @@ -7205,3 +7205,89 @@ async fn test_remote_git_branches( assert_eq!(host_branch.name(), "totally-new-branch"); } + +#[gpui::test] +async fn test_guest_can_rejoin_shared_project_after_leaving_call( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, + cx_c: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + let client_c = server.create_client(cx_c, "user_c").await; + + server + .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b), (&client_c, cx_c)]) + .await; + + client_a + .fs() + .insert_tree( + path!("/project"), + json!({ + "file.txt": "hello\n", + }), + ) + .await; + + let (project_a, _worktree_id) = client_a.build_local_project(path!("/project"), cx_a).await; + let active_call_a = cx_a.read(ActiveCall::global); + let project_id = active_call_a + .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) + .await + .unwrap(); + + let _project_b = client_b.join_remote_project(project_id, cx_b).await; + executor.run_until_parked(); + + // third client joins call to prevent room from being torn down + let _project_c = client_c.join_remote_project(project_id, cx_c).await; + executor.run_until_parked(); + + let active_call_b = cx_b.read(ActiveCall::global); + active_call_b + .update(cx_b, |call, cx| call.hang_up(cx)) + .await + .unwrap(); + executor.run_until_parked(); + + let user_id_b = client_b.current_user_id(cx_b).to_proto(); + let active_call_a = cx_a.read(ActiveCall::global); + active_call_a + .update(cx_a, |call, cx| call.invite(user_id_b, None, cx)) + .await + .unwrap(); + executor.run_until_parked(); + let active_call_b = cx_b.read(ActiveCall::global); + active_call_b + .update(cx_b, |call, cx| call.accept_incoming(cx)) + .await + .unwrap(); + executor.run_until_parked(); + + let _project_b2 = client_b.join_remote_project(project_id, cx_b).await; + executor.run_until_parked(); + + project_a.read_with(cx_a, |project, _| { + let guest_count = project + .collaborators() + .values() + .filter(|c| !c.is_host) + .count(); + + assert_eq!( + guest_count, 2, + "host should have exactly one guest collaborator after rejoin" + ); + }); + + _project_b.read_with(cx_b, |project, _| { + assert_eq!( + project.client_subscriptions().len(), + 0, + "We should clear all host subscriptions after leaving the project" + ); + }) +} diff --git a/crates/collab/tests/integration/remote_editing_collaboration_tests.rs b/crates/collab/tests/integration/remote_editing_collaboration_tests.rs index 4556c740ec74f6fb1bc8a2c760812376dae6b4a8..6825c468e783ee8d3a2a6107a031accfc108abd0 100644 --- a/crates/collab/tests/integration/remote_editing_collaboration_tests.rs +++ b/crates/collab/tests/integration/remote_editing_collaboration_tests.rs @@ -33,7 +33,7 @@ use settings::{ SettingsStore, }; use std::{ - path::Path, + path::{Path, PathBuf}, sync::{ Arc, atomic::{AtomicUsize, Ordering}, @@ -396,6 +396,130 @@ async fn test_ssh_collaboration_git_branches( }); } +#[gpui::test] +async fn test_ssh_collaboration_git_worktrees( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, + server_cx: &mut TestAppContext, +) { + cx_a.set_name("a"); + cx_b.set_name("b"); + server_cx.set_name("server"); + + cx_a.update(|cx| { + release_channel::init(semver::Version::new(0, 0, 0), cx); + }); + server_cx.update(|cx| { + release_channel::init(semver::Version::new(0, 0, 0), cx); + }); + + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + server + .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) + .await; + + let (opts, server_ssh, _) = RemoteClient::fake_server(cx_a, server_cx); + let remote_fs = FakeFs::new(server_cx.executor()); + remote_fs + .insert_tree("/project", json!({ ".git": {}, "file.txt": "content" })) + .await; + + server_cx.update(HeadlessProject::init); + let languages = Arc::new(LanguageRegistry::new(server_cx.executor())); + let headless_project = server_cx.new(|cx| { + HeadlessProject::new( + HeadlessAppState { + session: server_ssh, + fs: remote_fs.clone(), + http_client: Arc::new(BlockedHttpClient), + node_runtime: NodeRuntime::unavailable(), + languages, + extension_host_proxy: Arc::new(ExtensionHostProxy::new()), + startup_time: std::time::Instant::now(), + }, + false, + cx, + ) + }); + + let client_ssh = RemoteClient::connect_mock(opts, cx_a).await; + let (project_a, _) = client_a + .build_ssh_project("/project", client_ssh, false, cx_a) + .await; + + let active_call_a = cx_a.read(ActiveCall::global); + let project_id = active_call_a + .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) + .await + .unwrap(); + let project_b = client_b.join_remote_project(project_id, cx_b).await; + + executor.run_until_parked(); + + let repo_b = cx_b.update(|cx| project_b.read(cx).active_repository(cx).unwrap()); + + let worktrees = cx_b + .update(|cx| repo_b.update(cx, |repo, _| repo.worktrees())) + .await + .unwrap() + .unwrap(); + assert_eq!(worktrees.len(), 1); + + let worktree_directory = PathBuf::from("/project"); + cx_b.update(|cx| { + repo_b.update(cx, |repo, _| { + repo.create_worktree( + "feature-branch".to_string(), + worktree_directory.clone(), + Some("abc123".to_string()), + ) + }) + }) + .await + .unwrap() + .unwrap(); + + executor.run_until_parked(); + + let worktrees = cx_b + .update(|cx| repo_b.update(cx, |repo, _| repo.worktrees())) + .await + .unwrap() + .unwrap(); + assert_eq!(worktrees.len(), 2); + assert_eq!(worktrees[1].path, worktree_directory.join("feature-branch")); + assert_eq!(worktrees[1].ref_name.as_ref(), "refs/heads/feature-branch"); + assert_eq!(worktrees[1].sha.as_ref(), "abc123"); + + let server_worktrees = { + let server_repo = server_cx.update(|cx| { + headless_project.update(cx, |headless_project, cx| { + headless_project + .git_store + .read(cx) + .repositories() + .values() + .next() + .unwrap() + .clone() + }) + }); + server_cx + .update(|cx| server_repo.update(cx, |repo, _| repo.worktrees())) + .await + .unwrap() + .unwrap() + }; + assert_eq!(server_worktrees.len(), 2); + assert_eq!( + server_worktrees[1].path, + worktree_directory.join("feature-branch") + ); +} + #[gpui::test] async fn test_ssh_collaboration_formatting_with_prettier( executor: BackgroundExecutor, diff --git a/crates/crashes/Cargo.toml b/crates/crashes/Cargo.toml index 5e451853a925d86ffcc1491a5c95af1f94e6ed05..2c13dc83c5a88c3504da6f8be48c1d75c8e43652 100644 --- a/crates/crashes/Cargo.toml +++ b/crates/crashes/Cargo.toml @@ -6,13 +6,12 @@ edition.workspace = true license = "GPL-3.0-or-later" [dependencies] -bincode.workspace = true cfg-if.workspace = true crash-handler.workspace = true futures.workspace = true log.workspace = true minidumper.workspace = true - +parking_lot.workspace = true paths.workspace = true release_channel.workspace = true smol.workspace = true diff --git a/crates/crashes/src/crashes.rs b/crates/crashes/src/crashes.rs index a1a43dbb88198b7afd4b89141f7578c0a5bc25ce..0c848d759cd444f3eb6e2a9838d3005254a25b19 100644 --- a/crates/crashes/src/crashes.rs +++ b/crates/crashes/src/crashes.rs @@ -2,12 +2,14 @@ use crash_handler::{CrashEventResult, CrashHandler}; use futures::future::BoxFuture; use log::info; use minidumper::{Client, LoopAction, MinidumpBinary}; +use parking_lot::Mutex; use release_channel::{RELEASE_CHANNEL, ReleaseChannel}; use serde::{Deserialize, Serialize}; use std::mem; #[cfg(not(target_os = "windows"))] use smol::process::Command; +use system_specs::GpuSpecs; #[cfg(target_os = "macos")] use std::sync::atomic::AtomicU32; @@ -27,12 +29,14 @@ use std::{ }; // set once the crash handler has initialized and the client has connected to it -pub static CRASH_HANDLER: OnceLock> = OnceLock::new(); +static CRASH_HANDLER: OnceLock> = OnceLock::new(); // set when the first minidump request is made to avoid generating duplicate crash reports pub static REQUESTED_MINIDUMP: AtomicBool = AtomicBool::new(false); const CRASH_HANDLER_PING_TIMEOUT: Duration = Duration::from_secs(60); const CRASH_HANDLER_CONNECT_TIMEOUT: Duration = Duration::from_secs(10); +static PENDING_CRASH_SERVER_MESSAGES: Mutex> = Mutex::new(Vec::new()); + #[cfg(target_os = "macos")] static PANIC_THREAD_ID: AtomicU32 = AtomicU32::new(0); @@ -118,6 +122,7 @@ async fn connect_and_keepalive(crash_init: InitCrashHandler, handler: CrashHandl spawn_crash_handler_windows(&exe, &socket_name); info!("spawning crash handler process"); + send_crash_server_message(CrashServerMessage::Init(crash_init)); let mut elapsed = Duration::ZERO; let retry_frequency = Duration::from_millis(100); @@ -134,10 +139,6 @@ async fn connect_and_keepalive(crash_init: InitCrashHandler, handler: CrashHandl smol::Timer::after(retry_frequency).await; } let client = maybe_client.unwrap(); - client - .send_message(1, serde_json::to_vec(&crash_init).unwrap()) - .unwrap(); - let client = Arc::new(client); #[cfg(target_os = "linux")] @@ -146,6 +147,10 @@ async fn connect_and_keepalive(crash_init: InitCrashHandler, handler: CrashHandl // Publishing the client to the OnceLock makes it visible to the signal // handler callback installed earlier. CRASH_HANDLER.set(client.clone()).ok(); + let messages: Vec<_> = mem::take(PENDING_CRASH_SERVER_MESSAGES.lock().as_mut()); + for message in messages.into_iter() { + send_crash_server_message(message); + } // mem::forget so that the drop is not called mem::forget(handler); info!("crash handler registered"); @@ -177,9 +182,10 @@ unsafe fn suspend_all_other_threads() { } pub struct CrashServer { - initialization_params: OnceLock, - panic_info: OnceLock, - active_gpu: OnceLock, + initialization_params: Mutex>, + panic_info: Mutex>, + active_gpu: Mutex>, + user_info: Mutex>, has_connection: Arc, } @@ -190,6 +196,7 @@ pub struct CrashInfo { pub minidump_error: Option, pub gpus: Vec, pub active_gpu: Option, + pub user_info: Option, } #[derive(Debug, Deserialize, Serialize, Clone)] @@ -207,15 +214,55 @@ pub struct CrashPanic { pub span: String, } +#[derive(Deserialize, Serialize, Debug, Clone)] +pub struct UserInfo { + pub metrics_id: Option, + pub is_staff: Option, +} + +fn send_crash_server_message(message: CrashServerMessage) { + let Some(crash_server) = CRASH_HANDLER.get() else { + PENDING_CRASH_SERVER_MESSAGES.lock().push(message); + return; + }; + let data = match serde_json::to_vec(&message) { + Ok(data) => data, + Err(err) => { + log::warn!("Failed to serialize crash server message: {:?}", err); + return; + } + }; + + if let Err(err) = crash_server.send_message(0, data) { + log::warn!("Failed to send data to crash server {:?}", err); + } +} + +pub fn set_gpu_info(specs: GpuSpecs) { + send_crash_server_message(CrashServerMessage::GPUInfo(specs)); +} + +pub fn set_user_info(info: UserInfo) { + send_crash_server_message(CrashServerMessage::UserInfo(info)); +} + +#[derive(Serialize, Deserialize, Debug)] +enum CrashServerMessage { + Init(InitCrashHandler), + Panic(CrashPanic), + GPUInfo(GpuSpecs), + UserInfo(UserInfo), +} + impl minidumper::ServerHandler for CrashServer { fn create_minidump_file(&self) -> Result<(File, PathBuf), io::Error> { - let err_message = "Missing initialization data"; let dump_path = paths::logs_dir() .join( &self .initialization_params - .get() - .expect(err_message) + .lock() + .as_ref() + .expect("Missing initialization data") .session_id, ) .with_extension("dmp"); @@ -255,13 +302,14 @@ impl minidumper::ServerHandler for CrashServer { let crash_info = CrashInfo { init: self .initialization_params - .get() - .expect("not initialized") - .clone(), - panic: self.panic_info.get().cloned(), + .lock() + .clone() + .expect("not initialized"), + panic: self.panic_info.lock().clone(), minidump_error, - active_gpu: self.active_gpu.get().cloned(), + active_gpu: self.active_gpu.lock().clone(), gpus, + user_info: self.user_info.lock().clone(), }; let crash_data_path = paths::logs_dir() @@ -273,30 +321,21 @@ impl minidumper::ServerHandler for CrashServer { LoopAction::Exit } - fn on_message(&self, kind: u32, buffer: Vec) { - match kind { - 1 => { - let init_data = - serde_json::from_slice::(&buffer).expect("invalid init data"); - self.initialization_params - .set(init_data) - .expect("already initialized"); + fn on_message(&self, _: u32, buffer: Vec) { + let message: CrashServerMessage = + serde_json::from_slice(&buffer).expect("invalid init data"); + match message { + CrashServerMessage::Init(init_data) => { + self.initialization_params.lock().replace(init_data); } - 2 => { - let panic_data = - serde_json::from_slice::(&buffer).expect("invalid panic data"); - self.panic_info.set(panic_data).expect("already panicked"); + CrashServerMessage::Panic(crash_panic) => { + self.panic_info.lock().replace(crash_panic); } - 3 => { - let gpu_specs: system_specs::GpuSpecs = - bincode::deserialize(&buffer).expect("gpu specs"); - // we ignore the case where it was already set because this message is sent - // on each new window. in theory all zed windows should be using the same - // GPU so this is fine. - self.active_gpu.set(gpu_specs).ok(); + CrashServerMessage::GPUInfo(gpu_specs) => { + self.active_gpu.lock().replace(gpu_specs); } - _ => { - panic!("invalid message kind"); + CrashServerMessage::UserInfo(user_info) => { + self.user_info.lock().replace(user_info); } } } @@ -326,37 +365,33 @@ pub fn panic_hook(info: &PanicHookInfo) { // if it's still not there just write panic info and no minidump let retry_frequency = Duration::from_millis(100); for _ in 0..5 { - if let Some(client) = CRASH_HANDLER.get() { - let location = info - .location() - .map_or_else(|| "".to_owned(), |location| location.to_string()); - log::error!("thread '{thread_name}' panicked at {location}:\n{message}..."); - client - .send_message( - 2, - serde_json::to_vec(&CrashPanic { message, span }).unwrap(), - ) - .ok(); - log::error!("triggering a crash to generate a minidump..."); - - #[cfg(target_os = "macos")] - PANIC_THREAD_ID.store( - unsafe { mach2::mach_init::mach_thread_self() }, - Ordering::SeqCst, - ); - - cfg_if::cfg_if! { - if #[cfg(target_os = "windows")] { - // https://learn.microsoft.com/en-us/windows/win32/debug/system-error-codes--0-499- - CrashHandler.simulate_exception(Some(234)); // (MORE_DATA_AVAILABLE) - break; - } else { - std::process::abort(); - } - } + if CRASH_HANDLER.get().is_some() { + break; } thread::sleep(retry_frequency); } + let location = info + .location() + .map_or_else(|| "".to_owned(), |location| location.to_string()); + log::error!("thread '{thread_name}' panicked at {location}:\n{message}..."); + + send_crash_server_message(CrashServerMessage::Panic(CrashPanic { message, span })); + log::error!("triggering a crash to generate a minidump..."); + + #[cfg(target_os = "macos")] + PANIC_THREAD_ID.store( + unsafe { mach2::mach_init::mach_thread_self() }, + Ordering::SeqCst, + ); + + cfg_if::cfg_if! { + if #[cfg(target_os = "windows")] { + // https://learn.microsoft.com/en-us/windows/win32/debug/system-error-codes--0-499- + CrashHandler.simulate_exception(Some(234)); // (MORE_DATA_AVAILABLE) + } else { + std::process::abort(); + } + } } #[cfg(target_os = "windows")] @@ -436,10 +471,11 @@ pub fn crash_server(socket: &Path) { server .run( Box::new(CrashServer { - initialization_params: OnceLock::new(), - panic_info: OnceLock::new(), + initialization_params: Mutex::default(), + panic_info: Mutex::default(), + user_info: Mutex::default(), has_connection, - active_gpu: OnceLock::new(), + active_gpu: Mutex::default(), }), &shutdown, Some(CRASH_HANDLER_PING_TIMEOUT), diff --git a/crates/csv_preview/Cargo.toml b/crates/csv_preview/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..7e9ce2c4d515cfce9586a0686475a8dfed0ddc95 --- /dev/null +++ b/crates/csv_preview/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "csv_preview" +version = "0.1.0" +publish.workspace = true +edition.workspace = true + +[lib] +path = "src/csv_preview.rs" + +[dependencies] +anyhow.workspace = true +feature_flags.workspace = true +gpui.workspace = true +editor.workspace = true +ui.workspace = true +workspace.workspace = true +log.workspace = true +text.workspace = true + +[lints] +workspace = true diff --git a/crates/csv_preview/LICENSE-GPL b/crates/csv_preview/LICENSE-GPL new file mode 120000 index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4 --- /dev/null +++ b/crates/csv_preview/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/csv_preview/src/csv_preview.rs b/crates/csv_preview/src/csv_preview.rs new file mode 100644 index 0000000000000000000000000000000000000000..f056f5a12225b000527b9087760e3d683bda1b5b --- /dev/null +++ b/crates/csv_preview/src/csv_preview.rs @@ -0,0 +1,302 @@ +use editor::{Editor, EditorEvent}; +use feature_flags::{FeatureFlag, FeatureFlagAppExt as _}; +use gpui::{ + AppContext, Entity, EventEmitter, FocusHandle, Focusable, ListAlignment, Task, actions, +}; +use std::{ + collections::HashMap, + time::{Duration, Instant}, +}; + +use crate::table_data_engine::TableDataEngine; +use ui::{SharedString, TableColumnWidths, TableInteractionState, prelude::*}; +use workspace::{Item, SplitDirection, Workspace}; + +use crate::{parser::EditorState, settings::CsvPreviewSettings, types::TableLikeContent}; + +mod parser; +mod renderer; +mod settings; +mod table_data_engine; +mod types; + +actions!(csv, [OpenPreview, OpenPreviewToTheSide]); + +pub struct TabularDataPreviewFeatureFlag; + +impl FeatureFlag for TabularDataPreviewFeatureFlag { + const NAME: &'static str = "tabular-data-preview"; +} + +pub struct CsvPreviewView { + pub(crate) engine: TableDataEngine, + + pub(crate) focus_handle: FocusHandle, + active_editor_state: EditorState, + pub(crate) table_interaction_state: Entity, + pub(crate) column_widths: ColumnWidths, + pub(crate) parsing_task: Option>>, + pub(crate) settings: CsvPreviewSettings, + /// Performance metrics for debugging and monitoring CSV operations. + pub(crate) performance_metrics: PerformanceMetrics, + pub(crate) list_state: gpui::ListState, + /// Time when the last parsing operation ended, used for smart debouncing + pub(crate) last_parse_end_time: Option, +} + +pub fn init(cx: &mut App) { + cx.observe_new(|workspace: &mut Workspace, _, _| { + CsvPreviewView::register(workspace); + }) + .detach() +} + +impl CsvPreviewView { + pub fn register(workspace: &mut Workspace) { + workspace.register_action_renderer(|div, _, _, cx| { + div.when(cx.has_flag::(), |div| { + div.on_action(cx.listener(|workspace, _: &OpenPreview, window, cx| { + if let Some(editor) = workspace + .active_item(cx) + .and_then(|item| item.act_as::(cx)) + .filter(|editor| Self::is_csv_file(editor, cx)) + { + let csv_preview = Self::new(&editor, cx); + workspace.active_pane().update(cx, |pane, cx| { + let existing = pane + .items_of_type::() + .find(|view| view.read(cx).active_editor_state.editor == editor); + if let Some(idx) = existing.and_then(|e| pane.index_for_item(&e)) { + pane.activate_item(idx, true, true, window, cx); + } else { + pane.add_item(Box::new(csv_preview), true, true, None, window, cx); + } + }); + cx.notify(); + } + })) + .on_action(cx.listener( + |workspace, _: &OpenPreviewToTheSide, window, cx| { + if let Some(editor) = workspace + .active_item(cx) + .and_then(|item| item.act_as::(cx)) + .filter(|editor| Self::is_csv_file(editor, cx)) + { + let csv_preview = Self::new(&editor, cx); + let pane = workspace + .find_pane_in_direction(SplitDirection::Right, cx) + .unwrap_or_else(|| { + workspace.split_pane( + workspace.active_pane().clone(), + SplitDirection::Right, + window, + cx, + ) + }); + pane.update(cx, |pane, cx| { + let existing = + pane.items_of_type::().find(|view| { + view.read(cx).active_editor_state.editor == editor + }); + if let Some(idx) = existing.and_then(|e| pane.index_for_item(&e)) { + pane.activate_item(idx, true, true, window, cx); + } else { + pane.add_item( + Box::new(csv_preview), + false, + false, + None, + window, + cx, + ); + } + }); + cx.notify(); + } + }, + )) + }) + }); + } + + fn new(editor: &Entity, cx: &mut Context) -> Entity { + let contents = TableLikeContent::default(); + let table_interaction_state = cx.new(|cx| { + TableInteractionState::new(cx) + .with_custom_scrollbar(ui::Scrollbars::for_settings::()) + }); + + cx.new(|cx| { + let subscription = cx.subscribe( + editor, + |this: &mut CsvPreviewView, _editor, event: &EditorEvent, cx| { + match event { + EditorEvent::Edited { .. } + | EditorEvent::DirtyChanged + | EditorEvent::ExcerptsEdited { .. } => { + this.parse_csv_from_active_editor(true, cx); + } + _ => {} + }; + }, + ); + + let mut view = CsvPreviewView { + focus_handle: cx.focus_handle(), + active_editor_state: EditorState { + editor: editor.clone(), + _subscription: subscription, + }, + table_interaction_state, + column_widths: ColumnWidths::new(cx, 1), + parsing_task: None, + performance_metrics: PerformanceMetrics::default(), + list_state: gpui::ListState::new(contents.rows.len(), ListAlignment::Top, px(1.)), + settings: CsvPreviewSettings::default(), + last_parse_end_time: None, + engine: TableDataEngine::default(), + }; + + view.parse_csv_from_active_editor(false, cx); + view + }) + } + + pub(crate) fn editor_state(&self) -> &EditorState { + &self.active_editor_state + } + pub(crate) fn apply_sort(&mut self) { + self.performance_metrics.record("Sort", || { + self.engine.apply_sort(); + }); + } + + /// Update ordered indices when ordering or content changes + pub(crate) fn apply_filter_sort(&mut self) { + self.performance_metrics.record("Filter&sort", || { + self.engine.calculate_d2d_mapping(); + }); + + // Update list state with filtered row count + let visible_rows = self.engine.d2d_mapping().visible_row_count(); + self.list_state = gpui::ListState::new(visible_rows, ListAlignment::Top, px(1.)); + } + + pub fn resolve_active_item_as_csv_editor( + workspace: &Workspace, + cx: &mut Context, + ) -> Option> { + let editor = workspace + .active_item(cx) + .and_then(|item| item.act_as::(cx))?; + Self::is_csv_file(&editor, cx).then_some(editor) + } + + fn is_csv_file(editor: &Entity, cx: &App) -> bool { + editor + .read(cx) + .buffer() + .read(cx) + .as_singleton() + .and_then(|buffer| { + buffer + .read(cx) + .file() + .and_then(|file| file.path().extension()) + .map(|ext| ext.eq_ignore_ascii_case("csv")) + }) + .unwrap_or(false) + } +} + +impl Focusable for CsvPreviewView { + fn focus_handle(&self, _cx: &App) -> FocusHandle { + self.focus_handle.clone() + } +} + +impl EventEmitter<()> for CsvPreviewView {} + +impl Item for CsvPreviewView { + type Event = (); + + fn tab_icon(&self, _window: &Window, _cx: &App) -> Option { + Some(Icon::new(IconName::FileDoc)) + } + + fn tab_content_text(&self, _detail: usize, cx: &App) -> SharedString { + self.editor_state() + .editor + .read(cx) + .buffer() + .read(cx) + .as_singleton() + .and_then(|b| { + let file = b.read(cx).file()?; + let local_file = file.as_local()?; + local_file + .abs_path(cx) + .file_name() + .map(|name| format!("Preview {}", name.to_string_lossy()).into()) + }) + .unwrap_or_else(|| SharedString::from("CSV Preview")) + } +} + +#[derive(Debug, Default)] +pub struct PerformanceMetrics { + /// Map of timing metrics with their duration and measurement time. + pub timings: HashMap<&'static str, (Duration, Instant)>, + /// List of display indices that were rendered in the current frame. + pub rendered_indices: Vec, +} +impl PerformanceMetrics { + pub fn record(&mut self, name: &'static str, mut f: F) -> R + where + F: FnMut() -> R, + { + let start_time = Instant::now(); + let ret = f(); + let duration = start_time.elapsed(); + self.timings.insert(name, (duration, Instant::now())); + ret + } + + /// Displays all metrics sorted A-Z in format: `{name}: {took}ms {ago}s ago` + pub fn display(&self) -> String { + let mut metrics = self.timings.iter().collect::>(); + metrics.sort_by_key(|&(name, _)| *name); + metrics + .iter() + .map(|(name, (duration, time))| { + let took = duration.as_secs_f32() * 1000.; + let ago = time.elapsed().as_secs(); + format!("{name}: {took:.2}ms {ago}s ago") + }) + .collect::>() + .join("\n") + } + + /// Get timing for a specific metric + pub fn get_timing(&self, name: &str) -> Option { + self.timings.get(name).map(|(duration, _)| *duration) + } +} + +/// Holds state of column widths for a table component in CSV preview. +pub(crate) struct ColumnWidths { + pub widths: Entity, +} + +impl ColumnWidths { + pub(crate) fn new(cx: &mut Context, cols: usize) -> Self { + Self { + widths: cx.new(|cx| TableColumnWidths::new(cols, cx)), + } + } + /// Replace the current `TableColumnWidths` entity with a new one for the given column count. + pub(crate) fn replace(&self, cx: &mut Context, cols: usize) { + self.widths + .update(cx, |entity, cx| *entity = TableColumnWidths::new(cols, cx)); + } +} diff --git a/crates/csv_preview/src/parser.rs b/crates/csv_preview/src/parser.rs new file mode 100644 index 0000000000000000000000000000000000000000..b087404e0ebbd13cdaf20cab692f5470ea6ce292 --- /dev/null +++ b/crates/csv_preview/src/parser.rs @@ -0,0 +1,513 @@ +use crate::{ + CsvPreviewView, + types::TableLikeContent, + types::{LineNumber, TableCell}, +}; +use editor::Editor; +use gpui::{AppContext, Context, Entity, Subscription, Task}; +use std::time::{Duration, Instant}; +use text::BufferSnapshot; +use ui::{SharedString, table_row::TableRow}; + +pub(crate) const REPARSE_DEBOUNCE: Duration = Duration::from_millis(200); + +pub(crate) struct EditorState { + pub editor: Entity, + pub _subscription: Subscription, +} + +impl CsvPreviewView { + pub(crate) fn parse_csv_from_active_editor( + &mut self, + wait_for_debounce: bool, + cx: &mut Context, + ) { + let editor = self.active_editor_state.editor.clone(); + self.parsing_task = Some(self.parse_csv_in_background(wait_for_debounce, editor, cx)); + } + + fn parse_csv_in_background( + &mut self, + wait_for_debounce: bool, + editor: Entity, + cx: &mut Context, + ) -> Task> { + cx.spawn(async move |view, cx| { + if wait_for_debounce { + // Smart debouncing: check if cooldown period has already passed + let now = Instant::now(); + let should_wait = view.update(cx, |view, _| { + if let Some(last_end) = view.last_parse_end_time { + let cooldown_until = last_end + REPARSE_DEBOUNCE; + if now < cooldown_until { + Some(cooldown_until - now) + } else { + None // Cooldown already passed, parse immediately + } + } else { + None // First parse, no debounce + } + })?; + + if let Some(wait_duration) = should_wait { + cx.background_executor().timer(wait_duration).await; + } + } + + let buffer_snapshot = view.update(cx, |_, cx| { + editor + .read(cx) + .buffer() + .read(cx) + .as_singleton() + .map(|b| b.read(cx).text_snapshot()) + })?; + + let Some(buffer_snapshot) = buffer_snapshot else { + return Ok(()); + }; + + let instant = Instant::now(); + let parsed_csv = cx + .background_spawn(async move { from_buffer(&buffer_snapshot) }) + .await; + let parse_duration = instant.elapsed(); + let parse_end_time: Instant = Instant::now(); + log::debug!("Parsed CSV in {}ms", parse_duration.as_millis()); + view.update(cx, move |view, cx| { + view.performance_metrics + .timings + .insert("Parsing", (parse_duration, Instant::now())); + + log::debug!("Parsed {} rows", parsed_csv.rows.len()); + // Update table width so it can be rendered properly + let cols = parsed_csv.headers.cols(); + view.column_widths.replace(cx, cols + 1); // Add 1 for the line number column + + view.engine.contents = parsed_csv; + view.last_parse_end_time = Some(parse_end_time); + + view.apply_filter_sort(); + cx.notify(); + }) + }) + } +} + +pub fn from_buffer(buffer_snapshot: &BufferSnapshot) -> TableLikeContent { + let text = buffer_snapshot.text(); + + if text.trim().is_empty() { + return TableLikeContent::default(); + } + + let (parsed_cells_with_positions, line_numbers) = parse_csv_with_positions(&text); + if parsed_cells_with_positions.is_empty() { + return TableLikeContent::default(); + } + let raw_headers = parsed_cells_with_positions[0].clone(); + + // Calculating the longest row, as CSV might have less headers than max row width + let Some(max_number_of_cols) = parsed_cells_with_positions.iter().map(|r| r.len()).max() else { + return TableLikeContent::default(); + }; + + // Convert to TableCell objects with buffer positions + let headers = create_table_row(&buffer_snapshot, max_number_of_cols, raw_headers); + + let rows = parsed_cells_with_positions + .into_iter() + .skip(1) + .map(|row| create_table_row(&buffer_snapshot, max_number_of_cols, row)) + .collect(); + + let row_line_numbers = line_numbers.into_iter().skip(1).collect(); + + TableLikeContent { + headers, + rows, + line_numbers: row_line_numbers, + number_of_cols: max_number_of_cols, + } +} + +/// Parse CSV and track byte positions for each cell +fn parse_csv_with_positions( + text: &str, +) -> ( + Vec)>>, + Vec, +) { + let mut rows = Vec::new(); + let mut line_numbers = Vec::new(); + let mut current_row: Vec<(SharedString, std::ops::Range)> = Vec::new(); + let mut current_field = String::new(); + let mut field_start_offset = 0; + let mut current_offset = 0; + let mut in_quotes = false; + let mut current_line = 1; // 1-based line numbering + let mut row_start_line = 1; + let mut chars = text.chars().peekable(); + + while let Some(ch) = chars.next() { + let char_byte_len = ch.len_utf8(); + + match ch { + '"' => { + if in_quotes { + if chars.peek() == Some(&'"') { + // Escaped quote + chars.next(); + current_field.push('"'); + current_offset += 1; // Skip the second quote + } else { + // End of quoted field + in_quotes = false; + } + } else { + // Start of quoted field + in_quotes = true; + if current_field.is_empty() { + // Include the opening quote in the range + field_start_offset = current_offset; + } + } + } + ',' if !in_quotes => { + // Field separator + let field_end_offset = current_offset; + if current_field.is_empty() && !in_quotes { + field_start_offset = current_offset; + } + current_row.push(( + current_field.clone().into(), + field_start_offset..field_end_offset, + )); + current_field.clear(); + field_start_offset = current_offset + char_byte_len; + } + '\n' => { + current_line += 1; + if !in_quotes { + // Row separator (only when not inside quotes) + let field_end_offset = current_offset; + if current_field.is_empty() && current_row.is_empty() { + field_start_offset = 0; + } + current_row.push(( + current_field.clone().into(), + field_start_offset..field_end_offset, + )); + current_field.clear(); + + // Only add non-empty rows + if !current_row.is_empty() + && !current_row.iter().all(|(field, _)| field.trim().is_empty()) + { + rows.push(current_row); + // Add line number info for this row + let line_info = if row_start_line == current_line - 1 { + LineNumber::Line(row_start_line) + } else { + LineNumber::LineRange(row_start_line, current_line - 1) + }; + line_numbers.push(line_info); + } + current_row = Vec::new(); + row_start_line = current_line; + field_start_offset = current_offset + char_byte_len; + } else { + // Newline inside quotes - preserve it + current_field.push(ch); + } + } + '\r' => { + if chars.peek() == Some(&'\n') { + // Handle Windows line endings (\r\n): account for \r byte, let \n be handled next + current_offset += char_byte_len; + continue; + } else { + // Standalone \r + current_line += 1; + if !in_quotes { + // Row separator (only when not inside quotes) + let field_end_offset = current_offset; + current_row.push(( + current_field.clone().into(), + field_start_offset..field_end_offset, + )); + current_field.clear(); + + // Only add non-empty rows + if !current_row.is_empty() + && !current_row.iter().all(|(field, _)| field.trim().is_empty()) + { + rows.push(current_row); + // Add line number info for this row + let line_info = if row_start_line == current_line - 1 { + LineNumber::Line(row_start_line) + } else { + LineNumber::LineRange(row_start_line, current_line - 1) + }; + line_numbers.push(line_info); + } + current_row = Vec::new(); + row_start_line = current_line; + field_start_offset = current_offset + char_byte_len; + } else { + // \r inside quotes - preserve it + current_field.push(ch); + } + } + } + _ => { + if current_field.is_empty() && !in_quotes { + field_start_offset = current_offset; + } + current_field.push(ch); + } + } + + current_offset += char_byte_len; + } + + // Add the last field and row if not empty + if !current_field.is_empty() || !current_row.is_empty() { + let field_end_offset = current_offset; + current_row.push(( + current_field.clone().into(), + field_start_offset..field_end_offset, + )); + } + if !current_row.is_empty() && !current_row.iter().all(|(field, _)| field.trim().is_empty()) { + rows.push(current_row); + // Add line number info for the last row + let line_info = if row_start_line == current_line { + LineNumber::Line(row_start_line) + } else { + LineNumber::LineRange(row_start_line, current_line) + }; + line_numbers.push(line_info); + } + + (rows, line_numbers) +} + +fn create_table_row( + buffer_snapshot: &BufferSnapshot, + max_number_of_cols: usize, + row: Vec<(SharedString, std::ops::Range)>, +) -> TableRow { + let mut raw_row = row + .into_iter() + .map(|(content, range)| { + TableCell::from_buffer_position(content, range.start, range.end, &buffer_snapshot) + }) + .collect::>(); + + let append_elements = max_number_of_cols - raw_row.len(); + if append_elements > 0 { + for _ in 0..append_elements { + raw_row.push(TableCell::Virtual); + } + } + + TableRow::from_vec(raw_row, max_number_of_cols) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_csv_parsing_basic() { + let csv_data = "Name,Age,City\nJohn,30,New York\nJane,25,Los Angeles"; + let parsed = TableLikeContent::from_str(csv_data.to_string()); + + assert_eq!(parsed.headers.cols(), 3); + assert_eq!(parsed.headers[0].display_value().unwrap().as_ref(), "Name"); + assert_eq!(parsed.headers[1].display_value().unwrap().as_ref(), "Age"); + assert_eq!(parsed.headers[2].display_value().unwrap().as_ref(), "City"); + + assert_eq!(parsed.rows.len(), 2); + assert_eq!(parsed.rows[0][0].display_value().unwrap().as_ref(), "John"); + assert_eq!(parsed.rows[0][1].display_value().unwrap().as_ref(), "30"); + assert_eq!( + parsed.rows[0][2].display_value().unwrap().as_ref(), + "New York" + ); + } + + #[test] + fn test_csv_parsing_with_quotes() { + let csv_data = r#"Name,Description +"John Doe","A person with ""special"" characters" +Jane,"Simple name""#; + let parsed = TableLikeContent::from_str(csv_data.to_string()); + + assert_eq!(parsed.headers.cols(), 2); + assert_eq!(parsed.rows.len(), 2); + assert_eq!( + parsed.rows[0][1].display_value().unwrap().as_ref(), + r#"A person with "special" characters"# + ); + } + + #[test] + fn test_csv_parsing_with_newlines_in_quotes() { + let csv_data = "Name,Description,Status\n\"John\nDoe\",\"A person with\nmultiple lines\",Active\n\"Jane Smith\",\"Simple\",\"Also\nActive\""; + let parsed = TableLikeContent::from_str(csv_data.to_string()); + + assert_eq!(parsed.headers.cols(), 3); + assert_eq!(parsed.headers[0].display_value().unwrap().as_ref(), "Name"); + assert_eq!( + parsed.headers[1].display_value().unwrap().as_ref(), + "Description" + ); + assert_eq!( + parsed.headers[2].display_value().unwrap().as_ref(), + "Status" + ); + + assert_eq!(parsed.rows.len(), 2); + assert_eq!( + parsed.rows[0][0].display_value().unwrap().as_ref(), + "John\nDoe" + ); + assert_eq!( + parsed.rows[0][1].display_value().unwrap().as_ref(), + "A person with\nmultiple lines" + ); + assert_eq!( + parsed.rows[0][2].display_value().unwrap().as_ref(), + "Active" + ); + + assert_eq!( + parsed.rows[1][0].display_value().unwrap().as_ref(), + "Jane Smith" + ); + assert_eq!( + parsed.rows[1][1].display_value().unwrap().as_ref(), + "Simple" + ); + assert_eq!( + parsed.rows[1][2].display_value().unwrap().as_ref(), + "Also\nActive" + ); + + // Check line numbers + assert_eq!(parsed.line_numbers.len(), 2); + match &parsed.line_numbers[0] { + LineNumber::LineRange(start, end) => { + assert_eq!(start, &2); + assert_eq!(end, &4); + } + _ => panic!("Expected LineRange for multiline row"), + } + match &parsed.line_numbers[1] { + LineNumber::LineRange(start, end) => { + assert_eq!(start, &5); + assert_eq!(end, &6); + } + _ => panic!("Expected LineRange for second multiline row"), + } + } + + #[test] + fn test_empty_csv() { + let parsed = TableLikeContent::from_str("".to_string()); + assert_eq!(parsed.headers.cols(), 0); + assert!(parsed.rows.is_empty()); + } + + #[test] + fn test_csv_parsing_quote_offset_handling() { + let csv_data = r#"first,"se,cond",third"#; + let (parsed_cells, _) = parse_csv_with_positions(csv_data); + + assert_eq!(parsed_cells.len(), 1); // One row + assert_eq!(parsed_cells[0].len(), 3); // Three cells + + // first: 0..5 (no quotes) + let (content1, range1) = &parsed_cells[0][0]; + assert_eq!(content1.as_ref(), "first"); + assert_eq!(*range1, 0..5); + + // "se,cond": 6..15 (includes quotes in range, content without quotes) + let (content2, range2) = &parsed_cells[0][1]; + assert_eq!(content2.as_ref(), "se,cond"); + assert_eq!(*range2, 6..15); + + // third: 16..21 (no quotes) + let (content3, range3) = &parsed_cells[0][2]; + assert_eq!(content3.as_ref(), "third"); + assert_eq!(*range3, 16..21); + } + + #[test] + fn test_csv_parsing_complex_quotes() { + let csv_data = r#"id,"name with spaces","description, with commas",status +1,"John Doe","A person with ""quotes"" and, commas",active +2,"Jane Smith","Simple description",inactive"#; + let (parsed_cells, _) = parse_csv_with_positions(csv_data); + + assert_eq!(parsed_cells.len(), 3); // header + 2 rows + + // Check header row + let header_row = &parsed_cells[0]; + assert_eq!(header_row.len(), 4); + + // id: 0..2 + assert_eq!(header_row[0].0.as_ref(), "id"); + assert_eq!(header_row[0].1, 0..2); + + // "name with spaces": 3..21 (includes quotes) + assert_eq!(header_row[1].0.as_ref(), "name with spaces"); + assert_eq!(header_row[1].1, 3..21); + + // "description, with commas": 22..48 (includes quotes) + assert_eq!(header_row[2].0.as_ref(), "description, with commas"); + assert_eq!(header_row[2].1, 22..48); + + // status: 49..55 + assert_eq!(header_row[3].0.as_ref(), "status"); + assert_eq!(header_row[3].1, 49..55); + + // Check first data row + let first_row = &parsed_cells[1]; + assert_eq!(first_row.len(), 4); + + // 1: 56..57 + assert_eq!(first_row[0].0.as_ref(), "1"); + assert_eq!(first_row[0].1, 56..57); + + // "John Doe": 58..68 (includes quotes) + assert_eq!(first_row[1].0.as_ref(), "John Doe"); + assert_eq!(first_row[1].1, 58..68); + + // Content should be stripped of quotes but include escaped quotes + assert_eq!( + first_row[2].0.as_ref(), + r#"A person with "quotes" and, commas"# + ); + // The range should include the outer quotes: 69..107 + assert_eq!(first_row[2].1, 69..107); + + // active: 108..114 + assert_eq!(first_row[3].0.as_ref(), "active"); + assert_eq!(first_row[3].1, 108..114); + } +} + +impl TableLikeContent { + #[cfg(test)] + pub fn from_str(text: String) -> Self { + use text::{Buffer, BufferId, ReplicaId}; + + let buffer_id = BufferId::new(1).unwrap(); + let buffer = Buffer::new(ReplicaId::LOCAL, buffer_id, text); + let snapshot = buffer.snapshot(); + from_buffer(snapshot) + } +} diff --git a/crates/csv_preview/src/renderer.rs b/crates/csv_preview/src/renderer.rs new file mode 100644 index 0000000000000000000000000000000000000000..42ae05936c7ebd3fb9c619793376998b6d33e2c1 --- /dev/null +++ b/crates/csv_preview/src/renderer.rs @@ -0,0 +1,5 @@ +mod preview_view; +mod render_table; +mod row_identifiers; +mod table_cell; +mod table_header; diff --git a/crates/csv_preview/src/renderer/preview_view.rs b/crates/csv_preview/src/renderer/preview_view.rs new file mode 100644 index 0000000000000000000000000000000000000000..55e62d03806b578f59c2542cf997f90ec22a1f8f --- /dev/null +++ b/crates/csv_preview/src/renderer/preview_view.rs @@ -0,0 +1,50 @@ +use std::time::Instant; + +use ui::{div, prelude::*}; + +use crate::{CsvPreviewView, settings::FontType}; + +impl Render for CsvPreviewView { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + let theme = cx.theme(); + + self.performance_metrics.rendered_indices.clear(); + let render_prep_start = Instant::now(); + let table_with_settings = v_flex() + .size_full() + .p_4() + .bg(theme.colors().editor_background) + .track_focus(&self.focus_handle) + .child({ + if self.engine.contents.number_of_cols == 0 { + div() + .flex() + .items_center() + .justify_center() + .h_32() + .text_ui(cx) + .map(|div| match self.settings.font_type { + FontType::Ui => div.font_ui(cx), + FontType::Monospace => div.font_buffer(cx), + }) + .text_color(cx.theme().colors().text_muted) + .child("No CSV content to display") + .into_any_element() + } else { + self.create_table(&self.column_widths.widths, cx) + } + }); + + let render_prep_duration = render_prep_start.elapsed(); + self.performance_metrics.timings.insert( + "render_prep", + (render_prep_duration, std::time::Instant::now()), + ); + + div() + .relative() + .w_full() + .h_full() + .child(table_with_settings) + } +} diff --git a/crates/csv_preview/src/renderer/render_table.rs b/crates/csv_preview/src/renderer/render_table.rs new file mode 100644 index 0000000000000000000000000000000000000000..0cc3bc3c46fb24570b3c99c9121dff3860c6b820 --- /dev/null +++ b/crates/csv_preview/src/renderer/render_table.rs @@ -0,0 +1,193 @@ +use crate::types::TableCell; +use gpui::{AnyElement, Entity}; +use std::ops::Range; +use ui::Table; +use ui::TableColumnWidths; +use ui::TableResizeBehavior; +use ui::UncheckedTableRow; +use ui::{DefiniteLength, div, prelude::*}; + +use crate::{ + CsvPreviewView, + settings::RowRenderMechanism, + types::{AnyColumn, DisplayCellId, DisplayRow}, +}; + +impl CsvPreviewView { + /// Creates a new table. + /// Column number is derived from the `TableColumnWidths` entity. + pub(crate) fn create_table( + &self, + current_widths: &Entity, + cx: &mut Context, + ) -> AnyElement { + let cols = current_widths.read(cx).cols(); + let remaining_col_number = cols - 1; + let fraction = if remaining_col_number > 0 { + 1. / remaining_col_number as f32 + } else { + 1. // only column with line numbers is present. Put 100%, but it will be overwritten anyways :D + }; + let mut widths = vec![DefiniteLength::Fraction(fraction); cols]; + let line_number_width = self.calculate_row_identifier_column_width(); + widths[0] = DefiniteLength::Absolute(AbsoluteLength::Pixels(line_number_width.into())); + + let mut resize_behaviors = vec![TableResizeBehavior::Resizable; cols]; + resize_behaviors[0] = TableResizeBehavior::None; + + self.create_table_inner( + self.engine.contents.rows.len(), + widths, + resize_behaviors, + current_widths, + cx, + ) + } + + fn create_table_inner( + &self, + row_count: usize, + widths: UncheckedTableRow, + resize_behaviors: UncheckedTableRow, + current_widths: &Entity, + cx: &mut Context, + ) -> AnyElement { + let cols = widths.len(); + // Create headers array with interactive elements + let mut headers = Vec::with_capacity(cols); + + headers.push(self.create_row_identifier_header(cx)); + + // Add the actual CSV headers with sort buttons + for i in 0..(cols - 1) { + let header_text = self + .engine + .contents + .headers + .get(AnyColumn(i)) + .and_then(|h| h.display_value().cloned()) + .unwrap_or_else(|| format!("Col {}", i + 1).into()); + + headers.push(self.create_header_element_with_sort_button( + header_text, + cx, + AnyColumn::from(i), + )); + } + + Table::new(cols) + .interactable(&self.table_interaction_state) + .striped() + .column_widths(widths) + .resizable_columns(resize_behaviors, current_widths, cx) + .header(headers) + .disable_base_style() + .map(|table| { + let row_identifier_text_color = cx.theme().colors().editor_line_number; + match self.settings.rendering_with { + RowRenderMechanism::VariableList => { + table.variable_row_height_list(row_count, self.list_state.clone(), { + cx.processor(move |this, display_row: usize, _window, cx| { + this.performance_metrics.rendered_indices.push(display_row); + + let display_row = DisplayRow(display_row); + Self::render_single_table_row( + this, + cols, + display_row, + row_identifier_text_color, + cx, + ) + .unwrap_or_else(|| panic!("Expected to render a table row")) + }) + }) + } + RowRenderMechanism::UniformList => { + table.uniform_list("csv-table", row_count, { + cx.processor(move |this, range: Range, _window, cx| { + // Record all display indices in the range for performance metrics + this.performance_metrics + .rendered_indices + .extend(range.clone()); + + range + .filter_map(|display_index| { + Self::render_single_table_row( + this, + cols, + DisplayRow(display_index), + row_identifier_text_color, + cx, + ) + }) + .collect() + }) + }) + } + } + }) + .into_any_element() + } + + /// Render a single table row + /// + /// Used both by UniformList and VariableRowHeightList + fn render_single_table_row( + this: &CsvPreviewView, + cols: usize, + display_row: DisplayRow, + row_identifier_text_color: gpui::Hsla, + cx: &Context, + ) -> Option> { + // Get the actual row index from our sorted indices + let data_row = this.engine.d2d_mapping().get_data_row(display_row)?; + let row = this.engine.contents.get_row(data_row)?; + + let mut elements = Vec::with_capacity(cols); + elements.push(this.create_row_identifier_cell(display_row, data_row, cx)?); + + // Remaining columns: actual CSV data + for col in (0..this.engine.contents.number_of_cols).map(AnyColumn) { + let table_cell = row.expect_get(col); + + // TODO: Introduce `` cell type + let cell_content = table_cell.display_value().cloned().unwrap_or_default(); + + let display_cell_id = DisplayCellId::new(display_row, col); + + let cell = div().size_full().whitespace_nowrap().text_ellipsis().child( + CsvPreviewView::create_selectable_cell( + display_cell_id, + cell_content, + this.settings.vertical_alignment, + this.settings.font_type, + cx, + ), + ); + + elements.push( + div() + .size_full() + .when(this.settings.show_debug_info, |parent| { + parent.child(div().text_color(row_identifier_text_color).child( + match table_cell { + TableCell::Real { position: pos, .. } => { + let slv = pos.start.timestamp().value; + let so = pos.start.offset; + let elv = pos.end.timestamp().value; + let eo = pos.end.offset; + format!("Pos {so}(L{slv})-{eo}(L{elv})") + } + TableCell::Virtual => "Virtual cell".into(), + }, + )) + }) + .text_ui(cx) + .child(cell) + .into_any_element(), + ); + } + + Some(elements) + } +} diff --git a/crates/csv_preview/src/renderer/row_identifiers.rs b/crates/csv_preview/src/renderer/row_identifiers.rs new file mode 100644 index 0000000000000000000000000000000000000000..a122aa9bf3d803b9deb9c6211e117ba4aa593d93 --- /dev/null +++ b/crates/csv_preview/src/renderer/row_identifiers.rs @@ -0,0 +1,189 @@ +use ui::{ + ActiveTheme as _, AnyElement, Button, ButtonCommon as _, ButtonSize, ButtonStyle, + Clickable as _, Context, ElementId, FluentBuilder as _, IntoElement as _, ParentElement as _, + SharedString, Styled as _, StyledTypography as _, Tooltip, div, +}; + +use crate::{ + CsvPreviewView, + settings::{FontType, RowIdentifiers}, + types::{DataRow, DisplayRow, LineNumber}, +}; + +pub enum RowIdentDisplayMode { + /// E.g + /// ```text + /// 1 + /// ... + /// 5 + /// ``` + Vertical, + /// E.g. + /// ```text + /// 1-5 + /// ``` + Horizontal, +} + +impl LineNumber { + pub fn display_string(&self, mode: RowIdentDisplayMode) -> String { + match *self { + LineNumber::Line(line) => line.to_string(), + LineNumber::LineRange(start, end) => match mode { + RowIdentDisplayMode::Vertical => { + if start + 1 == end { + format!("{start}\n{end}") + } else { + format!("{start}\n...\n{end}") + } + } + RowIdentDisplayMode::Horizontal => { + format!("{start}-{end}") + } + }, + } + } +} + +impl CsvPreviewView { + /// Calculate the optimal width for the row identifier column (line numbers or row numbers). + /// + /// This ensures the column is wide enough to display the largest identifier comfortably, + /// but not wastefully wide for small files. + pub(crate) fn calculate_row_identifier_column_width(&self) -> f32 { + match self.settings.numbering_type { + RowIdentifiers::SrcLines => self.calculate_line_number_width(), + RowIdentifiers::RowNum => self.calculate_row_number_width(), + } + } + + /// Calculate width needed for line numbers (can be multi-line) + fn calculate_line_number_width(&self) -> f32 { + // Find the maximum line number that could be displayed + let max_line_number = self + .engine + .contents + .line_numbers + .iter() + .map(|ln| match ln { + LineNumber::Line(n) => *n, + LineNumber::LineRange(_, end) => *end, + }) + .max() + .unwrap_or_default(); + + let digit_count = if max_line_number == 0 { + 1 + } else { + (max_line_number as f32).log10().floor() as usize + 1 + }; + + // if !self.settings.multiline_cells_enabled { + // // Uses horizontal line numbers layout like `123-456`. Needs twice the size + // digit_count *= 2; + // } + + let char_width_px = 9.0; // TODO: get real width of the characters + let base_width = (digit_count as f32) * char_width_px; + let padding = 20.0; + let min_width = 60.0; + (base_width + padding).max(min_width) + } + + /// Calculate width needed for sequential row numbers + fn calculate_row_number_width(&self) -> f32 { + let max_row_number = self.engine.contents.rows.len(); + + let digit_count = if max_row_number == 0 { + 1 + } else { + (max_row_number as f32).log10().floor() as usize + 1 + }; + + let char_width_px = 9.0; // TODO: get real width of the characters + let base_width = (digit_count as f32) * char_width_px; + let padding = 20.0; + let min_width = 60.0; + (base_width + padding).max(min_width) + } + + pub(crate) fn create_row_identifier_header( + &self, + cx: &mut Context<'_, CsvPreviewView>, + ) -> AnyElement { + // First column: row identifier (clickable to toggle between Lines and Rows) + let row_identifier_text = match self.settings.numbering_type { + RowIdentifiers::SrcLines => "Lines", + RowIdentifiers::RowNum => "Rows", + }; + + let view = cx.entity(); + let value = div() + .map(|div| match self.settings.font_type { + FontType::Ui => div.font_ui(cx), + FontType::Monospace => div.font_buffer(cx), + }) + .child( + Button::new( + ElementId::Name("row-identifier-toggle".into()), + row_identifier_text, + ) + .style(ButtonStyle::Subtle) + .size(ButtonSize::Compact) + .tooltip(Tooltip::text( + "Toggle between: file line numbers or sequential row numbers", + )) + .on_click(move |_event, _window, cx| { + view.update(cx, |this, cx| { + this.settings.numbering_type = match this.settings.numbering_type { + RowIdentifiers::SrcLines => RowIdentifiers::RowNum, + RowIdentifiers::RowNum => RowIdentifiers::SrcLines, + }; + cx.notify(); + }); + }), + ) + .into_any_element(); + value + } + + pub(crate) fn create_row_identifier_cell( + &self, + display_row: DisplayRow, + data_row: DataRow, + cx: &Context<'_, CsvPreviewView>, + ) -> Option { + let row_identifier: SharedString = match self.settings.numbering_type { + RowIdentifiers::SrcLines => self + .engine + .contents + .line_numbers + .get(*data_row)? + .display_string(if self.settings.multiline_cells_enabled { + RowIdentDisplayMode::Vertical + } else { + RowIdentDisplayMode::Horizontal + }) + .into(), + RowIdentifiers::RowNum => (*display_row + 1).to_string().into(), + }; + + let value = div() + .flex() + .px_1() + .border_b_1() + .border_color(cx.theme().colors().border_variant) + .h_full() + .text_ui(cx) + // Row identifiers are always centered + .items_center() + .justify_end() + .map(|div| match self.settings.font_type { + FontType::Ui => div.font_ui(cx), + FontType::Monospace => div.font_buffer(cx), + }) + .child(row_identifier) + .into_any_element(); + Some(value) + } +} diff --git a/crates/csv_preview/src/renderer/table_cell.rs b/crates/csv_preview/src/renderer/table_cell.rs new file mode 100644 index 0000000000000000000000000000000000000000..32900ab77708936e218e9af10a4de5fba796e6a7 --- /dev/null +++ b/crates/csv_preview/src/renderer/table_cell.rs @@ -0,0 +1,72 @@ +//! Table Cell Rendering + +use gpui::{AnyElement, ElementId}; +use ui::{SharedString, Tooltip, div, prelude::*}; + +use crate::{ + CsvPreviewView, + settings::{FontType, VerticalAlignment}, + types::DisplayCellId, +}; + +impl CsvPreviewView { + /// Create selectable table cell with mouse event handlers. + pub fn create_selectable_cell( + display_cell_id: DisplayCellId, + cell_content: SharedString, + vertical_alignment: VerticalAlignment, + font_type: FontType, + cx: &Context, + ) -> AnyElement { + create_table_cell( + display_cell_id, + cell_content, + vertical_alignment, + font_type, + cx, + ) + // Mouse events handlers will be here + .into_any_element() + } +} + +/// Create styled table cell div element. +fn create_table_cell( + display_cell_id: DisplayCellId, + cell_content: SharedString, + vertical_alignment: VerticalAlignment, + font_type: FontType, + cx: &Context<'_, CsvPreviewView>, +) -> gpui::Stateful
{ + div() + .id(ElementId::NamedInteger( + format!( + "csv-display-cell-{}-{}", + *display_cell_id.row, *display_cell_id.col + ) + .into(), + 0, + )) + .cursor_pointer() + .flex() + .h_full() + .px_1() + .bg(cx.theme().colors().editor_background) + .border_b_1() + .border_r_1() + .border_color(cx.theme().colors().border_variant) + .map(|div| match vertical_alignment { + VerticalAlignment::Top => div.items_start(), + VerticalAlignment::Center => div.items_center(), + }) + .map(|div| match vertical_alignment { + VerticalAlignment::Top => div.content_start(), + VerticalAlignment::Center => div.content_center(), + }) + .map(|div| match font_type { + FontType::Ui => div.font_ui(cx), + FontType::Monospace => div.font_buffer(cx), + }) + .tooltip(Tooltip::text(cell_content.clone())) + .child(div().child(cell_content)) +} diff --git a/crates/csv_preview/src/renderer/table_header.rs b/crates/csv_preview/src/renderer/table_header.rs new file mode 100644 index 0000000000000000000000000000000000000000..52a16be9fc81ef1c3f001513b652a33c3b06dc82 --- /dev/null +++ b/crates/csv_preview/src/renderer/table_header.rs @@ -0,0 +1,94 @@ +use gpui::ElementId; +use ui::{Tooltip, prelude::*}; + +use crate::{ + CsvPreviewView, + settings::FontType, + table_data_engine::sorting_by_column::{AppliedSorting, SortDirection}, + types::AnyColumn, +}; + +impl CsvPreviewView { + /// Create header for data, which is orderable with text on the left and sort button on the right + pub(crate) fn create_header_element_with_sort_button( + &self, + header_text: SharedString, + cx: &mut Context<'_, CsvPreviewView>, + col_idx: AnyColumn, + ) -> AnyElement { + // CSV data columns: text + filter/sort buttons + h_flex() + .justify_between() + .items_center() + .w_full() + .map(|div| match self.settings.font_type { + FontType::Ui => div.font_ui(cx), + FontType::Monospace => div.font_buffer(cx), + }) + .child(div().child(header_text)) + .child(h_flex().gap_1().child(self.create_sort_button(cx, col_idx))) + .into_any_element() + } + + fn create_sort_button( + &self, + cx: &mut Context<'_, CsvPreviewView>, + col_idx: AnyColumn, + ) -> Button { + let sort_btn = Button::new( + ElementId::NamedInteger("sort-button".into(), col_idx.get() as u64), + match self.engine.applied_sorting { + Some(ordering) if ordering.col_idx == col_idx => match ordering.direction { + SortDirection::Asc => "↓", + SortDirection::Desc => "↑", + }, + _ => "↕", // Unsorted/available for sorting + }, + ) + .size(ButtonSize::Compact) + .style( + if self + .engine + .applied_sorting + .is_some_and(|o| o.col_idx == col_idx) + { + ButtonStyle::Filled + } else { + ButtonStyle::Subtle + }, + ) + .tooltip(Tooltip::text(match self.engine.applied_sorting { + Some(ordering) if ordering.col_idx == col_idx => match ordering.direction { + SortDirection::Asc => "Sorted A-Z. Click to sort Z-A", + SortDirection::Desc => "Sorted Z-A. Click to disable sorting", + }, + _ => "Not sorted. Click to sort A-Z", + })) + .on_click(cx.listener(move |this, _event, _window, cx| { + let new_sorting = match this.engine.applied_sorting { + Some(ordering) if ordering.col_idx == col_idx => { + // Same column clicked - cycle through states + match ordering.direction { + SortDirection::Asc => Some(AppliedSorting { + col_idx, + direction: SortDirection::Desc, + }), + SortDirection::Desc => None, // Clear sorting + } + } + _ => { + // Different column or no sorting - start with ascending + Some(AppliedSorting { + col_idx, + direction: SortDirection::Asc, + }) + } + }; + + this.engine.applied_sorting = new_sorting; + this.apply_sort(); + cx.notify(); + })); + sort_btn + } +} diff --git a/crates/csv_preview/src/settings.rs b/crates/csv_preview/src/settings.rs new file mode 100644 index 0000000000000000000000000000000000000000..e627b3cc994a84f54268a05ba17534789f631fe0 --- /dev/null +++ b/crates/csv_preview/src/settings.rs @@ -0,0 +1,46 @@ +#[derive(Default, Clone, Copy)] +pub enum RowRenderMechanism { + /// Default behaviour + #[default] + VariableList, + /// More performance oriented, but all rows are same height + #[allow(dead_code)] // Will be used when settings ui is added + UniformList, +} + +#[derive(Default, Clone, Copy)] +pub enum VerticalAlignment { + /// Align text to the top of cells + #[default] + Top, + /// Center text vertically in cells + Center, +} + +#[derive(Default, Clone, Copy)] +pub enum FontType { + /// Use the default UI font + #[default] + Ui, + /// Use monospace font (same as buffer/editor font) + Monospace, +} + +#[derive(Default, Clone, Copy)] +pub enum RowIdentifiers { + /// Show original line numbers from CSV file + #[default] + SrcLines, + /// Show sequential row numbers starting from 1 + RowNum, +} + +#[derive(Clone, Default)] +pub(crate) struct CsvPreviewSettings { + pub(crate) rendering_with: RowRenderMechanism, + pub(crate) vertical_alignment: VerticalAlignment, + pub(crate) font_type: FontType, + pub(crate) numbering_type: RowIdentifiers, + pub(crate) show_debug_info: bool, + pub(crate) multiline_cells_enabled: bool, +} diff --git a/crates/csv_preview/src/table_data_engine.rs b/crates/csv_preview/src/table_data_engine.rs new file mode 100644 index 0000000000000000000000000000000000000000..382b41a28507213dcc5993adb49a1fddc5e7b64c --- /dev/null +++ b/crates/csv_preview/src/table_data_engine.rs @@ -0,0 +1,90 @@ +//! This module defines core operations and config of tabular data view (CSV table) +//! It operates in 2 coordinate systems: +//! - `DataCellId` - indices of src data cells +//! - `DisplayCellId` - indices of data after applied transformations like sorting/filtering, which is used to render cell on the screen +//! +//! It's designed to contain core logic of operations without relying on `CsvPreviewView`, context or window handles. + +use std::{collections::HashMap, sync::Arc}; + +use ui::table_row::TableRow; + +use crate::{ + table_data_engine::sorting_by_column::{AppliedSorting, sort_data_rows}, + types::{DataRow, DisplayRow, TableCell, TableLikeContent}, +}; + +pub mod sorting_by_column; + +#[derive(Default)] +pub(crate) struct TableDataEngine { + pub applied_sorting: Option, + d2d_mapping: DisplayToDataMapping, + pub contents: TableLikeContent, +} + +impl TableDataEngine { + pub(crate) fn d2d_mapping(&self) -> &DisplayToDataMapping { + &self.d2d_mapping + } + + pub(crate) fn apply_sort(&mut self) { + self.d2d_mapping + .apply_sorting(self.applied_sorting, &self.contents.rows); + self.d2d_mapping.merge_mappings(); + } + + /// Applies sorting and filtering to the data and produces display to data mapping + pub(crate) fn calculate_d2d_mapping(&mut self) { + self.d2d_mapping + .apply_sorting(self.applied_sorting, &self.contents.rows); + self.d2d_mapping.merge_mappings(); + } +} + +/// Relation of Display (rendered) rows to Data (src) rows with applied transformations +/// Transformations applied: +/// - sorting by column +#[derive(Debug, Default)] +pub struct DisplayToDataMapping { + /// All rows sorted, regardless of applied filtering. Applied every time sorting changes + pub sorted_rows: Vec, + /// Filtered and sorted rows. Computed cheaply from `sorted_mapping` and `filtered_out_rows` + pub mapping: Arc>, +} + +impl DisplayToDataMapping { + /// Get the data row for a given display row + pub fn get_data_row(&self, display_row: DisplayRow) -> Option { + self.mapping.get(&display_row).copied() + } + + /// Get the number of filtered rows + pub fn visible_row_count(&self) -> usize { + self.mapping.len() + } + + /// Computes sorting + fn apply_sorting(&mut self, sorting: Option, rows: &[TableRow]) { + let data_rows: Vec = (0..rows.len()).map(DataRow).collect(); + + let sorted_rows = if let Some(sorting) = sorting { + sort_data_rows(&rows, data_rows, sorting) + } else { + data_rows + }; + + self.sorted_rows = sorted_rows; + } + + /// Take pre-computed sorting and filtering results, and apply them to the mapping + fn merge_mappings(&mut self) { + self.mapping = Arc::new( + self.sorted_rows + .iter() + .enumerate() + .map(|(display, data)| (DisplayRow(display), *data)) + .collect(), + ); + } +} diff --git a/crates/csv_preview/src/table_data_engine/sorting_by_column.rs b/crates/csv_preview/src/table_data_engine/sorting_by_column.rs new file mode 100644 index 0000000000000000000000000000000000000000..52d61351a3d4a8fad0cec60d8c6c594fec05c545 --- /dev/null +++ b/crates/csv_preview/src/table_data_engine/sorting_by_column.rs @@ -0,0 +1,49 @@ +use ui::table_row::TableRow; + +use crate::types::{AnyColumn, DataRow, TableCell}; + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum SortDirection { + Asc, + Desc, +} + +/// Config or currently active sorting +#[derive(Debug, Clone, Copy)] +pub struct AppliedSorting { + /// 0-based column index + pub col_idx: AnyColumn, + /// Direction of sorting (asc/desc) + pub direction: SortDirection, +} + +pub fn sort_data_rows( + content_rows: &[TableRow], + mut data_row_ids: Vec, + sorting: AppliedSorting, +) -> Vec { + data_row_ids.sort_by(|&a, &b| { + let row_a = &content_rows[*a]; + let row_b = &content_rows[*b]; + + // TODO: Decide how to handle nulls (on top or on bottom) + let val_a = row_a + .get(sorting.col_idx) + .and_then(|tc| tc.display_value()) + .map(|tc| tc.as_str()) + .unwrap_or(""); + let val_b = row_b + .get(sorting.col_idx) + .and_then(|tc| tc.display_value()) + .map(|tc| tc.as_str()) + .unwrap_or(""); + + let cmp = val_a.cmp(val_b); + match sorting.direction { + SortDirection::Asc => cmp, + SortDirection::Desc => cmp.reverse(), + } + }); + + data_row_ids +} diff --git a/crates/csv_preview/src/types.rs b/crates/csv_preview/src/types.rs new file mode 100644 index 0000000000000000000000000000000000000000..87fc513f53e61db996d39dcb05409c765fd0c6dc --- /dev/null +++ b/crates/csv_preview/src/types.rs @@ -0,0 +1,17 @@ +use std::fmt::Debug; + +pub use coordinates::*; +mod coordinates; +pub use table_cell::*; +mod table_cell; +pub use table_like_content::*; +mod table_like_content; + +/// Line number information for CSV rows +#[derive(Debug, Clone, Copy)] +pub enum LineNumber { + /// Single line row + Line(usize), + /// Multi-line row spanning from start to end line. Incluisive + LineRange(usize, usize), +} diff --git a/crates/csv_preview/src/types/coordinates.rs b/crates/csv_preview/src/types/coordinates.rs new file mode 100644 index 0000000000000000000000000000000000000000..d800bef6ce0dd54d5ae65301163f79013e447ce3 --- /dev/null +++ b/crates/csv_preview/src/types/coordinates.rs @@ -0,0 +1,127 @@ +//! Type definitions for CSV table coordinates and cell identifiers. +//! +//! Provides newtypes for self-documenting coordinate systems: +//! - Display coordinates: Visual positions in rendered table +//! - Data coordinates: Original CSV data positions + +use std::ops::Deref; + +///// Rows ///// +/// Visual row position in rendered table. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct DisplayRow(pub usize); + +impl DisplayRow { + /// Create a new display row + pub fn new(row: usize) -> Self { + Self(row) + } + + /// Get the inner row value + pub fn get(self) -> usize { + self.0 + } +} + +impl Deref for DisplayRow { + type Target = usize; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +/// Original CSV row position. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct DataRow(pub usize); + +impl DataRow { + /// Create a new data row + pub fn new(row: usize) -> Self { + Self(row) + } +} + +impl Deref for DataRow { + type Target = usize; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl From for DisplayRow { + fn from(row: usize) -> Self { + DisplayRow::new(row) + } +} + +impl From for DataRow { + fn from(row: usize) -> Self { + DataRow::new(row) + } +} + +///// Columns ///// +/// Data column position in CSV table. 0-based +/// +/// Currently represents both display and data coordinate systems since +/// column reordering is not yet implemented. When column reordering is added, +/// this will need to be split into `DisplayColumn` and `DataColumn` types. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct AnyColumn(pub usize); + +impl AnyColumn { + /// Create a new column ID + pub fn new(col: usize) -> Self { + Self(col) + } + + /// Get the inner column value + pub fn get(self) -> usize { + self.0 + } +} + +impl Deref for AnyColumn { + type Target = usize; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl From for AnyColumn { + fn from(col: usize) -> Self { + AnyColumn::new(col) + } +} + +impl From for usize { + fn from(value: AnyColumn) -> Self { + *value + } +} + +///// Cells ///// +/// Visual cell position in rendered table. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct DisplayCellId { + pub row: DisplayRow, + pub col: AnyColumn, +} + +impl DisplayCellId { + /// Create a new display cell ID + pub fn new(row: impl Into, col: impl Into) -> Self { + Self { + row: row.into(), + col: col.into(), + } + } + + /// Returns (row, column) + pub fn to_raw(&self) -> (usize, usize) { + (self.row.0, self.col.0) + } +} diff --git a/crates/csv_preview/src/types/table_cell.rs b/crates/csv_preview/src/types/table_cell.rs new file mode 100644 index 0000000000000000000000000000000000000000..b6f9adb3fe82b0d468d1ffc8404e707a762e94ea --- /dev/null +++ b/crates/csv_preview/src/types/table_cell.rs @@ -0,0 +1,54 @@ +use text::Anchor; +use ui::SharedString; + +/// Position of a cell within the source CSV buffer +#[derive(Clone, Debug)] +pub struct CellContentSpan { + /// Start anchor of the cell content in the source buffer + pub start: Anchor, + /// End anchor of the cell content in the source buffer + pub end: Anchor, +} + +/// A table cell with its content and position in the source buffer +#[derive(Clone, Debug)] +pub enum TableCell { + /// Cell existing in the CSV + Real { + /// Position of this cell in the source buffer + position: CellContentSpan, + /// Cached display value (for performance) + cached_value: SharedString, + }, + /// Virtual cell, created to pad malformed row + Virtual, +} + +impl TableCell { + /// Create a TableCell with buffer position tracking + pub fn from_buffer_position( + content: SharedString, + start_offset: usize, + end_offset: usize, + buffer_snapshot: &text::BufferSnapshot, + ) -> Self { + let start_anchor = buffer_snapshot.anchor_before(start_offset); + let end_anchor = buffer_snapshot.anchor_after(end_offset); + + Self::Real { + position: CellContentSpan { + start: start_anchor, + end: end_anchor, + }, + cached_value: content, + } + } + + /// Get the display value for this cell + pub fn display_value(&self) -> Option<&SharedString> { + match self { + TableCell::Real { cached_value, .. } => Some(cached_value), + TableCell::Virtual => None, + } + } +} diff --git a/crates/csv_preview/src/types/table_like_content.rs b/crates/csv_preview/src/types/table_like_content.rs new file mode 100644 index 0000000000000000000000000000000000000000..7bf205af812c24d70f33157f8ab7acc454c3b0d5 --- /dev/null +++ b/crates/csv_preview/src/types/table_like_content.rs @@ -0,0 +1,32 @@ +use ui::table_row::TableRow; + +use crate::types::{DataRow, LineNumber, TableCell}; + +/// Generic container struct of table-like data (CSV, TSV, etc) +#[derive(Clone)] +pub struct TableLikeContent { + /// Number of data columns. + /// Defines table width used to validate `TableRow` on creation + pub number_of_cols: usize, + pub headers: TableRow, + pub rows: Vec>, + /// Follows the same indices as `rows` + pub line_numbers: Vec, +} + +impl Default for TableLikeContent { + fn default() -> Self { + Self { + number_of_cols: 0, + headers: TableRow::::from_vec(vec![], 0), + rows: vec![], + line_numbers: vec![], + } + } +} + +impl TableLikeContent { + pub(crate) fn get_row(&self, data_row: DataRow) -> Option<&TableRow> { + self.rows.get(*data_row) + } +} diff --git a/crates/debugger_ui/src/session/running/memory_view.rs b/crates/debugger_ui/src/session/running/memory_view.rs index f10e5179e37f87be0e27985b557fcb63cf089a42..69ea556018fdadeb1e270b1d7c2520d25752e670 100644 --- a/crates/debugger_ui/src/session/running/memory_view.rs +++ b/crates/debugger_ui/src/session/running/memory_view.rs @@ -133,7 +133,7 @@ impl ViewState { fn set_offset(&mut self, point: Point) { if point.y >= -Pixels::ZERO { self.schedule_scroll_up(); - } else if point.y <= -self.scroll_handle.max_offset().height { + } else if point.y <= -self.scroll_handle.max_offset().y { self.schedule_scroll_down(); } self.scroll_handle.set_offset(point); @@ -141,7 +141,7 @@ impl ViewState { } impl ScrollableHandle for ViewStateHandle { - fn max_offset(&self) -> gpui::Size { + fn max_offset(&self) -> gpui::Point { self.0.borrow().scroll_handle.max_offset() } diff --git a/crates/docs_preprocessor/src/main.rs b/crates/docs_preprocessor/src/main.rs index 6ef599542a5b2f511915d7435af192162a5dbd3b..43efbeea0b0310cf70cd9bdb560b1b0d2b0c14ef 100644 --- a/crates/docs_preprocessor/src/main.rs +++ b/crates/docs_preprocessor/src/main.rs @@ -578,6 +578,7 @@ fn handle_postprocessing() -> Result<()> { .expect("Default title not a string") .to_string(); let amplitude_key = std::env::var("DOCS_AMPLITUDE_API_KEY").unwrap_or_default(); + let consent_io_instance = std::env::var("DOCS_CONSENT_IO_INSTANCE").unwrap_or_default(); output.insert("html".to_string(), zed_html); mdbook::Renderer::render(&mdbook::renderer::HtmlHandlebars::new(), &ctx)?; @@ -647,6 +648,7 @@ fn handle_postprocessing() -> Result<()> { zlog::trace!(logger => "Updating {:?}", pretty_path(&file, &root_dir)); let contents = contents.replace("#description#", meta_description); let contents = contents.replace("#amplitude_key#", &litude_key); + let contents = contents.replace("#consent_io_instance#", &consent_io_instance); let contents = title_regex() .replace(&contents, |_: ®ex::Captures| { format!("{}", meta_title) diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index e6e3a9abdf83deb785cd56d358b065973682b8cc..5c7ce045121739f341b84dd87d827878550f4048 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -1,7 +1,7 @@ use anyhow::Result; use arrayvec::ArrayVec; use client::{Client, EditPredictionUsage, UserStore}; -use cloud_api_types::SubmitEditPredictionFeedbackBody; +use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody}; use cloud_llm_client::predict_edits_v3::{ PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse, }; @@ -69,6 +69,7 @@ pub mod sweep_ai; pub mod udiff; mod capture_example; +pub mod open_ai_compatible; mod zed_edit_prediction_delegate; pub mod zeta; @@ -107,13 +108,8 @@ const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled"; const EDIT_PREDICTION_SETTLED_TTL: Duration = Duration::from_secs(60 * 5); const EDIT_PREDICTION_SETTLED_QUIESCENCE: Duration = Duration::from_secs(10); -pub struct Zeta2FeatureFlag; pub struct EditPredictionJumpsFeatureFlag; -impl FeatureFlag for Zeta2FeatureFlag { - const NAME: &'static str = "zeta2"; -} - impl FeatureFlag for EditPredictionJumpsFeatureFlag { const NAME: &'static str = "edit_prediction_jumps"; } @@ -129,6 +125,7 @@ impl Global for EditPredictionStoreGlobal {} #[derive(Clone)] pub struct Zeta2RawConfig { pub model_id: Option, + pub environment: Option, pub format: ZetaFormat, } @@ -147,7 +144,7 @@ pub struct EditPredictionStore { pub sweep_ai: SweepAi, pub mercury: Mercury, data_collection_choice: DataCollectionChoice, - reject_predictions_tx: mpsc::UnboundedSender, + reject_predictions_tx: mpsc::UnboundedSender, settled_predictions_tx: mpsc::UnboundedSender, shown_predictions: VecDeque, rated_predictions: HashSet, @@ -155,6 +152,11 @@ pub struct EditPredictionStore { settled_event_callback: Option>, } +pub(crate) struct EditPredictionRejectionPayload { + rejection: EditPredictionRejection, + organization_id: Option, +} + #[derive(Copy, Clone, PartialEq, Eq)] pub enum EditPredictionModel { Zeta, @@ -723,8 +725,13 @@ impl EditPredictionStore { |this, _listener, _event, cx| { let client = this.client.clone(); let llm_token = this.llm_token.clone(); + let organization_id = this + .user_store + .read(cx) + .current_organization() + .map(|organization| organization.id.clone()); cx.spawn(async move |_this, _cx| { - llm_token.refresh(&client).await?; + llm_token.refresh(&client, organization_id).await?; anyhow::Ok(()) }) .detach_and_log_err(cx); @@ -754,7 +761,12 @@ impl EditPredictionStore { let version_str = env::var("ZED_ZETA_FORMAT").ok()?; let format = ZetaFormat::parse(&version_str).ok()?; let model_id = env::var("ZED_ZETA_MODEL").ok(); - Some(Zeta2RawConfig { model_id, format }) + let environment = env::var("ZED_ZETA_ENVIRONMENT").ok(); + Some(Zeta2RawConfig { + model_id, + environment, + format, + }) } pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) { @@ -785,11 +797,17 @@ impl EditPredictionStore { let client = self.client.clone(); let llm_token = self.llm_token.clone(); let app_version = AppVersion::global(cx); + let organization_id = self + .user_store + .read(cx) + .current_organization() + .map(|organization| organization.id.clone()); + cx.spawn(async move |this, cx| { let experiments = cx .background_spawn(async move { let http_client = client.http_client(); - let token = llm_token.acquire(&client).await?; + let token = llm_token.acquire(&client, organization_id).await?; let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?; let request = http_client::Request::builder() .method(Method::GET) @@ -1428,7 +1446,7 @@ impl EditPredictionStore { } async fn handle_rejected_predictions( - rx: UnboundedReceiver, + rx: UnboundedReceiver, client: Arc, llm_token: LlmApiToken, app_version: Version, @@ -1437,7 +1455,11 @@ impl EditPredictionStore { let mut rx = std::pin::pin!(rx.peekable()); let mut batched = Vec::new(); - while let Some(rejection) = rx.next().await { + while let Some(EditPredictionRejectionPayload { + rejection, + organization_id, + }) = rx.next().await + { batched.push(rejection); if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 { @@ -1475,6 +1497,7 @@ impl EditPredictionStore { }, client.clone(), llm_token.clone(), + organization_id, app_version.clone(), true, ) @@ -1680,13 +1703,23 @@ impl EditPredictionStore { all_language_settings(None, cx).edit_predictions.provider, EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi ); + if is_cloud { + let organization_id = self + .user_store + .read(cx) + .current_organization() + .map(|organization| organization.id.clone()); + self.reject_predictions_tx - .unbounded_send(EditPredictionRejection { - request_id: prediction_id.to_string(), - reason, - was_shown, - model_version, + .unbounded_send(EditPredictionRejectionPayload { + rejection: EditPredictionRejection { + request_id: prediction_id.to_string(), + reason, + was_shown, + model_version, + }, + organization_id, }) .log_err(); } @@ -2108,7 +2141,7 @@ impl EditPredictionStore { active_buffer.clone(), position, trigger, - cx.has_flag::(), + cx.has_flag::(), cx, ) } @@ -2341,6 +2374,7 @@ impl EditPredictionStore { client: Arc, custom_url: Option>, llm_token: LlmApiToken, + organization_id: Option, app_version: Version, ) -> Result<(RawCompletionResponse, Option)> { let url = if let Some(custom_url) = custom_url { @@ -2360,6 +2394,7 @@ impl EditPredictionStore { }, client, llm_token, + organization_id, app_version, true, ) @@ -2370,6 +2405,7 @@ impl EditPredictionStore { input: ZetaPromptInput, client: Arc, llm_token: LlmApiToken, + organization_id: Option, app_version: Version, trigger: PredictEditsRequestTrigger, ) -> Result<(PredictEditsV3Response, Option)> { @@ -2392,6 +2428,7 @@ impl EditPredictionStore { }, client, llm_token, + organization_id, app_version, true, ) @@ -2445,6 +2482,7 @@ impl EditPredictionStore { build: impl Fn(http_client::http::request::Builder) -> Result>, client: Arc, llm_token: LlmApiToken, + organization_id: Option, app_version: Version, require_auth: bool, ) -> Result<(Res, Option)> @@ -2454,9 +2492,12 @@ impl EditPredictionStore { let http_client = client.http_client(); let mut token = if require_auth { - Some(llm_token.acquire(&client).await?) + Some(llm_token.acquire(&client, organization_id.clone()).await?) } else { - llm_token.acquire(&client).await.ok() + llm_token + .acquire(&client, organization_id.clone()) + .await + .ok() }; let mut did_retry = false; @@ -2498,7 +2539,7 @@ impl EditPredictionStore { return Ok((serde_json::from_slice(&body)?, usage)); } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() { did_retry = true; - token = Some(llm_token.refresh(&client).await?); + token = Some(llm_token.refresh(&client, organization_id.clone()).await?); } else { let mut body = String::new(); response.body_mut().read_to_string(&mut body).await?; diff --git a/crates/edit_prediction/src/edit_prediction_tests.rs b/crates/edit_prediction/src/edit_prediction_tests.rs index cc3bb84808981fd1430f9e71aa796e590cc78169..b34ff6fce71fe7afcaff68121510f48f6f8f98c4 100644 --- a/crates/edit_prediction/src/edit_prediction_tests.rs +++ b/crates/edit_prediction/src/edit_prediction_tests.rs @@ -1848,6 +1848,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) { experiment: None, in_open_source_repo: false, can_collect_data: false, + repo_url: None, }, buffer_snapshotted_at: Instant::now(), response_received_at: Instant::now(), diff --git a/crates/edit_prediction/src/fim.rs b/crates/edit_prediction/src/fim.rs index 66f2e58a3b01b4fbf49b11864db4daec6b4dc1c2..02053aae7154acdfa22a01a4f84d6b732a9ca696 100644 --- a/crates/edit_prediction/src/fim.rs +++ b/crates/edit_prediction/src/fim.rs @@ -1,6 +1,7 @@ use crate::{ - EditPredictionId, EditPredictionModelInput, cursor_excerpt, prediction::EditPredictionResult, - zeta, + EditPredictionId, EditPredictionModelInput, cursor_excerpt, + open_ai_compatible::{self, load_open_ai_compatible_api_key_if_needed}, + prediction::EditPredictionResult, }; use anyhow::{Context as _, Result, anyhow}; use gpui::{App, AppContext as _, Entity, Task}; @@ -58,6 +59,8 @@ pub fn request_prediction( return Task::ready(Err(anyhow!("Unsupported edit prediction provider for FIM"))); }; + let api_key = load_open_ai_compatible_api_key_if_needed(provider, cx); + let result = cx.background_spawn(async move { let (excerpt_range, _) = cursor_excerpt::editable_and_context_ranges_for_cursor_position( cursor_point, @@ -82,6 +85,7 @@ pub fn request_prediction( experiment: None, in_open_source_repo: false, can_collect_data: false, + repo_url: None, }; let prefix = inputs.cursor_excerpt[..inputs.cursor_offset_in_excerpt].to_string(); @@ -90,12 +94,14 @@ pub fn request_prediction( let stop_tokens = get_fim_stop_tokens(); let max_tokens = settings.max_output_tokens; - let (response_text, request_id) = zeta::send_custom_server_request( + + let (response_text, request_id) = open_ai_compatible::send_custom_server_request( provider, &settings, prompt, max_tokens, stop_tokens, + api_key, &http_client, ) .await?; diff --git a/crates/edit_prediction/src/mercury.rs b/crates/edit_prediction/src/mercury.rs index bf9b43d528db1717f54143e4805e41aefc81f64a..f61219e2f71d5efbb2fb67250b58b0a5a090e9a8 100644 --- a/crates/edit_prediction/src/mercury.rs +++ b/crates/edit_prediction/src/mercury.rs @@ -112,6 +112,7 @@ impl Mercury { }, in_open_source_repo: false, can_collect_data: false, + repo_url: None, }; let prompt = build_prompt(&inputs); diff --git a/crates/edit_prediction/src/open_ai_compatible.rs b/crates/edit_prediction/src/open_ai_compatible.rs new file mode 100644 index 0000000000000000000000000000000000000000..ca378ba1fd0bc9bdbb3e85c7610e1b94c1be388f --- /dev/null +++ b/crates/edit_prediction/src/open_ai_compatible.rs @@ -0,0 +1,133 @@ +use anyhow::{Context as _, Result}; +use cloud_llm_client::predict_edits_v3::{RawCompletionRequest, RawCompletionResponse}; +use futures::AsyncReadExt as _; +use gpui::{App, AppContext as _, Entity, Global, SharedString, Task, http_client}; +use language::language_settings::{OpenAiCompatibleEditPredictionSettings, all_language_settings}; +use language_model::{ApiKeyState, EnvVar, env_var}; +use std::sync::Arc; + +pub fn open_ai_compatible_api_url(cx: &App) -> SharedString { + all_language_settings(None, cx) + .edit_predictions + .open_ai_compatible_api + .as_ref() + .map(|settings| settings.api_url.clone()) + .unwrap_or_default() + .into() +} + +pub const OPEN_AI_COMPATIBLE_CREDENTIALS_USERNAME: &str = "openai-compatible-api-token"; +pub static OPEN_AI_COMPATIBLE_TOKEN_ENV_VAR: std::sync::LazyLock = + env_var!("ZED_OPEN_AI_COMPATIBLE_EDIT_PREDICTION_API_KEY"); + +struct GlobalOpenAiCompatibleApiKey(Entity); + +impl Global for GlobalOpenAiCompatibleApiKey {} + +pub fn open_ai_compatible_api_token(cx: &mut App) -> Entity { + if let Some(global) = cx.try_global::() { + return global.0.clone(); + } + + let entity = cx.new(|cx| { + ApiKeyState::new( + open_ai_compatible_api_url(cx), + OPEN_AI_COMPATIBLE_TOKEN_ENV_VAR.clone(), + ) + }); + cx.set_global(GlobalOpenAiCompatibleApiKey(entity.clone())); + entity +} + +pub fn load_open_ai_compatible_api_token( + cx: &mut App, +) -> Task> { + let api_url = open_ai_compatible_api_url(cx); + open_ai_compatible_api_token(cx).update(cx, |key_state, cx| { + key_state.load_if_needed(api_url, |s| s, cx) + }) +} + +pub fn load_open_ai_compatible_api_key_if_needed( + provider: settings::EditPredictionProvider, + cx: &mut App, +) -> Option> { + if provider != settings::EditPredictionProvider::OpenAiCompatibleApi { + return None; + } + _ = load_open_ai_compatible_api_token(cx); + let url = open_ai_compatible_api_url(cx); + return open_ai_compatible_api_token(cx).read(cx).key(&url); +} + +pub(crate) async fn send_custom_server_request( + provider: settings::EditPredictionProvider, + settings: &OpenAiCompatibleEditPredictionSettings, + prompt: String, + max_tokens: u32, + stop_tokens: Vec, + api_key: Option>, + http_client: &Arc, +) -> Result<(String, String)> { + match provider { + settings::EditPredictionProvider::Ollama => { + let response = crate::ollama::make_request( + settings.clone(), + prompt, + stop_tokens, + http_client.clone(), + ) + .await?; + Ok((response.response, response.created_at)) + } + _ => { + let request = RawCompletionRequest { + model: settings.model.clone(), + prompt, + max_tokens: Some(max_tokens), + temperature: None, + stop: stop_tokens + .into_iter() + .map(std::borrow::Cow::Owned) + .collect(), + environment: None, + }; + + let request_body = serde_json::to_string(&request)?; + let mut http_request_builder = http_client::Request::builder() + .method(http_client::Method::POST) + .uri(settings.api_url.as_ref()) + .header("Content-Type", "application/json"); + + if let Some(api_key) = api_key { + http_request_builder = + http_request_builder.header("Authorization", format!("Bearer {}", api_key)); + } + + let http_request = + http_request_builder.body(http_client::AsyncBody::from(request_body))?; + + let mut response = http_client.send(http_request).await?; + let status = response.status(); + + if !status.is_success() { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + anyhow::bail!("custom server error: {} - {}", status, body); + } + + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + let parsed: RawCompletionResponse = + serde_json::from_str(&body).context("Failed to parse completion response")?; + let text = parsed + .choices + .into_iter() + .next() + .map(|choice| choice.text) + .unwrap_or_default(); + Ok((text, parsed.id)) + } + } +} diff --git a/crates/edit_prediction/src/prediction.rs b/crates/edit_prediction/src/prediction.rs index 0dd33c03a95d77ec680d47d96daa8e6a44f51b62..263409043b397e2df1ac32514a0ce76656fbefe1 100644 --- a/crates/edit_prediction/src/prediction.rs +++ b/crates/edit_prediction/src/prediction.rs @@ -165,6 +165,7 @@ mod tests { experiment: None, in_open_source_repo: false, can_collect_data: false, + repo_url: None, }, buffer_snapshotted_at: Instant::now(), response_received_at: Instant::now(), diff --git a/crates/edit_prediction/src/sweep_ai.rs b/crates/edit_prediction/src/sweep_ai.rs index d88a159a47aa7633a5b064e72a75dd61604710e1..d8ce180801aa8902bfff79044cabaae7570ed05f 100644 --- a/crates/edit_prediction/src/sweep_ai.rs +++ b/crates/edit_prediction/src/sweep_ai.rs @@ -229,6 +229,7 @@ impl SweepAi { experiment: None, in_open_source_repo: false, can_collect_data: false, + repo_url: None, }; send_started_event( diff --git a/crates/edit_prediction/src/zeta.rs b/crates/edit_prediction/src/zeta.rs index f6a786572736908556535b9131c1cf7814a6126f..3397d31276efcc7e1d68336f87ccf3e035f51f3a 100644 --- a/crates/edit_prediction/src/zeta.rs +++ b/crates/edit_prediction/src/zeta.rs @@ -2,29 +2,30 @@ use crate::cursor_excerpt::compute_excerpt_ranges; use crate::prediction::EditPredictionResult; use crate::{ CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, - EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore, ollama, + EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore, }; -use anyhow::{Context as _, Result}; -use cloud_llm_client::predict_edits_v3::{RawCompletionRequest, RawCompletionResponse}; +use anyhow::Result; +use cloud_llm_client::predict_edits_v3::RawCompletionRequest; use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason}; use edit_prediction_types::PredictedCursorPosition; -use futures::AsyncReadExt as _; -use gpui::{App, AppContext as _, Task, http_client, prelude::*}; -use language::language_settings::{OpenAiCompatibleEditPredictionSettings, all_language_settings}; +use gpui::{App, AppContext as _, Task, prelude::*}; +use language::language_settings::all_language_settings; use language::{BufferSnapshot, ToOffset as _, ToPoint, text_diff}; use release_channel::AppVersion; use settings::EditPredictionPromptFormat; use text::{Anchor, Bias}; -use std::env; -use std::ops::Range; -use std::{path::Path, sync::Arc, time::Instant}; +use std::{env, ops::Range, path::Path, sync::Arc, time::Instant}; use zeta_prompt::{ CURSOR_MARKER, ZetaFormat, clean_zeta2_model_output, format_zeta_prompt, get_prefill, - prompt_input_contains_special_tokens, + output_with_context_for_format, prompt_input_contains_special_tokens, zeta1::{self, EDITABLE_REGION_END_MARKER}, }; +use crate::open_ai_compatible::{ + load_open_ai_compatible_api_key_if_needed, send_custom_server_request, +}; + pub fn request_prediction_with_zeta( store: &mut EditPredictionStore, EditPredictionModelInput { @@ -56,14 +57,32 @@ pub fn request_prediction_with_zeta( let buffer_snapshotted_at = Instant::now(); let raw_config = store.zeta2_raw_config().cloned(); let preferred_experiment = store.preferred_experiment().map(|s| s.to_owned()); + let open_ai_compatible_api_key = load_open_ai_compatible_api_key_if_needed(provider, cx); let excerpt_path: Arc = snapshot .file() .map(|file| -> Arc { file.full_path(cx).into() }) .unwrap_or_else(|| Arc::from(Path::new("untitled"))); + let repo_url = if can_collect_data { + let buffer_id = buffer.read(cx).remote_id(); + project + .read(cx) + .git_store() + .read(cx) + .repository_and_path_for_buffer_id(buffer_id, cx) + .and_then(|(repo, _)| repo.read(cx).default_remote_url()) + } else { + None + }; + let client = store.client.clone(); let llm_token = store.llm_token.clone(); + let organization_id = store + .user_store + .read(cx) + .current_organization() + .map(|organization| organization.id.clone()); let app_version = AppVersion::global(cx); let request_task = cx.background_spawn({ @@ -84,6 +103,7 @@ pub fn request_prediction_with_zeta( preferred_experiment, is_open_source, can_collect_data, + repo_url, ); if prompt_input_contains_special_tokens(&prompt_input, zeta_version) { @@ -131,6 +151,7 @@ pub fn request_prediction_with_zeta( prompt, max_tokens, stop_tokens, + open_ai_compatible_api_key.clone(), &http_client, ) .await?; @@ -157,6 +178,7 @@ pub fn request_prediction_with_zeta( prompt, max_tokens, vec![], + open_ai_compatible_api_key.clone(), &http_client, ) .await?; @@ -177,13 +199,17 @@ pub fn request_prediction_with_zeta( let prompt = format_zeta_prompt(&prompt_input, config.format); let prefill = get_prefill(&prompt_input, config.format); let prompt = format!("{prompt}{prefill}"); + let environment = config + .environment + .clone() + .or_else(|| Some(config.format.to_string().to_lowercase())); let request = RawCompletionRequest { model: config.model_id.clone().unwrap_or_default(), prompt, temperature: None, stop: vec![], max_tokens: Some(2048), - environment: Some(config.format.to_string().to_lowercase()), + environment, }; editable_range_in_excerpt = zeta_prompt::excerpt_range_for_format( @@ -197,6 +223,7 @@ pub fn request_prediction_with_zeta( client, None, llm_token, + organization_id, app_version, ) .await?; @@ -215,6 +242,7 @@ pub fn request_prediction_with_zeta( prompt_input.clone(), client, llm_token, + organization_id, app_version, trigger, ) @@ -240,6 +268,25 @@ pub fn request_prediction_with_zeta( return Ok((Some((request_id, None, model_version)), usage)); }; + let editable_range_in_buffer = editable_range_in_excerpt.start + + full_context_offset_range.start + ..editable_range_in_excerpt.end + full_context_offset_range.start; + + let mut old_text = snapshot + .text_for_range(editable_range_in_buffer.clone()) + .collect::(); + + // For the hashline format, the model may return <|set|>/<|insert|> + // edit commands instead of a full replacement. Apply them against + // the original editable region to produce the full replacement text. + // This must happen before cursor marker stripping because the cursor + // marker is embedded inside edit command content. + if let Some(rewritten_output) = + output_with_context_for_format(zeta_version, &old_text, &output_text)? + { + output_text = rewritten_output; + } + // Client-side cursor marker processing (applies to both raw and v3 responses) let cursor_offset_in_output = output_text.find(CURSOR_MARKER); if let Some(offset) = cursor_offset_in_output { @@ -259,14 +306,6 @@ pub fn request_prediction_with_zeta( .ok(); } - let editable_range_in_buffer = editable_range_in_excerpt.start - + full_context_offset_range.start - ..editable_range_in_excerpt.end + full_context_offset_range.start; - - let mut old_text = snapshot - .text_for_range(editable_range_in_buffer.clone()) - .collect::(); - if !output_text.is_empty() && !output_text.ends_with('\n') { output_text.push('\n'); } @@ -365,6 +404,7 @@ pub fn zeta2_prompt_input( preferred_experiment: Option, is_open_source: bool, can_collect_data: bool, + repo_url: Option, ) -> (Range, zeta_prompt::ZetaPromptInput) { let cursor_point = cursor_offset.to_point(snapshot); @@ -396,70 +436,11 @@ pub fn zeta2_prompt_input( experiment: preferred_experiment, in_open_source_repo: is_open_source, can_collect_data, + repo_url, }; (full_context_offset_range, prompt_input) } -pub(crate) async fn send_custom_server_request( - provider: settings::EditPredictionProvider, - settings: &OpenAiCompatibleEditPredictionSettings, - prompt: String, - max_tokens: u32, - stop_tokens: Vec, - http_client: &Arc, -) -> Result<(String, String)> { - match provider { - settings::EditPredictionProvider::Ollama => { - let response = - ollama::make_request(settings.clone(), prompt, stop_tokens, http_client.clone()) - .await?; - Ok((response.response, response.created_at)) - } - _ => { - let request = RawCompletionRequest { - model: settings.model.clone(), - prompt, - max_tokens: Some(max_tokens), - temperature: None, - stop: stop_tokens - .into_iter() - .map(std::borrow::Cow::Owned) - .collect(), - environment: None, - }; - - let request_body = serde_json::to_string(&request)?; - let http_request = http_client::Request::builder() - .method(http_client::Method::POST) - .uri(settings.api_url.as_ref()) - .header("Content-Type", "application/json") - .body(http_client::AsyncBody::from(request_body))?; - - let mut response = http_client.send(http_request).await?; - let status = response.status(); - - if !status.is_success() { - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - anyhow::bail!("custom server error: {} - {}", status, body); - } - - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - - let parsed: RawCompletionResponse = - serde_json::from_str(&body).context("Failed to parse completion response")?; - let text = parsed - .choices - .into_iter() - .next() - .map(|choice| choice.text) - .unwrap_or_default(); - Ok((text, parsed.id)) - } - } -} - pub(crate) fn edit_prediction_accepted( store: &EditPredictionStore, current_prediction: CurrentEditPrediction, @@ -475,6 +456,11 @@ pub(crate) fn edit_prediction_accepted( let require_auth = custom_accept_url.is_none(); let client = store.client.clone(); let llm_token = store.llm_token.clone(); + let organization_id = store + .user_store + .read(cx) + .current_organization() + .map(|organization| organization.id.clone()); let app_version = AppVersion::global(cx); cx.background_spawn(async move { @@ -499,6 +485,7 @@ pub(crate) fn edit_prediction_accepted( }, client, llm_token, + organization_id, app_version, require_auth, ) diff --git a/crates/edit_prediction_cli/src/format_prompt.rs b/crates/edit_prediction_cli/src/format_prompt.rs index ecacd963023d7d113ea5ad77b61fd1d88306fc95..f36eaf2799166d6fbd2b7b212003a1a0644b82c4 100644 --- a/crates/edit_prediction_cli/src/format_prompt.rs +++ b/crates/edit_prediction_cli/src/format_prompt.rs @@ -12,7 +12,8 @@ use similar::DiffableStr; use std::ops::Range; use std::sync::Arc; use zeta_prompt::{ - ZetaFormat, excerpt_range_for_format, format_zeta_prompt, resolve_cursor_region, + ZetaFormat, encode_patch_as_output_for_format, excerpt_range_for_format, format_zeta_prompt, + output_end_marker_for_format, resolve_cursor_region, }; pub async fn run_format_prompt( @@ -53,18 +54,22 @@ pub async fn run_format_prompt( let prompt = format_zeta_prompt(prompt_inputs, zeta_format); let prefill = zeta_prompt::get_prefill(prompt_inputs, zeta_format); - let (expected_patch, expected_cursor_offset) = example + let expected_output = example .spec .expected_patches_with_cursor_positions() .into_iter() .next() - .context("expected patches is empty")?; - let expected_output = zeta2_output_for_patch( - prompt_inputs, - &expected_patch, - expected_cursor_offset, - zeta_format, - )?; + .and_then(|(expected_patch, expected_cursor_offset)| { + zeta2_output_for_patch( + prompt_inputs, + &expected_patch, + expected_cursor_offset, + zeta_format, + ) + .ok() + }) + .unwrap_or_default(); + let rejected_output = example.spec.rejected_patch.as_ref().and_then(|patch| { zeta2_output_for_patch(prompt_inputs, patch, None, zeta_format).ok() }); @@ -97,6 +102,12 @@ pub fn zeta2_output_for_patch( old_editable_region.push('\n'); } + if let Some(encoded_output) = + encode_patch_as_output_for_format(version, &old_editable_region, patch, cursor_offset)? + { + return Ok(encoded_output); + } + let (mut result, first_hunk_offset) = udiff::apply_diff_to_string_with_hunk_offset(patch, &old_editable_region).with_context( || { @@ -116,16 +127,11 @@ pub fn zeta2_output_for_patch( result.insert_str(offset, zeta_prompt::CURSOR_MARKER); } - match version { - ZetaFormat::V0120GitMergeMarkers - | ZetaFormat::V0131GitMergeMarkersPrefix - | ZetaFormat::V0211SeedCoder => { - if !result.ends_with('\n') { - result.push('\n'); - } - result.push_str(zeta_prompt::v0120_git_merge_markers::END_MARKER); + if let Some(end_marker) = output_end_marker_for_format(version) { + if !result.ends_with('\n') { + result.push('\n'); } - _ => (), + result.push_str(end_marker); } Ok(result) diff --git a/crates/edit_prediction_cli/src/load_project.rs b/crates/edit_prediction_cli/src/load_project.rs index dcf417c2e8cc70dfcaffdf4b96dbe3b17daa61d4..df458770519be5accd72f33a56893bb13c9b88a9 100644 --- a/crates/edit_prediction_cli/src/load_project.rs +++ b/crates/edit_prediction_cli/src/load_project.rs @@ -105,6 +105,7 @@ pub async fn run_load_project( in_open_source_repo: false, can_collect_data: false, experiment: None, + repo_url: None, }, language_name, ) diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index 207a69328fb07277c39463c0c6a460862c95fe42..8bb4b2a8e2f50d448fc314a70e2fc94cfa2c3d71 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -358,6 +358,7 @@ enum PredictionProvider { Mercury, Zeta1, Zeta2(ZetaFormat), + Baseten(ZetaFormat), Teacher(TeacherBackend), TeacherNonBatching(TeacherBackend), Repair, @@ -376,6 +377,7 @@ impl std::fmt::Display for PredictionProvider { PredictionProvider::Mercury => write!(f, "mercury"), PredictionProvider::Zeta1 => write!(f, "zeta1"), PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"), + PredictionProvider::Baseten(format) => write!(f, "baseten:{format}"), PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"), PredictionProvider::TeacherNonBatching(backend) => { write!(f, "teacher-non-batching:{backend}") @@ -415,6 +417,13 @@ impl std::str::FromStr for PredictionProvider { Ok(PredictionProvider::TeacherNonBatching(backend)) } "repair" => Ok(PredictionProvider::Repair), + "baseten" => { + let format = arg + .map(ZetaFormat::parse) + .transpose()? + .unwrap_or(ZetaFormat::default()); + Ok(PredictionProvider::Baseten(format)) + } _ => { anyhow::bail!( "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:, teacher, teacher:, teacher-non-batching, repair\n\ diff --git a/crates/edit_prediction_cli/src/metrics.rs b/crates/edit_prediction_cli/src/metrics.rs index fc870c36c9c62f4d74486ddd4b2d35176b00bb5c..1bfd8e542fa3d74b55f091d2ac13aa22883f6a2f 100644 --- a/crates/edit_prediction_cli/src/metrics.rs +++ b/crates/edit_prediction_cli/src/metrics.rs @@ -76,14 +76,21 @@ impl ClassificationMetrics { } enum ChrfWhitespace { + /// Preserve whitespace as-is #[allow(unused)] Unchanged, + + /// Ignore all whitespace differences + #[allow(unused)] Ignore, + + /// Collapse whitespace into single spaces + Collapse, } const CHR_F_CHAR_ORDER: usize = 6; const CHR_F_BETA: f64 = 2.0; -const CHR_F_WHITESPACE: ChrfWhitespace = ChrfWhitespace::Ignore; +const CHR_F_WHITESPACE: ChrfWhitespace = ChrfWhitespace::Collapse; /// Computes a delta-chrF score that compares two sets of edits. /// @@ -196,9 +203,34 @@ fn filter_whitespace_chars(text: &str) -> Vec { match CHR_F_WHITESPACE { ChrfWhitespace::Unchanged => text.chars().collect(), ChrfWhitespace::Ignore => text.chars().filter(|c| !c.is_whitespace()).collect(), + ChrfWhitespace::Collapse => collapse_whitespace(text.chars()), } } +/// Collapse whitespace into single spaces. +/// Newlines and spaces are collapsed separately. +fn collapse_whitespace(chars: impl Iterator) -> Vec { + let mut result = Vec::new(); + let mut last_whitespace = None; + for c in chars { + if c.is_whitespace() && c != '\n' { + if last_whitespace != Some(' ') { + result.push(' '); + last_whitespace = Some(' '); + } + } else if c == '\n' { + if last_whitespace != Some('\n') { + result.push(c); + last_whitespace = Some('\n'); + } + } else { + result.push(c); + last_whitespace = None; + } + } + result +} + /// Extract only the changed regions between two texts, with context for n-gram boundaries. /// /// Returns (original_affected_region, modified_affected_region) as Vec. @@ -269,15 +301,15 @@ fn count_ngrams_from_chars(chars: &[char], n: usize) -> Counts { #[allow(dead_code)] fn chr_f_ngram_counts(text: &str) -> Vec { - // Ignore whitespace. The original chrF implementation skips all - // whitespace. We should consider compressing multiple consecutive - // spaces into one -- this may reflect our task more closely. let text = match CHR_F_WHITESPACE { ChrfWhitespace::Unchanged => text.to_string(), ChrfWhitespace::Ignore => text .chars() .filter(|c| !c.is_whitespace()) .collect::(), + ChrfWhitespace::Collapse => collapse_whitespace(text.chars()) + .into_iter() + .collect::(), }; (1..=CHR_F_CHAR_ORDER) @@ -1175,4 +1207,14 @@ index abc123..def456 100644 assert!(counts.deleted_tokens >= 2); assert!(counts.inserted_tokens >= 2); } + + #[test] + fn test_whitespace_collapse() { + let text = "abc \n\n\n 123"; + let collapsed = collapse_whitespace(text.chars()); + assert_eq!( + collapsed, + vec!['a', 'b', 'c', ' ', '\n', ' ', '1', '2', '3'] + ); + } } diff --git a/crates/edit_prediction_cli/src/parse_output.rs b/crates/edit_prediction_cli/src/parse_output.rs index 4b8af44785c1781de772f569c012ee64eee48aad..2c066b8b32b3eaab54ad6e3b3bcb0796ff27f950 100644 --- a/crates/edit_prediction_cli/src/parse_output.rs +++ b/crates/edit_prediction_cli/src/parse_output.rs @@ -6,7 +6,11 @@ use crate::{ }; use anyhow::{Context as _, Result}; use edit_prediction::example_spec::encode_cursor_in_patch; -use zeta_prompt::{CURSOR_MARKER, ZetaFormat}; +use zeta_prompt::{ + CURSOR_MARKER, ZetaFormat, clean_extracted_region_for_format, + current_region_markers_for_format, output_end_marker_for_format, + output_with_context_for_format, +}; pub fn run_parse_output(example: &mut Example) -> Result<()> { example @@ -51,22 +55,7 @@ pub fn parse_prediction_output( } fn extract_zeta2_current_region(prompt: &str, format: ZetaFormat) -> Result { - let (current_marker, end_marker) = match format { - ZetaFormat::V0112MiddleAtEnd => ("<|fim_middle|>current\n", "<|fim_middle|>updated"), - ZetaFormat::V0113Ordered | ZetaFormat::V0114180EditableRegion => { - ("<|fim_middle|>current\n", "<|fim_suffix|>") - } - ZetaFormat::V0120GitMergeMarkers - | ZetaFormat::V0131GitMergeMarkersPrefix - | ZetaFormat::V0211Prefill => ( - zeta_prompt::v0120_git_merge_markers::START_MARKER, - zeta_prompt::v0120_git_merge_markers::SEPARATOR, - ), - ZetaFormat::V0211SeedCoder => ( - zeta_prompt::seed_coder::START_MARKER, - zeta_prompt::seed_coder::SEPARATOR, - ), - }; + let (current_marker, end_marker) = current_region_markers_for_format(format); let start = prompt.find(current_marker).with_context(|| { format!( @@ -82,8 +71,7 @@ fn extract_zeta2_current_region(prompt: &str, format: ZetaFormat) -> Result { - zeta_prompt::v0131_git_merge_markers_prefix::END_MARKER - } - ZetaFormat::V0120GitMergeMarkers => zeta_prompt::v0120_git_merge_markers::END_MARKER, - ZetaFormat::V0112MiddleAtEnd - | ZetaFormat::V0113Ordered - | ZetaFormat::V0114180EditableRegion => "", - ZetaFormat::V0211SeedCoder => zeta_prompt::seed_coder::END_MARKER, - }; - if !suffix.is_empty() { + if let Some(marker) = output_end_marker_for_format(format) { new_text = new_text - .strip_suffix(suffix) + .strip_suffix(marker) .unwrap_or(&new_text) .to_string(); } diff --git a/crates/edit_prediction_cli/src/predict.rs b/crates/edit_prediction_cli/src/predict.rs index 02ba24b8a4f2627b9542254e3d118981737f8318..94e28d00da2d61f63b59364304c3b9b4276e15f7 100644 --- a/crates/edit_prediction_cli/src/predict.rs +++ b/crates/edit_prediction_cli/src/predict.rs @@ -6,14 +6,18 @@ use crate::{ headless::EpAppState, load_project::run_load_project, openai_client::OpenAiClient, + parse_output::parse_prediction_output, paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR}, - progress::{ExampleProgress, InfoStyle, Step}, + progress::{ExampleProgress, InfoStyle, Step, StepProgress}, retrieve_context::run_context_retrieval, }; use anyhow::Context as _; +use cloud_llm_client::predict_edits_v3::{RawCompletionRequest, RawCompletionResponse}; use edit_prediction::{DebugEvent, EditPredictionStore, Zeta2RawConfig}; -use futures::{FutureExt as _, StreamExt as _, future::Shared}; +use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _, future::Shared}; use gpui::{AppContext as _, AsyncApp, Task}; +use http_client::{AsyncBody, HttpClient, Method}; +use reqwest_client::ReqwestClient; use std::{ fs, sync::{ @@ -79,6 +83,22 @@ pub async fn run_prediction( .await; } + if let PredictionProvider::Baseten(format) = provider { + run_format_prompt( + example, + &FormatPromptArgs { + provider: PredictionProvider::Zeta2(format), + }, + app_state.clone(), + example_progress, + cx, + ) + .await?; + + let step_progress = example_progress.start(Step::Predict); + return predict_baseten(example, format, &step_progress).await; + } + run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?; run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?; @@ -116,7 +136,8 @@ pub async fn run_prediction( PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury, PredictionProvider::Teacher(..) | PredictionProvider::TeacherNonBatching(..) - | PredictionProvider::Repair => { + | PredictionProvider::Repair + | PredictionProvider::Baseten(_) => { unreachable!() } }; @@ -127,7 +148,12 @@ pub async fn run_prediction( if let PredictionProvider::Zeta2(format) = provider { if format != ZetaFormat::default() { let model_id = std::env::var("ZED_ZETA_MODEL").ok(); - store.set_zeta2_raw_config(Zeta2RawConfig { model_id, format }); + let environment = std::env::var("ZED_ZETA_ENVIRONMENT").ok(); + store.set_zeta2_raw_config(Zeta2RawConfig { + model_id, + environment, + format, + }); } } }); @@ -364,7 +390,7 @@ async fn predict_anthropic( .await? else { // Request stashed for batched processing - return Ok(()); + continue; }; let actual_output = response @@ -438,7 +464,7 @@ async fn predict_openai( .await? else { // Request stashed for batched processing - return Ok(()); + continue; }; let actual_output = response @@ -480,6 +506,89 @@ async fn predict_openai( Ok(()) } +pub async fn predict_baseten( + example: &mut Example, + format: ZetaFormat, + step_progress: &StepProgress, +) -> anyhow::Result<()> { + let model_id = + std::env::var("ZED_ZETA_MODEL").context("ZED_ZETA_MODEL environment variable required")?; + + let api_key = + std::env::var("BASETEN_API_KEY").context("BASETEN_API_KEY environment variable not set")?; + + let prompt = example.prompt.as_ref().context("Prompt is required")?; + let prompt_text = prompt.input.clone(); + let prefill = prompt.prefill.clone().unwrap_or_default(); + + step_progress.set_substatus("running prediction via baseten"); + + let environment: String = <&'static str>::from(&format).to_lowercase(); + let url = format!( + "https://model-{model_id}.api.baseten.co/environments/{environment}/sync/v1/completions" + ); + + let request_body = RawCompletionRequest { + model: model_id, + prompt: prompt_text.clone(), + max_tokens: Some(2048), + temperature: Some(0.), + stop: vec![], + environment: None, + }; + + let body_bytes = + serde_json::to_vec(&request_body).context("Failed to serialize request body")?; + + let http_client: Arc = Arc::new(ReqwestClient::new()); + let request = http_client::Request::builder() + .method(Method::POST) + .uri(&url) + .header("Content-Type", "application/json") + .header("Authorization", format!("Api-Key {api_key}")) + .body(AsyncBody::from(body_bytes))?; + + let mut response = http_client.send(request).await?; + let status = response.status(); + + let mut body = String::new(); + response + .body_mut() + .read_to_string(&mut body) + .await + .context("Failed to read Baseten response body")?; + + if !status.is_success() { + anyhow::bail!("Baseten API returned {status}: {body}"); + } + + let completion: RawCompletionResponse = + serde_json::from_str(&body).context("Failed to parse Baseten response")?; + + let actual_output = completion + .choices + .into_iter() + .next() + .map(|choice| choice.text) + .unwrap_or_default(); + + let actual_output = format!("{prefill}{actual_output}"); + + let (actual_patch, actual_cursor) = + parse_prediction_output(example, &actual_output, PredictionProvider::Zeta2(format))?; + + let prediction = ExamplePrediction { + actual_patch: Some(actual_patch), + actual_output, + actual_cursor, + error: None, + provider: PredictionProvider::Baseten(format), + }; + + example.predictions.push(prediction); + Ok(()) +} + pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> { match provider { Some(PredictionProvider::Teacher(backend)) => match backend { diff --git a/crates/edit_prediction_cli/src/pull_examples.rs b/crates/edit_prediction_cli/src/pull_examples.rs index 2f371675b29015795beef550ce5e3956c63751f9..cccd351dcdeda0dbf059d851a44b02bc1e558654 100644 --- a/crates/edit_prediction_cli/src/pull_examples.rs +++ b/crates/edit_prediction_cli/src/pull_examples.rs @@ -34,7 +34,7 @@ pub struct MinCaptureVersion { pub patch: u32, } -const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120; +const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 240; const SETTLED_STATEMENT_TIMEOUT_SECONDS: u64 = 240; pub(crate) const POLL_INTERVAL: Duration = Duration::from_secs(2); pub(crate) const MAX_POLL_ATTEMPTS: usize = 120; @@ -715,7 +715,7 @@ pub async fn fetch_rated_examples_after( AND rated.event_properties:inputs IS NOT NULL AND rated.event_properties:inputs:cursor_excerpt IS NOT NULL AND rated.event_properties:output IS NOT NULL - AND rated.event_properties:can_collect_data = true + AND rated.event_properties:inputs:can_collect_data = true ORDER BY rated.time ASC LIMIT ? OFFSET ? @@ -823,11 +823,11 @@ fn rated_examples_from_response<'a>( let environment = get_string("environment"); let zed_version = get_string("zed_version"); - match (inputs, output.clone(), rating.clone(), device_id.clone(), time.clone()) { - (Some(inputs), Some(output), Some(rating), Some(device_id), Some(time)) => { + match (inputs, output.clone(), rating.clone(), time.clone()) { + (Some(inputs), Some(output), Some(rating), Some(time)) => { Some(build_rated_example( request_id, - device_id, + device_id.unwrap_or_default(), time, inputs, output, @@ -840,11 +840,10 @@ fn rated_examples_from_response<'a>( } _ => { log::warn!( - "skipping row {row_index}: missing fields - inputs={:?} output={:?} rating={:?} device_id={:?} time={:?}", + "skipping row {row_index}: missing fields - inputs={:?} output={:?} rating={:?} time={:?}", inputs_json.is_some(), output.is_some(), rating.is_some(), - device_id.is_some(), time.is_some(), ); None diff --git a/crates/edit_prediction_cli/src/reversal_tracking.rs b/crates/edit_prediction_cli/src/reversal_tracking.rs index 2d578c8666f217365ed2ed24ff766ed6f19566d7..cb955dbdf7dd2375395e8c0ecd52df849e33fb38 100644 --- a/crates/edit_prediction_cli/src/reversal_tracking.rs +++ b/crates/edit_prediction_cli/src/reversal_tracking.rs @@ -681,6 +681,7 @@ mod tests { experiment: None, in_open_source_repo: false, can_collect_data: false, + repo_url: None, } } diff --git a/crates/edit_prediction_ui/src/edit_prediction_button.rs b/crates/edit_prediction_ui/src/edit_prediction_button.rs index 6339c7d6cd9fa1cc40101cc1bf14650a6904b3c7..b00a229164d480d38312ca97cac31a23010f8b69 100644 --- a/crates/edit_prediction_ui/src/edit_prediction_button.rs +++ b/crates/edit_prediction_ui/src/edit_prediction_button.rs @@ -3,7 +3,7 @@ use client::{Client, UserStore, zed_urls}; use cloud_llm_client::UsageLimit; use codestral::{self, CodestralEditPredictionDelegate}; use copilot::Status; -use edit_prediction::{EditPredictionStore, Zeta2FeatureFlag}; +use edit_prediction::EditPredictionStore; use edit_prediction_types::EditPredictionDelegateHandle; use editor::{ Editor, MultiBufferOffset, SelectionEffects, actions::ShowEditPrediction, scroll::Autoscroll, @@ -22,9 +22,7 @@ use language::{ }; use project::{DisableAiSettings, Project}; use regex::Regex; -use settings::{ - EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, Settings, SettingsStore, update_settings_file, -}; +use settings::{Settings, SettingsStore, update_settings_file}; use std::{ rc::Rc, sync::{Arc, LazyLock}, @@ -539,9 +537,15 @@ impl EditPredictionButton { edit_prediction::ollama::ensure_authenticated(cx); let sweep_api_token_task = edit_prediction::sweep_ai::load_sweep_api_token(cx); let mercury_api_token_task = edit_prediction::mercury::load_mercury_api_token(cx); + let open_ai_compatible_api_token_task = + edit_prediction::open_ai_compatible::load_open_ai_compatible_api_token(cx); cx.spawn(async move |this, cx| { - _ = futures::join!(sweep_api_token_task, mercury_api_token_task); + _ = futures::join!( + sweep_api_token_task, + mercury_api_token_task, + open_ai_compatible_api_token_task + ); this.update(cx, |_, cx| { cx.notify(); }) @@ -770,13 +774,7 @@ impl EditPredictionButton { menu = menu.separator().header("Privacy"); - if matches!( - provider, - EditPredictionProvider::Zed - | EditPredictionProvider::Experimental( - EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, - ) - ) { + if matches!(provider, EditPredictionProvider::Zed) { if let Some(provider) = &self.edit_prediction_provider { let data_collection = provider.data_collection_state(cx); @@ -1399,12 +1397,6 @@ pub fn get_available_providers(cx: &mut App) -> Vec { providers.push(EditPredictionProvider::Zed); - if cx.has_flag::() { - providers.push(EditPredictionProvider::Experimental( - EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, - )); - } - if let Some(app_state) = workspace::AppState::global(cx).upgrade() && copilot::GlobalCopilotAuth::try_get_or_init(app_state, cx) .is_some_and(|copilot| copilot.0.read(cx).is_authenticated()) diff --git a/crates/editor/src/display_map.rs b/crates/editor/src/display_map.rs index 57b8eb8ef6c1b29cb99da3e2a4e731d0c828038e..00a48a9ab3d249850b9749d64267d8274e7eaa79 100644 --- a/crates/editor/src/display_map.rs +++ b/crates/editor/src/display_map.rs @@ -789,6 +789,9 @@ impl DisplayMap { .collect(), cx, ); + for buffer_id in &other.block_snapshot.buffers_with_disabled_headers { + self.disable_header_for_buffer(*buffer_id, cx); + } } /// Creates folds for the given creases. @@ -1003,10 +1006,6 @@ impl DisplayMap { &self.block_map.folded_buffers } - pub(super) fn clear_folded_buffer(&mut self, buffer_id: language::BufferId) { - self.block_map.folded_buffers.remove(&buffer_id); - } - #[instrument(skip_all)] pub fn insert_creases( &mut self, @@ -1920,6 +1919,9 @@ impl DisplaySnapshot { color } }), + underline: chunk_highlight + .underline + .filter(|_| editor_style.show_underlines), ..chunk_highlight } }); diff --git a/crates/editor/src/display_map/block_map.rs b/crates/editor/src/display_map/block_map.rs index db7eb53b39088c6026d3d36bef636f748c80d587..2673baae84ab74b2852004320cf1d94c5ed1ed42 100644 --- a/crates/editor/src/display_map/block_map.rs +++ b/crates/editor/src/display_map/block_map.rs @@ -78,6 +78,7 @@ pub struct BlockSnapshot { custom_blocks_by_id: TreeMap>, pub(super) buffer_header_height: u32, pub(super) excerpt_header_height: u32, + pub(super) buffers_with_disabled_headers: HashSet, } impl Deref for BlockSnapshot { @@ -657,6 +658,7 @@ impl BlockMap { custom_blocks_by_id: self.custom_blocks_by_id.clone(), buffer_header_height: self.buffer_header_height, excerpt_header_height: self.excerpt_header_height, + buffers_with_disabled_headers: self.buffers_with_disabled_headers.clone(), }, } } diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index 28d96e721257eaad898408cafba67f9f991e4909..3b18c9a447d8fb4569bbf331f1ba8e4602a555b9 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -1973,6 +1973,8 @@ impl Editor { .clone_state(&self.scroll_manager, &my_snapshot, &clone_snapshot, cx); clone.searchable = self.searchable; clone.read_only = self.read_only; + clone.buffers_with_disabled_indent_guides = + self.buffers_with_disabled_indent_guides.clone(); clone } @@ -5214,29 +5216,48 @@ impl Editor { extra_line_additional_indent, prevent_auto_indent, } => { + let auto_indent_mode = + buffer.language_settings_at(start, cx).auto_indent; + let preserve_indent = + auto_indent_mode != language::AutoIndentMode::None; + let apply_syntax_indent = + auto_indent_mode == language::AutoIndentMode::SyntaxAware; let capacity_for_delimiter = delimiter.as_deref().map(str::len).unwrap_or_default(); + let existing_indent_len = if preserve_indent { + existing_indent.len as usize + } else { + 0 + }; let extra_line_len = extra_line_additional_indent - .map(|i| 1 + existing_indent.len as usize + i.len as usize) + .map(|i| 1 + existing_indent_len + i.len as usize) .unwrap_or(0); let mut new_text = String::with_capacity( 1 + capacity_for_delimiter - + existing_indent.len as usize + + existing_indent_len + additional_indent.len as usize + extra_line_len, ); new_text.push('\n'); - new_text.extend(existing_indent.chars()); + if preserve_indent { + new_text.extend(existing_indent.chars()); + } new_text.extend(additional_indent.chars()); if let Some(delimiter) = &delimiter { new_text.push_str(delimiter); } if let Some(extra_indent) = extra_line_additional_indent { new_text.push('\n'); - new_text.extend(existing_indent.chars()); + if preserve_indent { + new_text.extend(existing_indent.chars()); + } new_text.extend(extra_indent.chars()); } - (start, new_text, *prevent_auto_indent) + ( + start, + new_text, + *prevent_auto_indent || !apply_syntax_indent, + ) } }; @@ -24145,9 +24166,13 @@ impl Editor { self.display_map.update(cx, |display_map, cx| { display_map.invalidate_semantic_highlights(*buffer_id); display_map.clear_lsp_folding_ranges(*buffer_id, cx); - display_map.clear_folded_buffer(*buffer_id); }); } + + self.display_map.update(cx, |display_map, cx| { + display_map.unfold_buffers(removed_buffer_ids.iter().copied(), cx); + }); + jsx_tag_auto_close::refresh_enabled_in_any_buffer(self, multibuffer, cx); cx.emit(EditorEvent::ExcerptsRemoved { ids: ids.clone(), diff --git a/crates/editor/src/editor_tests.rs b/crates/editor/src/editor_tests.rs index 38abff942acf8717000090a90654f1117ba5005d..199cb0d3785a048f6390070d67546394bd89ff68 100644 --- a/crates/editor/src/editor_tests.rs +++ b/crates/editor/src/editor_tests.rs @@ -10716,7 +10716,9 @@ async fn test_autoindent(cx: &mut TestAppContext) { #[gpui::test] async fn test_autoindent_disabled(cx: &mut TestAppContext) { - init_test(cx, |settings| settings.defaults.auto_indent = Some(false)); + init_test(cx, |settings| { + settings.defaults.auto_indent = Some(settings::AutoIndentMode::None) + }); let language = Arc::new( Language::new( @@ -10794,14 +10796,165 @@ async fn test_autoindent_disabled(cx: &mut TestAppContext) { }); } +#[gpui::test] +async fn test_autoindent_none_does_not_preserve_indentation_on_newline(cx: &mut TestAppContext) { + init_test(cx, |settings| { + settings.defaults.auto_indent = Some(settings::AutoIndentMode::None) + }); + + let mut cx = EditorTestContext::new(cx).await; + + cx.set_state(indoc! {" + hello + indented lineˇ + world + "}); + + cx.update_editor(|editor, window, cx| { + editor.newline(&Newline, window, cx); + }); + + cx.assert_editor_state(indoc! {" + hello + indented line + ˇ + world + "}); +} + +#[gpui::test] +async fn test_autoindent_preserve_indent_maintains_indentation_on_newline(cx: &mut TestAppContext) { + // When auto_indent is "preserve_indent", pressing Enter on an indented line + // should preserve the indentation but not adjust based on syntax. + init_test(cx, |settings| { + settings.defaults.auto_indent = Some(settings::AutoIndentMode::PreserveIndent) + }); + + let mut cx = EditorTestContext::new(cx).await; + + cx.set_state(indoc! {" + hello + indented lineˇ + world + "}); + + cx.update_editor(|editor, window, cx| { + editor.newline(&Newline, window, cx); + }); + + // The new line SHOULD have the same indentation as the previous line + cx.assert_editor_state(indoc! {" + hello + indented line + ˇ + world + "}); +} + +#[gpui::test] +async fn test_autoindent_preserve_indent_does_not_apply_syntax_indent(cx: &mut TestAppContext) { + init_test(cx, |settings| { + settings.defaults.auto_indent = Some(settings::AutoIndentMode::PreserveIndent) + }); + + let language = Arc::new( + Language::new( + LanguageConfig { + brackets: BracketPairConfig { + pairs: vec![BracketPair { + start: "{".to_string(), + end: "}".to_string(), + close: false, + surround: false, + newline: false, // Disable extra newline behavior to isolate syntax indent test + }], + ..Default::default() + }, + ..Default::default() + }, + Some(tree_sitter_rust::LANGUAGE.into()), + ) + .with_indents_query(r#"(_ "{" "}" @end) @indent"#) + .unwrap(), + ); + + let buffer = + cx.new(|cx| Buffer::local("fn foo() {\n}", cx).with_language(language.clone(), cx)); + let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx)); + let (editor, cx) = cx.add_window_view(|window, cx| build_editor(buffer, window, cx)); + editor + .condition::(cx, |editor, cx| !editor.buffer.read(cx).is_parsing(cx)) + .await; + + // Position cursor at end of line containing `{` + editor.update_in(cx, |editor, window, cx| { + editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { + s.select_ranges([MultiBufferOffset(10)..MultiBufferOffset(10)]) // After "fn foo() {" + }); + editor.newline(&Newline, window, cx); + + // With PreserveIndent, the new line should have 0 indentation (same as the fn line) + // NOT 4 spaces (which tree-sitter would add for being inside `{}`) + assert_eq!(editor.text(cx), "fn foo() {\n\n}"); + }); +} + +#[gpui::test] +async fn test_autoindent_syntax_aware_applies_syntax_indent(cx: &mut TestAppContext) { + // Companion test to show that SyntaxAware DOES apply tree-sitter indentation + init_test(cx, |settings| { + settings.defaults.auto_indent = Some(settings::AutoIndentMode::SyntaxAware) + }); + + let language = Arc::new( + Language::new( + LanguageConfig { + brackets: BracketPairConfig { + pairs: vec![BracketPair { + start: "{".to_string(), + end: "}".to_string(), + close: false, + surround: false, + newline: false, // Disable extra newline behavior to isolate syntax indent test + }], + ..Default::default() + }, + ..Default::default() + }, + Some(tree_sitter_rust::LANGUAGE.into()), + ) + .with_indents_query(r#"(_ "{" "}" @end) @indent"#) + .unwrap(), + ); + + let buffer = + cx.new(|cx| Buffer::local("fn foo() {\n}", cx).with_language(language.clone(), cx)); + let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx)); + let (editor, cx) = cx.add_window_view(|window, cx| build_editor(buffer, window, cx)); + editor + .condition::(cx, |editor, cx| !editor.buffer.read(cx).is_parsing(cx)) + .await; + + // Position cursor at end of line containing `{` + editor.update_in(cx, |editor, window, cx| { + editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { + s.select_ranges([MultiBufferOffset(10)..MultiBufferOffset(10)]) // After "fn foo() {" + }); + editor.newline(&Newline, window, cx); + + // With SyntaxAware, tree-sitter adds indentation for being inside `{}` + assert_eq!(editor.text(cx), "fn foo() {\n \n}"); + }); +} + #[gpui::test] async fn test_autoindent_disabled_with_nested_language(cx: &mut TestAppContext) { init_test(cx, |settings| { - settings.defaults.auto_indent = Some(true); + settings.defaults.auto_indent = Some(settings::AutoIndentMode::SyntaxAware); settings.languages.0.insert( "python".into(), LanguageSettingsContent { - auto_indent: Some(false), + auto_indent: Some(settings::AutoIndentMode::None), ..Default::default() }, ); diff --git a/crates/feature_flags/src/flags.rs b/crates/feature_flags/src/flags.rs index eab9f8c1036a83451fc3201f97cfb1cc8c885043..77a98aae05572ac72b239db8bb3d4496bd1c0f4d 100644 --- a/crates/feature_flags/src/flags.rs +++ b/crates/feature_flags/src/flags.rs @@ -37,6 +37,16 @@ impl FeatureFlag for AgentSharingFeatureFlag { const NAME: &'static str = "agent-sharing"; } +pub struct AgentGitWorktreesFeatureFlag; + +impl FeatureFlag for AgentGitWorktreesFeatureFlag { + const NAME: &'static str = "agent-git-worktrees"; + + fn enabled_for_staff() -> bool { + false + } +} + pub struct DiffReviewFeatureFlag; impl FeatureFlag for DiffReviewFeatureFlag { @@ -59,6 +69,6 @@ impl FeatureFlag for StreamingEditFileToolFeatureFlag { const NAME: &'static str = "streaming-edit-file-tool"; fn enabled_for_staff() -> bool { - false + true } } diff --git a/crates/fs/Cargo.toml b/crates/fs/Cargo.toml index 6355524e4f328df0ca7fcf24c1df0557676ba6a6..04cae2dd2ad18f85a7c2ed663c1c3482febb22d3 100644 --- a/crates/fs/Cargo.toml +++ b/crates/fs/Cargo.toml @@ -58,4 +58,4 @@ gpui = { workspace = true, features = ["test-support"] } git = { workspace = true, features = ["test-support"] } [features] -test-support = ["gpui/test-support", "git/test-support"] +test-support = ["gpui/test-support", "git/test-support", "util/test-support"] diff --git a/crates/fs/src/fake_git_repo.rs b/crates/fs/src/fake_git_repo.rs index 12cd67cdae1a250d07468047617c8cc7a52737fa..85489b6057cd8214ee512fb477428c93cdb32219 100644 --- a/crates/fs/src/fake_git_repo.rs +++ b/crates/fs/src/fake_git_repo.rs @@ -20,7 +20,7 @@ use ignore::gitignore::GitignoreBuilder; use parking_lot::Mutex; use rope::Rope; use smol::{channel::Sender, future::FutureExt as _}; -use std::{path::PathBuf, sync::Arc}; +use std::{path::PathBuf, sync::Arc, sync::atomic::AtomicBool}; use text::LineEnding; use util::{paths::PathStyle, rel_path::RelPath}; @@ -32,6 +32,7 @@ pub struct FakeGitRepository { pub(crate) dot_git_path: PathBuf, pub(crate) repository_dir_path: PathBuf, pub(crate) common_dir_path: PathBuf, + pub(crate) is_trusted: Arc, } #[derive(Debug, Clone)] @@ -406,7 +407,31 @@ impl GitRepository for FakeGitRepository { } fn worktrees(&self) -> BoxFuture<'_, Result>> { - self.with_state_async(false, |state| Ok(state.worktrees.clone())) + let dot_git_path = self.dot_git_path.clone(); + self.with_state_async(false, move |state| { + let work_dir = dot_git_path + .parent() + .map(PathBuf::from) + .unwrap_or(dot_git_path); + let head_sha = state + .refs + .get("HEAD") + .cloned() + .unwrap_or_else(|| "0000000".to_string()); + let branch_ref = state + .current_branch_name + .as_ref() + .map(|name| format!("refs/heads/{name}")) + .unwrap_or_else(|| "refs/heads/main".to_string()); + let main_worktree = Worktree { + path: work_dir, + ref_name: branch_ref.into(), + sha: head_sha.into(), + }; + let mut all = vec![main_worktree]; + all.extend(state.worktrees.iter().cloned()); + Ok(all) + }) } fn create_worktree( @@ -770,8 +795,8 @@ impl GitRepository for FakeGitRepository { fn diff_stat( &self, - diff_type: git::repository::DiffType, - ) -> BoxFuture<'_, Result>> { + path_prefixes: &[RepoPath], + ) -> BoxFuture<'_, Result> { fn count_lines(s: &str) -> u32 { if s.is_empty() { 0 @@ -780,122 +805,95 @@ impl GitRepository for FakeGitRepository { } } - match diff_type { - git::repository::DiffType::HeadToIndex => self - .with_state_async(false, |state| { - let mut result = HashMap::default(); - let all_paths: HashSet<&RepoPath> = state - .head_contents - .keys() - .chain(state.index_contents.keys()) - .collect(); - for path in all_paths { - let head = state.head_contents.get(path); - let index = state.index_contents.get(path); - match (head, index) { - (Some(old), Some(new)) if old != new => { - result.insert( - path.clone(), - git::status::DiffStat { - added: count_lines(new), - deleted: count_lines(old), - }, - ); - } - (Some(old), None) => { - result.insert( - path.clone(), - git::status::DiffStat { - added: 0, - deleted: count_lines(old), - }, - ); - } - (None, Some(new)) => { - result.insert( - path.clone(), - git::status::DiffStat { - added: count_lines(new), - deleted: 0, - }, - ); - } - _ => {} - } - } - Ok(result) - }) - .boxed(), - git::repository::DiffType::HeadToWorktree => { - let workdir_path = self.dot_git_path.parent().unwrap().to_path_buf(); - let worktree_files: HashMap = self + fn matches_prefixes(path: &RepoPath, prefixes: &[RepoPath]) -> bool { + if prefixes.is_empty() { + return true; + } + prefixes.iter().any(|prefix| { + let prefix_str = prefix.as_unix_str(); + if prefix_str == "." { + return true; + } + path == prefix || path.starts_with(&prefix) + }) + } + + let path_prefixes = path_prefixes.to_vec(); + + let workdir_path = self.dot_git_path.parent().unwrap().to_path_buf(); + let worktree_files: HashMap = self + .fs + .files() + .iter() + .filter_map(|path| { + let repo_path = path.strip_prefix(&workdir_path).ok()?; + if repo_path.starts_with(".git") { + return None; + } + let content = self .fs - .files() - .iter() - .filter_map(|path| { - let repo_path = path.strip_prefix(&workdir_path).ok()?; - if repo_path.starts_with(".git") { - return None; - } - let content = self - .fs - .read_file_sync(path) - .ok() - .and_then(|bytes| String::from_utf8(bytes).ok())?; - let repo_path = RelPath::new(repo_path, PathStyle::local()).ok()?; - Some((RepoPath::from_rel_path(&repo_path), content)) - }) - .collect(); + .read_file_sync(path) + .ok() + .and_then(|bytes| String::from_utf8(bytes).ok())?; + let repo_path = RelPath::new(repo_path, PathStyle::local()).ok()?; + Some((RepoPath::from_rel_path(&repo_path), content)) + }) + .collect(); - self.with_state_async(false, move |state| { - let mut result = HashMap::default(); - let all_paths: HashSet<&RepoPath> = state - .head_contents + self.with_state_async(false, move |state| { + let mut entries = Vec::new(); + let all_paths: HashSet<&RepoPath> = state + .head_contents + .keys() + .chain( + worktree_files .keys() - .chain(worktree_files.keys()) - .collect(); - for path in all_paths { - let head = state.head_contents.get(path); - let worktree = worktree_files.get(path); - match (head, worktree) { - (Some(old), Some(new)) if old != new => { - result.insert( - path.clone(), - git::status::DiffStat { - added: count_lines(new), - deleted: count_lines(old), - }, - ); - } - (Some(old), None) => { - result.insert( - path.clone(), - git::status::DiffStat { - added: 0, - deleted: count_lines(old), - }, - ); - } - (None, Some(new)) => { - result.insert( - path.clone(), - git::status::DiffStat { - added: count_lines(new), - deleted: 0, - }, - ); - } - _ => {} - } + .filter(|p| state.index_contents.contains_key(*p)), + ) + .collect(); + for path in all_paths { + if !matches_prefixes(path, &path_prefixes) { + continue; + } + let head = state.head_contents.get(path); + let worktree = worktree_files.get(path); + match (head, worktree) { + (Some(old), Some(new)) if old != new => { + entries.push(( + path.clone(), + git::status::DiffStat { + added: count_lines(new), + deleted: count_lines(old), + }, + )); } - Ok(result) - }) - .boxed() - } - git::repository::DiffType::MergeBase { .. } => { - future::ready(Ok(HashMap::default())).boxed() + (Some(old), None) => { + entries.push(( + path.clone(), + git::status::DiffStat { + added: 0, + deleted: count_lines(old), + }, + )); + } + (None, Some(new)) => { + entries.push(( + path.clone(), + git::status::DiffStat { + added: count_lines(new), + deleted: 0, + }, + )); + } + _ => {} + } } - } + entries.sort_by(|(a, _), (b, _)| a.cmp(b)); + Ok(git::status::GitDiffStat { + entries: entries.into(), + }) + }) + .boxed() } fn checkpoint(&self) -> BoxFuture<'static, Result> { @@ -1011,146 +1009,13 @@ impl GitRepository for FakeGitRepository { fn commit_data_reader(&self) -> Result { anyhow::bail!("commit_data_reader not supported for FakeGitRepository") } -} -#[cfg(test)] -mod tests { - use super::*; - use crate::{FakeFs, Fs}; - use gpui::TestAppContext; - use serde_json::json; - use std::path::Path; - - #[gpui::test] - async fn test_fake_worktree_lifecycle(cx: &mut TestAppContext) { - let worktree_dir_settings = &["../worktrees", ".git/zed-worktrees", "my-worktrees/"]; - - for worktree_dir_setting in worktree_dir_settings { - let fs = FakeFs::new(cx.executor()); - fs.insert_tree("/project", json!({".git": {}, "file.txt": "content"})) - .await; - let repo = fs - .open_repo(Path::new("/project/.git"), None) - .expect("should open fake repo"); - - // Initially no worktrees - let worktrees = repo.worktrees().await.unwrap(); - assert!(worktrees.is_empty()); - - let expected_dir = git::repository::resolve_worktree_directory( - Path::new("/project"), - worktree_dir_setting, - ); - - // Create a worktree - repo.create_worktree( - "feature-branch".to_string(), - expected_dir.clone(), - Some("abc123".to_string()), - ) - .await - .unwrap(); - - // List worktrees — should have one - let worktrees = repo.worktrees().await.unwrap(); - assert_eq!(worktrees.len(), 1); - assert_eq!( - worktrees[0].path, - expected_dir.join("feature-branch"), - "failed for worktree_directory setting: {worktree_dir_setting:?}" - ); - assert_eq!(worktrees[0].ref_name.as_ref(), "refs/heads/feature-branch"); - assert_eq!(worktrees[0].sha.as_ref(), "abc123"); - - // Directory should exist in FakeFs after create - assert!( - fs.is_dir(&expected_dir.join("feature-branch")).await, - "worktree directory should be created in FakeFs for setting {worktree_dir_setting:?}" - ); - - // Create a second worktree (without explicit commit) - repo.create_worktree("bugfix-branch".to_string(), expected_dir.clone(), None) - .await - .unwrap(); - - let worktrees = repo.worktrees().await.unwrap(); - assert_eq!(worktrees.len(), 2); - assert!( - fs.is_dir(&expected_dir.join("bugfix-branch")).await, - "second worktree directory should be created in FakeFs for setting {worktree_dir_setting:?}" - ); - - // Rename the first worktree - repo.rename_worktree( - expected_dir.join("feature-branch"), - expected_dir.join("renamed-branch"), - ) - .await - .unwrap(); - - let worktrees = repo.worktrees().await.unwrap(); - assert_eq!(worktrees.len(), 2); - assert!( - worktrees - .iter() - .any(|w| w.path == expected_dir.join("renamed-branch")), - "renamed worktree should exist at new path for setting {worktree_dir_setting:?}" - ); - assert!( - worktrees - .iter() - .all(|w| w.path != expected_dir.join("feature-branch")), - "old path should no longer exist for setting {worktree_dir_setting:?}" - ); - - // Directory should be moved in FakeFs after rename - assert!( - !fs.is_dir(&expected_dir.join("feature-branch")).await, - "old worktree directory should not exist after rename for setting {worktree_dir_setting:?}" - ); - assert!( - fs.is_dir(&expected_dir.join("renamed-branch")).await, - "new worktree directory should exist after rename for setting {worktree_dir_setting:?}" - ); - - // Rename a nonexistent worktree should fail - let result = repo - .rename_worktree(PathBuf::from("/nonexistent"), PathBuf::from("/somewhere")) - .await; - assert!(result.is_err()); - - // Remove a worktree - repo.remove_worktree(expected_dir.join("renamed-branch"), false) - .await - .unwrap(); - - let worktrees = repo.worktrees().await.unwrap(); - assert_eq!(worktrees.len(), 1); - assert_eq!(worktrees[0].path, expected_dir.join("bugfix-branch")); - - // Directory should be removed from FakeFs after remove - assert!( - !fs.is_dir(&expected_dir.join("renamed-branch")).await, - "worktree directory should be removed from FakeFs for setting {worktree_dir_setting:?}" - ); - - // Remove a nonexistent worktree should fail - let result = repo - .remove_worktree(PathBuf::from("/nonexistent"), false) - .await; - assert!(result.is_err()); - - // Remove the last worktree - repo.remove_worktree(expected_dir.join("bugfix-branch"), false) - .await - .unwrap(); - - let worktrees = repo.worktrees().await.unwrap(); - assert!(worktrees.is_empty()); - assert!( - !fs.is_dir(&expected_dir.join("bugfix-branch")).await, - "last worktree directory should be removed from FakeFs for setting {worktree_dir_setting:?}" - ); - } + fn set_trusted(&self, trusted: bool) { + self.is_trusted + .store(trusted, std::sync::atomic::Ordering::Release); + } + + fn is_trusted(&self) -> bool { + self.is_trusted.load(std::sync::atomic::Ordering::Acquire) } } diff --git a/crates/fs/src/fs.rs b/crates/fs/src/fs.rs index 2db9e48a2e77bdb3e49fce0b16ea9b67ffaacbc0..0fde444171042eda859edcac7915c456ab91e265 100644 --- a/crates/fs/src/fs.rs +++ b/crates/fs/src/fs.rs @@ -2776,6 +2776,7 @@ impl Fs for FakeFs { repository_dir_path: repository_dir_path.to_owned(), common_dir_path: common_dir_path.to_owned(), checkpoints: Arc::default(), + is_trusted: Arc::default(), }) as _ }, ) diff --git a/crates/fs/tests/integration/fake_git_repo.rs b/crates/fs/tests/integration/fake_git_repo.rs index 36dfcaf168b4f0190c5c49bf4798fac7bc9bd37b..bae7f2fc94dd5161793f85f64cc0a1448a187134 100644 --- a/crates/fs/tests/integration/fake_git_repo.rs +++ b/crates/fs/tests/integration/fake_git_repo.rs @@ -1,9 +1,146 @@ use fs::{FakeFs, Fs}; -use gpui::BackgroundExecutor; +use gpui::{BackgroundExecutor, TestAppContext}; use serde_json::json; -use std::path::Path; +use std::path::{Path, PathBuf}; use util::path; +#[gpui::test] +async fn test_fake_worktree_lifecycle(cx: &mut TestAppContext) { + let worktree_dir_settings = &["../worktrees", ".git/zed-worktrees", "my-worktrees/"]; + + for worktree_dir_setting in worktree_dir_settings { + let fs = FakeFs::new(cx.executor()); + fs.insert_tree("/project", json!({".git": {}, "file.txt": "content"})) + .await; + let repo = fs + .open_repo(Path::new("/project/.git"), None) + .expect("should open fake repo"); + + // Initially only the main worktree exists + let worktrees = repo.worktrees().await.unwrap(); + assert_eq!(worktrees.len(), 1); + assert_eq!(worktrees[0].path, PathBuf::from("/project")); + + let expected_dir = git::repository::resolve_worktree_directory( + Path::new("/project"), + worktree_dir_setting, + ); + + // Create a worktree + repo.create_worktree( + "feature-branch".to_string(), + expected_dir.clone(), + Some("abc123".to_string()), + ) + .await + .unwrap(); + + // List worktrees — should have main + one created + let worktrees = repo.worktrees().await.unwrap(); + assert_eq!(worktrees.len(), 2); + assert_eq!(worktrees[0].path, PathBuf::from("/project")); + assert_eq!( + worktrees[1].path, + expected_dir.join("feature-branch"), + "failed for worktree_directory setting: {worktree_dir_setting:?}" + ); + assert_eq!(worktrees[1].ref_name.as_ref(), "refs/heads/feature-branch"); + assert_eq!(worktrees[1].sha.as_ref(), "abc123"); + + // Directory should exist in FakeFs after create + assert!( + fs.is_dir(&expected_dir.join("feature-branch")).await, + "worktree directory should be created in FakeFs for setting {worktree_dir_setting:?}" + ); + + // Create a second worktree (without explicit commit) + repo.create_worktree("bugfix-branch".to_string(), expected_dir.clone(), None) + .await + .unwrap(); + + let worktrees = repo.worktrees().await.unwrap(); + assert_eq!(worktrees.len(), 3); + assert!( + fs.is_dir(&expected_dir.join("bugfix-branch")).await, + "second worktree directory should be created in FakeFs for setting {worktree_dir_setting:?}" + ); + + // Rename the first worktree + repo.rename_worktree( + expected_dir.join("feature-branch"), + expected_dir.join("renamed-branch"), + ) + .await + .unwrap(); + + let worktrees = repo.worktrees().await.unwrap(); + assert_eq!(worktrees.len(), 3); + assert!( + worktrees + .iter() + .any(|w| w.path == expected_dir.join("renamed-branch")), + "renamed worktree should exist at new path for setting {worktree_dir_setting:?}" + ); + assert!( + worktrees + .iter() + .all(|w| w.path != expected_dir.join("feature-branch")), + "old path should no longer exist for setting {worktree_dir_setting:?}" + ); + + // Directory should be moved in FakeFs after rename + assert!( + !fs.is_dir(&expected_dir.join("feature-branch")).await, + "old worktree directory should not exist after rename for setting {worktree_dir_setting:?}" + ); + assert!( + fs.is_dir(&expected_dir.join("renamed-branch")).await, + "new worktree directory should exist after rename for setting {worktree_dir_setting:?}" + ); + + // Rename a nonexistent worktree should fail + let result = repo + .rename_worktree(PathBuf::from("/nonexistent"), PathBuf::from("/somewhere")) + .await; + assert!(result.is_err()); + + // Remove a worktree + repo.remove_worktree(expected_dir.join("renamed-branch"), false) + .await + .unwrap(); + + let worktrees = repo.worktrees().await.unwrap(); + assert_eq!(worktrees.len(), 2); + assert_eq!(worktrees[0].path, PathBuf::from("/project")); + assert_eq!(worktrees[1].path, expected_dir.join("bugfix-branch")); + + // Directory should be removed from FakeFs after remove + assert!( + !fs.is_dir(&expected_dir.join("renamed-branch")).await, + "worktree directory should be removed from FakeFs for setting {worktree_dir_setting:?}" + ); + + // Remove a nonexistent worktree should fail + let result = repo + .remove_worktree(PathBuf::from("/nonexistent"), false) + .await; + assert!(result.is_err()); + + // Remove the last worktree + repo.remove_worktree(expected_dir.join("bugfix-branch"), false) + .await + .unwrap(); + + let worktrees = repo.worktrees().await.unwrap(); + assert_eq!(worktrees.len(), 1); + assert_eq!(worktrees[0].path, PathBuf::from("/project")); + assert!( + !fs.is_dir(&expected_dir.join("bugfix-branch")).await, + "last worktree directory should be removed from FakeFs for setting {worktree_dir_setting:?}" + ); + } +} + #[gpui::test] async fn test_checkpoints(executor: BackgroundExecutor) { let fs = FakeFs::new(executor); diff --git a/crates/git/clippy.toml b/crates/git/clippy.toml new file mode 100644 index 0000000000000000000000000000000000000000..fb3926840493fd5981c1861e7cea96bd54b9647f --- /dev/null +++ b/crates/git/clippy.toml @@ -0,0 +1,28 @@ +allow-private-module-inception = true +avoid-breaking-exported-api = false +ignore-interior-mutability = [ + # Suppresses clippy::mutable_key_type, which is a false positive as the Eq + # and Hash impls do not use fields with interior mutability. + "agent_ui::context::AgentContextKey" +] +disallowed-methods = [ + { path = "std::process::Command::spawn", reason = "Spawning `std::process::Command` can block the current thread for an unknown duration", replacement = "smol::process::Command::spawn" }, + { path = "std::process::Command::output", reason = "Spawning `std::process::Command` can block the current thread for an unknown duration", replacement = "smol::process::Command::output" }, + { path = "std::process::Command::status", reason = "Spawning `std::process::Command` can block the current thread for an unknown duration", replacement = "smol::process::Command::status" }, + { path = "std::process::Command::stdin", reason = "`smol::process::Command::from()` does not preserve stdio configuration", replacement = "smol::process::Command::stdin" }, + { path = "std::process::Command::stdout", reason = "`smol::process::Command::from()` does not preserve stdio configuration", replacement = "smol::process::Command::stdout" }, + { path = "std::process::Command::stderr", reason = "`smol::process::Command::from()` does not preserve stdio configuration", replacement = "smol::process::Command::stderr" }, + { path = "smol::Timer::after", reason = "smol::Timer introduces non-determinism in tests", replacement = "gpui::BackgroundExecutor::timer" }, + { path = "serde_json::from_reader", reason = "Parsing from a buffer is much slower than first reading the buffer into a Vec/String, see https://github.com/serde-rs/json/issues/160#issuecomment-253446892. Use `serde_json::from_slice` instead." }, + { path = "serde_json_lenient::from_reader", reason = "Parsing from a buffer is much slower than first reading the buffer into a Vec/String, see https://github.com/serde-rs/json/issues/160#issuecomment-253446892, Use `serde_json_lenient::from_slice` instead." }, + { path = "cocoa::foundation::NSString::alloc", reason = "NSString must be autoreleased to avoid memory leaks. Use `ns_string()` helper instead." }, + { path = "smol::process::Command::new", reason = "Git commands must go through `GitBinary::build_command` to ensure security flags like `-c core.fsmonitor=false` are always applied.", replacement = "GitBinary::build_command" }, + { path = "util::command::new_command", reason = "Git commands must go through `GitBinary::build_command` to ensure security flags like `-c core.fsmonitor=false` are always applied.", replacement = "GitBinary::build_command" }, + { path = "util::command::Command::new", reason = "Git commands must go through `GitBinary::build_command` to ensure security flags like `-c core.fsmonitor=false` are always applied.", replacement = "GitBinary::build_command" }, +] +disallowed-types = [ + # { path = "std::collections::HashMap", replacement = "collections::HashMap" }, + # { path = "std::collections::HashSet", replacement = "collections::HashSet" }, + # { path = "indexmap::IndexSet", replacement = "collections::IndexSet" }, + # { path = "indexmap::IndexMap", replacement = "collections::IndexMap" }, +] \ No newline at end of file diff --git a/crates/git/src/blame.rs b/crates/git/src/blame.rs index 9dc184bf2ac253c8bc24f6203f13d6654ac2b64b..c44aea74051bb7c190a091703d6c60807fc4e27e 100644 --- a/crates/git/src/blame.rs +++ b/crates/git/src/blame.rs @@ -1,11 +1,11 @@ use crate::Oid; use crate::commit::get_messages; -use crate::repository::RepoPath; +use crate::repository::{GitBinary, RepoPath}; use anyhow::{Context as _, Result}; use collections::{HashMap, HashSet}; use futures::AsyncWriteExt; use serde::{Deserialize, Serialize}; -use std::{ops::Range, path::Path}; +use std::ops::Range; use text::{LineEnding, Rope}; use time::OffsetDateTime; use time::UtcOffset; @@ -21,15 +21,13 @@ pub struct Blame { } impl Blame { - pub async fn for_path( - git_binary: &Path, - working_directory: &Path, + pub(crate) async fn for_path( + git: &GitBinary, path: &RepoPath, content: &Rope, line_ending: LineEnding, ) -> Result { - let output = - run_git_blame(git_binary, working_directory, path, content, line_ending).await?; + let output = run_git_blame(git, path, content, line_ending).await?; let mut entries = parse_git_blame(&output)?; entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start)); @@ -40,7 +38,7 @@ impl Blame { } let shas = unique_shas.into_iter().collect::>(); - let messages = get_messages(working_directory, &shas) + let messages = get_messages(git, &shas) .await .context("failed to get commit messages")?; @@ -52,8 +50,7 @@ const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD"; const GIT_BLAME_NO_PATH: &str = "fatal: no such path"; async fn run_git_blame( - git_binary: &Path, - working_directory: &Path, + git: &GitBinary, path: &RepoPath, contents: &Rope, line_ending: LineEnding, @@ -61,12 +58,7 @@ async fn run_git_blame( let mut child = { let span = ztracing::debug_span!("spawning git-blame command", path = path.as_unix_str()); let _enter = span.enter(); - util::command::new_command(git_binary) - .current_dir(working_directory) - .arg("blame") - .arg("--incremental") - .arg("--contents") - .arg("-") + git.build_command(["blame", "--incremental", "--contents", "-"]) .arg(path.as_unix_str()) .stdin(Stdio::piped()) .stdout(Stdio::piped()) diff --git a/crates/git/src/commit.rs b/crates/git/src/commit.rs index 3f3526afc4ba8fa146592684a6d3acfc44ce7e73..46e050ce155fc049a670fdfa26101eb729b34352 100644 --- a/crates/git/src/commit.rs +++ b/crates/git/src/commit.rs @@ -1,11 +1,11 @@ use crate::{ BuildCommitPermalinkParams, GitHostingProviderRegistry, GitRemote, Oid, parse_git_remote_url, - status::StatusCode, + repository::GitBinary, status::StatusCode, }; use anyhow::{Context as _, Result}; use collections::HashMap; use gpui::SharedString; -use std::{path::Path, sync::Arc}; +use std::sync::Arc; #[derive(Clone, Debug, Default)] pub struct ParsedCommitMessage { @@ -48,7 +48,7 @@ impl ParsedCommitMessage { } } -pub async fn get_messages(working_directory: &Path, shas: &[Oid]) -> Result> { +pub(crate) async fn get_messages(git: &GitBinary, shas: &[Oid]) -> Result> { if shas.is_empty() { return Ok(HashMap::default()); } @@ -63,12 +63,12 @@ pub async fn get_messages(working_directory: &Path, shas: &[Oid]) -> Result Result>()) } -async fn get_messages_impl(working_directory: &Path, shas: &[Oid]) -> Result> { +async fn get_messages_impl(git: &GitBinary, shas: &[Oid]) -> Result> { const MARKER: &str = ""; - let output = util::command::new_command("git") - .current_dir(working_directory) - .arg("show") + let output = git + .build_command(["show"]) .arg("-s") .arg(format!("--format=%B{}", MARKER)) .args(shas.iter().map(ToString::to_string)) diff --git a/crates/git/src/repository.rs b/crates/git/src/repository.rs index ba77199d75f624c0dd44ad0b2ba4eec812d9a711..45e719fb6d5a586074de523b5974ee11bf225453 100644 --- a/crates/git/src/repository.rs +++ b/crates/git/src/repository.rs @@ -21,6 +21,7 @@ use text::LineEnding; use std::collections::HashSet; use std::ffi::{OsStr, OsString}; +use std::sync::atomic::AtomicBool; use std::process::ExitStatus; use std::str::FromStr; @@ -55,6 +56,26 @@ pub const GRAPH_CHUNK_SIZE: usize = 1000; /// Default value for the `git.worktree_directory` setting. pub const DEFAULT_WORKTREE_DIRECTORY: &str = "../worktrees"; +/// Given the git common directory (from `commondir()`), derive the original +/// repository's working directory. +/// +/// For a standard checkout, `common_dir` is `/.git`, so the parent +/// is the working directory. For a git worktree, `common_dir` is the **main** +/// repo's `.git` directory, so the parent is the original repo's working directory. +/// +/// Falls back to returning `common_dir` itself if it doesn't end with `.git` +/// (e.g. bare repos or unusual layouts). +pub fn original_repo_path_from_common_dir(common_dir: &Path) -> PathBuf { + if common_dir.file_name() == Some(OsStr::new(".git")) { + common_dir + .parent() + .map(|p| p.to_path_buf()) + .unwrap_or_else(|| common_dir.to_path_buf()) + } else { + common_dir.to_path_buf() + } +} + /// Resolves the configured worktree directory to an absolute path. /// /// `worktree_directory_setting` is the raw string from the user setting @@ -283,6 +304,7 @@ impl Branch { pub struct Worktree { pub path: PathBuf, pub ref_name: SharedString, + // todo(git_worktree) This type should be a Oid pub sha: SharedString, } @@ -320,6 +342,8 @@ pub fn parse_worktrees_from_str>(raw_worktrees: T) -> Vec BoxFuture<'_, Result>>; + path_prefixes: &[RepoPath], + ) -> BoxFuture<'_, Result>; /// Creates a checkpoint for the repository. fn checkpoint(&self) -> BoxFuture<'static, Result>; @@ -938,6 +962,9 @@ pub trait GitRepository: Send + Sync { ) -> BoxFuture<'_, Result<()>>; fn commit_data_reader(&self) -> Result; + + fn set_trusted(&self, trusted: bool); + fn is_trusted(&self) -> bool; } pub enum DiffType { @@ -964,6 +991,7 @@ pub struct RealGitRepository { pub any_git_binary_path: PathBuf, any_git_binary_help_output: Arc>>, executor: BackgroundExecutor, + is_trusted: Arc, } impl RealGitRepository { @@ -982,6 +1010,7 @@ impl RealGitRepository { any_git_binary_path, executor, any_git_binary_help_output: Arc::new(Mutex::new(None)), + is_trusted: Arc::new(AtomicBool::new(false)), }) } @@ -993,20 +1022,24 @@ impl RealGitRepository { .map(Path::to_path_buf) } + fn git_binary(&self) -> Result { + Ok(GitBinary::new( + self.any_git_binary_path.clone(), + self.working_directory() + .with_context(|| "Can't run git commands without a working directory")?, + self.executor.clone(), + self.is_trusted(), + )) + } + async fn any_git_binary_help_output(&self) -> SharedString { if let Some(output) = self.any_git_binary_help_output.lock().clone() { return output; } - let git_binary_path = self.any_git_binary_path.clone(); - let executor = self.executor.clone(); - let working_directory = self.working_directory(); + let git_binary = self.git_binary(); let output: SharedString = self .executor - .spawn(async move { - GitBinary::new(git_binary_path, working_directory?, executor) - .run(["help", "-a"]) - .await - }) + .spawn(async move { git_binary?.run(["help", "-a"]).await }) .await .unwrap_or_default() .into(); @@ -1049,6 +1082,7 @@ pub async fn get_git_committer(cx: &AsyncApp) -> GitCommitter { git_binary_path.unwrap_or(PathBuf::from("git")), paths::home_dir().clone(), cx.background_executor().clone(), + true, ); cx.background_spawn(async move { @@ -1080,14 +1114,12 @@ impl GitRepository for RealGitRepository { } fn show(&self, commit: String) -> BoxFuture<'_, Result> { - let git_binary_path = self.any_git_binary_path.clone(); - let working_directory = self.working_directory(); + let git_binary = self.git_binary(); self.executor .spawn(async move { - let working_directory = working_directory?; - let output = new_command(git_binary_path) - .current_dir(&working_directory) - .args([ + let git = git_binary?; + let output = git + .build_command([ "--no-optional-locks", "show", "--no-patch", @@ -1118,15 +1150,14 @@ impl GitRepository for RealGitRepository { } fn load_commit(&self, commit: String, cx: AsyncApp) -> BoxFuture<'_, Result> { - let Some(working_directory) = self.repository.lock().workdir().map(ToOwned::to_owned) - else { + if self.repository.lock().workdir().is_none() { return future::ready(Err(anyhow!("no working directory"))).boxed(); - }; - let git_binary_path = self.any_git_binary_path.clone(); + } + let git_binary = self.git_binary(); cx.background_spawn(async move { - let show_output = util::command::new_command(&git_binary_path) - .current_dir(&working_directory) - .args([ + let git = git_binary?; + let show_output = git + .build_command([ "--no-optional-locks", "show", "--format=", @@ -1147,9 +1178,8 @@ impl GitRepository for RealGitRepository { let changes = parse_git_diff_name_status(&show_stdout); let parent_sha = format!("{}^", commit); - let mut cat_file_process = util::command::new_command(&git_binary_path) - .current_dir(&working_directory) - .args(["--no-optional-locks", "cat-file", "--batch=%(objectsize)"]) + let mut cat_file_process = git + .build_command(["--no-optional-locks", "cat-file", "--batch=%(objectsize)"]) .stdin(Stdio::piped()) .stdout(Stdio::piped()) .stderr(Stdio::piped()) @@ -1256,18 +1286,17 @@ impl GitRepository for RealGitRepository { mode: ResetMode, env: Arc>, ) -> BoxFuture<'_, Result<()>> { + let git_binary = self.git_binary(); async move { - let working_directory = self.working_directory(); - let mode_flag = match mode { ResetMode::Mixed => "--mixed", ResetMode::Soft => "--soft", }; - let output = new_command(&self.any_git_binary_path) + let git = git_binary?; + let output = git + .build_command(["reset", mode_flag, &commit]) .envs(env.iter()) - .current_dir(&working_directory?) - .args(["reset", mode_flag, &commit]) .output() .await?; anyhow::ensure!( @@ -1286,17 +1315,16 @@ impl GitRepository for RealGitRepository { paths: Vec, env: Arc>, ) -> BoxFuture<'_, Result<()>> { - let working_directory = self.working_directory(); - let git_binary_path = self.any_git_binary_path.clone(); + let git_binary = self.git_binary(); async move { if paths.is_empty() { return Ok(()); } - let output = new_command(&git_binary_path) - .current_dir(&working_directory?) + let git = git_binary?; + let output = git + .build_command(["checkout", &commit, "--"]) .envs(env.iter()) - .args(["checkout", &commit, "--"]) .args(paths.iter().map(|path| path.as_unix_str())) .output() .await?; @@ -1391,18 +1419,16 @@ impl GitRepository for RealGitRepository { env: Arc>, is_executable: bool, ) -> BoxFuture<'_, anyhow::Result<()>> { - let working_directory = self.working_directory(); - let git_binary_path = self.any_git_binary_path.clone(); + let git_binary = self.git_binary(); self.executor .spawn(async move { - let working_directory = working_directory?; + let git = git_binary?; let mode = if is_executable { "100755" } else { "100644" }; if let Some(content) = content { - let mut child = new_command(&git_binary_path) - .current_dir(&working_directory) + let mut child = git + .build_command(["hash-object", "-w", "--stdin"]) .envs(env.iter()) - .args(["hash-object", "-w", "--stdin"]) .stdin(Stdio::piped()) .stdout(Stdio::piped()) .spawn()?; @@ -1415,10 +1441,9 @@ impl GitRepository for RealGitRepository { log::debug!("indexing SHA: {sha}, path {path:?}"); - let output = new_command(&git_binary_path) - .current_dir(&working_directory) + let output = git + .build_command(["update-index", "--add", "--cacheinfo", mode, sha]) .envs(env.iter()) - .args(["update-index", "--add", "--cacheinfo", mode, sha]) .arg(path.as_unix_str()) .output() .await?; @@ -1430,10 +1455,9 @@ impl GitRepository for RealGitRepository { ); } else { log::debug!("removing path {path:?} from the index"); - let output = new_command(&git_binary_path) - .current_dir(&working_directory) + let output = git + .build_command(["update-index", "--force-remove"]) .envs(env.iter()) - .args(["update-index", "--force-remove"]) .arg(path.as_unix_str()) .output() .await?; @@ -1462,14 +1486,12 @@ impl GitRepository for RealGitRepository { } fn revparse_batch(&self, revs: Vec) -> BoxFuture<'_, Result>>> { - let working_directory = self.working_directory(); - let git_binary_path = self.any_git_binary_path.clone(); + let git_binary = self.git_binary(); self.executor .spawn(async move { - let working_directory = working_directory?; - let mut process = new_command(&git_binary_path) - .current_dir(&working_directory) - .args([ + let git = git_binary?; + let mut process = git + .build_command([ "--no-optional-locks", "cat-file", "--batch-check=%(objectname)", @@ -1522,19 +1544,14 @@ impl GitRepository for RealGitRepository { } fn status(&self, path_prefixes: &[RepoPath]) -> Task> { - let git_binary_path = self.any_git_binary_path.clone(); - let working_directory = match self.working_directory() { - Ok(working_directory) => working_directory, + let git = match self.git_binary() { + Ok(git) => git, Err(e) => return Task::ready(Err(e)), }; let args = git_status_args(path_prefixes); log::debug!("Checking for git status in {path_prefixes:?}"); self.executor.spawn(async move { - let output = new_command(&git_binary_path) - .current_dir(working_directory) - .args(args) - .output() - .await?; + let output = git.build_command(args).output().await?; if output.status.success() { let stdout = String::from_utf8_lossy(&output.stdout); stdout.parse() @@ -1546,9 +1563,8 @@ impl GitRepository for RealGitRepository { } fn diff_tree(&self, request: DiffTreeType) -> BoxFuture<'_, Result> { - let git_binary_path = self.any_git_binary_path.clone(); - let working_directory = match self.working_directory() { - Ok(working_directory) => working_directory, + let git = match self.git_binary() { + Ok(git) => git, Err(e) => return Task::ready(Err(e)).boxed(), }; @@ -1573,11 +1589,7 @@ impl GitRepository for RealGitRepository { self.executor .spawn(async move { - let output = new_command(&git_binary_path) - .current_dir(working_directory) - .args(args) - .output() - .await?; + let output = git.build_command(args).output().await?; if output.status.success() { let stdout = String::from_utf8_lossy(&output.stdout); stdout.parse() @@ -1590,13 +1602,12 @@ impl GitRepository for RealGitRepository { } fn stash_entries(&self) -> BoxFuture<'_, Result> { - let git_binary_path = self.any_git_binary_path.clone(); - let working_directory = self.working_directory(); + let git_binary = self.git_binary(); self.executor .spawn(async move { - let output = new_command(&git_binary_path) - .current_dir(working_directory?) - .args(&["stash", "list", "--pretty=format:%gd%x00%H%x00%ct%x00%s"]) + let git = git_binary?; + let output = git + .build_command(&["stash", "list", "--pretty=format:%gd%x00%H%x00%ct%x00%s"]) .output() .await?; if output.status.success() { @@ -1611,8 +1622,7 @@ impl GitRepository for RealGitRepository { } fn branches(&self) -> BoxFuture<'_, Result>> { - let working_directory = self.working_directory(); - let git_binary_path = self.any_git_binary_path.clone(); + let git_binary = self.git_binary(); self.executor .spawn(async move { let fields = [ @@ -1634,12 +1644,8 @@ impl GitRepository for RealGitRepository { "--format", &fields, ]; - let working_directory = working_directory?; - let output = new_command(&git_binary_path) - .current_dir(&working_directory) - .args(args) - .output() - .await?; + let git = git_binary?; + let output = git.build_command(args).output().await?; anyhow::ensure!( output.status.success(), @@ -1653,11 +1659,7 @@ impl GitRepository for RealGitRepository { if branches.is_empty() { let args = vec!["symbolic-ref", "--quiet", "HEAD"]; - let output = new_command(&git_binary_path) - .current_dir(&working_directory) - .args(args) - .output() - .await?; + let output = git.build_command(args).output().await?; // git symbolic-ref returns a non-0 exit code if HEAD points // to something other than a branch @@ -1679,13 +1681,12 @@ impl GitRepository for RealGitRepository { } fn worktrees(&self) -> BoxFuture<'_, Result>> { - let git_binary_path = self.any_git_binary_path.clone(); - let working_directory = self.working_directory(); + let git_binary = self.git_binary(); self.executor .spawn(async move { - let output = new_command(&git_binary_path) - .current_dir(working_directory?) - .args(&["--no-optional-locks", "worktree", "list", "--porcelain"]) + let git = git_binary?; + let output = git + .build_command(&["--no-optional-locks", "worktree", "list", "--porcelain"]) .output() .await?; if output.status.success() { @@ -1705,8 +1706,7 @@ impl GitRepository for RealGitRepository { directory: PathBuf, from_commit: Option, ) -> BoxFuture<'_, Result<()>> { - let git_binary_path = self.any_git_binary_path.clone(); - let working_directory = self.working_directory(); + let git_binary = self.git_binary(); let final_path = directory.join(&name); let mut args = vec![ OsString::from("--no-optional-locks"), @@ -1726,11 +1726,8 @@ impl GitRepository for RealGitRepository { self.executor .spawn(async move { std::fs::create_dir_all(final_path.parent().unwrap_or(&final_path))?; - let output = new_command(&git_binary_path) - .current_dir(working_directory?) - .args(args) - .output() - .await?; + let git = git_binary?; + let output = git.build_command(args).output().await?; if output.status.success() { Ok(()) } else { @@ -1742,9 +1739,7 @@ impl GitRepository for RealGitRepository { } fn remove_worktree(&self, path: PathBuf, force: bool) -> BoxFuture<'_, Result<()>> { - let git_binary_path = self.any_git_binary_path.clone(); - let working_directory = self.working_directory(); - let executor = self.executor.clone(); + let git_binary = self.git_binary(); self.executor .spawn(async move { @@ -1758,18 +1753,14 @@ impl GitRepository for RealGitRepository { } args.push("--".into()); args.push(path.as_os_str().into()); - GitBinary::new(git_binary_path, working_directory?, executor) - .run(args) - .await?; + git_binary?.run(args).await?; anyhow::Ok(()) }) .boxed() } fn rename_worktree(&self, old_path: PathBuf, new_path: PathBuf) -> BoxFuture<'_, Result<()>> { - let git_binary_path = self.any_git_binary_path.clone(); - let working_directory = self.working_directory(); - let executor = self.executor.clone(); + let git_binary = self.git_binary(); self.executor .spawn(async move { @@ -1781,9 +1772,7 @@ impl GitRepository for RealGitRepository { old_path.as_os_str().into(), new_path.as_os_str().into(), ]; - GitBinary::new(git_binary_path, working_directory?, executor) - .run(args) - .await?; + git_binary?.run(args).await?; anyhow::Ok(()) }) .boxed() @@ -1791,9 +1780,7 @@ impl GitRepository for RealGitRepository { fn change_branch(&self, name: String) -> BoxFuture<'_, Result<()>> { let repo = self.repository.clone(); - let working_directory = self.working_directory(); - let git_binary_path = self.any_git_binary_path.clone(); - let executor = self.executor.clone(); + let git_binary = self.git_binary(); let branch = self.executor.spawn(async move { let repo = repo.lock(); let branch = if let Ok(branch) = repo.find_branch(&name, BranchType::Local) { @@ -1828,9 +1815,7 @@ impl GitRepository for RealGitRepository { self.executor .spawn(async move { let branch = branch.await?; - GitBinary::new(git_binary_path, working_directory?, executor) - .run(&["checkout", &branch]) - .await?; + git_binary?.run(&["checkout", &branch]).await?; anyhow::Ok(()) }) .boxed() @@ -1841,9 +1826,7 @@ impl GitRepository for RealGitRepository { name: String, base_branch: Option, ) -> BoxFuture<'_, Result<()>> { - let git_binary_path = self.any_git_binary_path.clone(); - let working_directory = self.working_directory(); - let executor = self.executor.clone(); + let git_binary = self.git_binary(); self.executor .spawn(async move { @@ -1854,22 +1837,18 @@ impl GitRepository for RealGitRepository { args.push(&base_branch_str); } - GitBinary::new(git_binary_path, working_directory?, executor) - .run(&args) - .await?; + git_binary?.run(&args).await?; anyhow::Ok(()) }) .boxed() } fn rename_branch(&self, branch: String, new_name: String) -> BoxFuture<'_, Result<()>> { - let git_binary_path = self.any_git_binary_path.clone(); - let working_directory = self.working_directory(); - let executor = self.executor.clone(); + let git_binary = self.git_binary(); self.executor .spawn(async move { - GitBinary::new(git_binary_path, working_directory?, executor) + git_binary? .run(&["branch", "-m", &branch, &new_name]) .await?; anyhow::Ok(()) @@ -1878,15 +1857,11 @@ impl GitRepository for RealGitRepository { } fn delete_branch(&self, name: String) -> BoxFuture<'_, Result<()>> { - let git_binary_path = self.any_git_binary_path.clone(); - let working_directory = self.working_directory(); - let executor = self.executor.clone(); + let git_binary = self.git_binary(); self.executor .spawn(async move { - GitBinary::new(git_binary_path, working_directory?, executor) - .run(&["branch", "-d", &name]) - .await?; + git_binary?.run(&["branch", "-d", &name]).await?; anyhow::Ok(()) }) .boxed() @@ -1898,20 +1873,11 @@ impl GitRepository for RealGitRepository { content: Rope, line_ending: LineEnding, ) -> BoxFuture<'_, Result> { - let working_directory = self.working_directory(); - let git_binary_path = self.any_git_binary_path.clone(); - let executor = self.executor.clone(); + let git = self.git_binary(); - executor + self.executor .spawn(async move { - crate::blame::Blame::for_path( - &git_binary_path, - &working_directory?, - &path, - &content, - line_ending, - ) - .await + crate::blame::Blame::for_path(&git?, &path, &content, line_ending).await }) .boxed() } @@ -1926,11 +1892,10 @@ impl GitRepository for RealGitRepository { skip: usize, limit: Option, ) -> BoxFuture<'_, Result> { - let working_directory = self.working_directory(); - let git_binary_path = self.any_git_binary_path.clone(); + let git_binary = self.git_binary(); self.executor .spawn(async move { - let working_directory = working_directory?; + let git = git_binary?; // Use a unique delimiter with a hardcoded UUID to separate commits // This essentially eliminates any chance of encountering the delimiter in actual commit data let commit_delimiter = @@ -1958,9 +1923,8 @@ impl GitRepository for RealGitRepository { args.push("--"); - let output = new_command(&git_binary_path) - .current_dir(&working_directory) - .args(&args) + let output = git + .build_command(&args) .arg(path.as_unix_str()) .output() .await?; @@ -2005,30 +1969,17 @@ impl GitRepository for RealGitRepository { } fn diff(&self, diff: DiffType) -> BoxFuture<'_, Result> { - let working_directory = self.working_directory(); - let git_binary_path = self.any_git_binary_path.clone(); + let git_binary = self.git_binary(); self.executor .spawn(async move { - let working_directory = working_directory?; + let git = git_binary?; let output = match diff { DiffType::HeadToIndex => { - new_command(&git_binary_path) - .current_dir(&working_directory) - .args(["diff", "--staged"]) - .output() - .await? - } - DiffType::HeadToWorktree => { - new_command(&git_binary_path) - .current_dir(&working_directory) - .args(["diff"]) - .output() - .await? + git.build_command(["diff", "--staged"]).output().await? } + DiffType::HeadToWorktree => git.build_command(["diff"]).output().await?, DiffType::MergeBase { base_ref } => { - new_command(&git_binary_path) - .current_dir(&working_directory) - .args(["diff", "--merge-base", base_ref.as_ref()]) + git.build_command(["diff", "--merge-base", base_ref.as_ref()]) .output() .await? } @@ -2046,51 +1997,30 @@ impl GitRepository for RealGitRepository { fn diff_stat( &self, - diff: DiffType, - ) -> BoxFuture<'_, Result>> { - let working_directory = self.working_directory(); - let git_binary_path = self.any_git_binary_path.clone(); + path_prefixes: &[RepoPath], + ) -> BoxFuture<'_, Result> { + let path_prefixes = path_prefixes.to_vec(); + let git_binary = self.git_binary(); + self.executor .spawn(async move { - let working_directory = working_directory?; - let output = match diff { - DiffType::HeadToIndex => { - new_command(&git_binary_path) - .current_dir(&working_directory) - .args(["diff", "--numstat", "--staged"]) - .output() - .await? - } - DiffType::HeadToWorktree => { - new_command(&git_binary_path) - .current_dir(&working_directory) - .args(["diff", "--numstat"]) - .output() - .await? - } - DiffType::MergeBase { base_ref } => { - new_command(&git_binary_path) - .current_dir(&working_directory) - .args([ - "diff", - "--numstat", - "--merge-base", - base_ref.as_ref(), - "HEAD", - ]) - .output() - .await? - } - }; - - anyhow::ensure!( - output.status.success(), - "Failed to run git diff --numstat:\n{}", - String::from_utf8_lossy(&output.stderr) - ); - Ok(crate::status::parse_numstat(&String::from_utf8_lossy( - &output.stdout, - ))) + let git_binary = git_binary?; + let mut args: Vec = vec![ + "diff".into(), + "--numstat".into(), + "--no-renames".into(), + "HEAD".into(), + ]; + if !path_prefixes.is_empty() { + args.push("--".into()); + args.extend( + path_prefixes + .iter() + .map(|p| p.as_std_path().to_string_lossy().into_owned()), + ); + } + let output = git_binary.run(&args).await?; + Ok(crate::status::parse_numstat(&output)) }) .boxed() } @@ -2100,15 +2030,14 @@ impl GitRepository for RealGitRepository { paths: Vec, env: Arc>, ) -> BoxFuture<'_, Result<()>> { - let working_directory = self.working_directory(); - let git_binary_path = self.any_git_binary_path.clone(); + let git_binary = self.git_binary(); self.executor .spawn(async move { if !paths.is_empty() { - let output = new_command(&git_binary_path) - .current_dir(&working_directory?) + let git = git_binary?; + let output = git + .build_command(["update-index", "--add", "--remove", "--"]) .envs(env.iter()) - .args(["update-index", "--add", "--remove", "--"]) .args(paths.iter().map(|p| p.as_unix_str())) .output() .await?; @@ -2128,16 +2057,15 @@ impl GitRepository for RealGitRepository { paths: Vec, env: Arc>, ) -> BoxFuture<'_, Result<()>> { - let working_directory = self.working_directory(); - let git_binary_path = self.any_git_binary_path.clone(); + let git_binary = self.git_binary(); self.executor .spawn(async move { if !paths.is_empty() { - let output = new_command(&git_binary_path) - .current_dir(&working_directory?) + let git = git_binary?; + let output = git + .build_command(["reset", "--quiet", "--"]) .envs(env.iter()) - .args(["reset", "--quiet", "--"]) .args(paths.iter().map(|p| p.as_std_path())) .output() .await?; @@ -2158,19 +2086,16 @@ impl GitRepository for RealGitRepository { paths: Vec, env: Arc>, ) -> BoxFuture<'_, Result<()>> { - let working_directory = self.working_directory(); - let git_binary_path = self.any_git_binary_path.clone(); + let git_binary = self.git_binary(); self.executor .spawn(async move { - let mut cmd = new_command(&git_binary_path); - cmd.current_dir(&working_directory?) + let git = git_binary?; + let output = git + .build_command(["stash", "push", "--quiet", "--include-untracked"]) .envs(env.iter()) - .args(["stash", "push", "--quiet"]) - .arg("--include-untracked"); - - cmd.args(paths.iter().map(|p| p.as_unix_str())); - - let output = cmd.output().await?; + .args(paths.iter().map(|p| p.as_unix_str())) + .output() + .await?; anyhow::ensure!( output.status.success(), @@ -2187,20 +2112,15 @@ impl GitRepository for RealGitRepository { index: Option, env: Arc>, ) -> BoxFuture<'_, Result<()>> { - let working_directory = self.working_directory(); - let git_binary_path = self.any_git_binary_path.clone(); + let git_binary = self.git_binary(); self.executor .spawn(async move { - let mut cmd = new_command(git_binary_path); + let git = git_binary?; let mut args = vec!["stash".to_string(), "pop".to_string()]; if let Some(index) = index { args.push(format!("stash@{{{}}}", index)); } - cmd.current_dir(&working_directory?) - .envs(env.iter()) - .args(args); - - let output = cmd.output().await?; + let output = git.build_command(&args).envs(env.iter()).output().await?; anyhow::ensure!( output.status.success(), @@ -2217,20 +2137,15 @@ impl GitRepository for RealGitRepository { index: Option, env: Arc>, ) -> BoxFuture<'_, Result<()>> { - let working_directory = self.working_directory(); - let git_binary_path = self.any_git_binary_path.clone(); + let git_binary = self.git_binary(); self.executor .spawn(async move { - let mut cmd = new_command(git_binary_path); + let git = git_binary?; let mut args = vec!["stash".to_string(), "apply".to_string()]; if let Some(index) = index { args.push(format!("stash@{{{}}}", index)); } - cmd.current_dir(&working_directory?) - .envs(env.iter()) - .args(args); - - let output = cmd.output().await?; + let output = git.build_command(&args).envs(env.iter()).output().await?; anyhow::ensure!( output.status.success(), @@ -2247,20 +2162,15 @@ impl GitRepository for RealGitRepository { index: Option, env: Arc>, ) -> BoxFuture<'_, Result<()>> { - let working_directory = self.working_directory(); - let git_binary_path = self.any_git_binary_path.clone(); + let git_binary = self.git_binary(); self.executor .spawn(async move { - let mut cmd = new_command(git_binary_path); + let git = git_binary?; let mut args = vec!["stash".to_string(), "drop".to_string()]; if let Some(index) = index { args.push(format!("stash@{{{}}}", index)); } - cmd.current_dir(&working_directory?) - .envs(env.iter()) - .args(args); - - let output = cmd.output().await?; + let output = git.build_command(&args).envs(env.iter()).output().await?; anyhow::ensure!( output.status.success(), @@ -2280,16 +2190,14 @@ impl GitRepository for RealGitRepository { ask_pass: AskPassDelegate, env: Arc>, ) -> BoxFuture<'_, Result<()>> { - let working_directory = self.working_directory(); - let git_binary_path = self.any_git_binary_path.clone(); + let git_binary = self.git_binary(); let executor = self.executor.clone(); // Note: Do not spawn this command on the background thread, it might pop open the credential helper // which we want to block on. async move { - let mut cmd = new_command(git_binary_path); - cmd.current_dir(&working_directory?) - .envs(env.iter()) - .args(["commit", "--quiet", "-m"]) + let git = git_binary?; + let mut cmd = git.build_command(["commit", "--quiet", "-m"]); + cmd.envs(env.iter()) .arg(&message.to_string()) .arg("--cleanup=strip") .arg("--no-verify") @@ -2328,16 +2236,21 @@ impl GitRepository for RealGitRepository { let working_directory = self.working_directory(); let executor = cx.background_executor().clone(); let git_binary_path = self.system_git_binary_path.clone(); + let is_trusted = self.is_trusted(); // Note: Do not spawn this command on the background thread, it might pop open the credential helper // which we want to block on. async move { let git_binary_path = git_binary_path.context("git not found on $PATH, can't push")?; let working_directory = working_directory?; - let mut command = new_command(git_binary_path); + let git = GitBinary::new( + git_binary_path, + working_directory, + executor.clone(), + is_trusted, + ); + let mut command = git.build_command(["push"]); command .envs(env.iter()) - .current_dir(&working_directory) - .args(["push"]) .args(options.map(|option| match option { PushOptions::SetUpstream => "--set-upstream", PushOptions::Force => "--force-with-lease", @@ -2365,15 +2278,20 @@ impl GitRepository for RealGitRepository { let working_directory = self.working_directory(); let executor = cx.background_executor().clone(); let git_binary_path = self.system_git_binary_path.clone(); + let is_trusted = self.is_trusted(); // Note: Do not spawn this command on the background thread, it might pop open the credential helper // which we want to block on. async move { let git_binary_path = git_binary_path.context("git not found on $PATH, can't pull")?; - let mut command = new_command(git_binary_path); - command - .envs(env.iter()) - .current_dir(&working_directory?) - .arg("pull"); + let working_directory = working_directory?; + let git = GitBinary::new( + git_binary_path, + working_directory, + executor.clone(), + is_trusted, + ); + let mut command = git.build_command(["pull"]); + command.envs(env.iter()); if rebase { command.arg("--rebase"); @@ -2401,15 +2319,21 @@ impl GitRepository for RealGitRepository { let remote_name = format!("{}", fetch_options); let git_binary_path = self.system_git_binary_path.clone(); let executor = cx.background_executor().clone(); + let is_trusted = self.is_trusted(); // Note: Do not spawn this command on the background thread, it might pop open the credential helper // which we want to block on. async move { let git_binary_path = git_binary_path.context("git not found on $PATH, can't fetch")?; - let mut command = new_command(git_binary_path); + let working_directory = working_directory?; + let git = GitBinary::new( + git_binary_path, + working_directory, + executor.clone(), + is_trusted, + ); + let mut command = git.build_command(["fetch", &remote_name]); command .envs(env.iter()) - .current_dir(&working_directory?) - .args(["fetch", &remote_name]) .stdout(Stdio::piped()) .stderr(Stdio::piped()); @@ -2419,14 +2343,12 @@ impl GitRepository for RealGitRepository { } fn get_push_remote(&self, branch: String) -> BoxFuture<'_, Result>> { - let working_directory = self.working_directory(); - let git_binary_path = self.any_git_binary_path.clone(); + let git_binary = self.git_binary(); self.executor .spawn(async move { - let working_directory = working_directory?; - let output = new_command(&git_binary_path) - .current_dir(&working_directory) - .args(["rev-parse", "--abbrev-ref"]) + let git = git_binary?; + let output = git + .build_command(["rev-parse", "--abbrev-ref"]) .arg(format!("{branch}@{{push}}")) .output() .await?; @@ -2446,14 +2368,12 @@ impl GitRepository for RealGitRepository { } fn get_branch_remote(&self, branch: String) -> BoxFuture<'_, Result>> { - let working_directory = self.working_directory(); - let git_binary_path = self.any_git_binary_path.clone(); + let git_binary = self.git_binary(); self.executor .spawn(async move { - let working_directory = working_directory?; - let output = new_command(&git_binary_path) - .current_dir(&working_directory) - .args(["config", "--get"]) + let git = git_binary?; + let output = git + .build_command(["config", "--get"]) .arg(format!("branch.{branch}.remote")) .output() .await?; @@ -2470,16 +2390,11 @@ impl GitRepository for RealGitRepository { } fn get_all_remotes(&self) -> BoxFuture<'_, Result>> { - let working_directory = self.working_directory(); - let git_binary_path = self.any_git_binary_path.clone(); + let git_binary = self.git_binary(); self.executor .spawn(async move { - let working_directory = working_directory?; - let output = new_command(&git_binary_path) - .current_dir(&working_directory) - .args(["remote", "-v"]) - .output() - .await?; + let git = git_binary?; + let output = git.build_command(["remote", "-v"]).output().await?; anyhow::ensure!( output.status.success(), @@ -2528,17 +2443,12 @@ impl GitRepository for RealGitRepository { } fn check_for_pushed_commit(&self) -> BoxFuture<'_, Result>> { - let working_directory = self.working_directory(); - let git_binary_path = self.any_git_binary_path.clone(); + let git_binary = self.git_binary(); self.executor .spawn(async move { - let working_directory = working_directory?; + let git = git_binary?; let git_cmd = async |args: &[&str]| -> Result { - let output = new_command(&git_binary_path) - .current_dir(&working_directory) - .args(args) - .output() - .await?; + let output = git.build_command(args).output().await?; anyhow::ensure!( output.status.success(), String::from_utf8_lossy(&output.stderr).to_string() @@ -2587,14 +2497,10 @@ impl GitRepository for RealGitRepository { } fn checkpoint(&self) -> BoxFuture<'static, Result> { - let working_directory = self.working_directory(); - let git_binary_path = self.any_git_binary_path.clone(); - let executor = self.executor.clone(); + let git_binary = self.git_binary(); self.executor .spawn(async move { - let working_directory = working_directory?; - let mut git = GitBinary::new(git_binary_path, working_directory.clone(), executor) - .envs(checkpoint_author_envs()); + let mut git = git_binary?.envs(checkpoint_author_envs()); git.with_temp_index(async |git| { let head_sha = git.run(&["rev-parse", "HEAD"]).await.ok(); let mut excludes = exclude_files(git).await?; @@ -2620,15 +2526,10 @@ impl GitRepository for RealGitRepository { } fn restore_checkpoint(&self, checkpoint: GitRepositoryCheckpoint) -> BoxFuture<'_, Result<()>> { - let working_directory = self.working_directory(); - let git_binary_path = self.any_git_binary_path.clone(); - - let executor = self.executor.clone(); + let git_binary = self.git_binary(); self.executor .spawn(async move { - let working_directory = working_directory?; - - let git = GitBinary::new(git_binary_path, working_directory, executor); + let git = git_binary?; git.run(&[ "restore", "--source", @@ -2659,14 +2560,10 @@ impl GitRepository for RealGitRepository { left: GitRepositoryCheckpoint, right: GitRepositoryCheckpoint, ) -> BoxFuture<'_, Result> { - let working_directory = self.working_directory(); - let git_binary_path = self.any_git_binary_path.clone(); - - let executor = self.executor.clone(); + let git_binary = self.git_binary(); self.executor .spawn(async move { - let working_directory = working_directory?; - let git = GitBinary::new(git_binary_path, working_directory, executor); + let git = git_binary?; let result = git .run(&[ "diff-tree", @@ -2697,14 +2594,10 @@ impl GitRepository for RealGitRepository { base_checkpoint: GitRepositoryCheckpoint, target_checkpoint: GitRepositoryCheckpoint, ) -> BoxFuture<'_, Result> { - let working_directory = self.working_directory(); - let git_binary_path = self.any_git_binary_path.clone(); - - let executor = self.executor.clone(); + let git_binary = self.git_binary(); self.executor .spawn(async move { - let working_directory = working_directory?; - let git = GitBinary::new(git_binary_path, working_directory, executor); + let git = git_binary?; git.run(&[ "diff", "--find-renames", @@ -2721,14 +2614,10 @@ impl GitRepository for RealGitRepository { &self, include_remote_name: bool, ) -> BoxFuture<'_, Result>> { - let working_directory = self.working_directory(); - let git_binary_path = self.any_git_binary_path.clone(); - - let executor = self.executor.clone(); + let git_binary = self.git_binary(); self.executor .spawn(async move { - let working_directory = working_directory?; - let git = GitBinary::new(git_binary_path, working_directory, executor); + let git = git_binary?; let strip_prefix = if include_remote_name { "refs/remotes/" @@ -2778,22 +2667,23 @@ impl GitRepository for RealGitRepository { hook: RunHook, env: Arc>, ) -> BoxFuture<'_, Result<()>> { - let working_directory = self.working_directory(); + let git_binary = self.git_binary(); let repository = self.repository.clone(); - let git_binary_path = self.any_git_binary_path.clone(); - let executor = self.executor.clone(); let help_output = self.any_git_binary_help_output(); // Note: Do not spawn these commands on the background thread, as this causes some git hooks to hang. async move { - let working_directory = working_directory?; + let git_binary = git_binary?; + + let working_directory = git_binary.working_directory.clone(); if !help_output .await .lines() .any(|line| line.trim().starts_with("hook ")) { let hook_abs_path = repository.lock().path().join("hooks").join(hook.as_str()); - if hook_abs_path.is_file() { + if hook_abs_path.is_file() && git_binary.is_trusted { + #[allow(clippy::disallowed_methods)] let output = new_command(&hook_abs_path) .envs(env.iter()) .current_dir(&working_directory) @@ -2813,10 +2703,12 @@ impl GitRepository for RealGitRepository { return Ok(()); } - let git = GitBinary::new(git_binary_path, working_directory, executor) - .envs(HashMap::clone(&env)); - git.run(&["hook", "run", "--ignore-missing", hook.as_str()]) - .await?; + if git_binary.is_trusted { + let git_binary = git_binary.envs(HashMap::clone(&env)); + git_binary + .run(&["hook", "run", "--ignore-missing", hook.as_str()]) + .await?; + } Ok(()) } .boxed() @@ -2828,13 +2720,10 @@ impl GitRepository for RealGitRepository { log_order: LogOrder, request_tx: Sender>>, ) -> BoxFuture<'_, Result<()>> { - let git_binary_path = self.any_git_binary_path.clone(); - let working_directory = self.working_directory(); - let executor = self.executor.clone(); + let git_binary = self.git_binary(); async move { - let working_directory = working_directory?; - let git = GitBinary::new(git_binary_path, working_directory, executor); + let git = git_binary?; let mut command = git.build_command([ "log", @@ -2888,19 +2777,12 @@ impl GitRepository for RealGitRepository { } fn commit_data_reader(&self) -> Result { - let git_binary_path = self.any_git_binary_path.clone(); - let working_directory = self - .working_directory() - .map_err(|_| anyhow!("no working directory"))?; - let executor = self.executor.clone(); + let git_binary = self.git_binary()?; let (request_tx, request_rx) = smol::channel::bounded::(64); let task = self.executor.spawn(async move { - if let Err(error) = - run_commit_data_reader(git_binary_path, working_directory, executor, request_rx) - .await - { + if let Err(error) = run_commit_data_reader(git_binary, request_rx).await { log::error!("commit data reader failed: {error:?}"); } }); @@ -2910,15 +2792,21 @@ impl GitRepository for RealGitRepository { _task: task, }) } + + fn set_trusted(&self, trusted: bool) { + self.is_trusted + .store(trusted, std::sync::atomic::Ordering::Release); + } + + fn is_trusted(&self) -> bool { + self.is_trusted.load(std::sync::atomic::Ordering::Acquire) + } } async fn run_commit_data_reader( - git_binary_path: PathBuf, - working_directory: PathBuf, - executor: BackgroundExecutor, + git: GitBinary, request_rx: smol::channel::Receiver, ) -> Result<()> { - let git = GitBinary::new(git_binary_path, working_directory, executor); let mut process = git .build_command(["--no-optional-locks", "cat-file", "--batch"]) .stdin(Stdio::piped()) @@ -3041,11 +2929,6 @@ fn git_status_args(path_prefixes: &[RepoPath]) -> Vec { OsString::from("--no-renames"), OsString::from("-z"), ]; - args.extend( - path_prefixes - .iter() - .map(|path_prefix| path_prefix.as_std_path().into()), - ); args.extend(path_prefixes.iter().map(|path_prefix| { if path_prefix.is_empty() { Path::new(".").into() @@ -3094,19 +2977,21 @@ async fn exclude_files(git: &GitBinary) -> Result { Ok(excludes) } -struct GitBinary { +pub(crate) struct GitBinary { git_binary_path: PathBuf, working_directory: PathBuf, executor: BackgroundExecutor, index_file_path: Option, envs: HashMap, + is_trusted: bool, } impl GitBinary { - fn new( + pub(crate) fn new( git_binary_path: PathBuf, working_directory: PathBuf, executor: BackgroundExecutor, + is_trusted: bool, ) -> Self { Self { git_binary_path, @@ -3114,6 +2999,7 @@ impl GitBinary { executor, index_file_path: None, envs: HashMap::default(), + is_trusted, } } @@ -3218,12 +3104,26 @@ impl GitBinary { Ok(String::from_utf8(output.stdout)?) } - fn build_command(&self, args: impl IntoIterator) -> util::command::Command + #[allow(clippy::disallowed_methods)] + pub(crate) fn build_command( + &self, + args: impl IntoIterator, + ) -> util::command::Command where S: AsRef, { let mut command = new_command(&self.git_binary_path); command.current_dir(&self.working_directory); + command.args(["-c", "core.fsmonitor=false"]); + command.arg("--no-pager"); + + if !self.is_trusted { + command.args(["-c", "core.hooksPath=/dev/null"]); + command.args(["-c", "core.sshCommand=ssh"]); + command.args(["-c", "credential.helper="]); + command.args(["-c", "protocol.ext.allow=never"]); + command.args(["-c", "diff.external="]); + } command.args(args); if let Some(index_file_path) = self.index_file_path.as_ref() { command.env("GIT_INDEX_FILE", index_file_path); @@ -3483,6 +3383,102 @@ mod tests { } } + #[gpui::test] + async fn test_build_command_untrusted_includes_both_safety_args(cx: &mut TestAppContext) { + cx.executor().allow_parking(); + let dir = tempfile::tempdir().unwrap(); + let git = GitBinary::new( + PathBuf::from("git"), + dir.path().to_path_buf(), + cx.executor(), + false, + ); + let output = git + .build_command(["version"]) + .output() + .await + .expect("git version should succeed"); + assert!(output.status.success()); + + let git = GitBinary::new( + PathBuf::from("git"), + dir.path().to_path_buf(), + cx.executor(), + false, + ); + let output = git + .build_command(["config", "--get", "core.fsmonitor"]) + .output() + .await + .expect("git config should run"); + let stdout = String::from_utf8_lossy(&output.stdout); + assert_eq!( + stdout.trim(), + "false", + "fsmonitor should be disabled for untrusted repos" + ); + + git2::Repository::init(dir.path()).unwrap(); + let git = GitBinary::new( + PathBuf::from("git"), + dir.path().to_path_buf(), + cx.executor(), + false, + ); + let output = git + .build_command(["config", "--get", "core.hooksPath"]) + .output() + .await + .expect("git config should run"); + let stdout = String::from_utf8_lossy(&output.stdout); + assert_eq!( + stdout.trim(), + "/dev/null", + "hooksPath should be /dev/null for untrusted repos" + ); + } + + #[gpui::test] + async fn test_build_command_trusted_only_disables_fsmonitor(cx: &mut TestAppContext) { + cx.executor().allow_parking(); + let dir = tempfile::tempdir().unwrap(); + git2::Repository::init(dir.path()).unwrap(); + + let git = GitBinary::new( + PathBuf::from("git"), + dir.path().to_path_buf(), + cx.executor(), + true, + ); + let output = git + .build_command(["config", "--get", "core.fsmonitor"]) + .output() + .await + .expect("git config should run"); + let stdout = String::from_utf8_lossy(&output.stdout); + assert_eq!( + stdout.trim(), + "false", + "fsmonitor should be disabled even for trusted repos" + ); + + let git = GitBinary::new( + PathBuf::from("git"), + dir.path().to_path_buf(), + cx.executor(), + true, + ); + let output = git + .build_command(["config", "--get", "core.hooksPath"]) + .output() + .await + .expect("git config should run"); + assert!( + !output.status.success(), + "hooksPath should NOT be overridden for trusted repos" + ); + } + #[gpui::test] async fn test_checkpoint_basic(cx: &mut TestAppContext) { disable_git_global_config(); @@ -4272,6 +4268,34 @@ mod tests { ); } + #[test] + fn test_original_repo_path_from_common_dir() { + // Normal repo: common_dir is /.git + assert_eq!( + original_repo_path_from_common_dir(Path::new("/code/zed5/.git")), + PathBuf::from("/code/zed5") + ); + + // Worktree: common_dir is the main repo's .git + // (same result — that's the point, it always traces back to the original) + assert_eq!( + original_repo_path_from_common_dir(Path::new("/code/zed5/.git")), + PathBuf::from("/code/zed5") + ); + + // Bare repo: no .git suffix, returns as-is + assert_eq!( + original_repo_path_from_common_dir(Path::new("/code/zed5.git")), + PathBuf::from("/code/zed5.git") + ); + + // Root-level .git directory + assert_eq!( + original_repo_path_from_common_dir(Path::new("/.git")), + PathBuf::from("/") + ); + } + #[test] fn test_validate_worktree_directory() { let work_dir = Path::new("/code/my-project"); @@ -4347,7 +4371,7 @@ mod tests { .spawn(async move { let git_binary_path = git_binary_path.clone(); let working_directory = working_directory?; - let git = GitBinary::new(git_binary_path, working_directory, executor); + let git = GitBinary::new(git_binary_path, working_directory, executor, true); git.run(&["gc", "--prune"]).await?; Ok(()) }) diff --git a/crates/git/src/status.rs b/crates/git/src/status.rs index b20919e7ecf4748d0035a003ed5eadebae752dd7..e8b5caec505f7bf65cb4f5cd7d789207ccd8784f 100644 --- a/crates/git/src/status.rs +++ b/crates/git/src/status.rs @@ -586,13 +586,18 @@ pub struct DiffStat { pub deleted: u32, } +#[derive(Clone, Debug)] +pub struct GitDiffStat { + pub entries: Arc<[(RepoPath, DiffStat)]>, +} + /// Parses the output of `git diff --numstat` where output looks like: /// /// ```text /// 24 12 dir/file.txt /// ``` -pub fn parse_numstat(output: &str) -> HashMap { - let mut stats = HashMap::default(); +pub fn parse_numstat(output: &str) -> GitDiffStat { + let mut entries = Vec::new(); for line in output.lines() { let line = line.trim(); if line.is_empty() { @@ -613,10 +618,14 @@ pub fn parse_numstat(output: &str) -> HashMap { let Ok(path) = RepoPath::new(path_str) else { continue; }; - let stat = DiffStat { added, deleted }; - stats.insert(path, stat); + entries.push((path, DiffStat { added, deleted })); + } + entries.sort_by(|(a, _), (b, _)| a.cmp(b)); + entries.dedup_by(|(a, _), (b, _)| a == b); + + GitDiffStat { + entries: entries.into(), } - stats } #[cfg(test)] @@ -629,20 +638,25 @@ mod tests { use super::{DiffStat, parse_numstat}; + fn lookup<'a>(entries: &'a [(RepoPath, DiffStat)], path: &str) -> Option<&'a DiffStat> { + let path = RepoPath::new(path).unwrap(); + entries.iter().find(|(p, _)| p == &path).map(|(_, s)| s) + } + #[test] fn test_parse_numstat_normal() { let input = "10\t5\tsrc/main.rs\n3\t1\tREADME.md\n"; let result = parse_numstat(input); - assert_eq!(result.len(), 2); + assert_eq!(result.entries.len(), 2); assert_eq!( - result.get(&RepoPath::new("src/main.rs").unwrap()), + lookup(&result.entries, "src/main.rs"), Some(&DiffStat { added: 10, deleted: 5 }) ); assert_eq!( - result.get(&RepoPath::new("README.md").unwrap()), + lookup(&result.entries, "README.md"), Some(&DiffStat { added: 3, deleted: 1 @@ -655,10 +669,10 @@ mod tests { // git diff --numstat outputs "-\t-\tpath" for binary files let input = "-\t-\timage.png\n5\t2\tsrc/lib.rs\n"; let result = parse_numstat(input); - assert_eq!(result.len(), 1); - assert!(!result.contains_key(&RepoPath::new("image.png").unwrap())); + assert_eq!(result.entries.len(), 1); + assert!(lookup(&result.entries, "image.png").is_none()); assert_eq!( - result.get(&RepoPath::new("src/lib.rs").unwrap()), + lookup(&result.entries, "src/lib.rs"), Some(&DiffStat { added: 5, deleted: 2 @@ -668,18 +682,18 @@ mod tests { #[test] fn test_parse_numstat_empty_input() { - assert!(parse_numstat("").is_empty()); - assert!(parse_numstat("\n\n").is_empty()); - assert!(parse_numstat(" \n \n").is_empty()); + assert!(parse_numstat("").entries.is_empty()); + assert!(parse_numstat("\n\n").entries.is_empty()); + assert!(parse_numstat(" \n \n").entries.is_empty()); } #[test] fn test_parse_numstat_malformed_lines_skipped() { let input = "not_a_number\t5\tfile.rs\n10\t5\tvalid.rs\n"; let result = parse_numstat(input); - assert_eq!(result.len(), 1); + assert_eq!(result.entries.len(), 1); assert_eq!( - result.get(&RepoPath::new("valid.rs").unwrap()), + lookup(&result.entries, "valid.rs"), Some(&DiffStat { added: 10, deleted: 5 @@ -692,9 +706,9 @@ mod tests { // Lines with fewer than 3 tab-separated fields are skipped let input = "10\t5\n7\t3\tok.rs\n"; let result = parse_numstat(input); - assert_eq!(result.len(), 1); + assert_eq!(result.entries.len(), 1); assert_eq!( - result.get(&RepoPath::new("ok.rs").unwrap()), + lookup(&result.entries, "ok.rs"), Some(&DiffStat { added: 7, deleted: 3 @@ -707,7 +721,7 @@ mod tests { let input = "0\t0\tunchanged_but_present.rs\n"; let result = parse_numstat(input); assert_eq!( - result.get(&RepoPath::new("unchanged_but_present.rs").unwrap()), + lookup(&result.entries, "unchanged_but_present.rs"), Some(&DiffStat { added: 0, deleted: 0 diff --git a/crates/git_ui/src/git_panel.rs b/crates/git_ui/src/git_panel.rs index 1fabc387247e3f0889749463e3aabd89ef0bff42..61d94b68a118525bd9b67217a929ce7462696dc7 100644 --- a/crates/git_ui/src/git_panel.rs +++ b/crates/git_ui/src/git_panel.rs @@ -41,7 +41,7 @@ use gpui::{ WeakEntity, actions, anchored, deferred, point, size, uniform_list, }; use itertools::Itertools; -use language::{Buffer, BufferEvent, File}; +use language::{Buffer, File}; use language_model::{ ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role, }; @@ -51,7 +51,6 @@ use notifications::status_toast::{StatusToast, ToastIcon}; use panel::{PanelHeader, panel_button, panel_filled_button, panel_icon_button}; use project::{ Fs, Project, ProjectPath, - buffer_store::BufferStoreEvent, git_store::{GitStoreEvent, Repository, RepositoryEvent, RepositoryId, pending_op}, project_settings::{GitPathStyle, ProjectSettings}, }; @@ -533,6 +532,7 @@ pub struct GitStatusEntry { pub(crate) repo_path: RepoPath, pub(crate) status: FileStatus, pub(crate) staging: StageStatus, + pub(crate) diff_stat: Option, } impl GitStatusEntry { @@ -653,8 +653,7 @@ pub struct GitPanel { local_committer_task: Option>, bulk_staging: Option, stash_entries: GitStash, - diff_stats: HashMap, - diff_stats_task: Task<()>, + _settings_subscription: Subscription, } @@ -723,18 +722,14 @@ impl GitPanel { if tree_view != was_tree_view { this.view_mode = GitPanelViewMode::from_settings(cx); } + + let mut update_entries = false; if sort_by_path != was_sort_by_path || tree_view != was_tree_view { this.bulk_staging.take(); - this.update_visible_entries(window, cx); + update_entries = true; } - if diff_stats != was_diff_stats { - if diff_stats { - this.fetch_diff_stats(cx); - } else { - this.diff_stats.clear(); - this.diff_stats_task = Task::ready(()); - cx.notify(); - } + if (diff_stats != was_diff_stats) || update_entries { + this.update_visible_entries(window, cx); } was_sort_by_path = sort_by_path; was_tree_view = tree_view; @@ -791,33 +786,6 @@ impl GitPanel { ) .detach(); - let buffer_store = project.read(cx).buffer_store().clone(); - - for buffer in project.read(cx).opened_buffers(cx) { - cx.subscribe(&buffer, |this, _buffer, event, cx| { - if matches!(event, BufferEvent::Saved) { - if GitPanelSettings::get_global(cx).diff_stats { - this.fetch_diff_stats(cx); - } - } - }) - .detach(); - } - - cx.subscribe(&buffer_store, |_this, _store, event, cx| { - if let BufferStoreEvent::BufferAdded(buffer) = event { - cx.subscribe(buffer, |this, _buffer, event, cx| { - if matches!(event, BufferEvent::Saved) { - if GitPanelSettings::get_global(cx).diff_stats { - this.fetch_diff_stats(cx); - } - } - }) - .detach(); - } - }) - .detach(); - let mut this = Self { active_repository, commit_editor, @@ -858,8 +826,6 @@ impl GitPanel { entry_count: 0, bulk_staging: None, stash_entries: Default::default(), - diff_stats: HashMap::default(), - diff_stats_task: Task::ready(()), _settings_subscription, }; @@ -3575,6 +3541,7 @@ impl GitPanel { repo_path: entry.repo_path.clone(), status: entry.status, staging, + diff_stat: entry.diff_stat, }; if staging.has_staged() { @@ -3611,6 +3578,7 @@ impl GitPanel { repo_path: ops.repo_path.clone(), status: status.status, staging: StageStatus::Staged, + diff_stat: status.diff_stat, }); } } @@ -3743,60 +3711,9 @@ impl GitPanel { editor.set_placeholder_text(&placeholder_text, window, cx) }); - if GitPanelSettings::get_global(cx).diff_stats { - self.fetch_diff_stats(cx); - } - cx.notify(); } - fn fetch_diff_stats(&mut self, cx: &mut Context) { - let Some(repo) = self.active_repository.clone() else { - self.diff_stats.clear(); - return; - }; - - let unstaged_rx = repo.update(cx, |repo, cx| repo.diff_stat(DiffType::HeadToWorktree, cx)); - let staged_rx = repo.update(cx, |repo, cx| repo.diff_stat(DiffType::HeadToIndex, cx)); - - self.diff_stats_task = cx.spawn(async move |this, cx| { - let (unstaged_result, staged_result) = - futures::future::join(unstaged_rx, staged_rx).await; - - let mut combined = match unstaged_result { - Ok(Ok(stats)) => stats, - Ok(Err(err)) => { - log::warn!("Failed to fetch unstaged diff stats: {err:?}"); - HashMap::default() - } - Err(_) => HashMap::default(), - }; - - let staged = match staged_result { - Ok(Ok(stats)) => Some(stats), - Ok(Err(err)) => { - log::warn!("Failed to fetch staged diff stats: {err:?}"); - None - } - Err(_) => None, - }; - - if let Some(staged) = staged { - for (path, stat) in staged { - let entry = combined.entry(path).or_default(); - entry.added += stat.added; - entry.deleted += stat.deleted; - } - } - - this.update(cx, |this, cx| { - this.diff_stats = combined; - cx.notify(); - }) - .ok(); - }); - } - fn header_state(&self, header_type: Section) -> ToggleState { let (staged_count, count) = match header_type { Section::New => (self.new_staged_count, self.new_count), @@ -5227,17 +5144,14 @@ impl GitPanel { .active(|s| s.bg(active_bg)) .child(name_row) .when(GitPanelSettings::get_global(cx).diff_stats, |el| { - el.when_some( - self.diff_stats.get(&entry.repo_path).copied(), - move |this, stat| { - let id = format!("diff-stat-{}", id_for_diff_stat); - this.child(ui::DiffStat::new( - id, - stat.added as usize, - stat.deleted as usize, - )) - }, - ) + el.when_some(entry.diff_stat, move |this, stat| { + let id = format!("diff-stat-{}", id_for_diff_stat); + this.child(ui::DiffStat::new( + id, + stat.added as usize, + stat.deleted as usize, + )) + }) }) .child( div() @@ -5629,6 +5543,21 @@ impl GitPanel { } } +#[cfg(any(test, feature = "test-support"))] +impl GitPanel { + pub fn new_test( + workspace: &mut Workspace, + window: &mut Window, + cx: &mut Context, + ) -> Entity { + Self::new(workspace, window, cx) + } + + pub fn active_repository(&self) -> Option<&Entity> { + self.active_repository.as_ref() + } +} + impl Render for GitPanel { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { let project = self.project.read(cx); @@ -6606,11 +6535,19 @@ mod tests { repo_path: repo_path("crates/gpui/gpui.rs"), status: StatusCode::Modified.worktree(), staging: StageStatus::Unstaged, + diff_stat: Some(DiffStat { + added: 1, + deleted: 1, + }), }), GitListEntry::Status(GitStatusEntry { repo_path: repo_path("crates/util/util.rs"), status: StatusCode::Modified.worktree(), staging: StageStatus::Unstaged, + diff_stat: Some(DiffStat { + added: 1, + deleted: 1, + }), },), ], ); @@ -6631,11 +6568,19 @@ mod tests { repo_path: repo_path("crates/gpui/gpui.rs"), status: StatusCode::Modified.worktree(), staging: StageStatus::Unstaged, + diff_stat: Some(DiffStat { + added: 1, + deleted: 1, + }), }), GitListEntry::Status(GitStatusEntry { repo_path: repo_path("crates/util/util.rs"), status: StatusCode::Modified.worktree(), staging: StageStatus::Unstaged, + diff_stat: Some(DiffStat { + added: 1, + deleted: 1, + }), },), ], ); diff --git a/crates/git_ui/src/git_picker.rs b/crates/git_ui/src/git_picker.rs index 82ef9c9516b7c145edbf26d6c5b8927189525cab..6cf82327b43abe6c3784e4ec8ca3d16161edfda7 100644 --- a/crates/git_ui/src/git_picker.rs +++ b/crates/git_ui/src/git_picker.rs @@ -15,7 +15,7 @@ use workspace::{ModalView, Workspace, pane}; use crate::branch_picker::{self, BranchList, DeleteBranch, FilterRemotes}; use crate::stash_picker::{self, DropStashItem, ShowStashItem, StashList}; use crate::worktree_picker::{ - self, WorktreeFromDefault, WorktreeFromDefaultOnWindow, WorktreeList, + self, DeleteWorktree, WorktreeFromDefault, WorktreeFromDefaultOnWindow, WorktreeList, }; actions!( @@ -408,6 +408,19 @@ impl GitPicker { } } + fn handle_worktree_delete( + &mut self, + _: &DeleteWorktree, + window: &mut Window, + cx: &mut Context, + ) { + if let Some(worktree_list) = &self.worktree_list { + worktree_list.update(cx, |list, cx| { + list.handle_delete(&DeleteWorktree, window, cx); + }); + } + } + fn handle_drop_stash( &mut self, _: &DropStashItem, @@ -524,6 +537,7 @@ impl Render for GitPicker { .when(self.tab == GitPickerTab::Worktrees, |el| { el.on_action(cx.listener(Self::handle_worktree_from_default)) .on_action(cx.listener(Self::handle_worktree_from_default_on_window)) + .on_action(cx.listener(Self::handle_worktree_delete)) }) .when(self.tab == GitPickerTab::Stash, |el| { el.on_action(cx.listener(Self::handle_drop_stash)) diff --git a/crates/git_ui/src/worktree_picker.rs b/crates/git_ui/src/worktree_picker.rs index f2826a2b543a73c5341653c42bbb5f1540213b2a..6c35e7c99ffb8f6efa1a2bd7a07c2ded8d158668 100644 --- a/crates/git_ui/src/worktree_picker.rs +++ b/crates/git_ui/src/worktree_picker.rs @@ -22,7 +22,16 @@ use ui::{HighlightedLabel, KeyBinding, ListItem, ListItemSpacing, prelude::*}; use util::ResultExt; use workspace::{ModalView, MultiWorkspace, Workspace, notifications::DetachAndPromptErr}; -actions!(git, [WorktreeFromDefault, WorktreeFromDefaultOnWindow]); +use crate::git_panel::show_error_toast; + +actions!( + git, + [ + WorktreeFromDefault, + WorktreeFromDefaultOnWindow, + DeleteWorktree + ] +); pub fn open( workspace: &mut Workspace, @@ -181,6 +190,19 @@ impl WorktreeList { ); }) } + + pub fn handle_delete( + &mut self, + _: &DeleteWorktree, + window: &mut Window, + cx: &mut Context, + ) { + self.picker.update(cx, |picker, cx| { + picker + .delegate + .delete_at(picker.delegate.selected_index, window, cx) + }) + } } impl ModalView for WorktreeList {} impl EventEmitter for WorktreeList {} @@ -203,6 +225,9 @@ impl Render for WorktreeList { .on_action(cx.listener(|this, _: &WorktreeFromDefaultOnWindow, w, cx| { this.handle_new_worktree(true, w, cx) })) + .on_action(cx.listener(|this, _: &DeleteWorktree, window, cx| { + this.handle_delete(&DeleteWorktree, window, cx) + })) .child(self.picker.clone()) .when(!self.embedded, |el| { el.on_mouse_down_out({ @@ -275,9 +300,9 @@ impl WorktreeListDelegate { .git .worktree_directory .clone(); - let work_dir = repo.work_directory_abs_path.clone(); + let original_repo = repo.original_repo_abs_path.clone(); let directory = - validate_worktree_directory(&work_dir, &worktree_directory_setting)?; + validate_worktree_directory(&original_repo, &worktree_directory_setting)?; let new_worktree_path = directory.join(&branch); let receiver = repo.create_worktree(branch.clone(), directory, commit); anyhow::Ok((receiver, new_worktree_path)) @@ -420,6 +445,57 @@ impl WorktreeListDelegate { .as_ref() .and_then(|repo| repo.read(cx).branch.as_ref().map(|b| b.name())) } + + fn delete_at(&self, idx: usize, window: &mut Window, cx: &mut Context>) { + let Some(entry) = self.matches.get(idx).cloned() else { + return; + }; + if entry.is_new { + return; + } + let Some(repo) = self.repo.clone() else { + return; + }; + let workspace = self.workspace.clone(); + let path = entry.worktree.path; + + cx.spawn_in(window, async move |picker, cx| { + let result = repo + .update(cx, |repo, _| repo.remove_worktree(path.clone(), false)) + .await?; + + if let Err(e) = result { + log::error!("Failed to remove worktree: {}", e); + if let Some(workspace) = workspace.upgrade() { + cx.update(|_window, cx| { + show_error_toast( + workspace, + format!("worktree remove {}", path.display()), + e, + cx, + ) + })?; + } + return Ok(()); + } + + picker.update_in(cx, |picker, _, cx| { + picker.delegate.matches.retain(|e| e.worktree.path != path); + if let Some(all_worktrees) = &mut picker.delegate.all_worktrees { + all_worktrees.retain(|w| w.path != path); + } + if picker.delegate.matches.is_empty() { + picker.delegate.selected_index = 0; + } else if picker.delegate.selected_index >= picker.delegate.matches.len() { + picker.delegate.selected_index = picker.delegate.matches.len() - 1; + } + cx.notify(); + })?; + + anyhow::Ok(()) + }) + .detach(); + } } async fn open_remote_worktree( @@ -778,6 +854,16 @@ impl PickerDelegate for WorktreeListDelegate { } else { Some( footer_container + .child( + Button::new("delete-worktree", "Delete") + .key_binding( + KeyBinding::for_action_in(&DeleteWorktree, &focus_handle, cx) + .map(|kb| kb.size(rems_from_px(12.))), + ) + .on_click(|_, window, cx| { + window.dispatch_action(DeleteWorktree.boxed_clone(), cx) + }), + ) .child( Button::new("open-in-new-window", "Open in New Window") .key_binding( diff --git a/crates/google_ai/src/google_ai.rs b/crates/google_ai/src/google_ai.rs index 3a686f97a8825b30a8f02f4149b110c3d1aacb1e..7659be8ab44da35efd16389c4abd0bf99d8cf3a4 100644 --- a/crates/google_ai/src/google_ai.rs +++ b/crates/google_ai/src/google_ai.rs @@ -510,11 +510,9 @@ pub enum Model { alias = "gemini-2.5-pro-preview-06-05" )] Gemini25Pro, - #[serde(rename = "gemini-3-pro-preview")] - Gemini3Pro, #[serde(rename = "gemini-3-flash-preview")] Gemini3Flash, - #[serde(rename = "gemini-3.1-pro-preview")] + #[serde(rename = "gemini-3.1-pro-preview", alias = "gemini-3-pro-preview")] Gemini31Pro, #[serde(rename = "custom")] Custom { @@ -537,7 +535,6 @@ impl Model { Self::Gemini25FlashLite => "gemini-2.5-flash-lite", Self::Gemini25Flash => "gemini-2.5-flash", Self::Gemini25Pro => "gemini-2.5-pro", - Self::Gemini3Pro => "gemini-3-pro-preview", Self::Gemini3Flash => "gemini-3-flash-preview", Self::Gemini31Pro => "gemini-3.1-pro-preview", Self::Custom { name, .. } => name, @@ -548,7 +545,6 @@ impl Model { Self::Gemini25FlashLite => "gemini-2.5-flash-lite", Self::Gemini25Flash => "gemini-2.5-flash", Self::Gemini25Pro => "gemini-2.5-pro", - Self::Gemini3Pro => "gemini-3-pro-preview", Self::Gemini3Flash => "gemini-3-flash-preview", Self::Gemini31Pro => "gemini-3.1-pro-preview", Self::Custom { name, .. } => name, @@ -560,7 +556,6 @@ impl Model { Self::Gemini25FlashLite => "Gemini 2.5 Flash-Lite", Self::Gemini25Flash => "Gemini 2.5 Flash", Self::Gemini25Pro => "Gemini 2.5 Pro", - Self::Gemini3Pro => "Gemini 3 Pro", Self::Gemini3Flash => "Gemini 3 Flash", Self::Gemini31Pro => "Gemini 3.1 Pro", Self::Custom { @@ -574,7 +569,6 @@ impl Model { Self::Gemini25FlashLite | Self::Gemini25Flash | Self::Gemini25Pro - | Self::Gemini3Pro | Self::Gemini3Flash | Self::Gemini31Pro => 1_048_576, Self::Custom { max_tokens, .. } => *max_tokens, @@ -586,7 +580,6 @@ impl Model { Model::Gemini25FlashLite | Model::Gemini25Flash | Model::Gemini25Pro - | Model::Gemini3Pro | Model::Gemini3Flash | Model::Gemini31Pro => Some(65_536), Model::Custom { .. } => None, @@ -603,10 +596,7 @@ impl Model { pub fn mode(&self) -> GoogleModelMode { match self { - Self::Gemini25FlashLite - | Self::Gemini25Flash - | Self::Gemini25Pro - | Self::Gemini3Pro => { + Self::Gemini25FlashLite | Self::Gemini25Flash | Self::Gemini25Pro => { GoogleModelMode::Thinking { // By default these models are set to "auto", so we preserve that behavior // but indicate they are capable of thinking mode diff --git a/crates/gpui/src/elements/div.rs b/crates/gpui/src/elements/div.rs index 2b4a3c84e8111796bf7ce32a4c6ad83854ded6fd..58f11a7fa1fb876ef4b4ef80fedf1948423a24f5 100644 --- a/crates/gpui/src/elements/div.rs +++ b/crates/gpui/src/elements/div.rs @@ -1886,18 +1886,18 @@ impl Interactivity { // high for the maximum scroll, we round the scroll max to 2 decimal // places here. let padded_content_size = self.content_size + padding_size; - let scroll_max = (padded_content_size - bounds.size) + let scroll_max = Point::from(padded_content_size - bounds.size) .map(round_to_two_decimals) .max(&Default::default()); // Clamp scroll offset in case scroll max is smaller now (e.g., if children // were removed or the bounds became larger). let mut scroll_offset = scroll_offset.borrow_mut(); - scroll_offset.x = scroll_offset.x.clamp(-scroll_max.width, px(0.)); + scroll_offset.x = scroll_offset.x.clamp(-scroll_max.x, px(0.)); if scroll_to_bottom { - scroll_offset.y = -scroll_max.height; + scroll_offset.y = -scroll_max.y; } else { - scroll_offset.y = scroll_offset.y.clamp(-scroll_max.height, px(0.)); + scroll_offset.y = scroll_offset.y.clamp(-scroll_max.y, px(0.)); } if let Some(mut scroll_handle_state) = tracked_scroll_handle { @@ -3285,7 +3285,7 @@ impl ScrollAnchor { struct ScrollHandleState { offset: Rc>>, bounds: Bounds, - max_offset: Size, + max_offset: Point, child_bounds: Vec>, scroll_to_bottom: bool, overflow: Point, @@ -3329,7 +3329,7 @@ impl ScrollHandle { } /// Get the maximum scroll offset. - pub fn max_offset(&self) -> Size { + pub fn max_offset(&self) -> Point { self.0.borrow().max_offset } diff --git a/crates/gpui/src/elements/list.rs b/crates/gpui/src/elements/list.rs index 5403bf10eb9a078dfd113462644636b49d1840e4..92b5389fecf219c0c113f682463498902df4c07d 100644 --- a/crates/gpui/src/elements/list.rs +++ b/crates/gpui/src/elements/list.rs @@ -491,7 +491,7 @@ impl ListState { /// Returns the maximum scroll offset according to the items we have measured. /// This value remains constant while dragging to prevent the scrollbar from moving away unexpectedly. - pub fn max_offset_for_scrollbar(&self) -> Size { + pub fn max_offset_for_scrollbar(&self) -> Point { let state = self.0.borrow(); let bounds = state.last_layout_bounds.unwrap_or_default(); @@ -499,7 +499,7 @@ impl ListState { .scrollbar_drag_start_height .unwrap_or_else(|| state.items.summary().height); - Size::new(Pixels::ZERO, Pixels::ZERO.max(height - bounds.size.height)) + point(Pixels::ZERO, Pixels::ZERO.max(height - bounds.size.height)) } /// Returns the current scroll offset adjusted for the scrollbar diff --git a/crates/gpui/src/elements/svg.rs b/crates/gpui/src/elements/svg.rs index dff389fb93fe7abd2862be70731cc9e6fb613e94..a29b106c0e223b01340ecab27b45fdb94163d207 100644 --- a/crates/gpui/src/elements/svg.rs +++ b/crates/gpui/src/elements/svg.rs @@ -3,8 +3,7 @@ use std::{fs, path::Path, sync::Arc}; use crate::{ App, Asset, Bounds, Element, GlobalElementId, Hitbox, InspectorElementId, InteractiveElement, Interactivity, IntoElement, LayoutId, Pixels, Point, Radians, SharedString, Size, - StyleRefinement, Styled, TransformationMatrix, Window, geometry::Negate as _, point, px, - radians, size, + StyleRefinement, Styled, TransformationMatrix, Window, point, px, radians, size, }; use gpui_util::ResultExt; @@ -254,7 +253,7 @@ impl Transformation { .translate(center.scale(scale_factor) + self.translate.scale(scale_factor)) .rotate(self.rotate) .scale(self.scale) - .translate(center.scale(scale_factor).negate()) + .translate(center.scale(-scale_factor)) } } diff --git a/crates/gpui/src/geometry.rs b/crates/gpui/src/geometry.rs index 73fa9906267412c9f1c840d8403beeef4718119e..76157a06a587ac851d19f19fc5a4ed23c634bab5 100644 --- a/crates/gpui/src/geometry.rs +++ b/crates/gpui/src/geometry.rs @@ -78,6 +78,7 @@ pub trait Along { Deserialize, JsonSchema, Hash, + Neg, )] #[refineable(Debug, PartialEq, Serialize, Deserialize, JsonSchema)] #[repr(C)] @@ -182,12 +183,6 @@ impl Along for Point { } } -impl Negate for Point { - fn negate(self) -> Self { - self.map(Negate::negate) - } -} - impl Point { /// Scales the point by a given factor, which is typically derived from the resolution /// of a target display to ensure proper sizing of UI elements. @@ -393,7 +388,9 @@ impl Display for Point { /// /// This struct is generic over the type `T`, which can be any type that implements `Clone`, `Default`, and `Debug`. /// It is commonly used to specify dimensions for elements in a UI, such as a window or element. -#[derive(Refineable, Default, Clone, Copy, PartialEq, Div, Hash, Serialize, Deserialize)] +#[derive( + Add, Clone, Copy, Default, Deserialize, Div, Hash, Neg, PartialEq, Refineable, Serialize, Sub, +)] #[refineable(Debug, PartialEq, Serialize, Deserialize, JsonSchema)] #[repr(C)] pub struct Size { @@ -598,34 +595,6 @@ where } } -impl Sub for Size -where - T: Sub + Clone + Debug + Default + PartialEq, -{ - type Output = Size; - - fn sub(self, rhs: Self) -> Self::Output { - Size { - width: self.width - rhs.width, - height: self.height - rhs.height, - } - } -} - -impl Add for Size -where - T: Add + Clone + Debug + Default + PartialEq, -{ - type Output = Size; - - fn add(self, rhs: Self) -> Self::Output { - Size { - width: self.width + rhs.width, - height: self.height + rhs.height, - } - } -} - impl Mul for Size where T: Mul + Clone + Debug + Default + PartialEq, @@ -1245,6 +1214,15 @@ where } } +impl From> for Point { + fn from(size: Size) -> Self { + Self { + x: size.width, + y: size.height, + } + } +} + impl Bounds where T: Add + Clone + Debug + Default + PartialEq, @@ -3754,48 +3732,6 @@ impl Half for Rems { } } -/// Provides a trait for types that can negate their values. -pub trait Negate { - /// Returns the negation of the given value - fn negate(self) -> Self; -} - -impl Negate for i32 { - fn negate(self) -> Self { - -self - } -} - -impl Negate for f32 { - fn negate(self) -> Self { - -self - } -} - -impl Negate for DevicePixels { - fn negate(self) -> Self { - Self(-self.0) - } -} - -impl Negate for ScaledPixels { - fn negate(self) -> Self { - Self(-self.0) - } -} - -impl Negate for Pixels { - fn negate(self) -> Self { - Self(-self.0) - } -} - -impl Negate for Rems { - fn negate(self) -> Self { - Self(-self.0) - } -} - /// A trait for checking if a value is zero. /// /// This trait provides a method to determine if a value is considered to be zero. diff --git a/crates/gpui_linux/src/linux/x11/window.rs b/crates/gpui_linux/src/linux/x11/window.rs index f2199ac65e425a8daa04755115264231dd869837..a7cdc67ecd908becd22f799767f482754527fa51 100644 --- a/crates/gpui_linux/src/linux/x11/window.rs +++ b/crates/gpui_linux/src/linux/x11/window.rs @@ -319,12 +319,28 @@ impl rwh::HasDisplayHandle for RawWindow { impl rwh::HasWindowHandle for X11Window { fn window_handle(&self) -> Result, rwh::HandleError> { - unimplemented!() + let Some(non_zero) = NonZeroU32::new(self.0.x_window) else { + return Err(rwh::HandleError::Unavailable); + }; + let handle = rwh::XcbWindowHandle::new(non_zero); + Ok(unsafe { rwh::WindowHandle::borrow_raw(handle.into()) }) } } + impl rwh::HasDisplayHandle for X11Window { fn display_handle(&self) -> Result, rwh::HandleError> { - unimplemented!() + let connection = + as_raw_xcb_connection::AsRawXcbConnection::as_raw_xcb_connection(&*self.0.xcb) + as *mut _; + let Some(non_zero) = NonNull::new(connection) else { + return Err(rwh::HandleError::Unavailable); + }; + let screen_id = { + let state = self.0.state.borrow(); + u32::from(state.display.id()) as i32 + }; + let handle = rwh::XcbDisplayHandle::new(Some(non_zero), screen_id); + Ok(unsafe { rwh::DisplayHandle::borrow_raw(handle.into()) }) } } diff --git a/crates/gpui_wgpu/src/wgpu_renderer.rs b/crates/gpui_wgpu/src/wgpu_renderer.rs index 5beeef6ad1238f25db7c50f739053e138b2e1295..2fd83b7b065e7ce4fe0ba9ec017f39264a33bee3 100644 --- a/crates/gpui_wgpu/src/wgpu_renderer.rs +++ b/crates/gpui_wgpu/src/wgpu_renderer.rs @@ -98,7 +98,6 @@ pub struct WgpuRenderer { queue: Arc, surface: wgpu::Surface<'static>, surface_config: wgpu::SurfaceConfiguration, - surface_configured: bool, pipelines: WgpuPipelines, bind_group_layouts: WgpuBindGroupLayouts, atlas: Arc, @@ -381,7 +380,6 @@ impl WgpuRenderer { queue, surface, surface_config, - surface_configured: true, pipelines, bind_group_layouts, atlas, @@ -875,9 +873,7 @@ impl WgpuRenderer { self.surface_config.width = clamped_width.max(1); self.surface_config.height = clamped_height.max(1); - if self.surface_configured { - self.surface.configure(&self.device, &self.surface_config); - } + self.surface.configure(&self.device, &self.surface_config); // Invalidate intermediate textures - they will be lazily recreated // in draw() after we confirm the surface is healthy. This avoids @@ -928,9 +924,7 @@ impl WgpuRenderer { if new_alpha_mode != self.surface_config.alpha_mode { self.surface_config.alpha_mode = new_alpha_mode; - if self.surface_configured { - self.surface.configure(&self.device, &self.surface_config); - } + self.surface.configure(&self.device, &self.surface_config); self.pipelines = Self::create_pipelines( &self.device, &self.bind_group_layouts, @@ -991,7 +985,7 @@ impl WgpuRenderer { let frame = match self.surface.get_current_texture() { Ok(frame) => frame, Err(wgpu::SurfaceError::Lost | wgpu::SurfaceError::Outdated) => { - self.surface_configured = false; + self.surface.configure(&self.device, &self.surface_config); return; } Err(e) => { diff --git a/crates/http_client/src/async_body.rs b/crates/http_client/src/async_body.rs index 8fb49f218568ea36078d772a7225229f31a916c4..a59a7339db1e4449b875e2c539e98c86b4279365 100644 --- a/crates/http_client/src/async_body.rs +++ b/crates/http_client/src/async_body.rs @@ -7,6 +7,7 @@ use std::{ use bytes::Bytes; use futures::AsyncRead; use http_body::{Body, Frame}; +use serde::Serialize; /// Based on the implementation of AsyncBody in /// . @@ -88,6 +89,19 @@ impl From<&'static str> for AsyncBody { } } +/// Newtype wrapper that serializes a value as JSON into an `AsyncBody`. +pub struct Json(pub T); + +impl From> for AsyncBody { + fn from(json: Json) -> Self { + Self::from_bytes( + serde_json::to_vec(&json.0) + .expect("failed to serialize JSON") + .into(), + ) + } +} + impl> From> for AsyncBody { fn from(body: Option) -> Self { match body { diff --git a/crates/http_client/src/http_client.rs b/crates/http_client/src/http_client.rs index 5cf25a8277872ba3c6d502565e8057623b267d42..bbbe3b1a832332bd6bee693b4c0b916b4f4c182a 100644 --- a/crates/http_client/src/http_client.rs +++ b/crates/http_client/src/http_client.rs @@ -5,7 +5,7 @@ pub mod github; pub mod github_download; pub use anyhow::{Result, anyhow}; -pub use async_body::{AsyncBody, Inner}; +pub use async_body::{AsyncBody, Inner, Json}; use derive_more::Deref; use http::HeaderValue; pub use http::{self, Method, Request, Response, StatusCode, Uri, request::Builder}; diff --git a/crates/language/src/buffer.rs b/crates/language/src/buffer.rs index eb9bb0827a7be9f4a725246c6d38777e340eee2c..d183615317ecaa481cda45d780c64b2ddf7ec833 100644 --- a/crates/language/src/buffer.rs +++ b/crates/language/src/buffer.rs @@ -4,7 +4,7 @@ use crate::{ DebuggerTextObject, LanguageScope, Outline, OutlineConfig, PLAIN_TEXT, RunnableCapture, RunnableTag, TextObject, TreeSitterOptions, diagnostic_set::{DiagnosticEntry, DiagnosticEntryRef, DiagnosticGroup}, - language_settings::{LanguageSettings, language_settings}, + language_settings::{AutoIndentMode, LanguageSettings, language_settings}, outline::OutlineItem, row_chunk::RowChunks, syntax_map::{ @@ -2738,17 +2738,18 @@ impl Buffer { .filter(|((_, (range, _)), _)| { let language = before_edit.language_at(range.start); let language_id = language.map(|l| l.id()); - if let Some((cached_language_id, auto_indent)) = previous_setting + if let Some((cached_language_id, apply_syntax_indent)) = previous_setting && cached_language_id == language_id { - auto_indent + apply_syntax_indent } else { // The auto-indent setting is not present in editorconfigs, hence // we can avoid passing the file here. - let auto_indent = + let auto_indent_mode = language_settings(language.map(|l| l.name()), None, cx).auto_indent; - previous_setting = Some((language_id, auto_indent)); - auto_indent + let apply_syntax_indent = auto_indent_mode == AutoIndentMode::SyntaxAware; + previous_setting = Some((language_id, apply_syntax_indent)); + apply_syntax_indent } }) .map(|((ix, (range, _)), new_text)| { diff --git a/crates/language/src/language.rs b/crates/language/src/language.rs index fe5c5d09aa0765e2c305d88c65e86d6832443b1e..29b569ba1aa68fe83f3456a2eaf9911b4c83677d 100644 --- a/crates/language/src/language.rs +++ b/crates/language/src/language.rs @@ -23,7 +23,7 @@ mod toolchain; pub mod buffer_tests; use crate::language_settings::SoftWrap; -pub use crate::language_settings::{EditPredictionsMode, IndentGuideSettings}; +pub use crate::language_settings::{AutoIndentMode, EditPredictionsMode, IndentGuideSettings}; use anyhow::{Context as _, Result}; use async_trait::async_trait; use collections::{HashMap, HashSet, IndexSet}; @@ -835,6 +835,11 @@ pub struct LanguageConfig { pub name: LanguageName, /// The name of this language for a Markdown code fence block pub code_fence_block_name: Option>, + /// Alternative language names that Jupyter kernels may report for this language. + /// Used when a kernel's `language` field differs from Zed's language name. + /// For example, the Nu extension would set this to `["nushell"]`. + #[serde(default)] + pub kernel_language_names: Vec>, // The name of the grammar in a WASM bundle (experimental). pub grammar: Option>, /// The criteria for matching this language to a given file. @@ -1141,6 +1146,7 @@ impl Default for LanguageConfig { Self { name: LanguageName::new_static(""), code_fence_block_name: None, + kernel_language_names: Default::default(), grammar: None, matcher: LanguageMatcher::default(), brackets: Default::default(), @@ -2075,6 +2081,23 @@ impl Language { .unwrap_or_else(|| self.config.name.as_ref().to_lowercase().into()) } + pub fn matches_kernel_language(&self, kernel_language: &str) -> bool { + let kernel_language_lower = kernel_language.to_lowercase(); + + if self.code_fence_block_name().to_lowercase() == kernel_language_lower { + return true; + } + + if self.config.name.as_ref().to_lowercase() == kernel_language_lower { + return true; + } + + self.config + .kernel_language_names + .iter() + .any(|name| name.to_lowercase() == kernel_language_lower) + } + pub fn context_provider(&self) -> Option> { self.context_provider.clone() } diff --git a/crates/language/src/language_settings.rs b/crates/language/src/language_settings.rs index 9a379697e8bddf9dc71d3d340d5e2a92d8b4405e..f2c55fd1e8a3b8bf5b6c2dd8ea24d1343385fa78 100644 --- a/crates/language/src/language_settings.rs +++ b/crates/language/src/language_settings.rs @@ -12,7 +12,7 @@ use itertools::{Either, Itertools}; use settings::{DocumentFoldingRanges, DocumentSymbols, IntoGpui, SemanticTokens}; pub use settings::{ - CompletionSettingsContent, EditPredictionPromptFormat, EditPredictionProvider, + AutoIndentMode, CompletionSettingsContent, EditPredictionPromptFormat, EditPredictionProvider, EditPredictionsMode, FormatOnSave, Formatter, FormatterList, InlayHintKind, LanguageSettingsContent, LspInsertMode, RewrapBehavior, ShowWhitespaceSetting, SoftWrap, WordsCompletionMode, @@ -144,8 +144,8 @@ pub struct LanguageSettings { /// Whether to use additional LSP queries to format (and amend) the code after /// every "trigger" symbol input, defined by LSP server capabilities. pub use_on_type_format: bool, - /// Whether indentation should be adjusted based on the context whilst typing. - pub auto_indent: bool, + /// Controls automatic indentation behavior when typing. + pub auto_indent: AutoIndentMode, /// Whether indentation of pasted content should be adjusted based on the context. pub auto_indent_on_paste: bool, /// Controls how the editor handles the autoclosed characters. diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index 18e099b4d6fc62867bf35fbd1d4573093af44744..b2af80a3c295cab1cf40a330eb8d84f94a137eb7 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use anyhow::{Context as _, Result}; use client::Client; use cloud_api_client::ClientApiError; +use cloud_api_types::OrganizationId; use cloud_api_types::websocket_protocol::MessageToClient; use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME}; use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _}; @@ -26,29 +27,46 @@ impl fmt::Display for PaymentRequiredError { pub struct LlmApiToken(Arc>>); impl LlmApiToken { - pub async fn acquire(&self, client: &Arc) -> Result { + pub async fn acquire( + &self, + client: &Arc, + organization_id: Option, + ) -> Result { let lock = self.0.upgradable_read().await; if let Some(token) = lock.as_ref() { Ok(token.to_string()) } else { - Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await + Self::fetch( + RwLockUpgradableReadGuard::upgrade(lock).await, + client, + organization_id, + ) + .await } } - pub async fn refresh(&self, client: &Arc) -> Result { - Self::fetch(self.0.write().await, client).await + pub async fn refresh( + &self, + client: &Arc, + organization_id: Option, + ) -> Result { + Self::fetch(self.0.write().await, client, organization_id).await } async fn fetch( mut lock: RwLockWriteGuard<'_, Option>, client: &Arc, + organization_id: Option, ) -> Result { let system_id = client .telemetry() .system_id() .map(|system_id| system_id.to_string()); - let result = client.cloud_client().create_llm_token(system_id).await; + let result = client + .cloud_client() + .create_llm_token(system_id, organization_id) + .await; match result { Ok(response) => { *lock = Some(response.token.0.clone()); diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs index bcf8401c1c14ae1a74bb7136141d0b35509cdd40..5b493fdf1087911372d8796cc88f4ad14eef8df0 100644 --- a/crates/language_models/src/provider/bedrock.rs +++ b/crates/language_models/src/provider/bedrock.rs @@ -658,6 +658,10 @@ impl LanguageModel for BedrockModel { } } + fn supports_streaming_tools(&self) -> bool { + true + } + fn telemetry_id(&self) -> String { format!("bedrock/{}", self.model.id()) } @@ -1200,8 +1204,25 @@ pub fn map_to_language_model_completion_events( .get_mut(&cb_delta.content_block_index) { tool_use.input_json.push_str(tool_output.input()); + if let Ok(input) = serde_json::from_str::( + &partial_json_fixer::fix_json(&tool_use.input_json), + ) { + Some(Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: tool_use.id.clone().into(), + name: tool_use.name.clone().into(), + is_input_complete: false, + raw_input: tool_use.input_json.clone(), + input, + thought_signature: None, + }, + ))) + } else { + None + } + } else { + None } - None } Some(ContentBlockDelta::ReasoningContent(thinking)) => match thinking { ReasoningContentBlockDelta::Text(thoughts) => { diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index 19009013bf84ad9751e9ed0de2d3338b279a258e..b84b19b038905ba9f3d9a0637c770acc95687976 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -3,7 +3,7 @@ use anthropic::AnthropicModelMode; use anyhow::{Context as _, Result, anyhow}; use chrono::{DateTime, Utc}; use client::{Client, UserStore, zed_urls}; -use cloud_api_types::Plan; +use cloud_api_types::{OrganizationId, Plan}; use cloud_llm_client::{ CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus, @@ -122,15 +122,25 @@ impl State { recommended_models: Vec::new(), _fetch_models_task: cx.spawn(async move |this, cx| { maybe!(async move { - let (client, llm_api_token) = this - .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?; + let (client, llm_api_token, organization_id) = + this.read_with(cx, |this, cx| { + ( + client.clone(), + this.llm_api_token.clone(), + this.user_store + .read(cx) + .current_organization() + .map(|o| o.id.clone()), + ) + })?; while current_user.borrow().is_none() { current_user.next().await; } let response = - Self::fetch_models(client.clone(), llm_api_token.clone()).await?; + Self::fetch_models(client.clone(), llm_api_token.clone(), organization_id) + .await?; this.update(cx, |this, cx| this.update_models(response, cx))?; anyhow::Ok(()) }) @@ -146,9 +156,17 @@ impl State { move |this, _listener, _event, cx| { let client = this.client.clone(); let llm_api_token = this.llm_api_token.clone(); + let organization_id = this + .user_store + .read(cx) + .current_organization() + .map(|o| o.id.clone()); cx.spawn(async move |this, cx| { - llm_api_token.refresh(&client).await?; - let response = Self::fetch_models(client, llm_api_token).await?; + llm_api_token + .refresh(&client, organization_id.clone()) + .await?; + let response = + Self::fetch_models(client, llm_api_token, organization_id).await?; this.update(cx, |this, cx| { this.update_models(response, cx); }) @@ -209,9 +227,10 @@ impl State { async fn fetch_models( client: Arc, llm_api_token: LlmApiToken, + organization_id: Option, ) -> Result { let http_client = &client.http_client(); - let token = llm_api_token.acquire(&client).await?; + let token = llm_api_token.acquire(&client, organization_id).await?; let request = http_client::Request::builder() .method(Method::GET) @@ -273,11 +292,13 @@ impl CloudLanguageModelProvider { &self, model: Arc, llm_api_token: LlmApiToken, + user_store: Entity, ) -> Arc { Arc::new(CloudLanguageModel { id: LanguageModelId(SharedString::from(model.id.0.clone())), model, llm_api_token, + user_store, client: self.client.clone(), request_limiter: RateLimiter::new(4), }) @@ -306,36 +327,46 @@ impl LanguageModelProvider for CloudLanguageModelProvider { } fn default_model(&self, cx: &App) -> Option> { - let default_model = self.state.read(cx).default_model.clone()?; - let llm_api_token = self.state.read(cx).llm_api_token.clone(); - Some(self.create_language_model(default_model, llm_api_token)) + let state = self.state.read(cx); + let default_model = state.default_model.clone()?; + let llm_api_token = state.llm_api_token.clone(); + let user_store = state.user_store.clone(); + Some(self.create_language_model(default_model, llm_api_token, user_store)) } fn default_fast_model(&self, cx: &App) -> Option> { - let default_fast_model = self.state.read(cx).default_fast_model.clone()?; - let llm_api_token = self.state.read(cx).llm_api_token.clone(); - Some(self.create_language_model(default_fast_model, llm_api_token)) + let state = self.state.read(cx); + let default_fast_model = state.default_fast_model.clone()?; + let llm_api_token = state.llm_api_token.clone(); + let user_store = state.user_store.clone(); + Some(self.create_language_model(default_fast_model, llm_api_token, user_store)) } fn recommended_models(&self, cx: &App) -> Vec> { - let llm_api_token = self.state.read(cx).llm_api_token.clone(); - self.state - .read(cx) + let state = self.state.read(cx); + let llm_api_token = state.llm_api_token.clone(); + let user_store = state.user_store.clone(); + state .recommended_models .iter() .cloned() - .map(|model| self.create_language_model(model, llm_api_token.clone())) + .map(|model| { + self.create_language_model(model, llm_api_token.clone(), user_store.clone()) + }) .collect() } fn provided_models(&self, cx: &App) -> Vec> { - let llm_api_token = self.state.read(cx).llm_api_token.clone(); - self.state - .read(cx) + let state = self.state.read(cx); + let llm_api_token = state.llm_api_token.clone(); + let user_store = state.user_store.clone(); + state .models .iter() .cloned() - .map(|model| self.create_language_model(model, llm_api_token.clone())) + .map(|model| { + self.create_language_model(model, llm_api_token.clone(), user_store.clone()) + }) .collect() } @@ -367,6 +398,7 @@ pub struct CloudLanguageModel { id: LanguageModelId, model: Arc, llm_api_token: LlmApiToken, + user_store: Entity, client: Arc, request_limiter: RateLimiter, } @@ -380,12 +412,15 @@ impl CloudLanguageModel { async fn perform_llm_completion( client: Arc, llm_api_token: LlmApiToken, + organization_id: Option, app_version: Option, body: CompletionBody, ) -> Result { let http_client = &client.http_client(); - let mut token = llm_api_token.acquire(&client).await?; + let mut token = llm_api_token + .acquire(&client, organization_id.clone()) + .await?; let mut refreshed_token = false; loop { @@ -416,7 +451,9 @@ impl CloudLanguageModel { } if !refreshed_token && response.needs_llm_token_refresh() { - token = llm_api_token.refresh(&client).await?; + token = llm_api_token + .refresh(&client, organization_id.clone()) + .await?; refreshed_token = true; continue; } @@ -670,12 +707,17 @@ impl LanguageModel for CloudLanguageModel { cloud_llm_client::LanguageModelProvider::Google => { let client = self.client.clone(); let llm_api_token = self.llm_api_token.clone(); + let organization_id = self + .user_store + .read(cx) + .current_organization() + .map(|o| o.id.clone()); let model_id = self.model.id.to_string(); let generate_content_request = into_google(request, model_id.clone(), GoogleModelMode::Default); async move { let http_client = &client.http_client(); - let token = llm_api_token.acquire(&client).await?; + let token = llm_api_token.acquire(&client, organization_id).await?; let request_body = CountTokensBody { provider: cloud_llm_client::LanguageModelProvider::Google, @@ -736,6 +778,13 @@ impl LanguageModel for CloudLanguageModel { let prompt_id = request.prompt_id.clone(); let intent = request.intent; let app_version = Some(cx.update(|cx| AppVersion::global(cx))); + let user_store = self.user_store.clone(); + let organization_id = cx.update(|cx| { + user_store + .read(cx) + .current_organization() + .map(|o| o.id.clone()) + }); let thinking_allowed = request.thinking_allowed; let enable_thinking = thinking_allowed && self.model.supports_thinking; let provider_name = provider_name(&self.model.provider); @@ -767,6 +816,7 @@ impl LanguageModel for CloudLanguageModel { let client = self.client.clone(); let llm_api_token = self.llm_api_token.clone(); + let organization_id = organization_id.clone(); let future = self.request_limiter.stream(async move { let PerformLlmCompletionResponse { response, @@ -774,6 +824,7 @@ impl LanguageModel for CloudLanguageModel { } = Self::perform_llm_completion( client.clone(), llm_api_token, + organization_id, app_version, CompletionBody { thread_id, @@ -803,6 +854,7 @@ impl LanguageModel for CloudLanguageModel { cloud_llm_client::LanguageModelProvider::OpenAi => { let client = self.client.clone(); let llm_api_token = self.llm_api_token.clone(); + let organization_id = organization_id.clone(); let effort = request .thinking_effort .as_ref() @@ -828,6 +880,7 @@ impl LanguageModel for CloudLanguageModel { } = Self::perform_llm_completion( client.clone(), llm_api_token, + organization_id, app_version, CompletionBody { thread_id, @@ -861,6 +914,7 @@ impl LanguageModel for CloudLanguageModel { None, ); let llm_api_token = self.llm_api_token.clone(); + let organization_id = organization_id.clone(); let future = self.request_limiter.stream(async move { let PerformLlmCompletionResponse { response, @@ -868,6 +922,7 @@ impl LanguageModel for CloudLanguageModel { } = Self::perform_llm_completion( client.clone(), llm_api_token, + organization_id, app_version, CompletionBody { thread_id, @@ -902,6 +957,7 @@ impl LanguageModel for CloudLanguageModel { } = Self::perform_llm_completion( client.clone(), llm_api_token, + organization_id, app_version, CompletionBody { thread_id, diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs index 4363430f865de63ed5fec0d6b40b085d9413fc2a..7d714cd93a2a93dbb9fd02ec4d2b95149bb43330 100644 --- a/crates/language_models/src/provider/copilot_chat.rs +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -246,6 +246,10 @@ impl LanguageModel for CopilotChatLanguageModel { self.model.supports_tools() } + fn supports_streaming_tools(&self) -> bool { + true + } + fn supports_images(&self) -> bool { self.model.supports_vision() } @@ -455,6 +459,23 @@ pub fn map_to_language_model_completion_events( entry.thought_signature = Some(thought_signature); } } + + if !entry.id.is_empty() && !entry.name.is_empty() { + if let Ok(input) = serde_json::from_str::( + &partial_json_fixer::fix_json(&entry.arguments), + ) { + events.push(Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: entry.id.clone().into(), + name: entry.name.as_str().into(), + is_input_complete: false, + input, + raw_input: entry.arguments.clone(), + thought_signature: entry.thought_signature.clone(), + }, + ))); + } + } } if let Some(usage) = event.usage { diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs index 2a9f7322b1fb5d3d1e6713c5a084b83dc2b01ce2..0bf86ef15c91b16dbc496ff732b087fedd0da0a9 100644 --- a/crates/language_models/src/provider/deepseek.rs +++ b/crates/language_models/src/provider/deepseek.rs @@ -246,6 +246,10 @@ impl LanguageModel for DeepSeekLanguageModel { true } + fn supports_streaming_tools(&self) -> bool { + true + } + fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool { true } @@ -469,6 +473,23 @@ impl DeepSeekEventMapper { entry.arguments.push_str(&arguments); } } + + if !entry.id.is_empty() && !entry.name.is_empty() { + if let Ok(input) = serde_json::from_str::( + &partial_json_fixer::fix_json(&entry.arguments), + ) { + events.push(Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: entry.id.clone().into(), + name: entry.name.as_str().into(), + is_input_complete: false, + input, + raw_input: entry.arguments.clone(), + thought_signature: None, + }, + ))); + } + } } } diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs index 02d46dcaa7ce7acc76d85c93cad610a7d2489bf0..6af66f4e9a9d257b385c84a6c0c6d989f04c013f 100644 --- a/crates/language_models/src/provider/mistral.rs +++ b/crates/language_models/src/provider/mistral.rs @@ -280,6 +280,10 @@ impl LanguageModel for MistralLanguageModel { self.model.supports_tools() } + fn supports_streaming_tools(&self) -> bool { + true + } + fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool { self.model.supports_tools() } @@ -629,6 +633,23 @@ impl MistralEventMapper { entry.arguments.push_str(&arguments); } } + + if !entry.id.is_empty() && !entry.name.is_empty() { + if let Ok(input) = serde_json::from_str::( + &partial_json_fixer::fix_json(&entry.arguments), + ) { + events.push(Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: entry.id.clone().into(), + name: entry.name.as_str().into(), + is_input_complete: false, + input, + raw_input: entry.arguments.clone(), + thought_signature: None, + }, + ))); + } + } } } diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index 7fb65df0a534c7600f7315fd85d7adda0d66314a..57b3a6b20a9712e7c4d99b3ccfc48719e632da9d 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -328,6 +328,10 @@ impl LanguageModel for OpenAiLanguageModel { } } + fn supports_streaming_tools(&self) -> bool { + true + } + fn supports_thinking(&self) -> bool { self.model.reasoning_effort().is_some() } @@ -824,6 +828,23 @@ impl OpenAiEventMapper { entry.arguments.push_str(&arguments); } } + + if !entry.id.is_empty() && !entry.name.is_empty() { + if let Ok(input) = serde_json::from_str::( + &partial_json_fixer::fix_json(&entry.arguments), + ) { + events.push(Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: entry.id.clone().into(), + name: entry.name.as_str().into(), + is_input_complete: false, + input, + raw_input: entry.arguments.clone(), + thought_signature: None, + }, + ))); + } + } } } } @@ -954,6 +975,20 @@ impl OpenAiResponseEventMapper { ResponsesStreamEvent::FunctionCallArgumentsDelta { item_id, delta, .. } => { if let Some(entry) = self.function_calls_by_item.get_mut(&item_id) { entry.arguments.push_str(&delta); + if let Ok(input) = serde_json::from_str::( + &partial_json_fixer::fix_json(&entry.arguments), + ) { + return vec![Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: LanguageModelToolUseId::from(entry.call_id.clone()), + name: entry.name.clone(), + is_input_complete: false, + input, + raw_input: entry.arguments.clone(), + thought_signature: None, + }, + ))]; + } } Vec::new() } @@ -1670,19 +1705,30 @@ mod tests { ]; let mapped = map_response_events(events); + assert_eq!(mapped.len(), 3); + // First event is the partial tool use (from FunctionCallArgumentsDelta) assert!(matches!( mapped[0], + LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { + is_input_complete: false, + .. + }) + )); + // Second event is the complete tool use (from FunctionCallArgumentsDone) + assert!(matches!( + mapped[1], LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref id, ref name, ref raw_input, + is_input_complete: true, .. }) if id.to_string() == "call_123" && name.as_ref() == "get_weather" && raw_input == "{\"city\":\"Boston\"}" )); assert!(matches!( - mapped[1], + mapped[2], LanguageModelCompletionEvent::Stop(StopReason::ToolUse) )); } @@ -1878,13 +1924,27 @@ mod tests { ]; let mapped = map_response_events(events); + assert_eq!(mapped.len(), 3); + // First event is the partial tool use (from FunctionCallArgumentsDelta) assert!(matches!( mapped[0], - LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref raw_input, .. }) - if raw_input == "{\"city\":\"Boston\"}" + LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { + is_input_complete: false, + .. + }) )); + // Second event is the complete tool use (from the Incomplete response output) assert!(matches!( mapped[1], + LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { + ref raw_input, + is_input_complete: true, + .. + }) + if raw_input == "{\"city\":\"Boston\"}" + )); + assert!(matches!( + mapped[2], LanguageModelCompletionEvent::Stop(StopReason::MaxTokens) )); } @@ -1976,4 +2036,80 @@ mod tests { LanguageModelCompletionEvent::Stop(StopReason::ToolUse) )); } + + #[test] + fn responses_stream_emits_partial_tool_use_events() { + let events = vec![ + ResponsesStreamEvent::OutputItemAdded { + output_index: 0, + sequence_number: None, + item: ResponseOutputItem::FunctionCall(ResponseFunctionToolCall { + id: Some("item_fn".to_string()), + status: Some("in_progress".to_string()), + name: Some("get_weather".to_string()), + call_id: Some("call_abc".to_string()), + arguments: String::new(), + }), + }, + ResponsesStreamEvent::FunctionCallArgumentsDelta { + item_id: "item_fn".into(), + output_index: 0, + delta: "{\"city\":\"Bos".into(), + sequence_number: None, + }, + ResponsesStreamEvent::FunctionCallArgumentsDelta { + item_id: "item_fn".into(), + output_index: 0, + delta: "ton\"}".into(), + sequence_number: None, + }, + ResponsesStreamEvent::FunctionCallArgumentsDone { + item_id: "item_fn".into(), + output_index: 0, + arguments: "{\"city\":\"Boston\"}".into(), + sequence_number: None, + }, + ResponsesStreamEvent::Completed { + response: ResponseSummary::default(), + }, + ]; + + let mapped = map_response_events(events); + // Two partial events + one complete event + Stop + assert!(mapped.len() >= 3); + + // The last complete ToolUse event should have is_input_complete: true + let complete_tool_use = mapped.iter().find(|e| { + matches!( + e, + LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { + is_input_complete: true, + .. + }) + ) + }); + assert!( + complete_tool_use.is_some(), + "should have a complete tool use event" + ); + + // All ToolUse events before the final one should have is_input_complete: false + let tool_uses: Vec<_> = mapped + .iter() + .filter(|e| matches!(e, LanguageModelCompletionEvent::ToolUse(_))) + .collect(); + assert!( + tool_uses.len() >= 2, + "should have at least one partial and one complete event" + ); + + let last = tool_uses.last().unwrap(); + assert!(matches!( + last, + LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { + is_input_complete: true, + .. + }) + )); + } } diff --git a/crates/language_models/src/provider/open_ai_compatible.rs b/crates/language_models/src/provider/open_ai_compatible.rs index d47ea26c594ab0abb5c859ed549d43e0ed3f859b..b478bc843c05e01d428561d9c255ef0d2ca97148 100644 --- a/crates/language_models/src/provider/open_ai_compatible.rs +++ b/crates/language_models/src/provider/open_ai_compatible.rs @@ -319,6 +319,10 @@ impl LanguageModel for OpenAiCompatibleLanguageModel { } } + fn supports_streaming_tools(&self) -> bool { + true + } + fn supports_split_token_display(&self) -> bool { true } diff --git a/crates/language_models/src/provider/open_router.rs b/crates/language_models/src/provider/open_router.rs index a044c7c25d7858f69dc8c4ac9fa0c8bda73f6e91..e0e56bc1beadd8309a4c1b3c7626efa99c1c6473 100644 --- a/crates/language_models/src/provider/open_router.rs +++ b/crates/language_models/src/provider/open_router.rs @@ -1,4 +1,4 @@ -use anyhow::{Result, anyhow}; +use anyhow::Result; use collections::HashMap; use futures::{FutureExt, Stream, StreamExt, future::BoxFuture}; use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task}; @@ -314,6 +314,10 @@ impl LanguageModel for OpenRouterLanguageModel { self.model.supports_tool_calls() } + fn supports_streaming_tools(&self) -> bool { + true + } + fn supports_thinking(&self) -> bool { matches!(self.model.mode, OpenRouterModelMode::Thinking { .. }) } @@ -591,14 +595,21 @@ impl OpenRouterEventMapper { &mut self, event: ResponseStreamEvent, ) -> Vec> { + let mut events = Vec::new(); + + if let Some(usage) = event.usage { + events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage { + input_tokens: usage.prompt_tokens, + output_tokens: usage.completion_tokens, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }))); + } + let Some(choice) = event.choices.first() else { - return vec![Err(LanguageModelCompletionError::from(anyhow!( - "Response contained no choices" - )))]; + return events; }; - let mut events = Vec::new(); - if let Some(details) = choice.delta.reasoning_details.clone() { // Emit reasoning_details immediately events.push(Ok(LanguageModelCompletionEvent::ReasoningDetails( @@ -643,16 +654,24 @@ impl OpenRouterEventMapper { entry.thought_signature = Some(signature); } } - } - } - if let Some(usage) = event.usage { - events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage { - input_tokens: usage.prompt_tokens, - output_tokens: usage.completion_tokens, - cache_creation_input_tokens: 0, - cache_read_input_tokens: 0, - }))); + if !entry.id.is_empty() && !entry.name.is_empty() { + if let Ok(input) = serde_json::from_str::( + &partial_json_fixer::fix_json(&entry.arguments), + ) { + events.push(Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: entry.id.clone().into(), + name: entry.name.as_str().into(), + is_input_complete: false, + input, + raw_input: entry.arguments.clone(), + thought_signature: entry.thought_signature.clone(), + }, + ))); + } + } + } } match choice.finish_reason.as_deref() { @@ -891,7 +910,7 @@ mod tests { ResponseStreamEvent { id: Some("response_123".into()), created: 1234567890, - model: "google/gemini-3-pro-preview".into(), + model: "google/gemini-3.1-pro-preview".into(), choices: vec![ChoiceDelta { index: 0, delta: ResponseMessageDelta { @@ -916,7 +935,7 @@ mod tests { ResponseStreamEvent { id: Some("response_123".into()), created: 1234567890, - model: "google/gemini-3-pro-preview".into(), + model: "google/gemini-3.1-pro-preview".into(), choices: vec![ChoiceDelta { index: 0, delta: ResponseMessageDelta { @@ -942,7 +961,7 @@ mod tests { ResponseStreamEvent { id: Some("response_123".into()), created: 1234567890, - model: "google/gemini-3-pro-preview".into(), + model: "google/gemini-3.1-pro-preview".into(), choices: vec![ChoiceDelta { index: 0, delta: ResponseMessageDelta { @@ -969,7 +988,7 @@ mod tests { ResponseStreamEvent { id: Some("response_123".into()), created: 1234567890, - model: "google/gemini-3-pro-preview".into(), + model: "google/gemini-3.1-pro-preview".into(), choices: vec![ChoiceDelta { index: 0, delta: ResponseMessageDelta { @@ -1055,6 +1074,32 @@ mod tests { ); } + #[gpui::test] + async fn test_usage_only_chunk_with_empty_choices_does_not_error() { + let mut mapper = OpenRouterEventMapper::new(); + + let events = mapper.map_event(ResponseStreamEvent { + id: Some("response_123".into()), + created: 1234567890, + model: "google/gemini-3-flash-preview".into(), + choices: Vec::new(), + usage: Some(open_router::Usage { + prompt_tokens: 12, + completion_tokens: 7, + total_tokens: 19, + }), + }); + + assert_eq!(events.len(), 1); + match events.into_iter().next().unwrap() { + Ok(LanguageModelCompletionEvent::UsageUpdate(usage)) => { + assert_eq!(usage.input_tokens, 12); + assert_eq!(usage.output_tokens, 7); + } + other => panic!("Expected usage update event, got: {other:?}"), + } + } + #[gpui::test] async fn test_agent_prevents_empty_reasoning_details_overwrite() { // This test verifies that the agent layer prevents empty reasoning_details diff --git a/crates/language_models/src/provider/vercel.rs b/crates/language_models/src/provider/vercel.rs index 3b324e46927f5864d83a5e4b74c46f5e39e8ab3a..b71da5b7db05710ee30115ab54379c9ee4e4c750 100644 --- a/crates/language_models/src/provider/vercel.rs +++ b/crates/language_models/src/provider/vercel.rs @@ -248,6 +248,10 @@ impl LanguageModel for VercelLanguageModel { true } + fn supports_streaming_tools(&self) -> bool { + true + } + fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { match choice { LanguageModelToolChoice::Auto diff --git a/crates/language_models/src/provider/vercel_ai_gateway.rs b/crates/language_models/src/provider/vercel_ai_gateway.rs index 69c54e624b9e7289abaefbe7ab654d73df385b62..78f900de0c94fd3bbbff3962e92d1a8cb9f3e118 100644 --- a/crates/language_models/src/provider/vercel_ai_gateway.rs +++ b/crates/language_models/src/provider/vercel_ai_gateway.rs @@ -385,6 +385,10 @@ impl LanguageModel for VercelAiGatewayLanguageModel { } } + fn supports_streaming_tools(&self) -> bool { + true + } + fn supports_split_token_display(&self) -> bool { true } diff --git a/crates/language_models/src/provider/x_ai.rs b/crates/language_models/src/provider/x_ai.rs index 06564224dea9621d594e5cf3f4a84093f1620446..f1f8bb658f04a91341951d1602af04f858af7bd3 100644 --- a/crates/language_models/src/provider/x_ai.rs +++ b/crates/language_models/src/provider/x_ai.rs @@ -257,6 +257,10 @@ impl LanguageModel for XAiLanguageModel { self.model.supports_images() } + fn supports_streaming_tools(&self) -> bool { + true + } + fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { match choice { LanguageModelToolChoice::Auto @@ -265,8 +269,7 @@ impl LanguageModel for XAiLanguageModel { } } fn tool_input_format(&self) -> LanguageModelToolSchemaFormat { - let model_id = self.model.id().trim().to_lowercase(); - if model_id.eq(x_ai::Model::Grok4.id()) || model_id.eq(x_ai::Model::GrokCodeFast1.id()) { + if self.model.requires_json_schema_subset() { LanguageModelToolSchemaFormat::JsonSchemaSubset } else { LanguageModelToolSchemaFormat::JsonSchema diff --git a/crates/languages/src/cpp/config.toml b/crates/languages/src/cpp/config.toml index 10c36a6ded1e1f3a1204d1e15af47fee78b8e049..e2608a8ce5f17cb648e4f86dc27da60ed8bdd2ae 100644 --- a/crates/languages/src/cpp/config.toml +++ b/crates/languages/src/cpp/config.toml @@ -1,6 +1,6 @@ name = "C++" grammar = "cpp" -path_suffixes = ["cc", "hh", "cpp", "h", "hpp", "cxx", "hxx", "c++", "h++", "ipp", "inl", "ino", "ixx", "cu", "cuh", "C", "H"] +path_suffixes = ["cc", "hh", "cpp", "cppm", "h", "hpp", "cxx", "hxx", "c++", "h++", "ipp", "inl", "ino", "ixx", "cu", "cuh", "C", "H"] line_comments = ["// ", "/// ", "//! "] first_line_pattern = '^//.*-\*-\s*C\+\+\s*-\*-' decrease_indent_patterns = [ diff --git a/crates/languages/src/go.rs b/crates/languages/src/go.rs index 581159503ce8aaf642b62789cb895858f1f963c2..5942a51f2a481b66cc8ba46072bd28c8285cbc07 100644 --- a/crates/languages/src/go.rs +++ b/crates/languages/src/go.rs @@ -11,6 +11,7 @@ use lsp::{LanguageServerBinary, LanguageServerName}; use project::lsp_store::language_server_settings; use regex::Regex; use serde_json::{Value, json}; +use settings::SemanticTokenRules; use smol::fs; use std::{ borrow::Cow, @@ -27,6 +28,16 @@ use std::{ use task::{TaskTemplate, TaskTemplates, TaskVariables, VariableName}; use util::{ResultExt, fs::remove_matching, maybe, merge_json_value_into}; +use crate::LanguageDir; + +pub(crate) fn semantic_token_rules() -> SemanticTokenRules { + let content = LanguageDir::get("go/semantic_token_rules.json") + .expect("missing go/semantic_token_rules.json"); + let json = std::str::from_utf8(&content.data).expect("invalid utf-8 in semantic_token_rules"); + settings::parse_json_with_comments::(json) + .expect("failed to parse go semantic_token_rules.json") +} + fn server_binary_arguments() -> Vec { vec!["-mode=stdio".into()] } diff --git a/crates/languages/src/go/semantic_token_rules.json b/crates/languages/src/go/semantic_token_rules.json new file mode 100644 index 0000000000000000000000000000000000000000..627a5c5f187b47918e6a56069c5ed1bda8583aa6 --- /dev/null +++ b/crates/languages/src/go/semantic_token_rules.json @@ -0,0 +1,7 @@ +[ + { + "token_type": "variable", + "token_modifiers": ["readonly"], + "style": ["constant"] + } +] diff --git a/crates/languages/src/lib.rs b/crates/languages/src/lib.rs index c31911f372261db47f689d29de9c60c0f9cad56e..275b8c58ecde831c8f89ae688dc236583b135c07 100644 --- a/crates/languages/src/lib.rs +++ b/crates/languages/src/lib.rs @@ -141,6 +141,7 @@ pub fn init(languages: Arc, fs: Arc, node: NodeRuntime name: "go", adapters: vec![go_lsp_adapter.clone()], context: Some(go_context_provider.clone()), + semantic_token_rules: Some(go::semantic_token_rules()), ..Default::default() }, LanguageInfo { @@ -179,7 +180,13 @@ pub fn init(languages: Arc, fs: Arc, node: NodeRuntime }, LanguageInfo { name: "python", - adapters: vec![basedpyright_lsp_adapter, ruff_lsp_adapter], + adapters: vec![ + basedpyright_lsp_adapter, + ruff_lsp_adapter, + ty_lsp_adapter, + py_lsp_adapter, + python_lsp_adapter, + ], context: Some(python_context_provider), toolchain: Some(python_toolchain_provider), manifest_name: Some(SharedString::new_static("pyproject.toml").into()), @@ -281,9 +288,6 @@ pub fn init(languages: Arc, fs: Arc, node: NodeRuntime typescript_lsp_adapter, ); - languages.register_available_lsp_adapter(python_lsp_adapter.name(), python_lsp_adapter); - languages.register_available_lsp_adapter(py_lsp_adapter.name(), py_lsp_adapter); - languages.register_available_lsp_adapter(ty_lsp_adapter.name(), ty_lsp_adapter); // Register Tailwind for the existing languages that should have it by default. // // This can be driven by the `language_servers` setting once we have a way for diff --git a/crates/lsp/Cargo.toml b/crates/lsp/Cargo.toml index 9533ddb600b18213de4d6e50599c62aa182b9b8a..2c48575a648a9eba12b16ce8edb2cf959d7cc8b3 100644 --- a/crates/lsp/Cargo.toml +++ b/crates/lsp/Cargo.toml @@ -13,12 +13,13 @@ path = "src/lsp.rs" doctest = false [features] -test-support = ["async-pipe"] +test-support = ["async-pipe", "gpui_util"] [dependencies] anyhow.workspace = true async-pipe = { workspace = true, optional = true } collections.workspace = true +gpui_util = { workspace = true, optional = true } futures.workspace = true gpui.workspace = true log.workspace = true @@ -34,6 +35,7 @@ release_channel.workspace = true [dev-dependencies] async-pipe.workspace = true +gpui_util.workspace = true ctor.workspace = true gpui = { workspace = true, features = ["test-support"] } semver.workspace = true diff --git a/crates/lsp/src/lsp.rs b/crates/lsp/src/lsp.rs index e552c21d701cefa8aa1f4b6e14e826892e3b25b6..2e2318065292ffdc2ac39b577afc7a264d36473d 100644 --- a/crates/lsp/src/lsp.rs +++ b/crates/lsp/src/lsp.rs @@ -1970,10 +1970,14 @@ impl FakeLanguageServer { let responded_tx = responded_tx.clone(); let executor = cx.background_executor().clone(); async move { + let _guard = gpui_util::defer({ + let responded_tx = responded_tx.clone(); + move || { + responded_tx.unbounded_send(()).ok(); + } + }); executor.simulate_random_delay().await; - let result = result.await; - responded_tx.unbounded_send(()).ok(); - result + result.await } }) .detach(); diff --git a/crates/migrator/src/migrations.rs b/crates/migrator/src/migrations.rs index d10116be6032486c92d9f27afcf922178463e151..ec33b6a53b3c598842aa29b6e2c31c08c7b11558 100644 --- a/crates/migrator/src/migrations.rs +++ b/crates/migrator/src/migrations.rs @@ -275,6 +275,12 @@ pub(crate) mod m_2025_12_15 { pub(crate) use settings::SETTINGS_PATTERNS; } +pub(crate) mod m_2025_01_27 { + mod settings; + + pub(crate) use settings::make_auto_indent_an_enum; +} + pub(crate) mod m_2026_02_02 { mod settings; diff --git a/crates/migrator/src/migrations/m_2025_01_27/settings.rs b/crates/migrator/src/migrations/m_2025_01_27/settings.rs new file mode 100644 index 0000000000000000000000000000000000000000..e8df2aa8aabed4daaae3e45e97532c1ce3557dfe --- /dev/null +++ b/crates/migrator/src/migrations/m_2025_01_27/settings.rs @@ -0,0 +1,27 @@ +use anyhow::Result; +use serde_json::Value; + +use crate::migrations::migrate_language_setting; + +pub fn make_auto_indent_an_enum(value: &mut Value) -> Result<()> { + migrate_language_setting(value, migrate_auto_indent) +} + +fn migrate_auto_indent(value: &mut Value, _path: &[&str]) -> Result<()> { + let Some(auto_indent) = value + .as_object_mut() + .and_then(|obj| obj.get_mut("auto_indent")) + else { + return Ok(()); + }; + + *auto_indent = match auto_indent { + Value::Bool(true) => Value::String("syntax_aware".to_string()), + Value::Bool(false) => Value::String("none".to_string()), + Value::String(s) if s == "syntax_aware" || s == "preserve_indent" || s == "none" => { + return Ok(()); + } + _ => anyhow::bail!("Expected auto_indent to be a boolean or valid enum value"), + }; + Ok(()) +} diff --git a/crates/migrator/src/migrator.rs b/crates/migrator/src/migrator.rs index 8b501020a559c74d81c5ad5b37e1adf60a964927..f208faf163aaf425127791f781d4569a737870ff 100644 --- a/crates/migrator/src/migrator.rs +++ b/crates/migrator/src/migrator.rs @@ -232,6 +232,7 @@ pub fn migrate_settings(text: &str) -> Result> { migrations::m_2025_12_15::SETTINGS_PATTERNS, &SETTINGS_QUERY_2025_12_15, ), + MigrationType::Json(migrations::m_2025_01_27::make_auto_indent_an_enum), MigrationType::Json( migrations::m_2026_02_02::move_edit_prediction_provider_to_edit_predictions, ), @@ -2606,6 +2607,91 @@ mod tests { ); } + #[test] + fn test_make_auto_indent_an_enum() { + // Empty settings should not change + assert_migrate_settings_with_migrations( + &[MigrationType::Json( + migrations::m_2025_01_27::make_auto_indent_an_enum, + )], + &r#"{ }"#.unindent(), + None, + ); + + // true should become "syntax_aware" + assert_migrate_settings_with_migrations( + &[MigrationType::Json( + migrations::m_2025_01_27::make_auto_indent_an_enum, + )], + &r#"{ + "auto_indent": true + }"# + .unindent(), + Some( + &r#"{ + "auto_indent": "syntax_aware" + }"# + .unindent(), + ), + ); + + // false should become "none" + assert_migrate_settings_with_migrations( + &[MigrationType::Json( + migrations::m_2025_01_27::make_auto_indent_an_enum, + )], + &r#"{ + "auto_indent": false + }"# + .unindent(), + Some( + &r#"{ + "auto_indent": "none" + }"# + .unindent(), + ), + ); + + // Already valid enum values should not change + assert_migrate_settings_with_migrations( + &[MigrationType::Json( + migrations::m_2025_01_27::make_auto_indent_an_enum, + )], + &r#"{ + "auto_indent": "preserve_indent" + }"# + .unindent(), + None, + ); + + // Should also work inside languages + assert_migrate_settings_with_migrations( + &[MigrationType::Json( + migrations::m_2025_01_27::make_auto_indent_an_enum, + )], + &r#"{ + "auto_indent": true, + "languages": { + "Python": { + "auto_indent": false + } + } + }"# + .unindent(), + Some( + &r#"{ + "auto_indent": "syntax_aware", + "languages": { + "Python": { + "auto_indent": "none" + } + } + }"# + .unindent(), + ), + ); + } + #[test] fn test_move_edit_prediction_provider_to_edit_predictions() { assert_migrate_settings_with_migrations( diff --git a/crates/miniprofiler_ui/src/miniprofiler_ui.rs b/crates/miniprofiler_ui/src/miniprofiler_ui.rs index 1f95dc3d230e7c50b4960560a96c9007fd77aab8..9ae0a33471d31f32852b4b376bbc71ff0911c60b 100644 --- a/crates/miniprofiler_ui/src/miniprofiler_ui.rs +++ b/crates/miniprofiler_ui/src/miniprofiler_ui.rs @@ -464,7 +464,7 @@ impl Render for ProfilerWindow { let scroll_offset = self.scroll_handle.offset(); let max_offset = self.scroll_handle.max_offset(); - self.autoscroll = -scroll_offset.y >= (max_offset.height - px(24.)); + self.autoscroll = -scroll_offset.y >= (max_offset.y - px(24.)); if self.autoscroll { self.scroll_handle.scroll_to_bottom(); } @@ -544,7 +544,7 @@ impl Render for ProfilerWindow { let path = cx.prompt_for_new_path( &active_path, - Some("performance_profile.miniprof"), + Some("performance_profile.miniprof.json"), ); cx.background_spawn(async move { diff --git a/crates/project/src/debugger/session.rs b/crates/project/src/debugger/session.rs index 2430d6c1024c61bb9af984c914df9c308c4cb64f..a6c3f52b17a4a6cf241aa49329f3f14f0b5cefbc 100644 --- a/crates/project/src/debugger/session.rs +++ b/crates/project/src/debugger/session.rs @@ -2645,10 +2645,40 @@ impl Session { self.fetch( command, move |this, variables, cx| { - let Some(variables) = variables.log_err() else { + let Some(mut variables) = variables.log_err() else { return; }; + if this.adapter.0.as_ref() == "Debugpy" { + for variable in variables.iter_mut() { + if variable.type_ == Some("str".into()) { + // reverse Python repr() escaping + let mut unescaped = String::with_capacity(variable.value.len()); + let mut chars = variable.value.chars(); + while let Some(c) = chars.next() { + if c != '\\' { + unescaped.push(c); + } else { + match chars.next() { + Some('\\') => unescaped.push('\\'), + Some('n') => unescaped.push('\n'), + Some('t') => unescaped.push('\t'), + Some('r') => unescaped.push('\r'), + Some('\'') => unescaped.push('\''), + Some('"') => unescaped.push('"'), + Some(c) => { + unescaped.push('\\'); + unescaped.push(c); + } + None => {} + } + } + } + variable.value = unescaped; + } + } + } + this.active_snapshot .variables .insert(variables_reference, variables); diff --git a/crates/project/src/git_store.rs b/crates/project/src/git_store.rs index ae776966a770ccadcffdbf9b140ed10d4871b317..eed16761974876247df2e5936f9db9fbdd8fafcc 100644 --- a/crates/project/src/git_store.rs +++ b/crates/project/src/git_store.rs @@ -6,6 +6,9 @@ pub mod pending_op; use crate::{ ProjectEnvironment, ProjectItem, ProjectPath, buffer_store::{BufferStore, BufferStoreEvent}, + trusted_worktrees::{ + PathTrust, TrustedWorktrees, TrustedWorktreesEvent, TrustedWorktreesStore, + }, worktree_store::{WorktreeStore, WorktreeStoreEvent}, }; use anyhow::{Context as _, Result, anyhow, bail}; @@ -21,7 +24,7 @@ use futures::{ mpsc, oneshot::{self, Canceled}, }, - future::{self, Shared}, + future::{self, BoxFuture, Shared}, stream::FuturesOrdered, }; use git::{ @@ -36,8 +39,8 @@ use git::{ }, stash::{GitStash, StashEntry}, status::{ - DiffTreeType, FileStatus, GitSummary, StatusCode, TrackedStatus, TreeDiff, TreeDiffStatus, - UnmergedStatus, UnmergedStatusCode, + self, DiffStat, DiffTreeType, FileStatus, GitSummary, StatusCode, TrackedStatus, TreeDiff, + TreeDiffStatus, UnmergedStatus, UnmergedStatusCode, }, }; use gpui::{ @@ -192,6 +195,7 @@ pub struct GitStoreCheckpoint { pub struct StatusEntry { pub repo_path: RepoPath, pub status: FileStatus, + pub diff_stat: Option, } impl StatusEntry { @@ -213,6 +217,8 @@ impl StatusEntry { repo_path: self.repo_path.to_proto(), simple_status, status: Some(status_to_proto(self.status)), + diff_stat_added: self.diff_stat.map(|ds| ds.added), + diff_stat_deleted: self.diff_stat.map(|ds| ds.deleted), } } } @@ -223,7 +229,15 @@ impl TryFrom for StatusEntry { fn try_from(value: proto::StatusEntry) -> Result { let repo_path = RepoPath::from_proto(&value.repo_path).context("invalid repo path")?; let status = status_from_proto(value.simple_status, value.status)?; - Ok(Self { repo_path, status }) + let diff_stat = match (value.diff_stat_added, value.diff_stat_deleted) { + (Some(added), Some(deleted)) => Some(DiffStat { added, deleted }), + _ => None, + }; + Ok(Self { + repo_path, + status, + diff_stat, + }) } } @@ -266,6 +280,11 @@ pub struct RepositorySnapshot { pub id: RepositoryId, pub statuses_by_path: SumTree, pub work_directory_abs_path: Arc, + /// The working directory of the original repository. For a normal + /// checkout this equals `work_directory_abs_path`. For a git worktree + /// checkout, this is the original repo's working directory — used to + /// anchor new worktree creation so they don't nest. + pub original_repo_abs_path: Arc, pub path_style: PathStyle, pub branch: Option, pub head_commit: Option, @@ -349,6 +368,7 @@ impl LocalRepositoryState { dot_git_abs_path: Arc, project_environment: WeakEntity, fs: Arc, + is_trusted: bool, cx: &mut AsyncApp, ) -> anyhow::Result { let environment = project_environment @@ -376,6 +396,7 @@ impl LocalRepositoryState { } }) .await?; + backend.set_trusted(is_trusted); Ok(LocalRepositoryState { backend, environment: Arc::new(environment), @@ -490,11 +511,15 @@ impl GitStore { state: GitStoreState, cx: &mut Context, ) -> Self { - let _subscriptions = vec![ + let mut _subscriptions = vec![ cx.subscribe(&worktree_store, Self::on_worktree_store_event), cx.subscribe(&buffer_store, Self::on_buffer_store_event), ]; + if let Some(trusted_worktrees) = TrustedWorktrees::try_get_global(cx) { + _subscriptions.push(cx.subscribe(&trusted_worktrees, Self::on_trusted_worktrees_event)); + } + GitStore { state, buffer_store, @@ -541,7 +566,6 @@ impl GitStore { client.add_entity_request_handler(Self::handle_askpass); client.add_entity_request_handler(Self::handle_check_for_pushed_commits); client.add_entity_request_handler(Self::handle_git_diff); - client.add_entity_request_handler(Self::handle_git_diff_stat); client.add_entity_request_handler(Self::handle_tree_diff); client.add_entity_request_handler(Self::handle_get_blob_content); client.add_entity_request_handler(Self::handle_open_unstaged_diff); @@ -1505,19 +1529,30 @@ impl GitStore { new_work_directory_abs_path: Some(work_directory_abs_path), dot_git_abs_path: Some(dot_git_abs_path), repository_dir_abs_path: Some(_repository_dir_abs_path), - common_dir_abs_path: Some(_common_dir_abs_path), + common_dir_abs_path: Some(common_dir_abs_path), .. } = update { + let original_repo_abs_path: Arc = + git::repository::original_repo_path_from_common_dir(common_dir_abs_path).into(); let id = RepositoryId(next_repository_id.fetch_add(1, atomic::Ordering::Release)); + let is_trusted = TrustedWorktrees::try_get_global(cx) + .map(|trusted_worktrees| { + trusted_worktrees.update(cx, |trusted_worktrees, cx| { + trusted_worktrees.can_trust(&self.worktree_store, worktree_id, cx) + }) + }) + .unwrap_or(false); let git_store = cx.weak_entity(); let repo = cx.new(|cx| { let mut repo = Repository::local( id, work_directory_abs_path.clone(), + original_repo_abs_path.clone(), dot_git_abs_path.clone(), project_environment.downgrade(), fs.clone(), + is_trusted, git_store, cx, ); @@ -1558,6 +1593,39 @@ impl GitStore { } } + fn on_trusted_worktrees_event( + &mut self, + _: Entity, + event: &TrustedWorktreesEvent, + cx: &mut Context, + ) { + if !matches!(self.state, GitStoreState::Local { .. }) { + return; + } + + let (is_trusted, event_paths) = match event { + TrustedWorktreesEvent::Trusted(_, trusted_paths) => (true, trusted_paths), + TrustedWorktreesEvent::Restricted(_, restricted_paths) => (false, restricted_paths), + }; + + for (repo_id, worktree_ids) in &self.worktree_ids { + if worktree_ids + .iter() + .any(|worktree_id| event_paths.contains(&PathTrust::Worktree(*worktree_id))) + { + if let Some(repo) = self.repositories.get(repo_id) { + let repository_state = repo.read(cx).repository_state.clone(); + cx.background_spawn(async move { + if let Ok(RepositoryState::Local(state)) = repository_state.await { + state.backend.set_trusted(is_trusted); + } + }) + .detach(); + } + } + } + } + fn on_buffer_store_event( &mut self, _: Entity, @@ -1840,6 +1908,11 @@ impl GitStore { let id = RepositoryId::from_proto(update.id); let client = this.upstream_client().context("no upstream client")?; + let original_repo_abs_path: Option> = update + .original_repo_abs_path + .as_deref() + .map(|p| Path::new(p).into()); + let mut repo_subscription = None; let repo = this.repositories.entry(id).or_insert_with(|| { let git_store = cx.weak_entity(); @@ -1847,6 +1920,7 @@ impl GitStore { Repository::remote( id, Path::new(&update.abs_path).into(), + original_repo_abs_path.clone(), path_style, ProjectId(update.project_id), client, @@ -2697,45 +2771,6 @@ impl GitStore { Ok(proto::GitDiffResponse { diff }) } - async fn handle_git_diff_stat( - this: Entity, - envelope: TypedEnvelope, - mut cx: AsyncApp, - ) -> Result { - let repository_id = RepositoryId::from_proto(envelope.payload.repository_id); - let repository_handle = Self::repository_for_request(&this, repository_id, &mut cx)?; - let diff_type = match envelope.payload.diff_type() { - proto::git_diff_stat::DiffType::HeadToIndex => DiffType::HeadToIndex, - proto::git_diff_stat::DiffType::HeadToWorktree => DiffType::HeadToWorktree, - proto::git_diff_stat::DiffType::MergeBase => { - let base_ref = envelope - .payload - .merge_base_ref - .ok_or_else(|| anyhow!("merge_base_ref is required for MergeBase diff type"))?; - DiffType::MergeBase { - base_ref: base_ref.into(), - } - } - }; - - let stats = repository_handle - .update(&mut cx, |repository_handle, cx| { - repository_handle.diff_stat(diff_type, cx) - }) - .await??; - - let entries = stats - .into_iter() - .map(|(path, stat)| proto::GitDiffStatEntry { - path: path.to_proto(), - added: stat.added, - deleted: stat.deleted, - }) - .collect(); - - Ok(proto::GitDiffStatResponse { entries }) - } - async fn handle_tree_diff( this: Entity, request: TypedEnvelope, @@ -3481,10 +3516,17 @@ impl RepositoryId { } impl RepositorySnapshot { - fn empty(id: RepositoryId, work_directory_abs_path: Arc, path_style: PathStyle) -> Self { + fn empty( + id: RepositoryId, + work_directory_abs_path: Arc, + original_repo_abs_path: Option>, + path_style: PathStyle, + ) -> Self { Self { id, statuses_by_path: Default::default(), + original_repo_abs_path: original_repo_abs_path + .unwrap_or_else(|| work_directory_abs_path.clone()), work_directory_abs_path, branch: None, head_commit: None, @@ -3528,6 +3570,9 @@ impl RepositorySnapshot { .collect(), remote_upstream_url: self.remote_upstream_url.clone(), remote_origin_url: self.remote_origin_url.clone(), + original_repo_abs_path: Some( + self.original_repo_abs_path.to_string_lossy().into_owned(), + ), } } @@ -3549,7 +3594,9 @@ impl RepositorySnapshot { current_new_entry = new_statuses.next(); } Ordering::Equal => { - if new_entry.status != old_entry.status { + if new_entry.status != old_entry.status + || new_entry.diff_stat != old_entry.diff_stat + { updated_statuses.push(new_entry.to_proto()); } current_old_entry = old_statuses.next(); @@ -3599,6 +3646,9 @@ impl RepositorySnapshot { .collect(), remote_upstream_url: self.remote_upstream_url.clone(), remote_origin_url: self.remote_origin_url.clone(), + original_repo_abs_path: Some( + self.original_repo_abs_path.to_string_lossy().into_owned(), + ), } } @@ -3616,6 +3666,12 @@ impl RepositorySnapshot { .cloned() } + pub fn diff_stat_for_path(&self, path: &RepoPath) -> Option { + self.statuses_by_path + .get(&PathKey(path.as_ref().clone()), ()) + .and_then(|entry| entry.diff_stat) + } + pub fn abs_path_to_repo_path(&self, abs_path: &Path) -> Option { Self::abs_path_to_repo_path_inner(&self.work_directory_abs_path, abs_path, self.path_style) } @@ -3736,6 +3792,13 @@ impl MergeDetails { } impl Repository { + pub fn is_trusted(&self) -> bool { + match self.repository_state.peek() { + Some(Ok(RepositoryState::Local(state))) => state.backend.is_trusted(), + _ => false, + } + } + pub fn snapshot(&self) -> RepositorySnapshot { self.snapshot.clone() } @@ -3757,14 +3820,20 @@ impl Repository { fn local( id: RepositoryId, work_directory_abs_path: Arc, + original_repo_abs_path: Arc, dot_git_abs_path: Arc, project_environment: WeakEntity, fs: Arc, + is_trusted: bool, git_store: WeakEntity, cx: &mut Context, ) -> Self { - let snapshot = - RepositorySnapshot::empty(id, work_directory_abs_path.clone(), PathStyle::local()); + let snapshot = RepositorySnapshot::empty( + id, + work_directory_abs_path.clone(), + Some(original_repo_abs_path), + PathStyle::local(), + ); let state = cx .spawn(async move |_, cx| { LocalRepositoryState::new( @@ -3772,6 +3841,7 @@ impl Repository { dot_git_abs_path, project_environment, fs, + is_trusted, cx, ) .await @@ -3818,13 +3888,19 @@ impl Repository { fn remote( id: RepositoryId, work_directory_abs_path: Arc, + original_repo_abs_path: Option>, path_style: PathStyle, project_id: ProjectId, client: AnyProtoClient, git_store: WeakEntity, cx: &mut Context, ) -> Self { - let snapshot = RepositorySnapshot::empty(id, work_directory_abs_path, path_style); + let snapshot = RepositorySnapshot::empty( + id, + work_directory_abs_path, + original_repo_abs_path, + path_style, + ); let repository_state = RemoteRepositoryState { project_id, client }; let job_sender = Self::spawn_remote_git_worker(repository_state.clone(), cx); let repository_state = Task::ready(Ok(RepositoryState::Remote(repository_state))).shared(); @@ -4096,6 +4172,10 @@ impl Repository { self.snapshot.status() } + pub fn diff_stat_for_path(&self, path: &RepoPath) -> Option { + self.snapshot.diff_stat_for_path(path) + } + pub fn cached_stash(&self) -> GitStash { self.snapshot.stash_entries.clone() } @@ -5650,6 +5730,24 @@ impl Repository { ) } + pub fn remove_worktree(&mut self, path: PathBuf, force: bool) -> oneshot::Receiver> { + self.send_job( + Some(format!("git worktree remove: {}", path.display()).into()), + move |repo, _cx| async move { + match repo { + RepositoryState::Local(LocalRepositoryState { backend, .. }) => { + backend.remove_worktree(path, force).await + } + RepositoryState::Remote(_) => { + anyhow::bail!( + "Removing worktrees on remote repositories is not yet supported" + ) + } + } + }, + ) + } + pub fn default_branch( &mut self, include_remote_name: bool, @@ -5769,63 +5867,6 @@ impl Repository { }) } - /// Fetches per-line diff statistics (additions/deletions) via `git diff --numstat`. - pub fn diff_stat( - &mut self, - diff_type: DiffType, - _cx: &App, - ) -> oneshot::Receiver< - Result>, - > { - let id = self.id; - self.send_job(None, move |repo, _cx| async move { - match repo { - RepositoryState::Local(LocalRepositoryState { backend, .. }) => { - backend.diff_stat(diff_type).await - } - RepositoryState::Remote(RemoteRepositoryState { project_id, client }) => { - let (proto_diff_type, merge_base_ref) = match &diff_type { - DiffType::HeadToIndex => { - (proto::git_diff_stat::DiffType::HeadToIndex.into(), None) - } - DiffType::HeadToWorktree => { - (proto::git_diff_stat::DiffType::HeadToWorktree.into(), None) - } - DiffType::MergeBase { base_ref } => ( - proto::git_diff_stat::DiffType::MergeBase.into(), - Some(base_ref.to_string()), - ), - }; - let response = client - .request(proto::GitDiffStat { - project_id: project_id.0, - repository_id: id.to_proto(), - diff_type: proto_diff_type, - merge_base_ref, - }) - .await?; - - let stats = response - .entries - .into_iter() - .filter_map(|entry| { - let path = RepoPath::from_proto(&entry.path).log_err()?; - Some(( - path, - git::status::DiffStat { - added: entry.added, - deleted: entry.deleted, - }, - )) - }) - .collect(); - - Ok(stats) - } - } - }) - } - pub fn create_branch( &mut self, branch_name: String, @@ -5988,6 +6029,10 @@ impl Repository { update: proto::UpdateRepository, cx: &mut Context, ) -> Result<()> { + if let Some(main_path) = &update.original_repo_abs_path { + self.snapshot.original_repo_abs_path = Path::new(main_path.as_str()).into(); + } + let new_branch = update.branch_summary.as_ref().map(proto_to_branch); let new_head_commit = update .head_commit_details @@ -6046,6 +6091,7 @@ impl Repository { cx.emit(RepositoryEvent::StatusesChanged); } self.snapshot.statuses_by_path.edit(edits, ()); + if update.is_last_update { self.snapshot.scan_id = update.scan_id; } @@ -6360,22 +6406,43 @@ impl Repository { return Ok(()); } + let has_head = prev_snapshot.head_commit.is_some(); + let stash_entries = backend.stash_entries().await?; let changed_path_statuses = cx .background_spawn(async move { let mut changed_paths = changed_paths.into_iter().flatten().collect::>(); - let statuses = backend - .status(&changed_paths.iter().cloned().collect::>()) - .await?; + let changed_paths_vec = changed_paths.iter().cloned().collect::>(); + + let status_task = backend.status(&changed_paths_vec); + let diff_stat_future = if has_head { + backend.diff_stat(&changed_paths_vec) + } else { + future::ready(Ok(status::GitDiffStat { + entries: Arc::default(), + })) + .boxed() + }; + + let (statuses, diff_stats) = + futures::future::try_join(status_task, diff_stat_future).await?; + + let diff_stats: HashMap = + HashMap::from_iter(diff_stats.entries.into_iter().cloned()); + let mut changed_path_statuses = Vec::new(); let prev_statuses = prev_snapshot.statuses_by_path.clone(); let mut cursor = prev_statuses.cursor::(()); for (repo_path, status) in &*statuses.entries { + let current_diff_stat = diff_stats.get(repo_path).copied(); + changed_paths.remove(repo_path); if cursor.seek_forward(&PathTarget::Path(repo_path), Bias::Left) - && cursor.item().is_some_and(|entry| entry.status == *status) + && cursor.item().is_some_and(|entry| { + entry.status == *status && entry.diff_stat == current_diff_stat + }) { continue; } @@ -6383,6 +6450,7 @@ impl Repository { changed_path_statuses.push(Edit::Insert(StatusEntry { repo_path: repo_path.clone(), status: *status, + diff_stat: current_diff_stat, })); } let mut cursor = prev_statuses.cursor::(()); @@ -6740,11 +6808,31 @@ async fn compute_snapshot( let mut events = Vec::new(); let branches = backend.branches().await?; let branch = branches.into_iter().find(|branch| branch.is_head); - let statuses = backend - .status(&[RepoPath::from_rel_path( + + // Useful when branch is None in detached head state + let head_commit = match backend.head_sha().await { + Some(head_sha) => backend.show(head_sha).await.log_err(), + None => None, + }; + + let diff_stat_future: BoxFuture<'_, Result> = if head_commit.is_some() { + backend.diff_stat(&[]) + } else { + future::ready(Ok(status::GitDiffStat { + entries: Arc::default(), + })) + .boxed() + }; + let (statuses, diff_stats) = futures::future::try_join( + backend.status(&[RepoPath::from_rel_path( &RelPath::new(".".as_ref(), PathStyle::local()).unwrap(), - )]) - .await?; + )]), + diff_stat_future, + ) + .await?; + + let diff_stat_map: HashMap<&RepoPath, DiffStat> = + diff_stats.entries.iter().map(|(p, s)| (p, *s)).collect(); let stash_entries = backend.stash_entries().await?; let mut conflicted_paths = Vec::new(); let statuses_by_path = SumTree::from_iter( @@ -6755,6 +6843,7 @@ async fn compute_snapshot( StatusEntry { repo_path: repo_path.clone(), status: *status, + diff_stat: diff_stat_map.get(repo_path).copied(), } }), (), @@ -6767,12 +6856,6 @@ async fn compute_snapshot( events.push(RepositoryEvent::StatusesChanged) } - // Useful when branch is None in detached head state - let head_commit = match backend.head_sha().await { - Some(head_sha) => backend.show(head_sha).await.log_err(), - None => None, - }; - if branch != prev_snapshot.branch || head_commit != prev_snapshot.head_commit { events.push(RepositoryEvent::BranchChanged); } @@ -6784,6 +6867,7 @@ async fn compute_snapshot( id, statuses_by_path, work_directory_abs_path, + original_repo_abs_path: prev_snapshot.original_repo_abs_path, path_style: prev_snapshot.path_style, scan_id: prev_snapshot.scan_id + 1, branch, diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index 9e37802213dfb8df5cf63af5648044ae8ec65ecb..756f095511a9688678df013458710e69d720c52e 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -1942,6 +1942,11 @@ impl Project { } } + #[cfg(feature = "test-support")] + pub fn client_subscriptions(&self) -> &Vec { + &self.client_subscriptions + } + #[cfg(feature = "test-support")] pub async fn example( root_paths: impl IntoIterator, @@ -2741,6 +2746,7 @@ impl Project { } = &mut self.client_state { *sharing_has_stopped = true; + self.client_subscriptions.clear(); self.collaborators.clear(); self.worktree_store.update(cx, |store, cx| { store.disconnected_from_host(cx); diff --git a/crates/project/tests/integration/git_store.rs b/crates/project/tests/integration/git_store.rs index 802e0c072bf60466c32146d12cadd7c1e35c61ad..82e92bc4f1cfb606fb09d5efd5d341ed2951c067 100644 --- a/crates/project/tests/integration/git_store.rs +++ b/crates/project/tests/integration/git_store.rs @@ -1174,3 +1174,327 @@ mod git_traversal { pretty_assertions::assert_eq!(found_statuses, expected_statuses); } } + +mod git_worktrees { + use std::path::PathBuf; + + use fs::FakeFs; + use gpui::TestAppContext; + use serde_json::json; + use settings::SettingsStore; + use util::path; + + fn init_test(cx: &mut gpui::TestAppContext) { + zlog::init_test(); + + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + }); + } + + #[gpui::test] + async fn test_git_worktrees_list_and_create(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.background_executor.clone()); + fs.insert_tree( + path!("/root"), + json!({ + ".git": {}, + "file.txt": "content", + }), + ) + .await; + + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + cx.executor().run_until_parked(); + + let repository = project.read_with(cx, |project, cx| { + project.repositories(cx).values().next().unwrap().clone() + }); + + let worktrees = cx + .update(|cx| repository.update(cx, |repository, _| repository.worktrees())) + .await + .unwrap() + .unwrap(); + assert_eq!(worktrees.len(), 1); + assert_eq!(worktrees[0].path, PathBuf::from(path!("/root"))); + + let worktree_directory = PathBuf::from(path!("/root")); + cx.update(|cx| { + repository.update(cx, |repository, _| { + repository.create_worktree( + "feature-branch".to_string(), + worktree_directory.clone(), + Some("abc123".to_string()), + ) + }) + }) + .await + .unwrap() + .unwrap(); + + cx.executor().run_until_parked(); + + let worktrees = cx + .update(|cx| repository.update(cx, |repository, _| repository.worktrees())) + .await + .unwrap() + .unwrap(); + assert_eq!(worktrees.len(), 2); + assert_eq!(worktrees[0].path, PathBuf::from(path!("/root"))); + assert_eq!(worktrees[1].path, worktree_directory.join("feature-branch")); + assert_eq!(worktrees[1].ref_name.as_ref(), "refs/heads/feature-branch"); + assert_eq!(worktrees[1].sha.as_ref(), "abc123"); + + cx.update(|cx| { + repository.update(cx, |repository, _| { + repository.create_worktree( + "bugfix-branch".to_string(), + worktree_directory.clone(), + None, + ) + }) + }) + .await + .unwrap() + .unwrap(); + + cx.executor().run_until_parked(); + + // List worktrees — should now have main + two created + let worktrees = cx + .update(|cx| repository.update(cx, |repository, _| repository.worktrees())) + .await + .unwrap() + .unwrap(); + assert_eq!(worktrees.len(), 3); + + let feature_worktree = worktrees + .iter() + .find(|worktree| worktree.ref_name.as_ref() == "refs/heads/feature-branch") + .expect("should find feature-branch worktree"); + assert_eq!( + feature_worktree.path, + worktree_directory.join("feature-branch") + ); + + let bugfix_worktree = worktrees + .iter() + .find(|worktree| worktree.ref_name.as_ref() == "refs/heads/bugfix-branch") + .expect("should find bugfix-branch worktree"); + assert_eq!( + bugfix_worktree.path, + worktree_directory.join("bugfix-branch") + ); + assert_eq!(bugfix_worktree.sha.as_ref(), "fake-sha"); + } + + use crate::Project; +} + +mod trust_tests { + use collections::HashSet; + use fs::FakeFs; + use gpui::TestAppContext; + use project::trusted_worktrees::*; + + use serde_json::json; + use settings::SettingsStore; + use util::path; + + use crate::Project; + + fn init_test(cx: &mut TestAppContext) { + zlog::init_test(); + + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + }); + } + + #[gpui::test] + async fn test_repository_defaults_to_untrusted_without_trust_system(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.background_executor.clone()); + fs.insert_tree( + path!("/project"), + json!({ + ".git": {}, + "a.txt": "hello", + }), + ) + .await; + + // Create project without trust system — repos should default to untrusted. + let project = Project::test(fs, [path!("/project").as_ref()], cx).await; + cx.executor().run_until_parked(); + + let repository = project.read_with(cx, |project, cx| { + project.repositories(cx).values().next().unwrap().clone() + }); + + repository.read_with(cx, |repo, _| { + assert!( + !repo.is_trusted(), + "repository should default to untrusted when no trust system is initialized" + ); + }); + } + + #[gpui::test] + async fn test_multiple_repos_trust_with_single_worktree(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.background_executor.clone()); + fs.insert_tree( + path!("/project"), + json!({ + ".git": {}, + "a.txt": "hello", + "sub": { + ".git": {}, + "b.txt": "world", + }, + }), + ) + .await; + + cx.update(|cx| { + init(DbTrustedPaths::default(), cx); + }); + + let project = + Project::test_with_worktree_trust(fs.clone(), [path!("/project").as_ref()], cx).await; + cx.executor().run_until_parked(); + + let worktree_store = project.read_with(cx, |project, _| project.worktree_store()); + let worktree_id = worktree_store.read_with(cx, |store, cx| { + store.worktrees().next().unwrap().read(cx).id() + }); + + let repos = project.read_with(cx, |project, cx| { + project + .repositories(cx) + .values() + .cloned() + .collect::>() + }); + assert_eq!(repos.len(), 2, "should have two repositories"); + for repo in &repos { + repo.read_with(cx, |repo, _| { + assert!( + !repo.is_trusted(), + "all repos should be untrusted initially" + ); + }); + } + + let trusted_worktrees = cx + .update(|cx| TrustedWorktrees::try_get_global(cx).expect("trust global should be set")); + trusted_worktrees.update(cx, |store, cx| { + store.trust( + &worktree_store, + HashSet::from_iter([PathTrust::Worktree(worktree_id)]), + cx, + ); + }); + cx.executor().run_until_parked(); + + for repo in &repos { + repo.read_with(cx, |repo, _| { + assert!( + repo.is_trusted(), + "all repos should be trusted after worktree is trusted" + ); + }); + } + } + + #[gpui::test] + async fn test_repository_trust_restrict_trust_cycle(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.background_executor.clone()); + fs.insert_tree( + path!("/project"), + json!({ + ".git": {}, + "a.txt": "hello", + }), + ) + .await; + + cx.update(|cx| { + project::trusted_worktrees::init(DbTrustedPaths::default(), cx); + }); + + let project = + Project::test_with_worktree_trust(fs.clone(), [path!("/project").as_ref()], cx).await; + cx.executor().run_until_parked(); + + let worktree_store = project.read_with(cx, |project, _| project.worktree_store()); + let worktree_id = worktree_store.read_with(cx, |store, cx| { + store.worktrees().next().unwrap().read(cx).id() + }); + + let repository = project.read_with(cx, |project, cx| { + project.repositories(cx).values().next().unwrap().clone() + }); + + repository.read_with(cx, |repo, _| { + assert!(!repo.is_trusted(), "repository should start untrusted"); + }); + + let trusted_worktrees = cx + .update(|cx| TrustedWorktrees::try_get_global(cx).expect("trust global should be set")); + + trusted_worktrees.update(cx, |store, cx| { + store.trust( + &worktree_store, + HashSet::from_iter([PathTrust::Worktree(worktree_id)]), + cx, + ); + }); + cx.executor().run_until_parked(); + + repository.read_with(cx, |repo, _| { + assert!( + repo.is_trusted(), + "repository should be trusted after worktree is trusted" + ); + }); + + trusted_worktrees.update(cx, |store, cx| { + store.restrict( + worktree_store.downgrade(), + HashSet::from_iter([PathTrust::Worktree(worktree_id)]), + cx, + ); + }); + cx.executor().run_until_parked(); + + repository.read_with(cx, |repo, _| { + assert!( + !repo.is_trusted(), + "repository should be untrusted after worktree is restricted" + ); + }); + + trusted_worktrees.update(cx, |store, cx| { + store.trust( + &worktree_store, + HashSet::from_iter([PathTrust::Worktree(worktree_id)]), + cx, + ); + }); + cx.executor().run_until_parked(); + + repository.read_with(cx, |repo, _| { + assert!( + repo.is_trusted(), + "repository should be trusted again after second trust" + ); + }); + } +} diff --git a/crates/project/tests/integration/project_tests.rs b/crates/project/tests/integration/project_tests.rs index 6092836c19ef280aa2d13abcb32932f3b47703b6..d86b969e61ed173ee314cde6f584f2dbab6859f9 100644 --- a/crates/project/tests/integration/project_tests.rs +++ b/crates/project/tests/integration/project_tests.rs @@ -31,7 +31,7 @@ use futures::{StreamExt, future}; use git::{ GitHostingProviderRegistry, repository::{RepoPath, repo_path}, - status::{FileStatus, StatusCode, TrackedStatus}, + status::{DiffStat, FileStatus, StatusCode, TrackedStatus}, }; use git2::RepositoryInitOptions; use gpui::{ @@ -5359,6 +5359,52 @@ async fn test_rescan_and_remote_updates(cx: &mut gpui::TestAppContext) { }); } +#[cfg(target_os = "linux")] +#[gpui::test(retries = 5)] +async fn test_recreated_directory_receives_child_events(cx: &mut gpui::TestAppContext) { + init_test(cx); + cx.executor().allow_parking(); + + let dir = TempTree::new(json!({})); + let project = Project::test(Arc::new(RealFs::new(None, cx.executor())), [dir.path()], cx).await; + let tree = project.update(cx, |project, cx| project.worktrees(cx).next().unwrap()); + + tree.flush_fs_events(cx).await; + + let repro_dir = dir.path().join("repro"); + std::fs::create_dir(&repro_dir).unwrap(); + tree.flush_fs_events(cx).await; + + cx.update(|cx| { + assert!(tree.read(cx).entry_for_path(rel_path("repro")).is_some()); + }); + + std::fs::remove_dir_all(&repro_dir).unwrap(); + tree.flush_fs_events(cx).await; + + cx.update(|cx| { + assert!(tree.read(cx).entry_for_path(rel_path("repro")).is_none()); + }); + + std::fs::create_dir(&repro_dir).unwrap(); + tree.flush_fs_events(cx).await; + + cx.update(|cx| { + assert!(tree.read(cx).entry_for_path(rel_path("repro")).is_some()); + }); + + std::fs::write(repro_dir.join("repro-marker"), "").unwrap(); + tree.flush_fs_events(cx).await; + + cx.update(|cx| { + assert!( + tree.read(cx) + .entry_for_path(rel_path("repro/repro-marker")) + .is_some() + ); + }); +} + #[gpui::test(iterations = 10)] async fn test_buffer_identity_across_renames(cx: &mut gpui::TestAppContext) { init_test(cx); @@ -9207,14 +9253,23 @@ async fn test_git_repository_status(cx: &mut gpui::TestAppContext) { StatusEntry { repo_path: repo_path("a.txt"), status: StatusCode::Modified.worktree(), + diff_stat: Some(DiffStat { + added: 1, + deleted: 1, + }), }, StatusEntry { repo_path: repo_path("b.txt"), status: FileStatus::Untracked, + diff_stat: None, }, StatusEntry { repo_path: repo_path("d.txt"), status: StatusCode::Deleted.worktree(), + diff_stat: Some(DiffStat { + added: 0, + deleted: 1, + }), }, ] ); @@ -9236,18 +9291,31 @@ async fn test_git_repository_status(cx: &mut gpui::TestAppContext) { StatusEntry { repo_path: repo_path("a.txt"), status: StatusCode::Modified.worktree(), + diff_stat: Some(DiffStat { + added: 1, + deleted: 1, + }), }, StatusEntry { repo_path: repo_path("b.txt"), status: FileStatus::Untracked, + diff_stat: None, }, StatusEntry { repo_path: repo_path("c.txt"), status: StatusCode::Modified.worktree(), + diff_stat: Some(DiffStat { + added: 1, + deleted: 1, + }), }, StatusEntry { repo_path: repo_path("d.txt"), status: StatusCode::Deleted.worktree(), + diff_stat: Some(DiffStat { + added: 0, + deleted: 1, + }), }, ] ); @@ -9281,6 +9349,10 @@ async fn test_git_repository_status(cx: &mut gpui::TestAppContext) { [StatusEntry { repo_path: repo_path("a.txt"), status: StatusCode::Deleted.worktree(), + diff_stat: Some(DiffStat { + added: 0, + deleted: 1, + }), }] ); }); @@ -9345,6 +9417,7 @@ async fn test_git_status_postprocessing(cx: &mut gpui::TestAppContext) { worktree_status: StatusCode::Added } .into(), + diff_stat: None, }] ) }); @@ -9547,6 +9620,10 @@ async fn test_repository_pending_ops_staging( worktree_status: StatusCode::Unmodified } .into(), + diff_stat: Some(DiffStat { + added: 1, + deleted: 0, + }), }] ); }); @@ -9653,6 +9730,10 @@ async fn test_repository_pending_ops_long_running_staging( worktree_status: StatusCode::Unmodified } .into(), + diff_stat: Some(DiffStat { + added: 1, + deleted: 0, + }), }] ); }); @@ -9777,10 +9858,12 @@ async fn test_repository_pending_ops_stage_all( StatusEntry { repo_path: repo_path("a.txt"), status: FileStatus::Untracked, + diff_stat: None, }, StatusEntry { repo_path: repo_path("b.txt"), status: FileStatus::Untracked, + diff_stat: None, }, ] ); diff --git a/crates/project_panel/src/project_panel.rs b/crates/project_panel/src/project_panel.rs index 7f746a6ccd7efec2b73354992c593433b0b6f281..082086d6a0a946e610be4c96e50d626b7000bda4 100644 --- a/crates/project_panel/src/project_panel.rs +++ b/crates/project_panel/src/project_panel.rs @@ -46,6 +46,7 @@ use settings::{ update_settings_file, }; use smallvec::SmallVec; +use std::ops::Neg; use std::{any::TypeId, time::Instant}; use std::{ cell::OnceCell, @@ -6457,11 +6458,14 @@ impl Render for ProjectPanel { el.on_action(cx.listener(Self::trash)) }) }) - .when(project.is_local(), |el| { - el.on_action(cx.listener(Self::reveal_in_finder)) - .on_action(cx.listener(Self::open_system)) - .on_action(cx.listener(Self::open_in_terminal)) - }) + .when( + project.is_local() || project.is_via_wsl_with_host_interop(cx), + |el| { + el.on_action(cx.listener(Self::reveal_in_finder)) + .on_action(cx.listener(Self::open_system)) + .on_action(cx.listener(Self::open_in_terminal)) + }, + ) .when(project.is_via_remote_server(), |el| { el.on_action(cx.listener(Self::open_in_terminal)) .on_action(cx.listener(Self::download_from_remote)) @@ -6688,6 +6692,24 @@ impl Render for ProjectPanel { .id("project-panel-blank-area") .block_mouse_except_scroll() .flex_grow() + .on_scroll_wheel({ + let scroll_handle = self.scroll_handle.clone(); + let entity_id = cx.entity().entity_id(); + move |event, window, cx| { + let state = scroll_handle.0.borrow(); + let base_handle = &state.base_handle; + let current_offset = base_handle.offset(); + let max_offset = base_handle.max_offset(); + let delta = event.delta.pixel_delta(window.line_height()); + let new_offset = (current_offset + delta) + .clamp(&max_offset.neg(), &Point::default()); + + if new_offset != current_offset { + base_handle.set_offset(new_offset); + cx.notify(entity_id); + } + } + }) .when( self.drag_target_entry.as_ref().is_some_and( |entry| match entry { diff --git a/crates/proto/proto/git.proto b/crates/proto/proto/git.proto index 86f3d4c328af06e1a3f4f7cc406ac84272577cd0..736abcdaa49f62d72582750a8a28ea785baee282 100644 --- a/crates/proto/proto/git.proto +++ b/crates/proto/proto/git.proto @@ -125,6 +125,7 @@ message UpdateRepository { repeated StashEntry stash_entries = 13; optional string remote_upstream_url = 14; optional string remote_origin_url = 15; + optional string original_repo_abs_path = 16; } message RemoveRepository { @@ -228,29 +229,6 @@ message GitDiffResponse { string diff = 1; } -message GitDiffStat { - uint64 project_id = 1; - uint64 repository_id = 2; - DiffType diff_type = 3; - optional string merge_base_ref = 4; - - enum DiffType { - HEAD_TO_WORKTREE = 0; - HEAD_TO_INDEX = 1; - MERGE_BASE = 2; - } -} - -message GitDiffStatResponse { - repeated GitDiffStatEntry entries = 1; -} - -message GitDiffStatEntry { - string path = 1; - uint32 added = 2; - uint32 deleted = 3; -} - message GitInit { uint64 project_id = 1; string abs_path = 2; @@ -359,6 +337,8 @@ message StatusEntry { // Can be removed once collab's min version is >=0.171.0. GitStatus simple_status = 2; GitFileStatus status = 3; + optional uint32 diff_stat_added = 4; + optional uint32 diff_stat_deleted = 5; } message StashEntry { diff --git a/crates/proto/proto/zed.proto b/crates/proto/proto/zed.proto index d6139f5342d153221d13917e26565a4c0eb5a707..c129b6eff26404b66b38439c29f0b83289b37172 100644 --- a/crates/proto/proto/zed.proto +++ b/crates/proto/proto/zed.proto @@ -474,9 +474,7 @@ message Envelope { SpawnKernel spawn_kernel = 426; SpawnKernelResponse spawn_kernel_response = 427; - KillKernel kill_kernel = 428; - GitDiffStat git_diff_stat = 429; - GitDiffStatResponse git_diff_stat_response = 430; // current max + KillKernel kill_kernel = 428; // current max } reserved 87 to 88; @@ -501,6 +499,7 @@ message Envelope { reserved 280 to 281; reserved 332 to 333; reserved 394 to 396; + reserved 429 to 430; } message Hello { diff --git a/crates/proto/src/proto.rs b/crates/proto/src/proto.rs index 3d30551557000c305a82b328828b566c9d78f75e..dd0a77beb29345021563b21bafd261d02b87e1ab 100644 --- a/crates/proto/src/proto.rs +++ b/crates/proto/src/proto.rs @@ -322,8 +322,6 @@ messages!( (CheckForPushedCommitsResponse, Background), (GitDiff, Background), (GitDiffResponse, Background), - (GitDiffStat, Background), - (GitDiffStatResponse, Background), (GitInit, Background), (GetDebugAdapterBinary, Background), (DebugAdapterBinary, Background), @@ -541,7 +539,6 @@ request_messages!( (GitRenameBranch, Ack), (CheckForPushedCommits, CheckForPushedCommitsResponse), (GitDiff, GitDiffResponse), - (GitDiffStat, GitDiffStatResponse), (GitInit, Ack), (ToggleBreakpoint, Ack), (GetDebugAdapterBinary, DebugAdapterBinary), @@ -730,7 +727,6 @@ entity_messages!( GitRemoveRemote, CheckForPushedCommits, GitDiff, - GitDiffStat, GitInit, BreakpointsForFile, ToggleBreakpoint, diff --git a/crates/recent_projects/src/remote_servers.rs b/crates/recent_projects/src/remote_servers.rs index 6c0ce4b18854320fda8e72f291800049b07cec1a..a94f7b1d57eaef8657fb0d448480f84c97ce7e70 100644 --- a/crates/recent_projects/src/remote_servers.rs +++ b/crates/recent_projects/src/remote_servers.rs @@ -1161,12 +1161,11 @@ impl RemoteServerProjects { workspace.toggle_modal(window, cx, |window, cx| { RemoteConnectionModal::new(&connection_options, Vec::new(), window, cx) }); - let prompt = workspace - .active_modal::(cx) - .unwrap() - .read(cx) - .prompt - .clone(); + // can be None if another copy of this modal opened in the meantime + let Some(modal) = workspace.active_modal::(cx) else { + return; + }; + let prompt = modal.read(cx).prompt.clone(); let connect = connect( ConnectionIdentifier::setup(), diff --git a/crates/remote_server/src/remote_editing_tests.rs b/crates/remote_server/src/remote_editing_tests.rs index 778f7292d2a032df6995169852deeecee6fa76a7..7f9953c8a4e746d9586b663330badb38149cfb64 100644 --- a/crates/remote_server/src/remote_editing_tests.rs +++ b/crates/remote_server/src/remote_editing_tests.rs @@ -2,15 +2,11 @@ /// The tests in this file assume that server_cx is running on Windows too. /// We neead to find a way to test Windows-Non-Windows interactions. use crate::headless_project::HeadlessProject; -use agent::{ - AgentTool, ReadFileTool, ReadFileToolInput, Templates, Thread, ToolCallEventStream, ToolInput, -}; +use agent::{AgentTool, ReadFileTool, ReadFileToolInput, ToolCallEventStream, ToolInput}; use client::{Client, UserStore}; use clock::FakeSystemClock; use collections::{HashMap, HashSet}; -use git::repository::DiffType; -use language_model::{LanguageModelToolResultContent, fake_provider::FakeLanguageModel}; -use prompt_store::ProjectContext; +use language_model::LanguageModelToolResultContent; use extension::ExtensionHostProxy; use fs::{FakeFs, Fs}; @@ -1920,129 +1916,6 @@ async fn test_remote_git_branches(cx: &mut TestAppContext, server_cx: &mut TestA assert_eq!(server_branch.name(), "totally-new-branch"); } -#[gpui::test] -async fn test_remote_git_diff_stat(cx: &mut TestAppContext, server_cx: &mut TestAppContext) { - let fs = FakeFs::new(server_cx.executor()); - fs.insert_tree( - path!("/code"), - json!({ - "project1": { - ".git": {}, - "src": { - "lib.rs": "line1\nline2\nline3\n", - "new_file.rs": "added1\nadded2\n", - }, - "README.md": "# project 1", - }, - }), - ) - .await; - - let dot_git = Path::new(path!("/code/project1/.git")); - - // HEAD: lib.rs (2 lines), deleted.rs (1 line) - fs.set_head_for_repo( - dot_git, - &[ - ("src/lib.rs", "line1\nold_line2\n".into()), - ("src/deleted.rs", "was_here\n".into()), - ], - "deadbeef", - ); - // Index: lib.rs modified (4 lines), staged_only.rs new (2 lines) - fs.set_index_for_repo( - dot_git, - &[ - ("src/lib.rs", "line1\nold_line2\nline3\nline4\n".into()), - ("src/staged_only.rs", "x\ny\n".into()), - ], - ); - - let (project, _headless) = init_test(&fs, cx, server_cx).await; - let (_worktree, _) = project - .update(cx, |project, cx| { - project.find_or_create_worktree(path!("/code/project1"), true, cx) - }) - .await - .unwrap(); - cx.run_until_parked(); - - let repo_path = |s: &str| git::repository::RepoPath::new(s).unwrap(); - - let repository = project.update(cx, |project, cx| project.active_repository(cx).unwrap()); - - // --- HeadToWorktree --- - let stats = cx - .update(|cx| repository.update(cx, |repo, cx| repo.diff_stat(DiffType::HeadToWorktree, cx))) - .await - .unwrap() - .unwrap(); - - // src/lib.rs: worktree 3 lines vs HEAD 2 lines - let stat = stats.get(&repo_path("src/lib.rs")).expect("src/lib.rs"); - assert_eq!((stat.added, stat.deleted), (3, 2)); - - // src/new_file.rs: only in worktree (2 lines) - let stat = stats - .get(&repo_path("src/new_file.rs")) - .expect("src/new_file.rs"); - assert_eq!((stat.added, stat.deleted), (2, 0)); - - // src/deleted.rs: only in HEAD (1 line) - let stat = stats - .get(&repo_path("src/deleted.rs")) - .expect("src/deleted.rs"); - assert_eq!((stat.added, stat.deleted), (0, 1)); - - // README.md: only in worktree (1 line) - let stat = stats.get(&repo_path("README.md")).expect("README.md"); - assert_eq!((stat.added, stat.deleted), (1, 0)); - - // --- HeadToIndex --- - let stats = cx - .update(|cx| repository.update(cx, |repo, cx| repo.diff_stat(DiffType::HeadToIndex, cx))) - .await - .unwrap() - .unwrap(); - - // src/lib.rs: index 4 lines vs HEAD 2 lines - let stat = stats.get(&repo_path("src/lib.rs")).expect("src/lib.rs"); - assert_eq!((stat.added, stat.deleted), (4, 2)); - - // src/staged_only.rs: only in index (2 lines) - let stat = stats - .get(&repo_path("src/staged_only.rs")) - .expect("src/staged_only.rs"); - assert_eq!((stat.added, stat.deleted), (2, 0)); - - // src/deleted.rs: in HEAD but not in index - let stat = stats - .get(&repo_path("src/deleted.rs")) - .expect("src/deleted.rs"); - assert_eq!((stat.added, stat.deleted), (0, 1)); - - // --- MergeBase (not implemented in FakeGitRepository) --- - let stats = cx - .update(|cx| { - repository.update(cx, |repo, cx| { - repo.diff_stat( - DiffType::MergeBase { - base_ref: "main".into(), - }, - cx, - ) - }) - }) - .await - .unwrap() - .unwrap(); - - assert!( - stats.is_empty(), - "MergeBase diff_stat should return empty from FakeGitRepository" - ); -} - #[gpui::test] async fn test_remote_agent_fs_tool_calls(cx: &mut TestAppContext, server_cx: &mut TestAppContext) { let fs = FakeFs::new(server_cx.executor()); @@ -2065,27 +1938,12 @@ async fn test_remote_agent_fs_tool_calls(cx: &mut TestAppContext, server_cx: &mu let action_log = cx.new(|_| action_log::ActionLog::new(project.clone())); - // Create a minimal thread for the ReadFileTool - let context_server_registry = - cx.new(|cx| agent::ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let input = ReadFileToolInput { path: "project/b.txt".into(), start_line: None, end_line: None, }; - let read_tool = Arc::new(ReadFileTool::new(thread.downgrade(), project, action_log)); + let read_tool = Arc::new(ReadFileTool::new(project, action_log, true)); let (event_stream, _) = ToolCallEventStream::test(); let exists_result = cx.update(|cx| { diff --git a/crates/repl/src/notebook/cell.rs b/crates/repl/src/notebook/cell.rs index d66261698b722cfcd0f547e09d84cf83a0d2b1a6..200424742aff113d637fe9aca30999c0f95e79a5 100644 --- a/crates/repl/src/notebook/cell.rs +++ b/crates/repl/src/notebook/cell.rs @@ -1,13 +1,11 @@ -#![allow(unused, dead_code)] use std::sync::Arc; use std::time::{Duration, Instant}; use editor::{Editor, EditorMode, MultiBuffer, SizingBehavior}; use futures::future::Shared; use gpui::{ - App, Entity, EventEmitter, Focusable, Hsla, InteractiveElement, KeyContext, - RetainAllImageCache, StatefulInteractiveElement, Task, TextStyleRefinement, image_cache, - prelude::*, + App, Entity, EventEmitter, Focusable, Hsla, InteractiveElement, RetainAllImageCache, + StatefulInteractiveElement, Task, TextStyleRefinement, prelude::*, }; use language::{Buffer, Language, LanguageRegistry}; use markdown::{Markdown, MarkdownElement, MarkdownStyle}; @@ -236,7 +234,7 @@ pub trait RenderableCell: Render { fn source(&self) -> &String; fn selected(&self) -> bool; fn set_selected(&mut self, selected: bool) -> &mut Self; - fn selected_bg_color(&self, window: &mut Window, cx: &mut Context) -> Hsla { + fn selected_bg_color(&self, _window: &mut Window, cx: &mut Context) -> Hsla { if self.selected() { let mut color = cx.theme().colors().element_hover; color.fade_out(0.5); @@ -253,7 +251,7 @@ pub trait RenderableCell: Render { fn cell_position_spacer( &self, is_first: bool, - window: &mut Window, + _window: &mut Window, cx: &mut Context, ) -> Option { let cell_position = self.cell_position(); @@ -328,7 +326,6 @@ pub struct MarkdownCell { editing: bool, selected: bool, cell_position: Option, - languages: Arc, _editor_subscription: gpui::Subscription, } @@ -386,7 +383,6 @@ impl MarkdownCell { let markdown = cx.new(|cx| Markdown::new(source.clone().into(), None, None, cx)); - let cell_id = id.clone(); let editor_subscription = cx.subscribe(&editor, move |this, _editor, event, cx| match event { editor::EditorEvent::Blurred => { @@ -410,7 +406,6 @@ impl MarkdownCell { editing: start_editing, selected: false, cell_position: None, - languages, _editor_subscription: editor_subscription, } } @@ -461,8 +456,6 @@ impl MarkdownCell { .unwrap_or_default(); self.source = source.clone(); - let languages = self.languages.clone(); - self.markdown.update(cx, |markdown, cx| { markdown.reset(source.into(), cx); }); @@ -606,7 +599,7 @@ pub struct CodeCell { outputs: Vec, selected: bool, cell_position: Option, - language_task: Task<()>, + _language_task: Task<()>, execution_start_time: Option, execution_duration: Option, is_executing: bool, @@ -670,10 +663,10 @@ impl CodeCell { outputs: Vec::new(), selected: false, cell_position: None, - language_task, execution_start_time: None, execution_duration: None, is_executing: false, + _language_task: language_task, } } @@ -748,10 +741,10 @@ impl CodeCell { outputs, selected: false, cell_position: None, - language_task, execution_start_time: None, execution_duration: None, is_executing: false, + _language_task: language_task, } } @@ -879,15 +872,7 @@ impl CodeCell { cx.notify(); } - fn output_control(&self) -> Option { - if self.has_outputs() { - Some(CellControlType::ClearCell) - } else { - None - } - } - - pub fn gutter_output(&self, window: &mut Window, cx: &mut Context) -> impl IntoElement { + pub fn gutter_output(&self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { let is_selected = self.selected(); div() @@ -948,7 +933,7 @@ impl RenderableCell for CodeCell { &self.source } - fn control(&self, window: &mut Window, cx: &mut Context) -> Option { + fn control(&self, _window: &mut Window, cx: &mut Context) -> Option { let control_type = if self.has_outputs() { CellControlType::RerunCell } else { @@ -1038,8 +1023,7 @@ impl RenderableCell for CodeCell { } impl RunnableCell for CodeCell { - fn run(&mut self, window: &mut Window, cx: &mut Context) { - println!("Running code cell: {}", self.id); + fn run(&mut self, _window: &mut Window, cx: &mut Context) { cx.emit(CellEvent::Run(self.id.clone())); } @@ -1062,11 +1046,8 @@ impl Render for CodeCell { } else { None }; - let output_max_width = plain::max_width_for_columns( - ReplSettings::get_global(cx).output_max_width_columns, - window, - cx, - ); + let output_max_width = + plain::max_width_for_columns(ReplSettings::get_global(cx).max_columns, window, cx); // get the language from the editor's buffer let language_name = self .editor @@ -1198,41 +1179,23 @@ impl Render for CodeCell { }, ) // output at bottom - .child(div().w_full().children(self.outputs.iter().map( - |output| { - let content = match output { - Output::Plain { content, .. } => { - Some(content.clone().into_any_element()) - } - Output::Markdown { content, .. } => { - Some(content.clone().into_any_element()) - } - Output::Stream { content, .. } => { - Some(content.clone().into_any_element()) - } - Output::Image { content, .. } => { - Some(content.clone().into_any_element()) - } - Output::Message(message) => Some( - div() - .child(message.clone()) - .into_any_element(), - ), - Output::Table { content, .. } => { - Some(content.clone().into_any_element()) - } - Output::Json { content, .. } => { - Some(content.clone().into_any_element()) - } - Output::ErrorOutput(error_view) => { - error_view.render(window, cx) - } - Output::ClearOutputWaitMarker => None, - }; - - div().children(content) - }, - ))), + .child( + div() + .id(( + ElementId::from(self.id.to_string()), + "output-scroll", + )) + .w_full() + .when_some(output_max_width, |div, max_width| { + div.max_w(max_width).overflow_x_scroll() + }) + .when_some(output_max_height, |div, max_height| { + div.max_h(max_height).overflow_y_scroll() + }) + .children(self.outputs.iter().map(|output| { + div().children(output.content(window, cx)) + })), + ), ), ), ) diff --git a/crates/repl/src/notebook/notebook_ui.rs b/crates/repl/src/notebook/notebook_ui.rs index 5b8c0746cdf1289ac3c612139fab1819b5596c07..87f18708a1988c70d66dc4cef5355d4cbcb11dba 100644 --- a/crates/repl/src/notebook/notebook_ui.rs +++ b/crates/repl/src/notebook/notebook_ui.rs @@ -1514,6 +1514,9 @@ impl project::ProjectItem for NotebookItem { nbformat::upgrade_legacy_notebook(legacy_notebook)? } + nbformat::Notebook::V3(v3_notebook) => { + nbformat::upgrade_v3_notebook(v3_notebook)? + } } }; @@ -1791,6 +1794,9 @@ impl Item for NotebookEditor { Ok(nbformat::Notebook::Legacy(legacy_notebook)) => { nbformat::upgrade_legacy_notebook(legacy_notebook)? } + Ok(nbformat::Notebook::V3(v3_notebook)) => { + nbformat::upgrade_v3_notebook(v3_notebook)? + } Err(e) => { anyhow::bail!("Failed to parse notebook: {:?}", e); } diff --git a/crates/repl/src/outputs.rs b/crates/repl/src/outputs.rs index 8be8c57cceee84435a6d99ba5c611d24c563bec3..f6d2bc4d3173ce64700b7b5ac45301df0fe0ab53 100644 --- a/crates/repl/src/outputs.rs +++ b/crates/repl/src/outputs.rs @@ -253,18 +253,8 @@ impl Output { ) } - pub fn render( - &self, - workspace: WeakEntity, - window: &mut Window, - cx: &mut Context, - ) -> impl IntoElement + use<> { - let max_width = plain::max_width_for_columns( - ReplSettings::get_global(cx).output_max_width_columns, - window, - cx, - ); - let content = match self { + pub fn content(&self, window: &mut Window, cx: &mut App) -> Option { + match self { Self::Plain { content, .. } => Some(content.clone().into_any_element()), Self::Markdown { content, .. } => Some(content.clone().into_any_element()), Self::Stream { content, .. } => Some(content.clone().into_any_element()), @@ -274,21 +264,36 @@ impl Output { Self::Json { content, .. } => Some(content.clone().into_any_element()), Self::ErrorOutput(error_view) => error_view.render(window, cx), Self::ClearOutputWaitMarker => None, - }; + } + } - let needs_horizontal_scroll = matches!(self, Self::Table { .. } | Self::Image { .. }); + pub fn render( + &self, + workspace: WeakEntity, + window: &mut Window, + cx: &mut Context, + ) -> impl IntoElement + use<> { + let max_width = + plain::max_width_for_columns(ReplSettings::get_global(cx).max_columns, window, cx); + let content = self.content(window, cx); + + let needs_horizontal_scroll = matches!(self, Self::Table { .. }); h_flex() .id("output-content") .w_full() - .when_some(max_width, |this, max_w| this.max_w(max_w)) - .overflow_x_scroll() + .when_else( + needs_horizontal_scroll, + |this| this.overflow_x_scroll(), + |this| this.overflow_x_hidden(), + ) .items_start() .child( div() .when(!needs_horizontal_scroll, |el| { el.flex_1().w_full().overflow_x_hidden() }) + .when_some(max_width, |el, max_width| el.max_w(max_width)) .children(content), ) .children(match self { diff --git a/crates/repl/src/outputs/image.rs b/crates/repl/src/outputs/image.rs index 9d1ffa3d2065281cd69e67b2faf960c9aa690bcb..e5444be3d779c9541fcadd55b9255d3e25da0cba 100644 --- a/crates/repl/src/outputs/image.rs +++ b/crates/repl/src/outputs/image.rs @@ -3,10 +3,10 @@ use base64::{ Engine as _, alphabet, engine::{DecodePaddingMode, GeneralPurpose, GeneralPurposeConfig}, }; -use gpui::{App, ClipboardItem, Image, ImageFormat, RenderImage, Window, img}; +use gpui::{App, ClipboardItem, Image, ImageFormat, Pixels, RenderImage, Window, img}; use settings::Settings as _; use std::sync::Arc; -use ui::{IntoElement, Styled, div, prelude::*}; +use ui::{IntoElement, Styled, prelude::*}; use crate::outputs::{OutputContent, plain}; use crate::repl_settings::ReplSettings; @@ -113,7 +113,7 @@ impl Render for ImageView { let settings = ReplSettings::get_global(cx); let line_height = window.line_height(); - let max_width = plain::max_width_for_columns(settings.output_max_width_columns, window, cx); + let max_width = plain::max_width_for_columns(settings.max_columns, window, cx); let max_height = if settings.output_max_height_lines > 0 { Some(line_height * settings.output_max_height_lines as f32) @@ -125,7 +125,7 @@ impl Render for ImageView { let image = self.image.clone(); - div().h(height).w(width).child(img(image)) + img(image).w(width).h(height) } } diff --git a/crates/repl/src/outputs/plain.rs b/crates/repl/src/outputs/plain.rs index 0db2f811fb9ca3b82114db23826e37fe699bd3a0..71e2624f8ad7b0172a86793d5d81b38339b04f36 100644 --- a/crates/repl/src/outputs/plain.rs +++ b/crates/repl/src/outputs/plain.rs @@ -22,7 +22,7 @@ use alacritty_terminal::{ term::Config, vte::ansi::Processor, }; -use gpui::{Bounds, ClipboardItem, Entity, FontStyle, TextStyle, WhiteSpace, canvas, size}; +use gpui::{Bounds, ClipboardItem, Entity, FontStyle, Pixels, TextStyle, WhiteSpace, canvas, size}; use language::Buffer; use settings::Settings as _; use terminal::terminal_settings::TerminalSettings; diff --git a/crates/repl/src/repl_editor.rs b/crates/repl/src/repl_editor.rs index 6e061c3e2e37aa94074f17f94791ad147f56f344..56b79e20ffca74ab3f9f9c7948a7caeffc4ad4ce 100644 --- a/crates/repl/src/repl_editor.rs +++ b/crates/repl/src/repl_editor.rs @@ -636,12 +636,9 @@ fn language_supported(language: &Arc, cx: &mut App) -> bool { let store = ReplStore::global(cx); let store_read = store.read(cx); - // Since we're just checking for general language support, we only need to look at - // the pure Jupyter kernels - these are all the globally available ones - store_read.pure_jupyter_kernel_specifications().any(|spec| { - // Convert to lowercase for case-insensitive comparison since kernels might report "python" while our language is "Python" - spec.language().as_ref().to_lowercase() == language.name().as_ref().to_lowercase() - }) + store_read + .pure_jupyter_kernel_specifications() + .any(|spec| language.matches_kernel_language(spec.language().as_ref())) } fn get_language(editor: WeakEntity, cx: &mut App) -> Option> { diff --git a/crates/repl/src/repl_settings.rs b/crates/repl/src/repl_settings.rs index 302164a5b360157edceff1b1f2e18f6c6fd7a50b..5fd7623bb71e6446b8cacd6029108e481efc8680 100644 --- a/crates/repl/src/repl_settings.rs +++ b/crates/repl/src/repl_settings.rs @@ -27,11 +27,6 @@ pub struct ReplSettings { /// /// Default: 0 pub output_max_height_lines: usize, - /// Maximum number of columns of output to display before scaling images. - /// Set to 0 to disable output width limits. - /// - /// Default: 0 - pub output_max_width_columns: usize, } impl Settings for ReplSettings { @@ -44,7 +39,6 @@ impl Settings for ReplSettings { inline_output: repl.inline_output.unwrap_or(true), inline_output_max_length: repl.inline_output_max_length.unwrap_or(50), output_max_height_lines: repl.output_max_height_lines.unwrap_or(0), - output_max_width_columns: repl.output_max_width_columns.unwrap_or(0), } } } diff --git a/crates/repl/src/repl_store.rs b/crates/repl/src/repl_store.rs index 1c6ce99c2177260c1b9aaf1733326ddbda85a64f..8da94eaa7fe40e28a1d6336a648d7eae5c6767ae 100644 --- a/crates/repl/src/repl_store.rs +++ b/crates/repl/src/repl_store.rs @@ -289,7 +289,6 @@ impl ReplStore { } let language_at_cursor = language_at_cursor?; - let language_name = language_at_cursor.code_fence_block_name().to_lowercase(); // Prefer the recommended (active toolchain) kernel if it has ipykernel if let Some(active_path) = self.active_python_toolchain_path(worktree_id) { @@ -297,7 +296,7 @@ impl ReplStore { .kernel_specifications_for_worktree(worktree_id) .find(|spec| { spec.has_ipykernel() - && spec.language().as_ref().to_lowercase() == language_name + && language_at_cursor.matches_kernel_language(spec.language().as_ref()) && spec.path().as_ref() == active_path.as_ref() }) .cloned(); @@ -312,7 +311,7 @@ impl ReplStore { .find(|spec| { matches!(spec, KernelSpecification::PythonEnv(_)) && spec.has_ipykernel() - && spec.language().as_ref().to_lowercase() == language_name + && language_at_cursor.matches_kernel_language(spec.language().as_ref()) }) .cloned(); if python_env.is_some() { @@ -350,10 +349,10 @@ impl ReplStore { return Some(found_by_name); } - let language_name = language_at_cursor.code_fence_block_name().to_lowercase(); self.kernel_specifications_for_worktree(worktree_id) .find(|spec| { - spec.has_ipykernel() && spec.language().as_ref().to_lowercase() == language_name + spec.has_ipykernel() + && language_at_cursor.matches_kernel_language(spec.language().as_ref()) }) .cloned() } diff --git a/crates/settings/src/settings_store.rs b/crates/settings/src/settings_store.rs index 8551fc2edd53df66965b18abbe91f7083dd08461..26425faf113a9dc0f52ad04809dc71c2f89eeb69 100644 --- a/crates/settings/src/settings_store.rs +++ b/crates/settings/src/settings_store.rs @@ -1706,7 +1706,7 @@ mod tests { r#"{ "languages": { "JSON": { - "auto_indent": true + "auto_indent": "syntax_aware" } } }"# @@ -1716,12 +1716,12 @@ mod tests { .languages_mut() .get_mut("JSON") .unwrap() - .auto_indent = Some(false); + .auto_indent = Some(crate::AutoIndentMode::None); settings.languages_mut().insert( "Rust".into(), LanguageSettingsContent { - auto_indent: Some(true), + auto_indent: Some(crate::AutoIndentMode::SyntaxAware), ..Default::default() }, ); @@ -1729,10 +1729,10 @@ mod tests { r#"{ "languages": { "Rust": { - "auto_indent": true + "auto_indent": "syntax_aware" }, "JSON": { - "auto_indent": false + "auto_indent": "none" } } }"# diff --git a/crates/settings_content/src/language.rs b/crates/settings_content/src/language.rs index d429f53824fd0f4f0a5810bce01b05badcfb9a51..fba636ee28be121a15da4b3d50046c53c0bdd5b3 100644 --- a/crates/settings_content/src/language.rs +++ b/crates/settings_content/src/language.rs @@ -90,7 +90,7 @@ pub enum EditPredictionProvider { Experimental(&'static str), } -pub const EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME: &str = "zeta2"; +const EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME: &str = "zeta2"; impl<'de> Deserialize<'de> for EditPredictionProvider { fn deserialize(deserializer: D) -> Result @@ -157,10 +157,7 @@ impl EditPredictionProvider { EditPredictionProvider::Codestral => Some("Codestral"), EditPredictionProvider::Sweep => Some("Sweep"), EditPredictionProvider::Mercury => Some("Mercury"), - EditPredictionProvider::Experimental( - EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, - ) => Some("Zeta2"), - EditPredictionProvider::None | EditPredictionProvider::Experimental(_) => None, + EditPredictionProvider::Experimental(_) | EditPredictionProvider::None => None, EditPredictionProvider::Ollama => Some("Ollama"), EditPredictionProvider::OpenAiCompatibleApi => Some("OpenAI-Compatible API"), } @@ -372,6 +369,32 @@ pub enum EditPredictionsMode { Eager, } +/// Controls the soft-wrapping behavior in the editor. +#[derive( + Copy, + Clone, + Debug, + Serialize, + Deserialize, + PartialEq, + Eq, + JsonSchema, + MergeFrom, + strum::VariantArray, + strum::VariantNames, +)] +#[serde(rename_all = "snake_case")] +pub enum AutoIndentMode { + /// Adjusts indentation based on syntax context when typing. + /// Uses tree-sitter to analyze code structure and indent accordingly. + SyntaxAware, + /// Preserve the indentation of the current line when creating new lines, + /// but don't adjust based on syntax context. + PreserveIndent, + /// No automatic indentation. New lines start at column 0. + None, +} + /// Controls the soft-wrapping behavior in the editor. #[derive( Copy, @@ -574,10 +597,14 @@ pub struct LanguageSettingsContent { /// /// Default: true pub linked_edits: Option, - /// Whether indentation should be adjusted based on the context whilst typing. + /// Controls automatic indentation behavior when typing. /// - /// Default: true - pub auto_indent: Option, + /// - "syntax_aware": Adjusts indentation based on syntax context (default) + /// - "preserve_indent": Preserves current line's indentation on new lines + /// - "none": No automatic indentation + /// + /// Default: syntax_aware + pub auto_indent: Option, /// Whether indentation of pasted content should be adjusted based on the context. /// /// Default: true diff --git a/crates/settings_content/src/settings_content.rs b/crates/settings_content/src/settings_content.rs index f94c6a0b98d7fa23686dc1c89012e3b1fe476c70..5a4e87c384d802f3de4c96c07f65cf163c3a6d1a 100644 --- a/crates/settings_content/src/settings_content.rs +++ b/crates/settings_content/src/settings_content.rs @@ -1148,11 +1148,6 @@ pub struct ReplSettingsContent { /// /// Default: 0 pub output_max_height_lines: Option, - /// Maximum number of columns of output to display before scaling images. - /// Set to 0 to disable output width limits. - /// - /// Default: 0 - pub output_max_width_columns: Option, } /// Settings for configuring the which-key popup behaviour. diff --git a/crates/settings_ui/src/page_data.rs b/crates/settings_ui/src/page_data.rs index afc84a9f9b91e32f3a110e19dc78db5634369458..dbac4d7ba350fcff07016a2ccfa483f3d84472c7 100644 --- a/crates/settings_ui/src/page_data.rs +++ b/crates/settings_ui/src/page_data.rs @@ -7405,7 +7405,7 @@ fn language_settings_data() -> Box<[SettingsPageItem]> { }), SettingsPageItem::SettingItem(SettingItem { title: "Auto Indent", - description: "Whether indentation should be adjusted based on the context whilst typing.", + description: "Controls automatic indentation behavior when typing.", field: Box::new(SettingField { json_path: Some("languages.$(language).auto_indent"), pick: |settings_content| { diff --git a/crates/settings_ui/src/pages/edit_prediction_provider_setup.rs b/crates/settings_ui/src/pages/edit_prediction_provider_setup.rs index 338fe4de14f1f7e9060fafe865253f09f0bdc481..32c4bee84bd1f72263ed28bcd44d7e6349c4b24c 100644 --- a/crates/settings_ui/src/pages/edit_prediction_provider_setup.rs +++ b/crates/settings_ui/src/pages/edit_prediction_provider_setup.rs @@ -2,6 +2,7 @@ use codestral::{CODESTRAL_API_URL, codestral_api_key_state, codestral_api_url}; use edit_prediction::{ ApiKeyState, mercury::{MERCURY_CREDENTIALS_URL, mercury_api_token}, + open_ai_compatible::{open_ai_compatible_api_token, open_ai_compatible_api_url}, sweep_ai::{SWEEP_CREDENTIALS_URL, sweep_api_token}, }; use edit_prediction_ui::{get_available_providers, set_completion_provider}; @@ -33,7 +34,9 @@ pub(crate) fn render_edit_prediction_setup_page( render_api_key_provider( IconName::Inception, "Mercury", - "https://platform.inceptionlabs.ai/dashboard/api-keys".into(), + ApiKeyDocs::Link { + dashboard_url: "https://platform.inceptionlabs.ai/dashboard/api-keys".into(), + }, mercury_api_token(cx), |_cx| MERCURY_CREDENTIALS_URL, None, @@ -46,7 +49,9 @@ pub(crate) fn render_edit_prediction_setup_page( render_api_key_provider( IconName::SweepAi, "Sweep", - "https://app.sweep.dev/".into(), + ApiKeyDocs::Link { + dashboard_url: "https://app.sweep.dev/".into(), + }, sweep_api_token(cx), |_cx| SWEEP_CREDENTIALS_URL, Some( @@ -68,7 +73,9 @@ pub(crate) fn render_edit_prediction_setup_page( render_api_key_provider( IconName::AiMistral, "Codestral", - "https://console.mistral.ai/codestral".into(), + ApiKeyDocs::Link { + dashboard_url: "https://console.mistral.ai/codestral".into(), + }, codestral_api_key_state(cx), |cx| codestral_api_url(cx), Some( @@ -87,7 +94,31 @@ pub(crate) fn render_edit_prediction_setup_page( .into_any_element(), ), Some(render_ollama_provider(settings_window, window, cx).into_any_element()), - Some(render_open_ai_compatible_provider(settings_window, window, cx).into_any_element()), + Some( + render_api_key_provider( + IconName::AiOpenAiCompat, + "OpenAI Compatible API", + ApiKeyDocs::Custom { + message: "Set an API key here. It will be sent as Authorization: Bearer {key}." + .into(), + }, + open_ai_compatible_api_token(cx), + |cx| open_ai_compatible_api_url(cx), + Some( + settings_window + .render_sub_page_items_section( + open_ai_compatible_settings().iter().enumerate(), + true, + window, + cx, + ) + .into_any_element(), + ), + window, + cx, + ) + .into_any_element(), + ), ]; div() @@ -162,10 +193,15 @@ fn render_provider_dropdown(window: &mut Window, cx: &mut App) -> AnyElement { .into_any_element() } +enum ApiKeyDocs { + Link { dashboard_url: SharedString }, + Custom { message: SharedString }, +} + fn render_api_key_provider( icon: IconName, title: &'static str, - link: SharedString, + docs: ApiKeyDocs, api_key_state: Entity, current_url: fn(&mut App) -> SharedString, additional_fields: Option, @@ -209,25 +245,32 @@ fn render_api_key_provider( .icon(icon) .no_padding(true); let button_link_label = format!("{} dashboard", title); - let description = h_flex() - .min_w_0() - .gap_0p5() - .child( - Label::new("Visit the") + let description = match docs { + ApiKeyDocs::Custom { message } => h_flex().min_w_0().gap_0p5().child( + Label::new(message) .size(LabelSize::Small) .color(Color::Muted), - ) - .child( - ButtonLink::new(button_link_label, link) - .no_icon(true) - .label_size(LabelSize::Small) - .label_color(Color::Muted), - ) - .child( - Label::new("to generate an API key.") - .size(LabelSize::Small) - .color(Color::Muted), - ); + ), + ApiKeyDocs::Link { dashboard_url } => h_flex() + .min_w_0() + .gap_0p5() + .child( + Label::new("Visit the") + .size(LabelSize::Small) + .color(Color::Muted), + ) + .child( + ButtonLink::new(button_link_label, dashboard_url) + .no_icon(true) + .label_size(LabelSize::Small) + .label_color(Color::Muted), + ) + .child( + Label::new("to generate an API key.") + .size(LabelSize::Small) + .color(Color::Muted), + ), + }; let configured_card_label = if is_from_env_var { "API Key Set in Environment Variable" } else { @@ -484,34 +527,6 @@ fn ollama_settings() -> Box<[SettingsPageItem]> { ]) } -fn render_open_ai_compatible_provider( - settings_window: &SettingsWindow, - window: &mut Window, - cx: &mut Context, -) -> impl IntoElement { - let open_ai_compatible_settings = open_ai_compatible_settings(); - let additional_fields = settings_window - .render_sub_page_items_section( - open_ai_compatible_settings.iter().enumerate(), - true, - window, - cx, - ) - .into_any_element(); - - v_flex() - .id("open-ai-compatible") - .min_w_0() - .pt_8() - .gap_1p5() - .child( - SettingsSectionHeader::new("OpenAI Compatible API") - .icon(IconName::AiOpenAiCompat) - .no_padding(true), - ) - .child(div().px_neg_8().child(additional_fields)) -} - fn open_ai_compatible_settings() -> Box<[SettingsPageItem]> { Box::new([ SettingsPageItem::SettingItem(SettingItem { diff --git a/crates/settings_ui/src/settings_ui.rs b/crates/settings_ui/src/settings_ui.rs index def4c7630869cae69c539e1d83660e8df9a18318..9d7fe83736be8d1d9ed79d85708c5ed0574b7e3a 100644 --- a/crates/settings_ui/src/settings_ui.rs +++ b/crates/settings_ui/src/settings_ui.rs @@ -474,6 +474,7 @@ fn init_renderers(cx: &mut App) { .add_basic_renderer::(render_dropdown) .add_basic_renderer::(render_dropdown) .add_basic_renderer::(render_dropdown) + .add_basic_renderer::(render_dropdown) .add_basic_renderer::(render_dropdown) .add_basic_renderer::(render_dropdown) .add_basic_renderer::(render_dropdown) diff --git a/crates/terminal/src/terminal_hyperlinks.rs b/crates/terminal/src/terminal_hyperlinks.rs index d239f680f9e2ecbd3d320e731d3cc74303a552ed..0ca6cb2edd916019a4a7822830faa1fdfaa238f3 100644 --- a/crates/terminal/src/terminal_hyperlinks.rs +++ b/crates/terminal/src/terminal_hyperlinks.rs @@ -905,6 +905,18 @@ mod tests { ); } + #[test] + // + fn issue_50531() { + // Paths preceded by "N:" prefix (e.g. grep output line numbers) + // should still be clickable + test_path!("0: ‹«foo/👉bar.txt»›"); + test_path!("0: ‹«👉foo/bar.txt»›"); + test_path!("42: ‹«👉foo/bar.txt»›"); + test_path!("1: ‹«/👉test/cool.rs»›"); + test_path!("1: ‹«/👉test/cool.rs»:«4»:«2»›"); + } + #[test] // fn issue_46795() { diff --git a/crates/terminal_view/src/terminal_scrollbar.rs b/crates/terminal_view/src/terminal_scrollbar.rs index 82ca0b4097dad1be899879b0241aed50d8e60bfa..16dc580e877310b79501ca469b0351935dbb46f7 100644 --- a/crates/terminal_view/src/terminal_scrollbar.rs +++ b/crates/terminal_view/src/terminal_scrollbar.rs @@ -3,7 +3,7 @@ use std::{ rc::Rc, }; -use gpui::{Bounds, Point, Size, size}; +use gpui::{Bounds, Point, point, size}; use terminal::Terminal; use ui::{Pixels, ScrollableHandle, px}; @@ -46,9 +46,9 @@ impl TerminalScrollHandle { } impl ScrollableHandle for TerminalScrollHandle { - fn max_offset(&self) -> Size { + fn max_offset(&self) -> Point { let state = self.state.borrow(); - size( + point( Pixels::ZERO, state.total_lines.saturating_sub(state.viewport_lines) as f32 * state.line_height, ) diff --git a/crates/theme/src/icon_theme.rs b/crates/theme/src/icon_theme.rs index 8415462595cb93a19365a929660b4e8e3f78f8d8..121ff9d7d4fbd841315b89e631606c7e67bc5cde 100644 --- a/crates/theme/src/icon_theme.rs +++ b/crates/theme/src/icon_theme.rs @@ -66,7 +66,7 @@ pub struct IconDefinition { } const FILE_STEMS_BY_ICON_KEY: &[(&str, &[&str])] = &[ - ("docker", &["Dockerfile"]), + ("docker", &["Containerfile", "Dockerfile"]), ("ruby", &["Podfile"]), ("heroku", &["Procfile"]), ]; @@ -89,7 +89,7 @@ const FILE_SUFFIXES_BY_ICON_KEY: &[(&str, &[&str])] = &[ ( "cpp", &[ - "c++", "h++", "cc", "cpp", "cxx", "hh", "hpp", "hxx", "inl", "ixx", + "c++", "h++", "cc", "cpp", "cppm", "cxx", "hh", "hpp", "hxx", "inl", "ixx", ], ), ("crystal", &["cr", "ecr"]), @@ -99,6 +99,15 @@ const FILE_SUFFIXES_BY_ICON_KEY: &[(&str, &[&str])] = &[ ("cue", &["cue"]), ("dart", &["dart"]), ("diff", &["diff"]), + ( + "docker", + &[ + "docker-compose.yml", + "docker-compose.yaml", + "compose.yml", + "compose.yaml", + ], + ), ( "document", &[ @@ -138,12 +147,27 @@ const FILE_SUFFIXES_BY_ICON_KEY: &[(&str, &[&str])] = &[ ("font", &["otf", "ttf", "woff", "woff2"]), ("fsharp", &["fs"]), ("fsproj", &["fsproj"]), - ("gitlab", &["gitlab-ci.yml"]), + ("gitlab", &["gitlab-ci.yml", "gitlab-ci.yaml"]), ("gleam", &["gleam"]), ("go", &["go", "mod", "work"]), ("graphql", &["gql", "graphql", "graphqls"]), ("haskell", &["hs"]), ("hcl", &["hcl"]), + ( + "helm", + &[ + "helmfile.yaml", + "helmfile.yml", + "Chart.yaml", + "Chart.yml", + "Chart.lock", + "values.yaml", + "values.yml", + "requirements.yaml", + "requirements.yml", + "tpl", + ], + ), ("html", &["htm", "html"]), ( "image", @@ -198,7 +222,7 @@ const FILE_SUFFIXES_BY_ICON_KEY: &[(&str, &[&str])] = &[ ("rust", &["rs"]), ("sass", &["sass", "scss"]), ("scala", &["scala", "sc"]), - ("settings", &["conf", "ini", "yaml", "yml"]), + ("settings", &["conf", "ini"]), ("solidity", &["sol"]), ( "storage", @@ -279,6 +303,7 @@ const FILE_SUFFIXES_BY_ICON_KEY: &[(&str, &[&str])] = &[ ("vue", &["vue"]), ("vyper", &["vy", "vyi"]), ("wgsl", &["wgsl"]), + ("yaml", &["yaml", "yml"]), ("zig", &["zig"]), ]; @@ -310,12 +335,13 @@ const FILE_ICONS: &[(&str, &str)] = &[ ("font", "icons/file_icons/font.svg"), ("fsharp", "icons/file_icons/fsharp.svg"), ("fsproj", "icons/file_icons/file.svg"), - ("gitlab", "icons/file_icons/settings.svg"), + ("gitlab", "icons/file_icons/gitlab.svg"), ("gleam", "icons/file_icons/gleam.svg"), ("go", "icons/file_icons/go.svg"), ("graphql", "icons/file_icons/graphql.svg"), ("haskell", "icons/file_icons/haskell.svg"), ("hcl", "icons/file_icons/hcl.svg"), + ("helm", "icons/file_icons/helm.svg"), ("heroku", "icons/file_icons/heroku.svg"), ("html", "icons/file_icons/html.svg"), ("image", "icons/file_icons/image.svg"), @@ -371,6 +397,7 @@ const FILE_ICONS: &[(&str, &str)] = &[ ("vue", "icons/file_icons/vue.svg"), ("vyper", "icons/file_icons/vyper.svg"), ("wgsl", "icons/file_icons/wgsl.svg"), + ("yaml", "icons/file_icons/yaml.svg"), ("zig", "icons/file_icons/zig.svg"), ]; diff --git a/crates/ui/src/components/callout.rs b/crates/ui/src/components/callout.rs index 24762ec1765a58259b061194ea31ed7e8721c2a0..23c820cd545adff2985a4116a6efb00c1e731693 100644 --- a/crates/ui/src/components/callout.rs +++ b/crates/ui/src/components/callout.rs @@ -295,7 +295,7 @@ impl Component for Callout { "Error details:", "• Quota exceeded for metric", "• Limit: 0", - "• Model: gemini-3-pro", + "• Model: gemini-3.1-pro", "Please retry in 26.33s.", "Additional details:", "- Request ID: abc123def456", diff --git a/crates/ui/src/components/data_table.rs b/crates/ui/src/components/data_table.rs index 8a40c246ca44ea9dbb25e61bb611882343ba7f94..76ed64850c92e274bd8aeca483dd197cfbccbf52 100644 --- a/crates/ui/src/components/data_table.rs +++ b/crates/ui/src/components/data_table.rs @@ -36,6 +36,13 @@ pub mod table_row { pub struct TableRow(Vec); impl TableRow { + pub fn from_element(element: T, length: usize) -> Self + where + T: Clone, + { + Self::from_vec(vec![element; length], length) + } + /// Constructs a `TableRow` from a `Vec`, panicking if the length does not match `expected_length`. /// /// Use this when you want to ensure at construction time that the row has the correct number of columns. @@ -70,7 +77,8 @@ pub mod table_row { /// /// # Panics /// Panics if `col` is out of bounds (i.e., `col >= self.cols()`). - pub fn expect_get(&self, col: usize) -> &T { + pub fn expect_get(&self, col: impl Into) -> &T { + let col = col.into(); self.0.get(col).unwrap_or_else(|| { panic!( "Expected table row of `{}` to have {col:?}", @@ -79,8 +87,8 @@ pub mod table_row { }) } - pub fn get(&self, col: usize) -> Option<&T> { - self.0.get(col) + pub fn get(&self, col: impl Into) -> Option<&T> { + self.0.get(col.into()) } pub fn as_slice(&self) -> &[T] { @@ -735,6 +743,7 @@ pub struct Table { empty_table_callback: Option AnyElement>>, /// The number of columns in the table. Used to assert column numbers in `TableRow` collections cols: usize, + disable_base_cell_style: bool, } impl Table { @@ -753,9 +762,19 @@ impl Table { use_ui_font: true, empty_table_callback: None, col_widths: None, + disable_base_cell_style: false, } } + /// Disables based styling of row cell (paddings, text ellipsis, nowrap, etc), keeping width settings + /// + /// Doesn't affect base style of header cell. + /// Doesn't remove overflow-hidden + pub fn disable_base_style(mut self) -> Self { + self.disable_base_cell_style = true; + self + } + /// Enables uniform list rendering. /// The provided function will be passed directly to the `uniform_list` element. /// Therefore, if this method is called, any calls to [`Table::row`] before or after @@ -973,10 +992,18 @@ pub fn render_table_row( .into_iter() .zip(column_widths.into_vec()) .map(|(cell, width)| { - base_cell_style_text(width, table_context.use_ui_font, cx) - .px_1() - .py_0p5() - .child(cell) + if table_context.disable_base_cell_style { + div() + .when_some(width, |this, width| this.w(width)) + .when(width.is_none(), |this| this.flex_1()) + .overflow_hidden() + .child(cell) + } else { + base_cell_style_text(width, table_context.use_ui_font, cx) + .px_1() + .py_0p5() + .child(cell) + } }), ); @@ -1071,6 +1098,7 @@ pub struct TableRenderContext { pub column_widths: Option>, pub map_row: Option), &mut Window, &mut App) -> AnyElement>>, pub use_ui_font: bool, + pub disable_base_cell_style: bool, } impl TableRenderContext { @@ -1083,6 +1111,7 @@ impl TableRenderContext { column_widths: table.col_widths.as_ref().map(|widths| widths.lengths(cx)), map_row: table.map_row.clone(), use_ui_font: table.use_ui_font, + disable_base_cell_style: table.disable_base_cell_style, } } } diff --git a/crates/ui/src/components/scrollbar.rs b/crates/ui/src/components/scrollbar.rs index 8e8e89be9c0580a7820685b5690a996dfd2dade0..21d6aa46d0f90a0d48e267e935b00d9f263a30c5 100644 --- a/crates/ui/src/components/scrollbar.rs +++ b/crates/ui/src/components/scrollbar.rs @@ -9,8 +9,8 @@ use gpui::{ Along, App, AppContext as _, Axis as ScrollbarAxis, BorderStyle, Bounds, ContentMask, Context, Corner, Corners, CursorStyle, DispatchPhase, Div, Edges, Element, ElementId, Entity, EntityId, GlobalElementId, Hitbox, HitboxBehavior, Hsla, InteractiveElement, IntoElement, IsZero, - LayoutId, ListState, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent, Negate, - ParentElement, Pixels, Point, Position, Render, ScrollHandle, ScrollWheelEvent, Size, Stateful, + LayoutId, ListState, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent, ParentElement, + Pixels, Point, Position, Render, ScrollHandle, ScrollWheelEvent, Size, Stateful, StatefulInteractiveElement, Style, Styled, Task, UniformListDecoration, UniformListScrollHandle, Window, ease_in_out, prelude::FluentBuilder as _, px, quad, relative, size, @@ -258,7 +258,7 @@ impl UniformListDecoration for ScrollbarStateWrapper { _cx: &mut App, ) -> gpui::AnyElement { ScrollbarElement { - origin: scroll_offset.negate(), + origin: -scroll_offset, state: self.0.clone(), } .into_any() @@ -911,7 +911,7 @@ impl ThumbState { } impl ScrollableHandle for UniformListScrollHandle { - fn max_offset(&self) -> Size { + fn max_offset(&self) -> Point { self.0.borrow().base_handle.max_offset() } @@ -929,7 +929,7 @@ impl ScrollableHandle for UniformListScrollHandle { } impl ScrollableHandle for ListState { - fn max_offset(&self) -> Size { + fn max_offset(&self) -> Point { self.max_offset_for_scrollbar() } @@ -955,7 +955,7 @@ impl ScrollableHandle for ListState { } impl ScrollableHandle for ScrollHandle { - fn max_offset(&self) -> Size { + fn max_offset(&self) -> Point { self.max_offset() } @@ -973,7 +973,7 @@ impl ScrollableHandle for ScrollHandle { } pub trait ScrollableHandle: 'static + Any + Sized + Clone { - fn max_offset(&self) -> Size; + fn max_offset(&self) -> Point; fn set_offset(&self, point: Point); fn offset(&self) -> Point; fn viewport(&self) -> Bounds; @@ -984,7 +984,7 @@ pub trait ScrollableHandle: 'static + Any + Sized + Clone { self.max_offset().along(axis) > Pixels::ZERO } fn content_size(&self) -> Size { - self.viewport().size + self.max_offset() + self.viewport().size + self.max_offset().into() } } @@ -1006,7 +1006,7 @@ impl ScrollbarLayout { fn compute_click_offset( &self, event_position: Point, - max_offset: Size, + max_offset: Point, event_type: ScrollbarMouseEvent, ) -> Pixels { let Self { diff --git a/crates/util/src/shell.rs b/crates/util/src/shell.rs index 27ab18b58ce14cc59d57e563103fc9135f93d060..87872856d916ae39809debaeb6c151705367246b 100644 --- a/crates/util/src/shell.rs +++ b/crates/util/src/shell.rs @@ -1012,4 +1012,40 @@ mod tests { "uname".to_string() ); } + + #[test] + fn test_try_quote_single_quote_paths() { + let path_with_quote = r"C:\Temp\O'Brien\repo"; + let shlex_shells = [ + ShellKind::Posix, + ShellKind::Fish, + ShellKind::Csh, + ShellKind::Tcsh, + ShellKind::Rc, + ShellKind::Xonsh, + ShellKind::Elvish, + ShellKind::Nushell, + ]; + + for shell_kind in shlex_shells { + let quoted = shell_kind.try_quote(path_with_quote).unwrap().into_owned(); + assert_ne!(quoted, path_with_quote); + assert_eq!( + shlex::split("ed), + Some(vec![path_with_quote.to_string()]) + ); + + if shell_kind == ShellKind::Nushell { + let prefixed = shell_kind.prepend_command_prefix("ed); + assert!(prefixed.starts_with('^')); + } + } + + for shell_kind in [ShellKind::PowerShell, ShellKind::Pwsh] { + let quoted = shell_kind.try_quote(path_with_quote).unwrap().into_owned(); + assert!(quoted.starts_with('\'')); + assert!(quoted.ends_with('\'')); + assert!(quoted.contains("O''Brien")); + } + } } diff --git a/crates/util/src/shell_env.rs b/crates/util/src/shell_env.rs index 4fc9fd2d69b608c1215495d84c340f11e5be8179..ba9e77cb81086e810af8d17c7f17f2b77f5392d9 100644 --- a/crates/util/src/shell_env.rs +++ b/crates/util/src/shell_env.rs @@ -141,6 +141,14 @@ async fn capture_windows( std::env::current_exe().context("Failed to determine current zed executable path.")?; let shell_kind = ShellKind::new(shell_path, true); + let directory_string = directory.display().to_string(); + let zed_path_string = zed_path.display().to_string(); + let quote_for_shell = |value: &str| { + shell_kind + .try_quote(value) + .map(|quoted| quoted.into_owned()) + .unwrap_or_else(|| value.to_owned()) + }; let mut cmd = crate::command::new_command(shell_path); cmd.args(args); let cmd = match shell_kind { @@ -149,52 +157,54 @@ async fn capture_windows( | ShellKind::Rc | ShellKind::Fish | ShellKind::Xonsh - | ShellKind::Posix => cmd.args([ - "-l", - "-i", - "-c", - &format!( - "cd '{}'; '{}' --printenv", - directory.display(), - zed_path.display() - ), - ]), - ShellKind::PowerShell | ShellKind::Pwsh => cmd.args([ - "-NonInteractive", - "-NoProfile", - "-Command", - &format!( - "Set-Location '{}'; & '{}' --printenv", - directory.display(), - zed_path.display() - ), - ]), - ShellKind::Elvish => cmd.args([ - "-c", - &format!( - "cd '{}'; '{}' --printenv", - directory.display(), - zed_path.display() - ), - ]), - ShellKind::Nushell => cmd.args([ - "-c", - &format!( - "cd '{}'; {}'{}' --printenv", - directory.display(), - shell_kind - .command_prefix() - .map(|prefix| prefix.to_string()) - .unwrap_or_default(), - zed_path.display() - ), - ]), + | ShellKind::Posix => { + let quoted_directory = quote_for_shell(&directory_string); + let quoted_zed_path = quote_for_shell(&zed_path_string); + cmd.args([ + "-l", + "-i", + "-c", + &format!("cd {}; {} --printenv", quoted_directory, quoted_zed_path), + ]) + } + ShellKind::PowerShell | ShellKind::Pwsh => { + let quoted_directory = ShellKind::quote_pwsh(&directory_string); + let quoted_zed_path = ShellKind::quote_pwsh(&zed_path_string); + cmd.args([ + "-NonInteractive", + "-NoProfile", + "-Command", + &format!( + "Set-Location {}; & {} --printenv", + quoted_directory, quoted_zed_path + ), + ]) + } + ShellKind::Elvish => { + let quoted_directory = quote_for_shell(&directory_string); + let quoted_zed_path = quote_for_shell(&zed_path_string); + cmd.args([ + "-c", + &format!("cd {}; {} --printenv", quoted_directory, quoted_zed_path), + ]) + } + ShellKind::Nushell => { + let quoted_directory = quote_for_shell(&directory_string); + let quoted_zed_path = quote_for_shell(&zed_path_string); + let zed_command = shell_kind + .prepend_command_prefix("ed_zed_path) + .into_owned(); + cmd.args([ + "-c", + &format!("cd {}; {} --printenv", quoted_directory, zed_command), + ]) + } ShellKind::Cmd => cmd.args([ "/c", "cd", - &directory.display().to_string(), + &directory_string, "&&", - &zed_path.display().to_string(), + &zed_path_string, "--printenv", ]), } diff --git a/crates/web_search_providers/Cargo.toml b/crates/web_search_providers/Cargo.toml index ecdca5883ff541459e94170986df3b7f16036c5a..ff264edcb150063237c633de746b2f6b9f6f250c 100644 --- a/crates/web_search_providers/Cargo.toml +++ b/crates/web_search_providers/Cargo.toml @@ -14,6 +14,7 @@ path = "src/web_search_providers.rs" [dependencies] anyhow.workspace = true client.workspace = true +cloud_api_types.workspace = true cloud_llm_client.workspace = true futures.workspace = true gpui.workspace = true diff --git a/crates/web_search_providers/src/cloud.rs b/crates/web_search_providers/src/cloud.rs index 2f3ccdbb52a884471250ad458e8b7922437cb9ae..c8bc89953f2b2d3ec62bac07e80f2737522824f7 100644 --- a/crates/web_search_providers/src/cloud.rs +++ b/crates/web_search_providers/src/cloud.rs @@ -1,7 +1,8 @@ use std::sync::Arc; use anyhow::{Context as _, Result}; -use client::Client; +use client::{Client, UserStore}; +use cloud_api_types::OrganizationId; use cloud_llm_client::{WebSearchBody, WebSearchResponse}; use futures::AsyncReadExt as _; use gpui::{App, AppContext, Context, Entity, Subscription, Task}; @@ -14,8 +15,8 @@ pub struct CloudWebSearchProvider { } impl CloudWebSearchProvider { - pub fn new(client: Arc, cx: &mut App) -> Self { - let state = cx.new(|cx| State::new(client, cx)); + pub fn new(client: Arc, user_store: Entity, cx: &mut App) -> Self { + let state = cx.new(|cx| State::new(client, user_store, cx)); Self { state } } @@ -23,24 +24,31 @@ impl CloudWebSearchProvider { pub struct State { client: Arc, + user_store: Entity, llm_api_token: LlmApiToken, _llm_token_subscription: Subscription, } impl State { - pub fn new(client: Arc, cx: &mut Context) -> Self { + pub fn new(client: Arc, user_store: Entity, cx: &mut Context) -> Self { let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx); Self { client, + user_store, llm_api_token: LlmApiToken::default(), _llm_token_subscription: cx.subscribe( &refresh_llm_token_listener, |this, _, _event, cx| { let client = this.client.clone(); let llm_api_token = this.llm_api_token.clone(); + let organization_id = this + .user_store + .read(cx) + .current_organization() + .map(|o| o.id.clone()); cx.spawn(async move |_this, _cx| { - llm_api_token.refresh(&client).await?; + llm_api_token.refresh(&client, organization_id).await?; anyhow::Ok(()) }) .detach_and_log_err(cx); @@ -61,21 +69,31 @@ impl WebSearchProvider for CloudWebSearchProvider { let state = self.state.read(cx); let client = state.client.clone(); let llm_api_token = state.llm_api_token.clone(); + let organization_id = state + .user_store + .read(cx) + .current_organization() + .map(|o| o.id.clone()); let body = WebSearchBody { query }; - cx.background_spawn(async move { perform_web_search(client, llm_api_token, body).await }) + cx.background_spawn(async move { + perform_web_search(client, llm_api_token, organization_id, body).await + }) } } async fn perform_web_search( client: Arc, llm_api_token: LlmApiToken, + organization_id: Option, body: WebSearchBody, ) -> Result { const MAX_RETRIES: usize = 3; let http_client = &client.http_client(); let mut retries_remaining = MAX_RETRIES; - let mut token = llm_api_token.acquire(&client).await?; + let mut token = llm_api_token + .acquire(&client, organization_id.clone()) + .await?; loop { if retries_remaining == 0 { @@ -100,7 +118,9 @@ async fn perform_web_search( response.body_mut().read_to_string(&mut body).await?; return Ok(serde_json::from_str(&body)?); } else if response.needs_llm_token_refresh() { - token = llm_api_token.refresh(&client).await?; + token = llm_api_token + .refresh(&client, organization_id.clone()) + .await?; retries_remaining -= 1; } else { // For now we will only retry if the LLM token is expired, diff --git a/crates/web_search_providers/src/web_search_providers.rs b/crates/web_search_providers/src/web_search_providers.rs index 8ab0aee47a414c4cc669ab05e727a827d17c2844..509632429fb167cd489cd4253ceae0ce479b10a8 100644 --- a/crates/web_search_providers/src/web_search_providers.rs +++ b/crates/web_search_providers/src/web_search_providers.rs @@ -1,26 +1,28 @@ mod cloud; -use client::Client; +use client::{Client, UserStore}; use gpui::{App, Context, Entity}; use language_model::LanguageModelRegistry; use std::sync::Arc; use web_search::{WebSearchProviderId, WebSearchRegistry}; -pub fn init(client: Arc, cx: &mut App) { +pub fn init(client: Arc, user_store: Entity, cx: &mut App) { let registry = WebSearchRegistry::global(cx); registry.update(cx, |registry, cx| { - register_web_search_providers(registry, client, cx); + register_web_search_providers(registry, client, user_store, cx); }); } fn register_web_search_providers( registry: &mut WebSearchRegistry, client: Arc, + user_store: Entity, cx: &mut Context, ) { register_zed_web_search_provider( registry, client.clone(), + user_store.clone(), &LanguageModelRegistry::global(cx), cx, ); @@ -29,7 +31,13 @@ fn register_web_search_providers( &LanguageModelRegistry::global(cx), move |this, registry, event, cx| { if let language_model::Event::DefaultModelChanged = event { - register_zed_web_search_provider(this, client.clone(), ®istry, cx) + register_zed_web_search_provider( + this, + client.clone(), + user_store.clone(), + ®istry, + cx, + ) } }, ) @@ -39,6 +47,7 @@ fn register_web_search_providers( fn register_zed_web_search_provider( registry: &mut WebSearchRegistry, client: Arc, + user_store: Entity, language_model_registry: &Entity, cx: &mut Context, ) { @@ -47,7 +56,10 @@ fn register_zed_web_search_provider( .default_model() .is_some_and(|default| default.is_provided_by_zed()); if using_zed_provider { - registry.register_provider(cloud::CloudWebSearchProvider::new(client, cx), cx) + registry.register_provider( + cloud::CloudWebSearchProvider::new(client, user_store, cx), + cx, + ) } else { registry.unregister_provider(WebSearchProviderId( cloud::ZED_WEB_SEARCH_PROVIDER_ID.into(), diff --git a/crates/workspace/Cargo.toml b/crates/workspace/Cargo.toml index dcd0bf640fdf279fb1874ba77307ccbd3c431393..84fd10c8c03e4f7411fc8c813b70255f5e00031d 100644 --- a/crates/workspace/Cargo.toml +++ b/crates/workspace/Cargo.toml @@ -14,7 +14,6 @@ doctest = false [features] test-support = [ - "call/test-support", "client/test-support", "http_client/test-support", "db/test-support", @@ -72,7 +71,6 @@ zed_actions.workspace = true windows.workspace = true [dev-dependencies] -call = { workspace = true, features = ["test-support"] } client = { workspace = true, features = ["test-support"] } dap = { workspace = true, features = ["test-support"] } db = { workspace = true, features = ["test-support"] } diff --git a/crates/workspace/src/pane.rs b/crates/workspace/src/pane.rs index a39be125a5784b8c9d995bb750b9d7ff57a67191..81283427e83afb820b113250545d90f787030e25 100644 --- a/crates/workspace/src/pane.rs +++ b/crates/workspace/src/pane.rs @@ -3450,7 +3450,7 @@ impl Pane { cx, ) .children(pinned_tabs.len().ne(&0).then(|| { - let max_scroll = self.tab_bar_scroll_handle.max_offset().width; + let max_scroll = self.tab_bar_scroll_handle.max_offset().x; // We need to check both because offset returns delta values even when the scroll handle is not scrollable let is_scrolled = self.tab_bar_scroll_handle.offset().x < px(0.); // Avoid flickering when max_offset is very small (< 2px). @@ -7974,7 +7974,7 @@ mod tests { let scroll_handle = pane.update_in(cx, |pane, _window, _cx| pane.tab_bar_scroll_handle.clone()); assert!( - scroll_handle.max_offset().width > px(0.), + scroll_handle.max_offset().x > px(0.), "Test requires tab overflow to verify scrolling. Increase tab count or reduce window width." ); diff --git a/crates/workspace/src/persistence/model.rs b/crates/workspace/src/persistence/model.rs index cdb646ec3b8248bdd0b5784424ed7b8df8ac0ee8..0971ebd0ddc9265ccf9ea10da7745ba59914db30 100644 --- a/crates/workspace/src/persistence/model.rs +++ b/crates/workspace/src/persistence/model.rs @@ -93,9 +93,9 @@ pub(crate) struct SerializedWorkspace { #[derive(Debug, PartialEq, Clone, Default, Serialize, Deserialize)] pub struct DockStructure { - pub(crate) left: DockData, - pub(crate) right: DockData, - pub(crate) bottom: DockData, + pub left: DockData, + pub right: DockData, + pub bottom: DockData, } impl RemoteConnectionKind { @@ -143,9 +143,9 @@ impl Bind for DockStructure { #[derive(Debug, PartialEq, Clone, Default, Serialize, Deserialize)] pub struct DockData { - pub(crate) visible: bool, - pub(crate) active_panel: Option, - pub(crate) zoom: bool, + pub visible: bool, + pub active_panel: Option, + pub zoom: bool, } impl Column for DockData { diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index b62f6b5eb60eafb7177f7883b825a208e7c81d62..3839b4446e7399536a12e7951c004cce81d5c4e6 100644 --- a/crates/workspace/src/workspace.rs +++ b/crates/workspace/src/workspace.rs @@ -79,7 +79,10 @@ pub use pane_group::{ use persistence::{DB, SerializedWindowBounds, model::SerializedWorkspace}; pub use persistence::{ DB as WORKSPACE_DB, WorkspaceDb, delete_unloaded_items, - model::{ItemId, SerializedMultiWorkspace, SerializedWorkspaceLocation, SessionWorkspace}, + model::{ + DockStructure, ItemId, SerializedMultiWorkspace, SerializedWorkspaceLocation, + SessionWorkspace, + }, read_serialized_multi_workspaces, }; use postage::stream::Stream; @@ -149,7 +152,7 @@ use crate::{item::ItemBufferKind, notifications::NotificationId}; use crate::{ persistence::{ SerializedAxis, - model::{DockData, DockStructure, SerializedItem, SerializedPane, SerializedPaneGroup}, + model::{DockData, SerializedItem, SerializedPane, SerializedPaneGroup}, }, security_modal::SecurityModal, }; @@ -628,7 +631,7 @@ fn prompt_and_open_paths(app_state: Arc, options: PathPromptOptions, c }) .ok(); } else { - let task = Workspace::new_local(Vec::new(), app_state.clone(), None, None, None, cx); + let task = Workspace::new_local(Vec::new(), app_state.clone(), None, None, None, true, cx); cx.spawn(async move |cx| { let (window, _) = task.await?; window.update(cx, |multi_workspace, window, cx| { @@ -1290,6 +1293,7 @@ pub struct Workspace { scheduled_tasks: Vec>, last_open_dock_positions: Vec, removing: bool, + _panels_task: Option>>, } impl EventEmitter for Workspace {} @@ -1660,6 +1664,7 @@ impl Workspace { left_dock, bottom_dock, right_dock, + _panels_task: None, project: project.clone(), follower_states: Default::default(), last_leaders_by_pane: Default::default(), @@ -1703,6 +1708,7 @@ impl Workspace { requesting_window: Option>, env: Option>, init: Option) + Send>>, + activate: bool, cx: &mut App, ) -> Task< anyhow::Result<( @@ -1830,7 +1836,11 @@ impl Workspace { workspace }); - multi_workspace.activate(workspace.clone(), cx); + if activate { + multi_workspace.activate(workspace.clone(), cx); + } else { + multi_workspace.add_workspace(workspace.clone(), cx); + } workspace })?; (window, workspace) @@ -1984,6 +1994,76 @@ impl Workspace { [&self.left_dock, &self.bottom_dock, &self.right_dock] } + pub fn capture_dock_state(&self, _window: &Window, cx: &App) -> DockStructure { + let left_dock = self.left_dock.read(cx); + let left_visible = left_dock.is_open(); + let left_active_panel = left_dock + .active_panel() + .map(|panel| panel.persistent_name().to_string()); + // `zoomed_position` is kept in sync with individual panel zoom state + // by the dock code in `Dock::new` and `Dock::add_panel`. + let left_dock_zoom = self.zoomed_position == Some(DockPosition::Left); + + let right_dock = self.right_dock.read(cx); + let right_visible = right_dock.is_open(); + let right_active_panel = right_dock + .active_panel() + .map(|panel| panel.persistent_name().to_string()); + let right_dock_zoom = self.zoomed_position == Some(DockPosition::Right); + + let bottom_dock = self.bottom_dock.read(cx); + let bottom_visible = bottom_dock.is_open(); + let bottom_active_panel = bottom_dock + .active_panel() + .map(|panel| panel.persistent_name().to_string()); + let bottom_dock_zoom = self.zoomed_position == Some(DockPosition::Bottom); + + DockStructure { + left: DockData { + visible: left_visible, + active_panel: left_active_panel, + zoom: left_dock_zoom, + }, + right: DockData { + visible: right_visible, + active_panel: right_active_panel, + zoom: right_dock_zoom, + }, + bottom: DockData { + visible: bottom_visible, + active_panel: bottom_active_panel, + zoom: bottom_dock_zoom, + }, + } + } + + pub fn set_dock_structure( + &self, + docks: DockStructure, + window: &mut Window, + cx: &mut Context, + ) { + for (dock, data) in [ + (&self.left_dock, docks.left), + (&self.bottom_dock, docks.bottom), + (&self.right_dock, docks.right), + ] { + dock.update(cx, |dock, cx| { + dock.serialized_dock = Some(data); + dock.restore_state(window, cx); + }); + } + } + + pub fn open_item_abs_paths(&self, cx: &App) -> Vec { + self.items(cx) + .filter_map(|item| { + let project_path = item.project_path(cx)?; + self.project.read(cx).absolute_path(&project_path, cx) + }) + .collect() + } + pub fn dock_at_position(&self, position: DockPosition) -> &Entity { match position { DockPosition::Left => &self.left_dock, @@ -2043,6 +2123,14 @@ impl Workspace { &self.app_state } + pub fn set_panels_task(&mut self, task: Task>) { + self._panels_task = Some(task); + } + + pub fn take_panels_task(&mut self) -> Option>> { + self._panels_task.take() + } + pub fn user_store(&self) -> &Entity { &self.app_state.user_store } @@ -2548,7 +2636,15 @@ impl Workspace { Task::ready(Ok(callback(self, window, cx))) } else { let env = self.project.read(cx).cli_environment(cx); - let task = Self::new_local(Vec::new(), self.app_state.clone(), None, env, None, cx); + let task = Self::new_local( + Vec::new(), + self.app_state.clone(), + None, + env, + None, + true, + cx, + ); cx.spawn_in(window, async move |_vh, cx| { let (multi_workspace_window, _) = task.await?; multi_workspace_window.update(cx, |multi_workspace, window, cx| { @@ -2578,7 +2674,15 @@ impl Workspace { Task::ready(Ok(callback(self, window, cx))) } else { let env = self.project.read(cx).cli_environment(cx); - let task = Self::new_local(Vec::new(), self.app_state.clone(), None, env, None, cx); + let task = Self::new_local( + Vec::new(), + self.app_state.clone(), + None, + env, + None, + true, + cx, + ); cx.spawn_in(window, async move |_vh, cx| { let (multi_workspace_window, _) = task.await?; multi_workspace_window.update(cx, |multi_workspace, window, cx| { @@ -6012,53 +6116,7 @@ impl Workspace { window: &mut Window, cx: &mut App, ) -> DockStructure { - let left_dock = this.left_dock.read(cx); - let left_visible = left_dock.is_open(); - let left_active_panel = left_dock - .active_panel() - .map(|panel| panel.persistent_name().to_string()); - let left_dock_zoom = left_dock - .active_panel() - .map(|panel| panel.is_zoomed(window, cx)) - .unwrap_or(false); - - let right_dock = this.right_dock.read(cx); - let right_visible = right_dock.is_open(); - let right_active_panel = right_dock - .active_panel() - .map(|panel| panel.persistent_name().to_string()); - let right_dock_zoom = right_dock - .active_panel() - .map(|panel| panel.is_zoomed(window, cx)) - .unwrap_or(false); - - let bottom_dock = this.bottom_dock.read(cx); - let bottom_visible = bottom_dock.is_open(); - let bottom_active_panel = bottom_dock - .active_panel() - .map(|panel| panel.persistent_name().to_string()); - let bottom_dock_zoom = bottom_dock - .active_panel() - .map(|panel| panel.is_zoomed(window, cx)) - .unwrap_or(false); - - DockStructure { - left: DockData { - visible: left_visible, - active_panel: left_active_panel, - zoom: left_dock_zoom, - }, - right: DockData { - visible: right_visible, - active_panel: right_active_panel, - zoom: right_dock_zoom, - }, - bottom: DockData { - visible: bottom_visible, - active_panel: bottom_active_panel, - zoom: bottom_dock_zoom, - }, - } + this.capture_dock_state(window, cx) } match self.workspace_location(cx) { @@ -8087,6 +8145,7 @@ pub async fn restore_multiworkspace( None, None, None, + true, cx, ) }) @@ -8116,6 +8175,7 @@ pub async fn restore_multiworkspace( Some(window_handle), None, None, + true, cx, ) }) @@ -8385,6 +8445,7 @@ pub fn join_channel( requesting_window, None, None, + true, cx, ) }) @@ -8457,7 +8518,7 @@ pub async fn get_any_active_multi_workspace( // find an existing workspace to focus and show call controls let active_window = activate_any_workspace_window(&mut cx); if active_window.is_none() { - cx.update(|cx| Workspace::new_local(vec![], app_state.clone(), None, None, None, cx)) + cx.update(|cx| Workspace::new_local(vec![], app_state.clone(), None, None, None, true, cx)) .await?; } activate_any_workspace_window(&mut cx).context("could not open zed") @@ -8845,6 +8906,7 @@ pub fn open_paths( open_options.replace_window, open_options.env, None, + true, cx, ) }) @@ -8908,6 +8970,7 @@ pub fn open_new( open_options.replace_window, open_options.env, Some(Box::new(init)), + true, cx, ); cx.spawn(async move |cx| { diff --git a/crates/worktree/src/worktree.rs b/crates/worktree/src/worktree.rs index 69b0be24e7ffb09d3fe759ec0bd3d54b54db21d3..9e62beb3c375fb8d580be02382091cafe04d31e2 100644 --- a/crates/worktree/src/worktree.rs +++ b/crates/worktree/src/worktree.rs @@ -2945,7 +2945,7 @@ impl BackgroundScannerState { self.snapshot.check_invariants(false); } - fn remove_path(&mut self, path: &RelPath) { + fn remove_path(&mut self, path: &RelPath, watcher: &dyn Watcher) { log::trace!("background scanner removing path {path:?}"); let mut new_entries; let removed_entries; @@ -2961,7 +2961,12 @@ impl BackgroundScannerState { self.snapshot.entries_by_path = new_entries; let mut removed_ids = Vec::with_capacity(removed_entries.summary().count); + let mut removed_dir_abs_paths = Vec::new(); for entry in removed_entries.cursor::<()>(()) { + if entry.is_dir() { + removed_dir_abs_paths.push(self.snapshot.absolutize(&entry.path)); + } + match self.removed_entries.entry(entry.inode) { hash_map::Entry::Occupied(mut e) => { let prev_removed_entry = e.get_mut(); @@ -2997,6 +3002,10 @@ impl BackgroundScannerState { .git_repositories .retain(|id, _| removed_ids.binary_search(id).is_err()); + for removed_dir_abs_path in removed_dir_abs_paths { + watcher.remove(&removed_dir_abs_path).log_err(); + } + #[cfg(feature = "test-support")] self.snapshot.check_invariants(false); } @@ -4461,7 +4470,10 @@ impl BackgroundScanner { if self.settings.is_path_excluded(&child_path) { log::debug!("skipping excluded child entry {child_path:?}"); - self.state.lock().await.remove_path(&child_path); + self.state + .lock() + .await + .remove_path(&child_path, self.watcher.as_ref()); continue; } @@ -4651,7 +4663,7 @@ impl BackgroundScanner { // detected regardless of the order of the paths. for (path, metadata) in relative_paths.iter().zip(metadata.iter()) { if matches!(metadata, Ok(None)) || doing_recursive_update { - state.remove_path(path); + state.remove_path(path, self.watcher.as_ref()); } } diff --git a/crates/x_ai/src/x_ai.rs b/crates/x_ai/src/x_ai.rs index 072a893a6a8f4fc7fbc8a6f4f5ed43316915b974..1abb2b53771fa1e29e2979560e9f394744b26158 100644 --- a/crates/x_ai/src/x_ai.rs +++ b/crates/x_ai/src/x_ai.rs @@ -165,6 +165,18 @@ impl Model { } } + pub fn requires_json_schema_subset(&self) -> bool { + match self { + Self::Grok4 + | Self::Grok4FastReasoning + | Self::Grok4FastNonReasoning + | Self::Grok41FastNonReasoning + | Self::Grok41FastReasoning + | Self::GrokCodeFast1 => true, + _ => false, + } + } + pub fn supports_prompt_cache_key(&self) -> bool { false } diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index cf8df08c010bfe643b93b5628cf520ee2ec1dd8b..6ea308db5a32cf82e48439c477c8bb81f02ab777 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -2,7 +2,7 @@ description = "The fast, collaborative code editor." edition.workspace = true name = "zed" -version = "0.227.0" +version = "0.228.0" publish.workspace = true license = "GPL-3.0-or-later" authors = ["Zed Team "] @@ -17,7 +17,6 @@ test-support = [ "gpui/test-support", "gpui_platform/screen-capture", "dep:image", - "dep:semver", "workspace/test-support", "project/test-support", "editor/test-support", @@ -32,7 +31,6 @@ visual-tests = [ "gpui_platform/screen-capture", "gpui_platform/test-support", "dep:image", - "dep:semver", "dep:tempfile", "dep:action_log", "dep:agent_servers", @@ -76,7 +74,6 @@ assets.workspace = true audio.workspace = true auto_update.workspace = true auto_update_ui.workspace = true -bincode.workspace = true breadcrumbs.workspace = true call.workspace = true chrono.workspace = true @@ -94,6 +91,7 @@ copilot.workspace = true copilot_chat.workspace = true copilot_ui.workspace = true crashes.workspace = true +csv_preview.workspace = true dap_adapters.workspace = true db.workspace = true debug_adapter_extension.workspace = true @@ -121,7 +119,7 @@ system_specs.workspace = true gpui.workspace = true gpui_platform = {workspace = true, features=["screen-capture", "font-kit", "wayland", "x11"]} image = { workspace = true, optional = true } -semver = { workspace = true, optional = true } +semver.workspace = true tempfile = { workspace = true, optional = true } clock = { workspace = true, optional = true } acp_thread.workspace = true diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index e93bd92d041a18e927e1560379bcdb2886605874..0d50339f6c9d42ffa653e5c7565ae6e22441bdca 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -335,7 +335,13 @@ fn main() { crashes::init( InitCrashHandler { session_id, - zed_version: app_version.to_string(), + // strip the build and channel information from the version string, we send them separately + zed_version: semver::Version::new( + app_version.major, + app_version.minor, + app_version.patch, + ) + .to_string(), binary: "zed".to_string(), release_channel: release_channel::RELEASE_CHANNEL_NAME.clone(), commit_sha: app_commit_sha @@ -573,6 +579,19 @@ fn main() { session.id().to_owned(), cx, ); + cx.subscribe(&user_store, { + let telemetry = telemetry.clone(); + move |_, evt: &client::user::Event, _| match evt { + client::user::Event::PrivateUserInfoUpdated => { + crashes::set_user_info(crashes::UserInfo { + metrics_id: telemetry.metrics_id().map(|s| s.to_string()), + is_staff: telemetry.is_staff(), + }); + } + _ => {} + } + }) + .detach(); // We should rename these in the future to `first app open`, `first app open for release channel`, and `app open` if let (Some(system_id), Some(installation_id)) = (&system_id, &installation_id) { @@ -645,7 +664,7 @@ fn main() { zed::remote_debug::init(cx); edit_prediction_ui::init(cx); web_search::init(cx); - web_search_providers::init(app_state.client.clone(), cx); + web_search_providers::init(app_state.client.clone(), app_state.user_store.clone(), cx); snippet_provider::init(cx); edit_prediction_registry::init(app_state.client.clone(), app_state.user_store.clone(), cx); let prompt_builder = PromptBuilder::load(app_state.fs.clone(), stdout_is_a_pty(), cx); @@ -715,6 +734,7 @@ fn main() { git_graph::init(cx); feedback::init(cx); markdown_preview::init(cx); + csv_preview::init(cx); svg_preview::init(cx); onboarding::init(cx); settings_ui::init(cx); diff --git a/crates/zed/src/reliability.rs b/crates/zed/src/reliability.rs index b291b9c8493db75e20282c8c9bc5a3750fb5e705..2f284027929b19e5b0d5ac084267cf5548cda667 100644 --- a/crates/zed/src/reliability.rs +++ b/crates/zed/src/reliability.rs @@ -144,7 +144,7 @@ fn cleanup_old_hang_traces() { entry .path() .extension() - .is_some_and(|ext| ext == "miniprof") + .is_some_and(|ext| ext == "json" || ext == "miniprof") }) .collect(); @@ -175,7 +175,7 @@ fn save_hang_trace( .collect::>(); let trace_path = paths::hang_traces_dir().join(&format!( - "hang-{}.miniprof", + "hang-{}.miniprof.json", hang_time.format("%Y-%m-%d_%H-%M-%S") )); @@ -193,7 +193,7 @@ fn save_hang_trace( entry .path() .extension() - .is_some_and(|ext| ext == "miniprof") + .is_some_and(|ext| ext == "json" || ext == "miniprof") }) .collect(); @@ -288,16 +288,23 @@ async fn upload_minidump( form = form.text("minidump_error", minidump_error); } - if let Some(id) = client.telemetry().metrics_id() { - form = form.text("sentry[user][id]", id.to_string()); + if let Some(is_staff) = &metadata + .user_info + .as_ref() + .and_then(|user_info| user_info.is_staff) + { form = form.text( "sentry[user][is_staff]", - if client.telemetry().is_staff().unwrap_or_default() { - "true" - } else { - "false" - }, + if *is_staff { "true" } else { "false" }, ); + } + + if let Some(metrics_id) = metadata + .user_info + .as_ref() + .and_then(|user_info| user_info.metrics_id.as_ref()) + { + form = form.text("sentry[user][id]", metrics_id.clone()); } else if let Some(id) = client.telemetry().installation_id() { form = form.text("sentry[user][id]", format!("installation-{}", id)) } diff --git a/crates/zed/src/visual_test_runner.rs b/crates/zed/src/visual_test_runner.rs index 0ae98d510aa34b05f7fa1766176f21ea353394d9..8f005fa68b6accb5cf5686157bbb065e33bb1b0c 100644 --- a/crates/zed/src/visual_test_runner.rs +++ b/crates/zed/src/visual_test_runner.rs @@ -71,7 +71,7 @@ use { time::Duration, }, util::ResultExt as _, - workspace::{AppState, MultiWorkspace, Workspace, WorkspaceId}, + workspace::{AppState, MultiWorkspace, Panel as _, Workspace, WorkspaceId}, zed_actions::OpenSettingsAt, }; @@ -548,6 +548,27 @@ fn run_visual_tests(project_path: PathBuf, update_baseline: bool) -> Result<()> } } + // Run Test 11: Thread target selector visual tests + #[cfg(feature = "visual-tests")] + { + println!("\n--- Test 11: start_thread_in_selector (6 variants) ---"); + match run_start_thread_in_selector_visual_tests(app_state.clone(), &mut cx, update_baseline) + { + Ok(TestResult::Passed) => { + println!("✓ start_thread_in_selector: PASSED"); + passed += 1; + } + Ok(TestResult::BaselineUpdated(_)) => { + println!("✓ start_thread_in_selector: Baselines updated"); + updated += 1; + } + Err(e) => { + eprintln!("✗ start_thread_in_selector: FAILED - {}", e); + failed += 1; + } + } + } + // Run Test 9: Tool Permissions Settings UI visual test println!("\n--- Test 9: tool_permissions_settings ---"); match run_tool_permissions_visual_tests(app_state.clone(), &mut cx, update_baseline) { @@ -2011,32 +2032,9 @@ fn run_agent_thread_view_test( // Create the necessary entities for the ReadFileTool let action_log = cx.update(|cx| cx.new(|_| action_log::ActionLog::new(project.clone()))); - let context_server_registry = cx.update(|cx| { - cx.new(|cx| agent::ContextServerRegistry::new(project.read(cx).context_server_store(), cx)) - }); - let fake_model = Arc::new(language_model::fake_provider::FakeLanguageModel::default()); - let project_context = cx.update(|cx| cx.new(|_| prompt_store::ProjectContext::default())); - - // Create the agent Thread - let thread = cx.update(|cx| { - cx.new(|cx| { - agent::Thread::new( - project.clone(), - project_context, - context_server_registry, - agent::Templates::new(), - Some(fake_model), - cx, - ) - }) - }); // Create the ReadFileTool - let tool = Arc::new(agent::ReadFileTool::new( - thread.downgrade(), - project.clone(), - action_log, - )); + let tool = Arc::new(agent::ReadFileTool::new(project.clone(), action_log, true)); // Create a test event stream to capture tool output let (event_stream, mut event_receiver) = agent::ToolCallEventStream::test(); @@ -3066,3 +3064,629 @@ fn run_error_wrapping_visual_tests( Ok(test_result) } + +#[cfg(all(target_os = "macos", feature = "visual-tests"))] +/// Runs a git command in the given directory and returns an error with +/// stderr/stdout context if the command fails (non-zero exit status). +fn run_git_command(args: &[&str], dir: &std::path::Path) -> Result<()> { + let output = std::process::Command::new("git") + .args(args) + .current_dir(dir) + .output() + .with_context(|| format!("failed to spawn `git {}`", args.join(" ")))?; + + if !output.status.success() { + let stdout = String::from_utf8_lossy(&output.stdout); + let stderr = String::from_utf8_lossy(&output.stderr); + anyhow::bail!( + "`git {}` failed (exit {})\nstdout: {}\nstderr: {}", + args.join(" "), + output.status, + stdout.trim(), + stderr.trim(), + ); + } + Ok(()) +} + +#[cfg(all(target_os = "macos", feature = "visual-tests"))] +fn run_start_thread_in_selector_visual_tests( + app_state: Arc, + cx: &mut VisualTestAppContext, + update_baseline: bool, +) -> Result { + use agent_ui::{AgentPanel, StartThreadIn, WorktreeCreationStatus}; + + // Enable feature flags so the thread target selector renders + cx.update(|cx| { + cx.update_flags( + true, + vec!["agent-v2".to_string(), "agent-git-worktrees".to_string()], + ); + }); + + // Create a temp directory with a real git repo so "New Worktree" is enabled + let temp_dir = tempfile::tempdir()?; + let temp_path = temp_dir.keep(); + let canonical_temp = temp_path.canonicalize()?; + let project_path = canonical_temp.join("project"); + std::fs::create_dir_all(&project_path)?; + + // Initialize git repo + run_git_command(&["init"], &project_path)?; + run_git_command(&["config", "user.email", "test@test.com"], &project_path)?; + run_git_command(&["config", "user.name", "Test User"], &project_path)?; + + // Create source files + let src_dir = project_path.join("src"); + std::fs::create_dir_all(&src_dir)?; + std::fs::write( + src_dir.join("main.rs"), + r#"fn main() { + println!("Hello, world!"); + + let x = 42; + let y = x * 2; + + if y > 50 { + println!("y is greater than 50"); + } else { + println!("y is not greater than 50"); + } + + for i in 0..10 { + println!("i = {}", i); + } +} + +fn helper_function(a: i32, b: i32) -> i32 { + a + b +} +"#, + )?; + + std::fs::write( + project_path.join("Cargo.toml"), + r#"[package] +name = "test_project" +version = "0.1.0" +edition = "2021" +"#, + )?; + + // Commit so git status is clean + run_git_command(&["add", "."], &project_path)?; + run_git_command(&["commit", "-m", "Initial commit"], &project_path)?; + + let project = cx.update(|cx| { + project::Project::local( + app_state.client.clone(), + app_state.node_runtime.clone(), + app_state.user_store.clone(), + app_state.languages.clone(), + app_state.fs.clone(), + None, + project::LocalProjectFlags { + init_worktree_trust: false, + ..Default::default() + }, + cx, + ) + }); + + // Use a wide window so we see project panel + editor + agent panel + let window_size = size(px(1280.0), px(800.0)); + let bounds = Bounds { + origin: point(px(0.0), px(0.0)), + size: window_size, + }; + + let workspace_window: WindowHandle = cx + .update(|cx| { + cx.open_window( + WindowOptions { + window_bounds: Some(WindowBounds::Windowed(bounds)), + focus: false, + show: false, + ..Default::default() + }, + |window, cx| { + let workspace = cx.new(|cx| { + Workspace::new(None, project.clone(), app_state.clone(), window, cx) + }); + cx.new(|cx| MultiWorkspace::new(workspace, window, cx)) + }, + ) + }) + .context("Failed to open thread target selector test window")?; + + cx.run_until_parked(); + + // Create and register the workspace sidebar + let sidebar = workspace_window + .update(cx, |_multi_workspace, window, cx| { + let multi_workspace_handle = cx.entity(); + cx.new(|cx| sidebar::Sidebar::new(multi_workspace_handle, window, cx)) + }) + .context("Failed to create sidebar")?; + + workspace_window + .update(cx, |multi_workspace, window, cx| { + multi_workspace.register_sidebar(sidebar.clone(), window, cx); + }) + .context("Failed to register sidebar")?; + + // Open the sidebar + workspace_window + .update(cx, |multi_workspace, window, cx| { + multi_workspace.toggle_sidebar(window, cx); + }) + .context("Failed to toggle sidebar")?; + + cx.run_until_parked(); + + // Add the git project as a worktree + let add_worktree_task = workspace_window + .update(cx, |multi_workspace, _window, cx| { + let workspace = &multi_workspace.workspaces()[0]; + let project = workspace.read(cx).project().clone(); + project.update(cx, |project, cx| { + project.find_or_create_worktree(&project_path, true, cx) + }) + }) + .context("Failed to start adding worktree")?; + + cx.background_executor.allow_parking(); + cx.foreground_executor + .block_test(add_worktree_task) + .context("Failed to add worktree")?; + cx.background_executor.forbid_parking(); + + cx.run_until_parked(); + + // Wait for worktree scan and git status + for _ in 0..5 { + cx.advance_clock(Duration::from_millis(100)); + cx.run_until_parked(); + } + + // Open the project panel + let (weak_workspace, async_window_cx) = workspace_window + .update(cx, |multi_workspace, window, cx| { + let workspace = &multi_workspace.workspaces()[0]; + (workspace.read(cx).weak_handle(), window.to_async(cx)) + }) + .context("Failed to get workspace handle")?; + + cx.background_executor.allow_parking(); + let project_panel = cx + .foreground_executor + .block_test(ProjectPanel::load(weak_workspace, async_window_cx)) + .context("Failed to load project panel")?; + cx.background_executor.forbid_parking(); + + workspace_window + .update(cx, |multi_workspace, window, cx| { + let workspace = &multi_workspace.workspaces()[0]; + workspace.update(cx, |workspace, cx| { + workspace.add_panel(project_panel, window, cx); + workspace.open_panel::(window, cx); + }); + }) + .context("Failed to add project panel")?; + + cx.run_until_parked(); + + // Open main.rs in the editor + let open_file_task = workspace_window + .update(cx, |multi_workspace, window, cx| { + let workspace = &multi_workspace.workspaces()[0]; + workspace.update(cx, |workspace, cx| { + let worktree = workspace.project().read(cx).worktrees(cx).next(); + if let Some(worktree) = worktree { + let worktree_id = worktree.read(cx).id(); + let rel_path: std::sync::Arc = + util::rel_path::rel_path("src/main.rs").into(); + let project_path: project::ProjectPath = (worktree_id, rel_path).into(); + Some(workspace.open_path(project_path, None, true, window, cx)) + } else { + None + } + }) + }) + .log_err() + .flatten(); + + if let Some(task) = open_file_task { + cx.background_executor.allow_parking(); + cx.foreground_executor.block_test(task).log_err(); + cx.background_executor.forbid_parking(); + } + + cx.run_until_parked(); + + // Load the AgentPanel + let (weak_workspace, async_window_cx) = workspace_window + .update(cx, |multi_workspace, window, cx| { + let workspace = &multi_workspace.workspaces()[0]; + (workspace.read(cx).weak_handle(), window.to_async(cx)) + }) + .context("Failed to get workspace handle for agent panel")?; + + let prompt_builder = + cx.update(|cx| prompt_store::PromptBuilder::load(app_state.fs.clone(), false, cx)); + + // Register an observer so that workspaces created by the worktree creation + // flow get AgentPanel and ProjectPanel loaded automatically. Without this, + // `workspace.panel::(cx)` returns None in the new workspace and + // the creation flow's `focus_panel::` call is a no-op. + let _workspace_observer = cx.update({ + let prompt_builder = prompt_builder.clone(); + |cx| { + cx.observe_new(move |workspace: &mut Workspace, window, cx| { + let Some(window) = window else { return }; + let prompt_builder = prompt_builder.clone(); + let panels_task = cx.spawn_in(window, async move |workspace_handle, cx| { + let project_panel = ProjectPanel::load(workspace_handle.clone(), cx.clone()); + let agent_panel = + AgentPanel::load(workspace_handle.clone(), prompt_builder, cx.clone()); + if let Ok(panel) = project_panel.await { + workspace_handle + .update_in(cx, |workspace, window, cx| { + workspace.add_panel(panel, window, cx); + }) + .log_err(); + } + if let Ok(panel) = agent_panel.await { + workspace_handle + .update_in(cx, |workspace, window, cx| { + workspace.add_panel(panel, window, cx); + }) + .log_err(); + } + anyhow::Ok(()) + }); + workspace.set_panels_task(panels_task); + }) + } + }); + + cx.background_executor.allow_parking(); + let panel = cx + .foreground_executor + .block_test(AgentPanel::load( + weak_workspace, + prompt_builder, + async_window_cx, + )) + .context("Failed to load AgentPanel")?; + cx.background_executor.forbid_parking(); + + workspace_window + .update(cx, |multi_workspace, window, cx| { + let workspace = &multi_workspace.workspaces()[0]; + workspace.update(cx, |workspace, cx| { + workspace.add_panel(panel.clone(), window, cx); + workspace.open_panel::(window, cx); + }); + }) + .context("Failed to add and open AgentPanel")?; + + cx.run_until_parked(); + + // Inject the stub server and open a thread so the toolbar is visible + let connection = StubAgentConnection::new(); + let stub_agent: Rc = Rc::new(StubAgentServer::new(connection)); + + cx.update_window(workspace_window.into(), |_, window, cx| { + panel.update(cx, |panel, cx| { + panel.open_external_thread_with_server(stub_agent.clone(), window, cx); + }); + })?; + + cx.run_until_parked(); + + // ---- Screenshot 1: Default "Local Project" selector (dropdown closed) ---- + cx.update_window(workspace_window.into(), |_, window, _cx| { + window.refresh(); + })?; + cx.run_until_parked(); + + let result_default = run_visual_test( + "start_thread_in_selector_default", + workspace_window.into(), + cx, + update_baseline, + ); + + // ---- Screenshot 2: Dropdown open showing menu entries ---- + cx.update_window(workspace_window.into(), |_, window, cx| { + panel.update(cx, |panel, cx| { + panel.open_start_thread_in_menu_for_tests(window, cx); + }); + })?; + cx.run_until_parked(); + + cx.update_window(workspace_window.into(), |_, window, _cx| { + window.refresh(); + })?; + cx.run_until_parked(); + + let result_open_dropdown = run_visual_test( + "start_thread_in_selector_open", + workspace_window.into(), + cx, + update_baseline, + ); + + // ---- Screenshot 3: "New Worktree" selected (dropdown closed, label changed) ---- + // First dismiss the dropdown, then change the target so the toolbar label is visible + cx.update_window(workspace_window.into(), |_, _window, cx| { + panel.update(cx, |panel, cx| { + panel.close_start_thread_in_menu_for_tests(cx); + }); + })?; + cx.run_until_parked(); + + cx.update_window(workspace_window.into(), |_, _window, cx| { + panel.update(cx, |panel, cx| { + panel.set_start_thread_in_for_tests(StartThreadIn::NewWorktree, cx); + }); + })?; + cx.run_until_parked(); + + cx.update_window(workspace_window.into(), |_, window, _cx| { + window.refresh(); + })?; + cx.run_until_parked(); + + let result_new_worktree = run_visual_test( + "start_thread_in_selector_new_worktree", + workspace_window.into(), + cx, + update_baseline, + ); + + // ---- Screenshot 4: "Creating worktree…" status banner ---- + cx.update_window(workspace_window.into(), |_, _window, cx| { + panel.update(cx, |panel, cx| { + panel + .set_worktree_creation_status_for_tests(Some(WorktreeCreationStatus::Creating), cx); + }); + })?; + cx.run_until_parked(); + + cx.update_window(workspace_window.into(), |_, window, _cx| { + window.refresh(); + })?; + cx.run_until_parked(); + + let result_creating = run_visual_test( + "worktree_creation_status_creating", + workspace_window.into(), + cx, + update_baseline, + ); + + // ---- Screenshot 5: Error status banner ---- + cx.update_window(workspace_window.into(), |_, _window, cx| { + panel.update(cx, |panel, cx| { + panel.set_worktree_creation_status_for_tests( + Some(WorktreeCreationStatus::Error( + "Failed to create worktree: branch already exists".into(), + )), + cx, + ); + }); + })?; + cx.run_until_parked(); + + cx.update_window(workspace_window.into(), |_, window, _cx| { + window.refresh(); + })?; + cx.run_until_parked(); + + let result_error = run_visual_test( + "worktree_creation_status_error", + workspace_window.into(), + cx, + update_baseline, + ); + + // ---- Screenshot 6: Worktree creation succeeded ---- + // Clear the error status and re-select New Worktree to ensure a clean state. + cx.update_window(workspace_window.into(), |_, _window, cx| { + panel.update(cx, |panel, cx| { + panel.set_worktree_creation_status_for_tests(None, cx); + }); + })?; + cx.run_until_parked(); + + cx.update_window(workspace_window.into(), |_, window, cx| { + window.dispatch_action(Box::new(StartThreadIn::NewWorktree), cx); + })?; + cx.run_until_parked(); + + // Insert a message into the active thread's message editor and submit. + let thread_view = cx + .read(|cx| panel.read(cx).as_active_thread_view(cx)) + .ok_or_else(|| anyhow::anyhow!("No active thread view"))?; + + cx.update_window(workspace_window.into(), |_, window, cx| { + let message_editor = thread_view.read(cx).message_editor.clone(); + message_editor.update(cx, |message_editor, cx| { + message_editor.set_message( + vec![acp::ContentBlock::Text(acp::TextContent::new( + "Add a CLI flag to set the log level".to_string(), + ))], + window, + cx, + ); + message_editor.send(cx); + }); + })?; + cx.run_until_parked(); + + // Wait for the full worktree creation flow to complete. The creation status + // is cleared to `None` at the very end of the async task, after panels are + // loaded, the agent panel is focused, and the new workspace is activated. + cx.background_executor.allow_parking(); + let mut creation_complete = false; + for _ in 0..120 { + cx.run_until_parked(); + let status_cleared = cx.read(|cx| { + panel + .read(cx) + .worktree_creation_status_for_tests() + .is_none() + }); + let workspace_count = workspace_window.update(cx, |multi_workspace, _window, _cx| { + multi_workspace.workspaces().len() + })?; + if workspace_count == 2 && status_cleared { + creation_complete = true; + break; + } + cx.advance_clock(Duration::from_millis(100)); + } + cx.background_executor.forbid_parking(); + + if !creation_complete { + return Err(anyhow::anyhow!("Worktree creation did not complete")); + } + + // The creation flow called `external_thread` on the new workspace's agent + // panel, which tried to launch a real agent binary and failed. Replace the + // error state by injecting the stub server, and shrink the panel so the + // editor content is visible. + workspace_window.update(cx, |multi_workspace, window, cx| { + let new_workspace = &multi_workspace.workspaces()[1]; + new_workspace.update(cx, |workspace, cx| { + if let Some(new_panel) = workspace.panel::(cx) { + new_panel.update(cx, |panel, cx| { + panel.set_size(Some(px(480.0)), window, cx); + panel.open_external_thread_with_server(stub_agent.clone(), window, cx); + }); + } + }); + })?; + cx.run_until_parked(); + + // Type and send a message so the thread target dropdown disappears. + let new_panel = workspace_window.update(cx, |multi_workspace, _window, cx| { + let new_workspace = &multi_workspace.workspaces()[1]; + new_workspace.read(cx).panel::(cx) + })?; + if let Some(new_panel) = new_panel { + let new_thread_view = cx.read(|cx| new_panel.read(cx).as_active_thread_view(cx)); + if let Some(new_thread_view) = new_thread_view { + cx.update_window(workspace_window.into(), |_, window, cx| { + let message_editor = new_thread_view.read(cx).message_editor.clone(); + message_editor.update(cx, |editor, cx| { + editor.set_message( + vec![acp::ContentBlock::Text(acp::TextContent::new( + "Add a CLI flag to set the log level".to_string(), + ))], + window, + cx, + ); + editor.send(cx); + }); + })?; + cx.run_until_parked(); + } + } + + cx.update_window(workspace_window.into(), |_, window, _cx| { + window.refresh(); + })?; + cx.run_until_parked(); + + let result_succeeded = run_visual_test( + "worktree_creation_succeeded", + workspace_window.into(), + cx, + update_baseline, + ); + + // Clean up — drop the workspace observer first so no new panels are + // registered on workspaces created during teardown. + drop(_workspace_observer); + + workspace_window + .update(cx, |multi_workspace, _window, cx| { + let workspace = &multi_workspace.workspaces()[0]; + let project = workspace.read(cx).project().clone(); + project.update(cx, |project, cx| { + let worktree_ids: Vec<_> = + project.worktrees(cx).map(|wt| wt.read(cx).id()).collect(); + for id in worktree_ids { + project.remove_worktree(id, cx); + } + }); + }) + .log_err(); + + cx.run_until_parked(); + + cx.update_window(workspace_window.into(), |_, window, _cx| { + window.remove_window(); + }) + .log_err(); + + cx.run_until_parked(); + + for _ in 0..15 { + cx.advance_clock(Duration::from_millis(100)); + cx.run_until_parked(); + } + + // Delete the preserved temp directory so visual-test runs don't + // accumulate filesystem artifacts. + if let Err(err) = std::fs::remove_dir_all(&temp_path) { + log::warn!( + "failed to clean up visual-test temp dir {}: {err}", + temp_path.display() + ); + } + + // Reset feature flags + cx.update(|cx| { + cx.update_flags(false, vec![]); + }); + + let results = [ + ("default", result_default), + ("open_dropdown", result_open_dropdown), + ("new_worktree", result_new_worktree), + ("creating", result_creating), + ("error", result_error), + ("succeeded", result_succeeded), + ]; + + let mut has_baseline_update = None; + let mut failures = Vec::new(); + + for (name, result) in &results { + match result { + Ok(TestResult::Passed) => {} + Ok(TestResult::BaselineUpdated(p)) => { + has_baseline_update = Some(p.clone()); + } + Err(e) => { + failures.push(format!("{}: {}", name, e)); + } + } + } + + if !failures.is_empty() { + Err(anyhow::anyhow!( + "start_thread_in_selector failures: {}", + failures.join("; ") + )) + } else if let Some(p) = has_baseline_update { + Ok(TestResult::BaselineUpdated(p)) + } else { + Ok(TestResult::Passed) + } +} diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index 55f185aae13e49c6b90610a50ad197ee47ee8a98..aeb740c5ec05f5382e3b93527bb2191cb44f9d51 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -422,16 +422,7 @@ pub fn initialize_workspace( if let Some(specs) = window.gpu_specs() { log::info!("Using GPU: {:?}", specs); show_software_emulation_warning_if_needed(specs.clone(), window, cx); - if let Some((crash_server, message)) = crashes::CRASH_HANDLER - .get() - .zip(bincode::serialize(&specs).ok()) - && let Err(err) = crash_server.send_message(3, message) - { - log::warn!( - "Failed to store active gpu info for crash reporting: {}", - err - ); - } + crashes::set_gpu_info(specs); } let edit_prediction_menu_handle = PopoverMenuHandle::default(); @@ -496,7 +487,8 @@ pub fn initialize_workspace( status_bar.add_right_item(image_info, window, cx); }); - initialize_panels(prompt_builder.clone(), window, cx); + let panels_task = initialize_panels(prompt_builder.clone(), window, cx); + workspace.set_panels_task(panels_task); register_actions(app_state.clone(), workspace, window, cx); workspace.focus_handle(cx).focus(window, cx); @@ -620,7 +612,7 @@ fn initialize_panels( prompt_builder: Arc, window: &mut Window, cx: &mut Context, -) { +) -> Task> { cx.spawn_in(window, async move |workspace_handle, cx| { let project_panel = ProjectPanel::load(workspace_handle.clone(), cx.clone()); let outline_panel = OutlinePanel::load(workspace_handle.clone(), cx.clone()); @@ -662,7 +654,6 @@ fn initialize_panels( anyhow::Ok(()) }) - .detach(); } fn setup_or_teardown_ai_panel( @@ -1103,7 +1094,7 @@ fn register_actions( ); }, ) - .detach(); + .detach_and_log_err(cx); } } }) @@ -4809,6 +4800,7 @@ mod tests { "console", "context_server", "copilot", + "csv", "debug_panel", "debugger", "dev", @@ -5020,7 +5012,7 @@ mod tests { language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx); web_search::init(cx); git_graph::init(cx); - web_search_providers::init(app_state.client.clone(), cx); + web_search_providers::init(app_state.client.clone(), app_state.user_store.clone(), cx); let prompt_builder = PromptBuilder::load(app_state.fs.clone(), false, cx); project::AgentRegistryStore::init_global( cx, @@ -5807,7 +5799,15 @@ mod tests { // Window B: workspace for dir3 let (window_a, _) = cx .update(|cx| { - Workspace::new_local(vec![dir1.into()], app_state.clone(), None, None, None, cx) + Workspace::new_local( + vec![dir1.into()], + app_state.clone(), + None, + None, + None, + true, + cx, + ) }) .await .expect("failed to open first workspace"); @@ -5823,7 +5823,15 @@ mod tests { let (window_b, _) = cx .update(|cx| { - Workspace::new_local(vec![dir3.into()], app_state.clone(), None, None, None, cx) + Workspace::new_local( + vec![dir3.into()], + app_state.clone(), + None, + None, + None, + true, + cx, + ) }) .await .expect("failed to open third workspace"); diff --git a/crates/zed/src/zed/edit_prediction_registry.rs b/crates/zed/src/zed/edit_prediction_registry.rs index 67b0d26c88cf0bd254a776834de09fb89d6ea195..9f05c5795e6f16cab231df8a5586106ed25b03ee 100644 --- a/crates/zed/src/zed/edit_prediction_registry.rs +++ b/crates/zed/src/zed/edit_prediction_registry.rs @@ -2,15 +2,12 @@ use client::{Client, UserStore}; use codestral::{CodestralEditPredictionDelegate, load_codestral_api_key}; use collections::HashMap; use copilot::CopilotEditPredictionDelegate; -use edit_prediction::{EditPredictionModel, ZedEditPredictionDelegate, Zeta2FeatureFlag}; +use edit_prediction::{EditPredictionModel, ZedEditPredictionDelegate}; use editor::Editor; -use feature_flags::FeatureFlagAppExt; use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity}; use language::language_settings::{EditPredictionProvider, all_language_settings}; -use settings::{ - EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, EditPredictionPromptFormat, SettingsStore, -}; +use settings::{EditPredictionPromptFormat, SettingsStore}; use std::{cell::RefCell, rc::Rc, sync::Arc}; use ui::Window; @@ -81,9 +78,6 @@ pub fn init(client: Arc, user_store: Entity, cx: &mut App) { .detach(); cx.observe_global::({ - let editors = editors.clone(); - let client = client.clone(); - let user_store = user_store.clone(); let mut previous_config = edit_prediction_provider_config_for_settings(cx); move |cx| { let new_provider_config = edit_prediction_provider_config_for_settings(cx); @@ -107,24 +101,6 @@ pub fn init(client: Arc, user_store: Entity, cx: &mut App) { } }) .detach(); - - cx.observe_flag::({ - let mut previous_config = edit_prediction_provider_config_for_settings(cx); - move |_is_enabled, cx| { - let new_provider_config = edit_prediction_provider_config_for_settings(cx); - if new_provider_config != previous_config { - previous_config = new_provider_config; - assign_edit_prediction_providers( - &editors, - new_provider_config, - &client, - user_store.clone(), - cx, - ); - } - } - }) - .detach(); } fn edit_prediction_provider_config_for_settings(cx: &App) -> Option { @@ -154,7 +130,10 @@ fn edit_prediction_provider_config_for_settings(cx: &App) -> Option Option Some(EditPredictionProviderConfig::Zed( EditPredictionModel::Mercury, )), - EditPredictionProvider::Experimental(name) => { - if name == EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME - && cx.has_flag::() - { - Some(EditPredictionProviderConfig::Zed(EditPredictionModel::Zeta)) - } else { - None - } - } + EditPredictionProvider::Experimental(_) => None, } } diff --git a/crates/zed/src/zed/quick_action_bar/preview.rs b/crates/zed/src/zed/quick_action_bar/preview.rs index 5d43e79542357977b06fbbd884472f94ad3595c8..01e2d164d7d7a8a81e64ab77ad646111e4baacd7 100644 --- a/crates/zed/src/zed/quick_action_bar/preview.rs +++ b/crates/zed/src/zed/quick_action_bar/preview.rs @@ -1,3 +1,8 @@ +use csv_preview::{ + CsvPreviewView, OpenPreview as CsvOpenPreview, OpenPreviewToTheSide as CsvOpenPreviewToTheSide, + TabularDataPreviewFeatureFlag, +}; +use feature_flags::FeatureFlagAppExt as _; use gpui::{AnyElement, Modifiers, WeakEntity}; use markdown_preview::{ OpenPreview as MarkdownOpenPreview, OpenPreviewToTheSide as MarkdownOpenPreviewToTheSide, @@ -16,6 +21,7 @@ use super::QuickActionBar; enum PreviewType { Markdown, Svg, + Csv, } impl QuickActionBar { @@ -35,6 +41,10 @@ impl QuickActionBar { } else if SvgPreviewView::resolve_active_item_as_svg_buffer(workspace, cx).is_some() { preview_type = Some(PreviewType::Svg); + } else if cx.has_flag::() + && CsvPreviewView::resolve_active_item_as_csv_editor(workspace, cx).is_some() + { + preview_type = Some(PreviewType::Csv); } }); } @@ -57,6 +67,13 @@ impl QuickActionBar { Box::new(SvgOpenPreviewToTheSide) as Box, &svg_preview::OpenPreview as &dyn gpui::Action, ), + PreviewType::Csv => ( + "toggle-csv-preview", + "Preview CSV", + Box::new(CsvOpenPreview) as Box, + Box::new(CsvOpenPreviewToTheSide) as Box, + &csv_preview::OpenPreview as &dyn gpui::Action, + ), }; let alt_click = gpui::Keystroke { diff --git a/crates/zeta_prompt/src/zeta_prompt.rs b/crates/zeta_prompt/src/zeta_prompt.rs index 0cd37a455397334933dbfa2464c2dbcb72bba456..d1cb24a8c83710e06d04e0c006a1963882982f59 100644 --- a/crates/zeta_prompt/src/zeta_prompt.rs +++ b/crates/zeta_prompt/src/zeta_prompt.rs @@ -61,6 +61,8 @@ pub struct ZetaPromptInput { pub in_open_source_repo: bool, #[serde(default)] pub can_collect_data: bool, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub repo_url: Option, } #[derive( @@ -86,6 +88,7 @@ pub enum ZetaFormat { V0131GitMergeMarkersPrefix, V0211Prefill, V0211SeedCoder, + v0226Hashline, } impl std::fmt::Display for ZetaFormat { @@ -122,25 +125,6 @@ impl ZetaFormat { .collect::>() .concat() } - - pub fn special_tokens(&self) -> &'static [&'static str] { - match self { - ZetaFormat::V0112MiddleAtEnd - | ZetaFormat::V0113Ordered - | ZetaFormat::V0114180EditableRegion => &[ - "<|fim_prefix|>", - "<|fim_suffix|>", - "<|fim_middle|>", - "<|file_sep|>", - CURSOR_MARKER, - ], - ZetaFormat::V0120GitMergeMarkers => v0120_git_merge_markers::special_tokens(), - ZetaFormat::V0131GitMergeMarkersPrefix | ZetaFormat::V0211Prefill => { - v0131_git_merge_markers_prefix::special_tokens() - } - ZetaFormat::V0211SeedCoder => seed_coder::special_tokens(), - } - } } #[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)] @@ -212,33 +196,29 @@ pub struct RelatedExcerpt { } pub fn prompt_input_contains_special_tokens(input: &ZetaPromptInput, format: ZetaFormat) -> bool { - format - .special_tokens() + special_tokens_for_format(format) .iter() .any(|token| input.cursor_excerpt.contains(token)) } pub fn format_zeta_prompt(input: &ZetaPromptInput, format: ZetaFormat) -> String { - format_zeta_prompt_with_budget(input, format, MAX_PROMPT_TOKENS) + format_prompt_with_budget_for_format(input, format, MAX_PROMPT_TOKENS) } -/// Post-processes model output for the given zeta format by stripping format-specific suffixes. -pub fn clean_zeta2_model_output(output: &str, format: ZetaFormat) -> &str { +pub fn special_tokens_for_format(format: ZetaFormat) -> &'static [&'static str] { match format { - ZetaFormat::V0120GitMergeMarkers => output - .strip_suffix(v0120_git_merge_markers::END_MARKER) - .unwrap_or(output), - ZetaFormat::V0131GitMergeMarkersPrefix => output - .strip_suffix(v0131_git_merge_markers_prefix::END_MARKER) - .unwrap_or(output), - ZetaFormat::V0211SeedCoder => output - .strip_suffix(seed_coder::END_MARKER) - .unwrap_or(output), - _ => output, + ZetaFormat::V0112MiddleAtEnd => v0112_middle_at_end::special_tokens(), + ZetaFormat::V0113Ordered => v0113_ordered::special_tokens(), + ZetaFormat::V0114180EditableRegion => v0114180_editable_region::special_tokens(), + ZetaFormat::V0120GitMergeMarkers => v0120_git_merge_markers::special_tokens(), + ZetaFormat::V0131GitMergeMarkersPrefix => v0131_git_merge_markers_prefix::special_tokens(), + ZetaFormat::V0211Prefill => v0211_prefill::special_tokens(), + ZetaFormat::V0211SeedCoder => seed_coder::special_tokens(), + ZetaFormat::v0226Hashline => hashline::special_tokens(), } } -pub fn excerpt_range_for_format( +pub fn excerpt_ranges_for_format( format: ZetaFormat, ranges: &ExcerptRanges, ) -> (Range, Range) { @@ -247,129 +227,257 @@ pub fn excerpt_range_for_format( ranges.editable_150.clone(), ranges.editable_150_context_350.clone(), ), - ZetaFormat::V0114180EditableRegion - | ZetaFormat::V0120GitMergeMarkers + ZetaFormat::V0114180EditableRegion => ( + ranges.editable_180.clone(), + ranges.editable_180_context_350.clone(), + ), + ZetaFormat::V0120GitMergeMarkers | ZetaFormat::V0131GitMergeMarkersPrefix | ZetaFormat::V0211Prefill - | ZetaFormat::V0211SeedCoder => ( + | ZetaFormat::V0211SeedCoder + | ZetaFormat::v0226Hashline => ( ranges.editable_350.clone(), ranges.editable_350_context_150.clone(), ), } } -pub fn resolve_cursor_region( - input: &ZetaPromptInput, - format: ZetaFormat, -) -> (&str, Range, usize) { - let (editable_range, context_range) = excerpt_range_for_format(format, &input.excerpt_ranges); - let context_start = context_range.start; - let context_text = &input.cursor_excerpt[context_range]; - let adjusted_editable = - (editable_range.start - context_start)..(editable_range.end - context_start); - let adjusted_cursor = input.cursor_offset_in_excerpt - context_start; - - (context_text, adjusted_editable, adjusted_cursor) -} - -fn format_zeta_prompt_with_budget( - input: &ZetaPromptInput, +pub fn write_cursor_excerpt_section_for_format( format: ZetaFormat, - max_tokens: usize, -) -> String { - let (context, editable_range, cursor_offset) = resolve_cursor_region(input, format); - let path = &*input.cursor_path; - - let mut cursor_section = String::new(); + prompt: &mut String, + path: &Path, + context: &str, + editable_range: &Range, + cursor_offset: usize, +) { match format { - ZetaFormat::V0112MiddleAtEnd => { - v0112_middle_at_end::write_cursor_excerpt_section( - &mut cursor_section, - path, - context, - &editable_range, - cursor_offset, - ); - } + ZetaFormat::V0112MiddleAtEnd => v0112_middle_at_end::write_cursor_excerpt_section( + prompt, + path, + context, + editable_range, + cursor_offset, + ), ZetaFormat::V0113Ordered | ZetaFormat::V0114180EditableRegion => { v0113_ordered::write_cursor_excerpt_section( - &mut cursor_section, + prompt, path, context, - &editable_range, + editable_range, cursor_offset, ) } ZetaFormat::V0120GitMergeMarkers => v0120_git_merge_markers::write_cursor_excerpt_section( - &mut cursor_section, + prompt, path, context, - &editable_range, + editable_range, cursor_offset, ), ZetaFormat::V0131GitMergeMarkersPrefix | ZetaFormat::V0211Prefill => { v0131_git_merge_markers_prefix::write_cursor_excerpt_section( - &mut cursor_section, + prompt, path, context, - &editable_range, + editable_range, cursor_offset, ) } - ZetaFormat::V0211SeedCoder => { - return seed_coder::format_prompt_with_budget( + ZetaFormat::V0211SeedCoder => seed_coder::write_cursor_excerpt_section( + prompt, + path, + context, + editable_range, + cursor_offset, + ), + ZetaFormat::v0226Hashline => hashline::write_cursor_excerpt_section( + prompt, + path, + context, + editable_range, + cursor_offset, + ), + } +} + +pub fn format_prompt_with_budget_for_format( + input: &ZetaPromptInput, + format: ZetaFormat, + max_tokens: usize, +) -> String { + let (context, editable_range, cursor_offset) = resolve_cursor_region(input, format); + let path = &*input.cursor_path; + + match format { + ZetaFormat::V0211SeedCoder => seed_coder::format_prompt_with_budget( + path, + context, + &editable_range, + cursor_offset, + &input.events, + &input.related_files, + max_tokens, + ), + _ => { + let mut cursor_section = String::new(); + write_cursor_excerpt_section_for_format( + format, + &mut cursor_section, path, context, &editable_range, cursor_offset, + ); + + let cursor_tokens = estimate_tokens(cursor_section.len()); + let budget_after_cursor = max_tokens.saturating_sub(cursor_tokens); + + let edit_history_section = format_edit_history_within_budget( &input.events, + "<|file_sep|>", + "edit history", + budget_after_cursor, + ); + let edit_history_tokens = estimate_tokens(edit_history_section.len()); + let budget_after_edit_history = budget_after_cursor.saturating_sub(edit_history_tokens); + + let related_files_section = format_related_files_within_budget( &input.related_files, - max_tokens, + "<|file_sep|>", + "", + budget_after_edit_history, ); + + let mut prompt = String::new(); + prompt.push_str(&related_files_section); + prompt.push_str(&edit_history_section); + prompt.push_str(&cursor_section); + prompt } } - - let cursor_tokens = estimate_tokens(cursor_section.len()); - let budget_after_cursor = max_tokens.saturating_sub(cursor_tokens); - - let edit_history_section = format_edit_history_within_budget( - &input.events, - "<|file_sep|>", - "edit history", - budget_after_cursor, - ); - let edit_history_tokens = estimate_tokens(edit_history_section.len()); - let budget_after_edit_history = budget_after_cursor.saturating_sub(edit_history_tokens); - - let related_files_section = format_related_files_within_budget( - &input.related_files, - "<|file_sep|>", - "", - budget_after_edit_history, - ); - - let mut prompt = String::new(); - prompt.push_str(&related_files_section); - prompt.push_str(&edit_history_section); - prompt.push_str(&cursor_section); - prompt } -pub fn get_prefill(input: &ZetaPromptInput, format: ZetaFormat) -> String { +pub fn get_prefill_for_format( + format: ZetaFormat, + context: &str, + editable_range: &Range, +) -> String { match format { + ZetaFormat::V0211Prefill => v0211_prefill::get_prefill(context, editable_range), ZetaFormat::V0112MiddleAtEnd | ZetaFormat::V0113Ordered | ZetaFormat::V0114180EditableRegion | ZetaFormat::V0120GitMergeMarkers | ZetaFormat::V0131GitMergeMarkersPrefix - | ZetaFormat::V0211SeedCoder => String::new(), - ZetaFormat::V0211Prefill => { - let (context, editable_range, _) = resolve_cursor_region(input, format); - v0211_prefill::get_prefill(context, &editable_range) + | ZetaFormat::V0211SeedCoder + | ZetaFormat::v0226Hashline => String::new(), + } +} + +pub fn output_end_marker_for_format(format: ZetaFormat) -> Option<&'static str> { + match format { + ZetaFormat::V0120GitMergeMarkers => Some(v0120_git_merge_markers::END_MARKER), + ZetaFormat::V0131GitMergeMarkersPrefix => Some(v0131_git_merge_markers_prefix::END_MARKER), + ZetaFormat::V0211Prefill => Some(v0131_git_merge_markers_prefix::END_MARKER), + ZetaFormat::V0211SeedCoder => Some(seed_coder::END_MARKER), + ZetaFormat::V0112MiddleAtEnd + | ZetaFormat::V0113Ordered + | ZetaFormat::V0114180EditableRegion + | ZetaFormat::v0226Hashline => None, + } +} + +pub fn current_region_markers_for_format(format: ZetaFormat) -> (&'static str, &'static str) { + match format { + ZetaFormat::V0112MiddleAtEnd => ("<|fim_middle|>current\n", "<|fim_middle|>updated"), + ZetaFormat::V0113Ordered + | ZetaFormat::V0114180EditableRegion + | ZetaFormat::v0226Hashline => ("<|fim_middle|>current\n", "<|fim_suffix|>"), + ZetaFormat::V0120GitMergeMarkers + | ZetaFormat::V0131GitMergeMarkersPrefix + | ZetaFormat::V0211Prefill => ( + v0120_git_merge_markers::START_MARKER, + v0120_git_merge_markers::SEPARATOR, + ), + ZetaFormat::V0211SeedCoder => (seed_coder::START_MARKER, seed_coder::SEPARATOR), + } +} + +pub fn clean_extracted_region_for_format(format: ZetaFormat, region: &str) -> String { + match format { + ZetaFormat::v0226Hashline => hashline::strip_hashline_prefixes(region), + _ => region.to_string(), + } +} + +pub fn encode_patch_as_output_for_format( + format: ZetaFormat, + old_editable_region: &str, + patch: &str, + cursor_offset: Option, +) -> Result> { + match format { + ZetaFormat::v0226Hashline => { + hashline::patch_to_edit_commands(old_editable_region, patch, cursor_offset).map(Some) + } + _ => Ok(None), + } +} + +pub fn output_with_context_for_format( + format: ZetaFormat, + old_editable_region: &str, + output: &str, +) -> Result> { + match format { + ZetaFormat::v0226Hashline => { + if hashline::output_has_edit_commands(output) { + Ok(Some(hashline::apply_edit_commands( + old_editable_region, + output, + ))) + } else { + Ok(None) + } } + _ => Ok(None), + } +} + +/// Post-processes model output for the given zeta format by stripping format-specific suffixes. +pub fn clean_zeta2_model_output(output: &str, format: ZetaFormat) -> &str { + match output_end_marker_for_format(format) { + Some(marker) => output.strip_suffix(marker).unwrap_or(output), + None => output, } } +pub fn excerpt_range_for_format( + format: ZetaFormat, + ranges: &ExcerptRanges, +) -> (Range, Range) { + excerpt_ranges_for_format(format, ranges) +} + +pub fn resolve_cursor_region( + input: &ZetaPromptInput, + format: ZetaFormat, +) -> (&str, Range, usize) { + let (editable_range, context_range) = excerpt_range_for_format(format, &input.excerpt_ranges); + let context_start = context_range.start; + let context_text = &input.cursor_excerpt[context_range]; + let adjusted_editable = + (editable_range.start - context_start)..(editable_range.end - context_start); + let adjusted_cursor = input.cursor_offset_in_excerpt - context_start; + + (context_text, adjusted_editable, adjusted_cursor) +} + +pub fn get_prefill(input: &ZetaPromptInput, format: ZetaFormat) -> String { + let (context, editable_range, _) = resolve_cursor_region(input, format); + get_prefill_for_format(format, context, &editable_range) +} + fn format_edit_history_within_budget( events: &[Arc], file_marker: &str, @@ -533,6 +641,16 @@ pub fn write_related_files( mod v0112_middle_at_end { use super::*; + pub fn special_tokens() -> &'static [&'static str] { + &[ + "<|fim_prefix|>", + "<|fim_suffix|>", + "<|fim_middle|>", + "<|file_sep|>", + CURSOR_MARKER, + ] + } + pub fn write_cursor_excerpt_section( prompt: &mut String, path: &Path, @@ -567,6 +685,16 @@ mod v0112_middle_at_end { mod v0113_ordered { use super::*; + pub fn special_tokens() -> &'static [&'static str] { + &[ + "<|fim_prefix|>", + "<|fim_suffix|>", + "<|fim_middle|>", + "<|file_sep|>", + CURSOR_MARKER, + ] + } + pub fn write_cursor_excerpt_section( prompt: &mut String, path: &Path, @@ -601,6 +729,14 @@ mod v0113_ordered { } } +mod v0114180_editable_region { + use super::*; + + pub fn special_tokens() -> &'static [&'static str] { + v0113_ordered::special_tokens() + } +} + pub mod v0120_git_merge_markers { //! A prompt that uses git-style merge conflict markers to represent the editable region. //! @@ -752,6 +888,10 @@ pub mod v0131_git_merge_markers_prefix { pub mod v0211_prefill { use super::*; + pub fn special_tokens() -> &'static [&'static str] { + v0131_git_merge_markers_prefix::special_tokens() + } + pub fn get_prefill(context: &str, editable_range: &Range) -> String { let editable_region = &context[editable_range.start..editable_range.end]; @@ -783,6 +923,1413 @@ pub mod v0211_prefill { } } +pub mod hashline { + + use std::fmt::Display; + + pub const END_MARKER: &str = "<|fim_middle|>updated"; + pub const START_MARKER: &str = "<|fim_middle|>current"; + + use super::*; + + const SET_COMMAND_MARKER: &str = "<|set|>"; + const INSERT_COMMAND_MARKER: &str = "<|insert|>"; + + pub fn special_tokens() -> &'static [&'static str] { + return &[ + SET_COMMAND_MARKER, + "<|set_range|>", + INSERT_COMMAND_MARKER, + CURSOR_MARKER, + "<|file_sep|>", + "<|fim_prefix|>", + "<|fim_suffix|>", + "<|fim_middle|>", + ]; + } + + /// A parsed line reference like `3:c3` (line index 3 with hash 0xc3). + #[derive(Debug, Clone, PartialEq, Eq)] + struct LineRef { + index: usize, + hash: u8, + } + + impl Display for LineRef { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}:{:02x}", self.index, self.hash) + } + } + + pub fn hash_line(line: &[u8]) -> u8 { + let mut h: u8 = 0; + for &byte in line { + h = h.wrapping_add(byte); + } + return h; + } + + /// Write the hashline-encoded editable region into `out`. Each line of + /// `editable_text` is prefixed with `{line_index}:{hash}|` and the cursor + /// marker is inserted at `cursor_offset_in_editable` (byte offset relative + /// to the start of `editable_text`). + pub fn write_hashline_editable_region( + out: &mut String, + editable_text: &str, + cursor_offset_in_editable: usize, + ) { + let mut offset = 0; + for (i, line) in editable_text.lines().enumerate() { + let (head, cursor, tail) = if cursor_offset_in_editable > offset + && cursor_offset_in_editable < offset + line.len() + { + ( + &line[..cursor_offset_in_editable - offset], + CURSOR_MARKER, + &line[cursor_offset_in_editable - offset..], + ) + } else { + (line, "", "") + }; + write!( + out, + "\n{}|{head}{cursor}{tail}", + LineRef { + index: i, + hash: hash_line(line.as_bytes()) + } + ) + .unwrap(); + offset += line.len() + 1; + } + } + + pub fn write_cursor_excerpt_section( + prompt: &mut String, + path: &Path, + context: &str, + editable_range: &Range, + cursor_offset: usize, + ) { + let path_str = path.to_string_lossy(); + write!(prompt, "<|file_sep|>{}\n", path_str).ok(); + + prompt.push_str("<|fim_prefix|>\n"); + prompt.push_str(&context[..editable_range.start]); + prompt.push_str(START_MARKER); + + let cursor_offset_in_editable = cursor_offset.saturating_sub(editable_range.start); + let editable_region = &context[editable_range.clone()]; + write_hashline_editable_region(prompt, editable_region, cursor_offset_in_editable); + + if !prompt.ends_with('\n') { + prompt.push('\n'); + } + + prompt.push_str("<|fim_suffix|>\n"); + prompt.push_str(&context[editable_range.end..]); + if !prompt.ends_with('\n') { + prompt.push('\n'); + } + + prompt.push_str(END_MARKER); + } + + /// A single edit command parsed from the model output. + #[derive(Debug)] + enum EditCommand<'a> { + /// Replace a range of lines (inclusive on both ends). Single-line set is + /// represented by `start == end`. + Set { + start: LineRef, + end: LineRef, + content: &'a str, + }, + /// Insert new lines after the given line, or before the first line if + /// `after` is `None`. + Insert { + after: Option, + content: &'a str, + }, + } + + /// Parse a line reference like `3:c3` into a `LineRef`. + fn parse_line_ref(s: &str) -> Option { + let (idx_str, hash_str) = s.split_once(':')?; + let index = idx_str.parse::().ok()?; + let hash = u8::from_str_radix(hash_str, 16).ok()?; + Some(LineRef { index, hash }) + } + + /// Parse the model output into a list of `EditCommand`s. + fn parse_edit_commands(model_output: &str) -> Vec> { + let mut commands = Vec::new(); + let mut offset = 0usize; + + while offset < model_output.len() { + let next_nl = model_output[offset..] + .find('\n') + .map(|i| offset + i) + .unwrap_or(model_output.len()); + let line = &model_output[offset..next_nl]; + let line_end = if next_nl < model_output.len() { + next_nl + 1 + } else { + next_nl + }; + + let trimmed = line.trim(); + let (is_set, specifier) = if let Some(spec) = trimmed.strip_prefix(SET_COMMAND_MARKER) { + (true, spec) + } else if let Some(spec) = trimmed.strip_prefix(INSERT_COMMAND_MARKER) { + (false, spec) + } else { + offset = line_end; + continue; + }; + + let mut content_end = line_end; + let mut scan = line_end; + + while scan < model_output.len() { + let body_nl = model_output[scan..] + .find('\n') + .map(|i| scan + i) + .unwrap_or(model_output.len()); + let body_line = &model_output[scan..body_nl]; + if body_line.trim().starts_with(SET_COMMAND_MARKER) + || body_line.trim().starts_with(INSERT_COMMAND_MARKER) + { + break; + } + scan = if body_nl < model_output.len() { + body_nl + 1 + } else { + body_nl + }; + content_end = scan; + } + + let content = &model_output[line_end..content_end]; + + if is_set { + if let Some((start_str, end_str)) = specifier.split_once('-') { + if let (Some(start), Some(end)) = + (parse_line_ref(start_str), parse_line_ref(end_str)) + { + commands.push(EditCommand::Set { + start, + end, + content, + }); + } + } else if let Some(target) = parse_line_ref(specifier) { + commands.push(EditCommand::Set { + start: target.clone(), + end: target, + content, + }); + } + } else { + let after = parse_line_ref(specifier); + commands.push(EditCommand::Insert { after, content }); + } + + offset = scan; + } + + commands + } + + /// Returns `true` if the model output contains `<|set|>` or `<|insert|>` commands + /// (as opposed to being a plain full-replacement output). + /// Strip the `{line_num}:{hash}|` prefixes from each line of a hashline-encoded + /// editable region, returning the plain text content. + pub fn strip_hashline_prefixes(region: &str) -> String { + let mut decoded: String = region + .lines() + .map(|line| line.find('|').map_or(line, |pos| &line[pos + 1..])) + .collect::>() + .join("\n"); + if region.ends_with('\n') { + decoded.push('\n'); + } + decoded + } + + pub fn output_has_edit_commands(model_output: &str) -> bool { + model_output.contains(SET_COMMAND_MARKER) || model_output.contains(INSERT_COMMAND_MARKER) + } + + /// Apply `<|set|>` and `<|insert|>` edit commands from the model output to the + /// original editable region text. + /// + /// `editable_region` is the original text of the editable region (without hash + /// prefixes). `model_output` is the raw model response containing edit commands. + /// + /// Returns the full replacement text for the editable region. + pub fn apply_edit_commands(editable_region: &str, model_output: &str) -> String { + let original_lines: Vec<&str> = editable_region.lines().collect(); + let old_hashes: Vec = original_lines + .iter() + .map(|line| hash_line(line.as_bytes())) + .collect(); + + let commands = parse_edit_commands(model_output); + + // For set operations: indexed by start line → Some((end line index, content)) + // For insert operations: indexed by line index → vec of content to insert after + // Insert-before-first is tracked separately. + let mut set_ops: Vec> = vec![None; original_lines.len()]; + let mut insert_before_first: Vec<&str> = Vec::new(); + let mut insert_after: Vec> = vec![Vec::new(); original_lines.len()]; + + for command in &commands { + match command { + EditCommand::Set { + start, + end, + content, + } => { + if start.index < old_hashes.len() + && end.index < old_hashes.len() + && start.index <= end.index + && old_hashes[start.index] == start.hash + && old_hashes[end.index] == end.hash + { + set_ops[start.index] = Some((end.index, *content)); + } + } + EditCommand::Insert { after, content } => match after { + None => insert_before_first.push(*content), + Some(line_ref) => { + if line_ref.index < old_hashes.len() + && old_hashes[line_ref.index] == line_ref.hash + { + insert_after[line_ref.index].push(*content); + } + } + }, + } + } + + let mut result = String::new(); + + // Emit any insertions before the first line + for content in &insert_before_first { + result.push_str(content); + if !content.ends_with('\n') { + result.push('\n'); + } + } + + let mut i = 0; + while i < original_lines.len() { + if let Some((end_index, replacement)) = set_ops[i].as_ref() { + // Replace lines i..=end_index with the replacement content + result.push_str(replacement); + if !replacement.is_empty() && !replacement.ends_with('\n') { + result.push('\n'); + } + // Emit any insertions after the end of this set range + if *end_index < insert_after.len() { + for content in &insert_after[*end_index] { + result.push_str(content); + if !content.ends_with('\n') { + result.push('\n'); + } + } + } + i = end_index + 1; + } else { + // Keep the original line + result.push_str(original_lines[i]); + result.push('\n'); + // Emit any insertions after this line + for content in &insert_after[i] { + result.push_str(content); + if !content.ends_with('\n') { + result.push('\n'); + } + } + i += 1; + } + } + + // Preserve trailing newline behavior: if the original ended with a + // newline the result already has one; if it didn't, trim the extra one + // we added. + if !editable_region.ends_with('\n') && result.ends_with('\n') { + result.pop(); + } + + result + } + + /// Convert a unified diff patch into hashline edit commands. + /// + /// Parses the unified diff `patch` directly to determine which lines of + /// `old_text` are deleted/replaced and what new lines are added, then emits + /// `<|set|>` and `<|insert|>` edit commands referencing old lines by their + /// `{index}:{hash}` identifiers. + /// + /// `cursor_offset` is an optional byte offset into the first hunk's new + /// text (context + additions) where the cursor marker should be placed. + pub fn patch_to_edit_commands( + old_text: &str, + patch: &str, + cursor_offset: Option, + ) -> Result { + let old_lines: Vec<&str> = old_text.lines().collect(); + let old_hashes: Vec = old_lines + .iter() + .map(|line| hash_line(line.as_bytes())) + .collect(); + + let mut result = String::new(); + let mut first_hunk = true; + + struct Hunk<'a> { + line_range: Range, + new_text_lines: Vec<&'a str>, + cursor_line_offset_in_new_text: Option<(usize, usize)>, + } + + // Parse the patch line by line. We only care about hunk headers, + // context, deletions, and additions. + let mut old_line_index: usize = 0; + let mut current_hunk: Option = None; + // Byte offset tracking within the hunk's new text for cursor placement. + let mut new_text_byte_offset: usize = 0; + // The line index of the last old line seen before/in the current hunk + // (used for insert-after reference). + let mut last_old_line_before_hunk: Option = None; + + fn flush_hunk( + hunk: Hunk, + last_old_line: Option, + result: &mut String, + old_hashes: &[u8], + ) { + if hunk.line_range.is_empty() { + // Pure insertion — reference the old line to insert after when in bounds. + if let Some(after) = last_old_line + && let Some(&hash) = old_hashes.get(after) + { + write!( + result, + "{INSERT_COMMAND_MARKER}{}\n", + LineRef { index: after, hash } + ) + .unwrap(); + } else { + result.push_str(INSERT_COMMAND_MARKER); + result.push('\n'); + } + } else { + let start = hunk.line_range.start; + let end_exclusive = hunk.line_range.end; + let deleted_line_count = end_exclusive.saturating_sub(start); + + if deleted_line_count == 1 { + if let Some(&hash) = old_hashes.get(start) { + write!( + result, + "{SET_COMMAND_MARKER}{}\n", + LineRef { index: start, hash } + ) + .unwrap(); + } else { + result.push_str(SET_COMMAND_MARKER); + result.push('\n'); + } + } else { + let end_inclusive = end_exclusive - 1; + match ( + old_hashes.get(start).copied(), + old_hashes.get(end_inclusive).copied(), + ) { + (Some(start_hash), Some(end_hash)) => { + write!( + result, + "{SET_COMMAND_MARKER}{}-{}\n", + LineRef { + index: start, + hash: start_hash + }, + LineRef { + index: end_inclusive, + hash: end_hash + } + ) + .unwrap(); + } + _ => { + result.push_str(SET_COMMAND_MARKER); + result.push('\n'); + } + } + } + } + for (line_offset, line) in hunk.new_text_lines.iter().enumerate() { + if let Some((cursor_line_offset, char_offset)) = hunk.cursor_line_offset_in_new_text + && line_offset == cursor_line_offset + { + result.push_str(&line[..char_offset]); + result.push_str(CURSOR_MARKER); + result.push_str(&line[char_offset..]); + continue; + } + + result.push_str(line); + } + } + + for raw_line in patch.split_inclusive('\n') { + if raw_line.starts_with("@@") { + // Flush any pending change hunk from a previous patch hunk. + if let Some(hunk) = current_hunk.take() { + flush_hunk(hunk, last_old_line_before_hunk, &mut result, &old_hashes); + } + + // Parse hunk header: @@ -old_start[,old_count] +new_start[,new_count] @@ + // We intentionally do not trust old_start as a direct local index into `old_text`, + // because some patches are produced against a larger file region and carry + // non-local line numbers. We keep indexing local by advancing from parsed patch lines. + if first_hunk { + new_text_byte_offset = 0; + first_hunk = false; + } + continue; + } + + if raw_line.starts_with("---") || raw_line.starts_with("+++") { + continue; + } + if raw_line.starts_with("\\ No newline") { + continue; + } + + if raw_line.starts_with('-') { + // Extend or start a change hunk with this deleted old line. + match &mut current_hunk { + Some(Hunk { + line_range: range, .. + }) => range.end = old_line_index + 1, + None => { + current_hunk = Some(Hunk { + line_range: old_line_index..old_line_index + 1, + new_text_lines: Vec::new(), + cursor_line_offset_in_new_text: None, + }); + } + } + old_line_index += 1; + } else if let Some(added_content) = raw_line.strip_prefix('+') { + // Place cursor marker if cursor_offset falls within this line. + let mut cursor_line_offset = None; + if let Some(cursor_off) = cursor_offset + && (first_hunk + || cursor_off >= new_text_byte_offset + && cursor_off <= new_text_byte_offset + added_content.len()) + { + let line_offset = added_content.floor_char_boundary( + cursor_off + .saturating_sub(new_text_byte_offset) + .min(added_content.len()), + ); + cursor_line_offset = Some(line_offset); + } + + new_text_byte_offset += added_content.len(); + + let hunk = current_hunk.get_or_insert(Hunk { + line_range: old_line_index..old_line_index, + new_text_lines: vec![], + cursor_line_offset_in_new_text: None, + }); + hunk.new_text_lines.push(added_content); + hunk.cursor_line_offset_in_new_text = cursor_line_offset + .map(|offset_in_line| (hunk.new_text_lines.len() - 1, offset_in_line)); + } else { + // Context line (starts with ' ' or is empty). + if let Some(hunk) = current_hunk.take() { + flush_hunk(hunk, last_old_line_before_hunk, &mut result, &old_hashes); + } + last_old_line_before_hunk = Some(old_line_index); + old_line_index += 1; + let content = raw_line.strip_prefix(' ').unwrap_or(raw_line); + new_text_byte_offset += content.len(); + } + } + + // Flush final group. + if let Some(hunk) = current_hunk.take() { + flush_hunk(hunk, last_old_line_before_hunk, &mut result, &old_hashes); + } + + // Trim a single trailing newline. + if result.ends_with('\n') { + result.pop(); + } + + Ok(result) + } + + #[cfg(test)] + mod tests { + use super::*; + use indoc::indoc; + + #[test] + fn test_format_cursor_region() { + struct Case { + name: &'static str, + context: &'static str, + editable_range: Range, + cursor_offset: usize, + expected: &'static str, + } + + let cases = [ + Case { + name: "basic_cursor_placement", + context: "hello world\n", + editable_range: 0..12, + cursor_offset: 5, + expected: indoc! {" + <|file_sep|>test.rs + <|fim_prefix|> + <|fim_middle|>current + 0:5c|hello<|user_cursor|> world + <|fim_suffix|> + <|fim_middle|>updated"}, + }, + Case { + name: "multiline_cursor_on_second_line", + context: "aaa\nbbb\nccc\n", + editable_range: 0..12, + cursor_offset: 5, // byte 5 → 1 byte into "bbb" + expected: indoc! {" + <|file_sep|>test.rs + <|fim_prefix|> + <|fim_middle|>current + 0:23|aaa + 1:26|b<|user_cursor|>bb + 2:29|ccc + <|fim_suffix|> + <|fim_middle|>updated"}, + }, + Case { + name: "no_trailing_newline_in_context", + context: "line1\nline2", + editable_range: 0..11, + cursor_offset: 3, + expected: indoc! {" + <|file_sep|>test.rs + <|fim_prefix|> + <|fim_middle|>current + 0:d9|lin<|user_cursor|>e1 + 1:da|line2 + <|fim_suffix|> + <|fim_middle|>updated"}, + }, + Case { + name: "leading_newline_in_editable_region", + context: "\nabc\n", + editable_range: 0..5, + cursor_offset: 2, // byte 2 = 'a' in "abc" (after leading \n) + expected: indoc! {" + <|file_sep|>test.rs + <|fim_prefix|> + <|fim_middle|>current + 0:00| + 1:26|a<|user_cursor|>bc + <|fim_suffix|> + <|fim_middle|>updated"}, + }, + Case { + name: "with_suffix", + context: "abc\ndef", + editable_range: 0..4, // editable region = "abc\n", suffix = "def" + cursor_offset: 2, + expected: indoc! {" + <|file_sep|>test.rs + <|fim_prefix|> + <|fim_middle|>current + 0:26|ab<|user_cursor|>c + <|fim_suffix|> + def + <|fim_middle|>updated"}, + }, + Case { + name: "unicode_two_byte_chars", + context: "héllo\n", + editable_range: 0..7, + cursor_offset: 3, // byte 3 = after "hé" (h=1 byte, é=2 bytes), before "llo" + expected: indoc! {" + <|file_sep|>test.rs + <|fim_prefix|> + <|fim_middle|>current + 0:1b|hé<|user_cursor|>llo + <|fim_suffix|> + <|fim_middle|>updated"}, + }, + Case { + name: "unicode_three_byte_chars", + context: "日本語\n", + editable_range: 0..10, + cursor_offset: 6, // byte 6 = after "日本" (3+3 bytes), before "語" + expected: indoc! {" + <|file_sep|>test.rs + <|fim_prefix|> + <|fim_middle|>current + 0:80|日本<|user_cursor|>語 + <|fim_suffix|> + <|fim_middle|>updated"}, + }, + Case { + name: "unicode_four_byte_chars", + context: "a🌍b\n", + editable_range: 0..7, + cursor_offset: 5, // byte 5 = after "a🌍" (1+4 bytes), before "b" + expected: indoc! {" + <|file_sep|>test.rs + <|fim_prefix|> + <|fim_middle|>current + 0:6b|a🌍<|user_cursor|>b + <|fim_suffix|> + <|fim_middle|>updated"}, + }, + Case { + name: "cursor_at_start_of_region_not_placed", + context: "abc\n", + editable_range: 0..4, + cursor_offset: 0, // cursor_offset(0) > offset(0) is false → cursor not placed + expected: indoc! {" + <|file_sep|>test.rs + <|fim_prefix|> + <|fim_middle|>current + 0:26|abc + <|fim_suffix|> + <|fim_middle|>updated"}, + }, + Case { + name: "cursor_at_end_of_line_not_placed", + context: "abc\ndef\n", + editable_range: 0..8, + cursor_offset: 3, // byte 3 = the \n after "abc" → falls between lines, not placed + expected: indoc! {" + <|file_sep|>test.rs + <|fim_prefix|> + <|fim_middle|>current + 0:26|abc + 1:2f|def + <|fim_suffix|> + <|fim_middle|>updated"}, + }, + Case { + name: "cursor_offset_relative_to_context_not_editable_region", + // cursor_offset is relative to `context`, so when editable_range.start > 0, + // write_cursor_excerpt_section must subtract it before comparing against + // per-line offsets within the editable region. + context: "pre\naaa\nbbb\nsuf\n", + editable_range: 4..12, // editable region = "aaa\nbbb\n" + cursor_offset: 9, // byte 9 in context = second 'b' in "bbb" + expected: indoc! {" + <|file_sep|>test.rs + <|fim_prefix|> + pre + <|fim_middle|>current + 0:23|aaa + 1:26|b<|user_cursor|>bb + <|fim_suffix|> + suf + <|fim_middle|>updated"}, + }, + ]; + + for case in &cases { + let mut prompt = String::new(); + hashline::write_cursor_excerpt_section( + &mut prompt, + Path::new("test.rs"), + case.context, + &case.editable_range, + case.cursor_offset, + ); + assert_eq!(prompt, case.expected, "failed case: {}", case.name); + } + } + + #[test] + fn test_apply_edit_commands() { + struct Case { + name: &'static str, + original: &'static str, + model_output: &'static str, + expected: &'static str, + } + + let cases = vec![ + Case { + name: "set_single_line", + original: indoc! {" + let mut total = 0; + for product in products { + total += ; + } + total + "}, + model_output: indoc! {" + <|set|>2:87 + total += product.price; + "}, + expected: indoc! {" + let mut total = 0; + for product in products { + total += product.price; + } + total + "}, + }, + Case { + name: "set_range", + original: indoc! {" + fn foo() { + let x = 1; + let y = 2; + let z = 3; + } + "}, + model_output: indoc! {" + <|set|>1:46-3:4a + let sum = 6; + "}, + expected: indoc! {" + fn foo() { + let sum = 6; + } + "}, + }, + Case { + name: "insert_after_line", + original: indoc! {" + fn main() { + let x = 1; + } + "}, + model_output: indoc! {" + <|insert|>1:46 + let y = 2; + "}, + expected: indoc! {" + fn main() { + let x = 1; + let y = 2; + } + "}, + }, + Case { + name: "insert_before_first", + original: indoc! {" + let x = 1; + let y = 2; + "}, + model_output: indoc! {" + <|insert|> + use std::io; + "}, + expected: indoc! {" + use std::io; + let x = 1; + let y = 2; + "}, + }, + Case { + name: "set_with_cursor_marker", + original: indoc! {" + fn main() { + println!(); + } + "}, + model_output: indoc! {" + <|set|>1:34 + eprintln!(\"<|user_cursor|>\"); + "}, + expected: indoc! {" + fn main() { + eprintln!(\"<|user_cursor|>\"); + } + "}, + }, + Case { + name: "multiple_set_commands", + original: indoc! {" + aaa + bbb + ccc + ddd + "}, + model_output: indoc! {" + <|set|>0:23 + AAA + <|set|>2:29 + CCC + "}, + expected: indoc! {" + AAA + bbb + CCC + ddd + "}, + }, + Case { + name: "set_range_multiline_replacement", + original: indoc! {" + fn handle_submit() { + } + + fn handle_keystroke() { + "}, + model_output: indoc! {" + <|set|>0:3f-1:7d + fn handle_submit(modal_state: &mut ModalState) { + <|user_cursor|> + } + "}, + expected: indoc! {" + fn handle_submit(modal_state: &mut ModalState) { + <|user_cursor|> + } + + fn handle_keystroke() { + "}, + }, + Case { + name: "no_edit_commands_returns_original", + original: indoc! {" + hello + world + "}, + model_output: "some random text with no commands", + expected: indoc! {" + hello + world + "}, + }, + Case { + name: "wrong_hash_set_ignored", + original: indoc! {" + aaa + bbb + "}, + model_output: indoc! {" + <|set|>0:ff + ZZZ + "}, + expected: indoc! {" + aaa + bbb + "}, + }, + Case { + name: "insert_and_set_combined", + original: indoc! {" + alpha + beta + gamma + "}, + model_output: indoc! {" + <|set|>0:06 + ALPHA + <|insert|>1:9c + beta_extra + "}, + expected: indoc! {" + ALPHA + beta + beta_extra + gamma + "}, + }, + Case { + name: "no_trailing_newline_preserved", + original: "hello\nworld", + model_output: indoc! {" + <|set|>0:14 + HELLO + "}, + expected: "HELLO\nworld", + }, + Case { + name: "set_range_hash_mismatch_in_end_bound", + original: indoc! {" + one + two + three + "}, + model_output: indoc! {" + <|set|>0:42-2:ff + ONE_TWO_THREE + "}, + expected: indoc! {" + one + two + three + "}, + }, + Case { + name: "set_range_start_greater_than_end_ignored", + original: indoc! {" + a + b + c + "}, + model_output: indoc! {" + <|set|>2:63-1:62 + X + "}, + expected: indoc! {" + a + b + c + "}, + }, + Case { + name: "insert_out_of_bounds_ignored", + original: indoc! {" + x + y + "}, + model_output: indoc! {" + <|insert|>99:aa + z + "}, + expected: indoc! {" + x + y + "}, + }, + Case { + name: "set_out_of_bounds_ignored", + original: indoc! {" + x + y + "}, + model_output: indoc! {" + <|set|>99:aa + z + "}, + expected: indoc! {" + x + y + "}, + }, + Case { + name: "malformed_set_command_ignored", + original: indoc! {" + alpha + beta + "}, + model_output: indoc! {" + <|set|>not-a-line-ref + UPDATED + "}, + expected: indoc! {" + alpha + beta + "}, + }, + Case { + name: "malformed_insert_hash_treated_as_before_first", + original: indoc! {" + alpha + beta + "}, + model_output: indoc! {" + <|insert|>1:nothex + preamble + "}, + expected: indoc! {" + preamble + alpha + beta + "}, + }, + Case { + name: "set_then_insert_same_target_orders_insert_after_replacement", + original: indoc! {" + cat + dog + "}, + model_output: indoc! {" + <|set|>0:38 + CAT + <|insert|>0:38 + TAIL + "}, + expected: indoc! {" + CAT + TAIL + dog + "}, + }, + Case { + name: "overlapping_set_ranges_last_wins", + original: indoc! {" + a + b + c + d + "}, + model_output: indoc! {" + <|set|>0:61-2:63 + FIRST + <|set|>1:62-3:64 + SECOND + "}, + expected: indoc! {" + FIRST + d + "}, + }, + Case { + name: "insert_before_first_and_after_line", + original: indoc! {" + a + b + "}, + model_output: indoc! {" + <|insert|> + HEAD + <|insert|>0:61 + MID + "}, + expected: indoc! {" + HEAD + a + MID + b + "}, + }, + ]; + + for case in &cases { + let result = hashline::apply_edit_commands(case.original, &case.model_output); + assert_eq!(result, case.expected, "failed case: {}", case.name); + } + } + + #[test] + fn test_output_has_edit_commands() { + assert!(hashline::output_has_edit_commands(&format!( + "{}0:ab\nnew", + SET_COMMAND_MARKER + ))); + assert!(hashline::output_has_edit_commands(&format!( + "{}0:ab\nnew", + INSERT_COMMAND_MARKER + ))); + assert!(hashline::output_has_edit_commands(&format!( + "some text\n{}1:cd\nstuff", + SET_COMMAND_MARKER + ))); + assert!(!hashline::output_has_edit_commands("just plain text")); + assert!(!hashline::output_has_edit_commands("NO_EDITS")); + } + + // ---- hashline::patch_to_edit_commands round-trip tests ---- + + #[test] + fn test_patch_to_edit_commands() { + struct Case { + name: &'static str, + old: &'static str, + patch: &'static str, + expected_new: &'static str, + } + + let cases = [ + Case { + name: "single_line_replacement", + old: indoc! {" + let mut total = 0; + for product in products { + total += ; + } + total + "}, + patch: indoc! {" + @@ -1,5 +1,5 @@ + let mut total = 0; + for product in products { + - total += ; + + total += product.price; + } + total + "}, + expected_new: indoc! {" + let mut total = 0; + for product in products { + total += product.price; + } + total + "}, + }, + Case { + name: "multiline_replacement", + old: indoc! {" + fn foo() { + let x = 1; + let y = 2; + let z = 3; + } + "}, + patch: indoc! {" + @@ -1,5 +1,3 @@ + fn foo() { + - let x = 1; + - let y = 2; + - let z = 3; + + let sum = 1 + 2 + 3; + } + "}, + expected_new: indoc! {" + fn foo() { + let sum = 1 + 2 + 3; + } + "}, + }, + Case { + name: "insertion", + old: indoc! {" + fn main() { + let x = 1; + } + "}, + patch: indoc! {" + @@ -1,3 +1,4 @@ + fn main() { + let x = 1; + + let y = 2; + } + "}, + expected_new: indoc! {" + fn main() { + let x = 1; + let y = 2; + } + "}, + }, + Case { + name: "insertion_before_first", + old: indoc! {" + let x = 1; + let y = 2; + "}, + patch: indoc! {" + @@ -1,2 +1,3 @@ + +use std::io; + let x = 1; + let y = 2; + "}, + expected_new: indoc! {" + use std::io; + let x = 1; + let y = 2; + "}, + }, + Case { + name: "deletion", + old: indoc! {" + aaa + bbb + ccc + ddd + "}, + patch: indoc! {" + @@ -1,4 +1,2 @@ + aaa + -bbb + -ccc + ddd + "}, + expected_new: indoc! {" + aaa + ddd + "}, + }, + Case { + name: "multiple_changes", + old: indoc! {" + alpha + beta + gamma + delta + epsilon + "}, + patch: indoc! {" + @@ -1,5 +1,5 @@ + -alpha + +ALPHA + beta + gamma + -delta + +DELTA + epsilon + "}, + expected_new: indoc! {" + ALPHA + beta + gamma + DELTA + epsilon + "}, + }, + Case { + name: "replace_with_insertion", + old: indoc! {r#" + fn handle() { + modal_state.close(); + modal_state.dismiss(); + "#}, + patch: indoc! {r#" + @@ -1,3 +1,4 @@ + fn handle() { + modal_state.close(); + + eprintln!(""); + modal_state.dismiss(); + "#}, + expected_new: indoc! {r#" + fn handle() { + modal_state.close(); + eprintln!(""); + modal_state.dismiss(); + "#}, + }, + Case { + name: "complete_replacement", + old: indoc! {" + aaa + bbb + ccc + "}, + patch: indoc! {" + @@ -1,3 +1,3 @@ + -aaa + -bbb + -ccc + +xxx + +yyy + +zzz + "}, + expected_new: indoc! {" + xxx + yyy + zzz + "}, + }, + Case { + name: "add_function_body", + old: indoc! {" + fn foo() { + modal_state.dismiss(); + } + + fn + + fn handle_keystroke() { + "}, + patch: indoc! {" + @@ -1,6 +1,8 @@ + fn foo() { + modal_state.dismiss(); + } + + -fn + +fn handle_submit() { + + todo() + +} + + fn handle_keystroke() { + "}, + expected_new: indoc! {" + fn foo() { + modal_state.dismiss(); + } + + fn handle_submit() { + todo() + } + + fn handle_keystroke() { + "}, + }, + Case { + name: "with_cursor_offset", + old: indoc! {r#" + fn main() { + println!(); + } + "#}, + patch: indoc! {r#" + @@ -1,3 +1,3 @@ + fn main() { + - println!(); + + eprintln!(""); + } + "#}, + expected_new: indoc! {r#" + fn main() { + eprintln!("<|user_cursor|>"); + } + "#}, + }, + Case { + name: "non_local_hunk_header_pure_insertion_repro", + old: indoc! {" + aaa + bbb + "}, + patch: indoc! {" + @@ -20,2 +20,3 @@ + aaa + +xxx + bbb + "}, + expected_new: indoc! {" + aaa + xxx + bbb + "}, + }, + ]; + + for case in &cases { + // The cursor_offset for patch_to_edit_commands is relative to + // the first hunk's new text (context + additions). We compute + // it by finding where the marker sits in the expected output + // (which mirrors the new text of the hunk). + let cursor_offset = case.expected_new.find(CURSOR_MARKER); + + let commands = + hashline::patch_to_edit_commands(case.old, case.patch, cursor_offset) + .unwrap_or_else(|e| panic!("failed case {}: {e}", case.name)); + + assert!( + hashline::output_has_edit_commands(&commands), + "case {}: expected edit commands, got: {commands:?}", + case.name, + ); + + let applied = hashline::apply_edit_commands(case.old, &commands); + assert_eq!(applied, case.expected_new, "case {}", case.name); + } + } + } +} + pub mod seed_coder { //! Seed-Coder prompt format using SPM (Suffix-Prefix-Middle) FIM mode. //! @@ -847,6 +2394,17 @@ pub mod seed_coder { ] } + pub fn write_cursor_excerpt_section( + prompt: &mut String, + path: &Path, + context: &str, + editable_range: &Range, + cursor_offset: usize, + ) { + let section = build_cursor_prefix_section(path, context, editable_range, cursor_offset); + prompt.push_str(§ion); + } + pub fn format_prompt_with_budget( path: &Path, context: &str, @@ -1159,6 +2717,7 @@ mod tests { experiment: None, in_open_source_repo: false, can_collect_data: false, + repo_url: None, } } @@ -1186,7 +2745,7 @@ mod tests { } fn format_with_budget(input: &ZetaPromptInput, max_tokens: usize) -> String { - format_zeta_prompt_with_budget(input, ZetaFormat::V0114180EditableRegion, max_tokens) + format_prompt_with_budget_for_format(input, ZetaFormat::V0114180EditableRegion, max_tokens) } #[test] @@ -1551,11 +3110,11 @@ mod tests { } fn format_seed_coder(input: &ZetaPromptInput) -> String { - format_zeta_prompt_with_budget(input, ZetaFormat::V0211SeedCoder, 10000) + format_prompt_with_budget_for_format(input, ZetaFormat::V0211SeedCoder, 10000) } fn format_seed_coder_with_budget(input: &ZetaPromptInput, max_tokens: usize) -> String { - format_zeta_prompt_with_budget(input, ZetaFormat::V0211SeedCoder, max_tokens) + format_prompt_with_budget_for_format(input, ZetaFormat::V0211SeedCoder, max_tokens) } #[test] @@ -1756,6 +3315,7 @@ mod tests { experiment: None, in_open_source_repo: false, can_collect_data: false, + repo_url: None, }; let prompt = zeta1::format_zeta1_from_input(&input, 15..41, 0..excerpt.len()); @@ -1818,6 +3378,7 @@ mod tests { experiment: None, in_open_source_repo: false, can_collect_data: false, + repo_url: None, }; let prompt = zeta1::format_zeta1_from_input(&input, 0..28, 0..28); @@ -1875,6 +3436,7 @@ mod tests { experiment: None, in_open_source_repo: false, can_collect_data: false, + repo_url: None, }; let prompt = zeta1::format_zeta1_from_input(&input, editable_range, context_range); diff --git a/docs/.prettierignore b/docs/.prettierignore index a52439689a83a1c2e834918c39441186b47120e5..c742ed4b6859f32219cecbac9f722db8a6929710 100644 --- a/docs/.prettierignore +++ b/docs/.prettierignore @@ -1,2 +1,5 @@ # Handlebars partials are not supported by Prettier. *.hbs + +# Automatically generated +theme/c15t@*.js diff --git a/docs/README.md b/docs/README.md index e1649f4bc99e1668352a46ee2071dcfe1775f4a7..a0f9bbd5c628f41d291880239ca555ea7ec0e3ea 100644 --- a/docs/README.md +++ b/docs/README.md @@ -64,6 +64,22 @@ This will render a human-readable version of the action name, e.g., "zed: open s Templates are functions that modify the source of the docs pages (usually with a regex match and replace). You can see how the actions and keybindings are templated in `crates/docs_preprocessor/src/main.rs` for reference on how to create new templates. +## Consent Banner + +We pre-bundle the `c15t` package because the docs pipeline does not include a JS bundler. If you need to update `c15t` and rebuild the bundle, use: + +``` +mkdir c15t-bundle && cd c15t-bundle +npm init -y +npm install c15t@ esbuild +echo "import { getOrCreateConsentRuntime } from 'c15t'; window.c15t = { getOrCreateConsentRuntime };" > entry.js +npx esbuild entry.js --bundle --format=iife --minify --outfile=c15t@.js +cp c15t@.js ../theme/c15t@.js +cd .. && rm -rf c15t-bundle +``` + +Replace `` with the new version of `c15t` you are installing. Then update `book.toml` to reference the new bundle filename. + ### References - Template Trait: `crates/docs_preprocessor/src/templates.rs` diff --git a/docs/book.toml b/docs/book.toml index 86fa447f581fba88ff7df53bb51e08440585a9dc..3269003a1d37ede19ec18b62809a928a08764d2f 100644 --- a/docs/book.toml +++ b/docs/book.toml @@ -23,8 +23,8 @@ default-description = "Learn how to use and customize Zed, the fast, collaborati default-title = "Zed Code Editor Documentation" no-section-label = true preferred-dark-theme = "dark" -additional-css = ["theme/page-toc.css", "theme/plugins.css", "theme/highlight.css"] -additional-js = ["theme/page-toc.js", "theme/plugins.js"] +additional-css = ["theme/page-toc.css", "theme/plugins.css", "theme/highlight.css", "theme/consent-banner.css"] +additional-js = ["theme/page-toc.js", "theme/plugins.js", "theme/c15t@2.0.0-rc.3.js", "theme/analytics.js"] [output.zed-html.print] enable = false diff --git a/docs/src/ai/edit-prediction.md b/docs/src/ai/edit-prediction.md index 973dc9546a8b81ad58fc996102ff25aed2d241a9..92fde3eddd3be0a2dbfb1b6d37065b58cf2ad411 100644 --- a/docs/src/ai/edit-prediction.md +++ b/docs/src/ai/edit-prediction.md @@ -406,8 +406,6 @@ After adding your API key, Codestral will appear in the provider dropdown in the ### Self-Hosted OpenAI-compatible servers -> **Preview:** This feature is available in Zed Preview. It will be included in the next Stable release. - You can use any self-hosted server that implements the OpenAI completion API format. This works with vLLM, llama.cpp server, LocalAI, and other compatible servers. #### Configuration diff --git a/docs/src/ai/llm-providers.md b/docs/src/ai/llm-providers.md index 3a32bd96e73d9df427897798681f203c4ceb2273..24501ab2d356b8dc4098808ed8e9193cf6e171c6 100644 --- a/docs/src/ai/llm-providers.md +++ b/docs/src/ai/llm-providers.md @@ -88,7 +88,7 @@ With that done, choose one of the three authentication methods: While it's possible to configure through the Agent Panel settings UI by entering your AWS access key and secret directly, we recommend using named profiles instead for better security practices. To do this: -1. Create an IAM User that you can assume in the [IAM Console](https://us-east-1.console.aws.amazon.com/iam/home?region=us-east-1#/users). +1. Create an IAM User in the [IAM Console](https://us-east-1.console.aws.amazon.com/iam/home?region=us-east-1#/users). 2. Create security credentials for that User, save them and keep them secure. 3. Open the Agent Configuration with (`agent: open settings`) and go to the Amazon Bedrock section 4. Copy the credentials from Step 2 into the respective **Access Key ID**, **Secret Access Key**, and **Region** fields. @@ -152,8 +152,6 @@ For the most up-to-date supported regions and models, refer to the [Supported Mo #### Extended Context Window {#bedrock-extended-context} -> **Preview:** This feature is available in Zed Preview. It will be included in the next Stable release. - Anthropic models on Bedrock support a 1M token extended context window through the `anthropic_beta` API parameter. To enable this feature, set `"allow_extended_context": true` in your Bedrock configuration: ```json [settings] @@ -173,8 +171,6 @@ Zed enables extended context for supported models (Claude Sonnet 4.5 and Claude #### Image Support {#bedrock-image-support} -> **Preview:** This feature is available in Zed Preview. It will be included in the next Stable release. - Bedrock models that support vision (Claude 3 and later, Amazon Nova Pro and Lite, Meta Llama 3.2 Vision models, Mistral Pixtral) can receive images in conversations and tool results. ### Anthropic {#anthropic} @@ -630,8 +626,6 @@ The OpenRouter API key will be saved in your keychain. Zed will also use the `OPENROUTER_API_KEY` environment variable if it's defined. -> **Changed in Preview (v0.225).** See [release notes](/releases#0.225). - When using OpenRouter as your assistant provider, you must explicitly select a model in your settings. OpenRouter no longer provides a default model selection. Configure your preferred OpenRouter model in `settings.json`: diff --git a/docs/src/ai/models.md b/docs/src/ai/models.md index a86b873ef8aff112ceddbe7da000e4350023ec42..bbf41cf66cc4d93b38123c12fadd7a60c119dfef 100644 --- a/docs/src/ai/models.md +++ b/docs/src/ai/models.md @@ -43,10 +43,6 @@ Zed's plans offer hosted versions of major LLMs with higher rate limits than dir | | OpenAI | Cached Input | $0.005 | $0.0055 | | Gemini 3.1 Pro | Google | Input | $2.00 | $2.20 | | | Google | Output | $12.00 | $13.20 | -| Gemini 3.1 Pro | Google | Input | $2.00 | $2.20 | -| | Google | Output | $12.00 | $13.20 | -| Gemini 3 Pro | Google | Input | $2.00 | $2.20 | -| | Google | Output | $12.00 | $13.20 | | Gemini 3 Flash | Google | Input | $0.30 | $0.33 | | | Google | Output | $2.50 | $2.75 | | Grok 4 | X.ai | Input | $3.00 | $3.30 | @@ -70,7 +66,8 @@ As of February 19, 2026, Zed Pro serves newer model versions in place of the ret - Claude Sonnet 4 → Claude Sonnet 4.5 or Claude Sonnet 4.6 - Claude Sonnet 3.7 (retired Feb 19) → Claude Sonnet 4.5 or Claude Sonnet 4.6 - GPT-5.1 and GPT-5 → GPT-5.2 or GPT-5.2 Codex -- Gemini 2.5 Pro → Gemini 3 Pro or Gemini 3.1 Pro +- Gemini 2.5 Pro → Gemini 3.1 Pro +- Gemini 3 Pro → Gemini 3.1 Pro - Gemini 2.5 Flash → Gemini 3 Flash ## Usage {#usage} @@ -95,7 +92,6 @@ A context window is the maximum span of text and code an LLM can consider at onc | GPT-5 mini | OpenAI | 400k | | GPT-5 nano | OpenAI | 400k | | Gemini 3.1 Pro | Google | 200k | -| Gemini 3 Pro | Google | 200k | | Gemini 3 Flash | Google | 200k | > Context window limits for hosted Sonnet 4.5/4.6 and Gemini 3.1 Pro/3 Pro/Flash may increase in future releases. diff --git a/docs/src/ai/tools.md b/docs/src/ai/tools.md index 66f0af571d70fb8db7add2bd89139bf788369de6..faafc76b164f7f786c91c212bf51960f24a6bb0a 100644 --- a/docs/src/ai/tools.md +++ b/docs/src/ai/tools.md @@ -91,6 +91,6 @@ Executes shell commands and returns the combined output, creating a new shell pr ## Other Tools -### `subagent` +### `spawn_agent` -Spawns a subagent with its own context window to perform a delegated task. Useful for running parallel investigations, completing self-contained tasks, or performing research where only the outcome matters. Each subagent has access to the same tools as the parent agent. +Spawns a subagent with its own context window to perform a delegated task. Each subagent has access to the same tools as the parent agent. diff --git a/docs/src/collaboration/overview.md b/docs/src/collaboration/overview.md index 97efdae088d1692ad5840e23c13bc50d4ecb75c7..1022ec683bf5eefab55b9aff939c568098fdda30 100644 --- a/docs/src/collaboration/overview.md +++ b/docs/src/collaboration/overview.md @@ -24,8 +24,6 @@ See the [Data and Privacy FAQs](https://zed.dev/faq#data-and-privacy) for more d ### Selecting Audio Devices -> **Preview:** This feature is available in Zed Preview. It will be included in the next Stable release. - You can select specific input and output audio devices instead of using system defaults. To configure audio devices: 1. Open {#kb zed::OpenSettings} diff --git a/docs/src/configuring-languages.md b/docs/src/configuring-languages.md index 4e9bbce822f2f0d87ac2a8c9617698acd5983243..485d843fd480177376cf4e5e990fc495e2bb60a7 100644 --- a/docs/src/configuring-languages.md +++ b/docs/src/configuring-languages.md @@ -122,11 +122,40 @@ You can specify your preference using the `language_servers` setting: In this example: -- `intelephense` is set as the primary language server -- `phpactor` is disabled (note the `!` prefix) -- `...` expands to the rest of the language servers that are registered for PHP +- `intelephense` is set as the primary language server. +- `phpactor` and `phptools` are disabled (note the `!` prefix). +- `"..."` expands to the rest of the language servers registered for PHP that are not already listed. -This configuration allows you to tailor the language server setup to your specific needs, ensuring that you get the most suitable functionality for your development workflow. +The `"..."` entry acts as a wildcard that includes any registered language server you haven't explicitly mentioned. Servers you list by name keep their position, and `"..."` fills in the remaining ones at that point in the list. Servers prefixed with `!` are excluded entirely. This means that if a new language server extension is installed or a new server is registered for a language, `"..."` will automatically include it. If you want full control over which servers are enabled, omit `"..."` — only the servers you list by name will be used. + +#### Examples + +Suppose you're working with Ruby. The default configuration is: + +```json [settings] +{ + "language_servers": [ + "solargraph", + "!ruby-lsp", + "!rubocop", + "!sorbet", + "!steep", + "!kanayago", + "..." + ] +} +``` + +When you override `language_servers` in your settings, your list **replaces** the default entirely. This means default-disabled servers like `kanayago` will be re-enabled by `"..."` unless you explicitly disable them again. + +| Configuration | Result | +| ------------------------------------------------- | ------------------------------------------------------------------ | +| `["..."]` | `solargraph`, `ruby-lsp`, `rubocop`, `sorbet`, `steep`, `kanayago` | +| `["ruby-lsp", "..."]` | `ruby-lsp`, `solargraph`, `rubocop`, `sorbet`, `steep`, `kanayago` | +| `["ruby-lsp", "!solargraph", "!kanayago", "..."]` | `ruby-lsp`, `rubocop`, `sorbet`, `steep` | +| `["ruby-lsp", "solargraph"]` | `ruby-lsp`, `solargraph` | + +> Note: In the first example, `"..."` includes `kanayago` even though it is disabled by default. The override replaced the default list, so the `"!kanayago"` entry is no longer present. To keep it disabled, you must include `"!kanayago"` in your configuration. ### Toolchains @@ -136,8 +165,6 @@ Not all languages in Zed support toolchain discovery and selection, but for thos ### Configuring Language Servers -> **Changed in Preview (v0.225).** See [release notes](/releases#0.225). - When configuring language servers in your `settings.json`, autocomplete suggestions include all available LSP adapters recognized by Zed, not only those currently active for loaded languages. This helps you discover and configure language servers before opening files that use them. Many language servers accept custom configuration options. You can set these in the `lsp` section of your `settings.json`: diff --git a/docs/src/debugger.md b/docs/src/debugger.md index c659c1410b38166cf11da0af728e18f8c9282054..bf05de0f6ccccff4e95fd622bab7130d655a1167 100644 --- a/docs/src/debugger.md +++ b/docs/src/debugger.md @@ -165,8 +165,6 @@ The debug adapter will then stop whenever an exception of a given kind occurs. W ## Working with Split Panes -> **Changed in Preview (v0.225).** See [release notes](/releases#0.225). - When debugging with multiple split panes open, Zed shows the active debug line in one pane and preserves your layout in others. If you have the same file open in multiple panes, the debugger picks a pane where the file is already the active tab—it won't switch tabs in panes where the file is inactive. Once the debugger picks a pane, it continues using that pane for subsequent breakpoints during the session. If you drag the tab with the active debug line to a different split, the debugger tracks the move and uses the new pane. diff --git a/docs/src/getting-started.md b/docs/src/getting-started.md index af6a41c26a6f70f073b2d7e45267871962bb1697..a87e1bea0f4c3eacaa330b34874283a0b61b5eb9 100644 --- a/docs/src/getting-started.md +++ b/docs/src/getting-started.md @@ -13,8 +13,6 @@ This guide covers the essential commands, environment setup, and navigation basi ### Welcome Page -> **Changed in Preview (v0.225).** See [release notes](/releases#0.225). - When you open Zed without a folder, you see the welcome page in the main editor area. The welcome page offers quick actions to open a folder, clone a repository, or view documentation. Once you open a folder or file, the welcome page disappears. If you split the editor into multiple panes, the welcome page appears only in the center pane when empty—other panes show a standard empty state. To reopen the welcome page, close all items in the center pane or use the command palette to search for "Welcome". diff --git a/docs/src/languages/python.md b/docs/src/languages/python.md index d66f52c71cb9295fe9ca94e5890de48cd1275e57..fdeabec5069ed20a9b168ab19129dde0cc6280ba 100644 --- a/docs/src/languages/python.md +++ b/docs/src/languages/python.md @@ -89,8 +89,8 @@ Configure language servers in Settings ({#kb zed::OpenSettings}) under Languages "languages": { "Python": { "language_servers": [ - // Disable basedpyright and enable ty, and otherwise - // use the default configuration. + // Disable basedpyright and enable ty, and include all + // other registered language servers (ruff, pylsp, pyright). "ty", "!basedpyright", "..." diff --git a/docs/src/outline-panel.md b/docs/src/outline-panel.md index 1bacc3cacf4f556c9c3a06e59d6f3fac9b8c74b0..7b31725bf2cec844881e0c5b0b41aac864e28fc9 100644 --- a/docs/src/outline-panel.md +++ b/docs/src/outline-panel.md @@ -7,8 +7,6 @@ description: Navigate code structure with Zed's outline panel. View symbols, jum In addition to the modal outline (`cmd-shift-o`), Zed offers an outline panel. The outline panel can be deployed via `cmd-shift-b` (`outline panel: toggle focus` via the command palette), or by clicking the `Outline Panel` button in the status bar. -> **Changed in Preview (v0.225).** See [release notes](/releases#0.225). - When viewing a "singleton" buffer (i.e., a single file on a tab), the outline panel works similarly to that of the outline modal-it displays the outline of the current buffer's symbols. Each symbol entry shows its type prefix (such as "struct", "fn", "mod", "impl") along with the symbol name, helping you quickly identify what kind of symbol you're looking at. Clicking on an entry allows you to jump to the associated section in the file. The outline view will also automatically scroll to the section associated with the current cursor position within the file. ![Using the outline panel in a singleton buffer](https://zed.dev/img/outline-panel/singleton.png) diff --git a/docs/src/performance.md b/docs/src/performance.md index 09abecdeffe4e268413a73b189ef301511b1a20e..e974d63f8816b68d30a1c06d7cbbc083f8564327 100644 --- a/docs/src/performance.md +++ b/docs/src/performance.md @@ -78,7 +78,7 @@ Download the importer - `cd import && mkdir build && cd build` - Run cmake to generate build files: `cmake -G Ninja -DCMAKE_BUILD_TYPE=Release ..` - Build the importer: `ninja` -- Run the importer on the trace file: `./tracy-import-miniprofiler /path/to/trace.miniprof /path/to/output.tracy` +- Run the importer on the trace file: `./tracy-import-miniprofiler /path/to/trace.miniprof.json /path/to/output.tracy` - Open the trace in tracy: - If you're on windows download the v0.12.2 version from the releases on the upstream repo - If you're on other platforms open it on the website: https://tracy.nereid.pl/ (the version might mismatch so your luck might vary, we need to host our own ideally..) @@ -87,7 +87,7 @@ Download the importer - Run the action: `zed open performance profiler` - Hit the save button. This opens a save dialog or if that fails to open the trace gets saved in your working directory. -- Convert the profile so it can be imported in tracy using the importer: `./tracy-import-miniprofiler output.tracy` +- Convert the profile so it can be imported in tracy using the importer: `./tracy-import-miniprofiler output.tracy` - Go to hit the 'power button' in the top left and then open saved trace. - Now zoom in to see the tasks and how long they took diff --git a/docs/src/snippets.md b/docs/src/snippets.md index 72cbec7b20ff694304a58a70cd9b142a60fc58a2..9f6b6c880be9edcace23f0e3fd0a02263549776a 100644 --- a/docs/src/snippets.md +++ b/docs/src/snippets.md @@ -42,24 +42,4 @@ To create JSX snippets you have to use `javascript.json` snippets file, instead ## Known Limitations - Only the first prefix is used when a list of prefixes is passed in. -- Currently only the `json` snippet file format is supported, even though the `simple-completion-language-server` supports both `json` and `toml` file formats. - -## See also - -The `feature_paths` option in `simple-completion-language-server` is disabled by default. - -If you want to enable it you can add the following to your `settings.json`: - -```json [settings] -{ - "lsp": { - "snippet-completion-server": { - "settings": { - "feature_paths": true - } - } - } -} -``` - -For more configuration information, see the [`simple-completion-language-server` instructions](https://github.com/zed-industries/simple-completion-language-server/tree/main). +- Currently only the `json` snippet file format is supported. diff --git a/docs/src/tasks.md b/docs/src/tasks.md index 0fa659eb2cc58fe63536e721475b0093e0650618..482ca7b4d5779a4861756332ce2c0f25eaad4ad4 100644 --- a/docs/src/tasks.md +++ b/docs/src/tasks.md @@ -225,8 +225,6 @@ This could be useful for launching a terminal application that you want to use i ## VS Code Task Format -> **Preview:** This feature is available in Zed Preview. It will be included in the next Stable release. - When importing VS Code tasks from `.vscode/tasks.json`, you can omit the `label` field. Zed automatically generates labels based on the task type: - **npm tasks**: `npm: {{/if}} + +
@@ -343,6 +345,13 @@ href="https://zed.dev/blog" >Blog + +
@@ -444,23 +453,82 @@ {{/if}} {{/if}} - - + +
diff --git a/nix/livekit-libwebrtc/0001-shared-libraries.patch b/nix/livekit-libwebrtc/0001-shared-libraries.patch index e0b8709a4d1607f2ab416d725079d71f0fe40105..2a7fcf0cbdd519d51d9df446d5b9db00b22d521e 100644 --- a/nix/livekit-libwebrtc/0001-shared-libraries.patch +++ b/nix/livekit-libwebrtc/0001-shared-libraries.patch @@ -1,6 +1,6 @@ ---- a/BUILD.gn 2026-01-10 19:22:47.201811909 -0500 -+++ b/BUILD.gn 2026-01-10 19:24:36.440918317 -0500 -@@ -143,8 +143,8 @@ +--- a/BUILD.gn ++++ b/BUILD.gn +@@ -143,8 +143,12 @@ # target_defaults and direct_dependent_settings. config("common_inherited_config") { defines = [ "PROTOBUF_ENABLE_DEBUG_LOGGING_MAY_LEAK_PII=0" ] @@ -8,6 +8,10 @@ - ldflags = [] + cflags = [ "-fvisibility=default" ] + ldflags = [ "-lavutil", "-lavformat", "-lavcodec" ] ++ ++ if (is_linux) { ++ ldflags += [ "-Wl,--version-script=" + rebase_path("//libwebrtc.version", root_build_dir) ] ++ } if (rtc_objc_prefix != "") { defines += [ "RTC_OBJC_TYPE_PREFIX=${rtc_objc_prefix}" ] diff --git a/nix/livekit-libwebrtc/libwebrtc.version b/nix/livekit-libwebrtc/libwebrtc.version new file mode 100644 index 0000000000000000000000000000000000000000..abf9d5b9df61640d4775e2e1aeea6f113954a944 --- /dev/null +++ b/nix/livekit-libwebrtc/libwebrtc.version @@ -0,0 +1,22 @@ +/* Linker version script for libwebrtc.so (Linux only). + * + * When libwebrtc.so is built with rtc_use_pipewire=true and + * -fvisibility=default, PipeWire lazy-load trampoline stubs (pw_*, spa_*) + * are exported as weak symbols. If the PipeWire ALSA plugin + * (libasound_module_pcm_pipewire.so) is later dlopen'd by libasound, + * the dynamic linker may resolve the plugin's pw_* references through + * libwebrtc.so's broken trampolines instead of the real libpipewire.so, + * causing a SIGSEGV (NULL function pointer dereference). + * + * This script hides only those third-party symbol namespaces while + * keeping every WebRTC / BoringSSL / internal symbol exported (which + * the Rust webrtc-sys bindings require). + */ +{ + global: + *; + + local: + pw_*; + spa_*; +}; diff --git a/nix/livekit-libwebrtc/package.nix b/nix/livekit-libwebrtc/package.nix index 80ed3e18c58e9f3d1a4c5695b9fa7772a9bf51de..dd7b5808ac65ab07d1293683905b694910ee503a 100644 --- a/nix/livekit-libwebrtc/package.nix +++ b/nix/livekit-libwebrtc/package.nix @@ -114,7 +114,9 @@ stdenv.mkDerivation { stripLen = 1; extraPrefix = "third_party/"; }) - # Required for dynamically linking to ffmpeg libraries and exposing symbols + # Required for dynamically linking to ffmpeg libraries, exposing symbols, + # and hiding PipeWire symbols via version script (Linux only) to prevent + # SIGSEGV when ALSA's PipeWire plugin is loaded. ./0001-shared-libraries.patch # Borrow a patch from chromium to prevent a build failure due to missing libclang libraries ./chromium-129-rust.patch @@ -161,6 +163,7 @@ stdenv.mkDerivation { + lib.optionalString stdenv.hostPlatform.isLinux '' mkdir -p buildtools/linux64 ln -sf ${lib.getExe gn} buildtools/linux64/gn + cp ${./libwebrtc.version} libwebrtc.version substituteInPlace build/toolchain/linux/BUILD.gn \ --replace 'toolprefix = "aarch64-linux-gnu-"' 'toolprefix = ""' '' diff --git a/tooling/xtask/Cargo.toml b/tooling/xtask/Cargo.toml index 13179b2eb69ba9a63ba6be5784907b78bba1b9f2..21090d1304ea0eab9ad70808b91f76789f2fd923 100644 --- a/tooling/xtask/Cargo.toml +++ b/tooling/xtask/Cargo.toml @@ -9,6 +9,7 @@ license = "GPL-3.0-or-later" workspace = true [dependencies] +annotate-snippets = "0.12.1" anyhow.workspace = true backtrace.workspace = true cargo_metadata.workspace = true @@ -17,7 +18,11 @@ clap = { workspace = true, features = ["derive"] } toml.workspace = true indoc.workspace = true indexmap.workspace = true +itertools.workspace = true +regex.workspace = true serde.workspace = true serde_json.workspace = true +serde_yaml = "0.9.34" +strum.workspace = true toml_edit.workspace = true gh-workflow.workspace = true diff --git a/tooling/xtask/src/main.rs b/tooling/xtask/src/main.rs index 8246b98772184276ecabc685a9b4d2e7c5346edf..05afe3c766829137a7c2ba6e73d57638624d5e6a 100644 --- a/tooling/xtask/src/main.rs +++ b/tooling/xtask/src/main.rs @@ -23,6 +23,7 @@ enum CliCommand { /// Builds GPUI web examples and serves them. WebExamples(tasks::web_examples::WebExamplesArgs), Workflows(tasks::workflows::GenerateWorkflowArgs), + CheckWorkflows(tasks::workflow_checks::WorkflowValidationArgs), } fn main() -> Result<()> { @@ -37,5 +38,6 @@ fn main() -> Result<()> { CliCommand::PublishGpui(args) => tasks::publish_gpui::run_publish_gpui(args), CliCommand::WebExamples(args) => tasks::web_examples::run_web_examples(args), CliCommand::Workflows(args) => tasks::workflows::run_workflows(args), + CliCommand::CheckWorkflows(args) => tasks::workflow_checks::validate(args), } } diff --git a/tooling/xtask/src/tasks.rs b/tooling/xtask/src/tasks.rs index 4701b56d8dd201ad5b5f28764976b0c5397f3a3e..80f504fa0345de0d5bc71c5b44c71846f04c50bc 100644 --- a/tooling/xtask/src/tasks.rs +++ b/tooling/xtask/src/tasks.rs @@ -3,4 +3,5 @@ pub mod licenses; pub mod package_conformity; pub mod publish_gpui; pub mod web_examples; +pub mod workflow_checks; pub mod workflows; diff --git a/tooling/xtask/src/tasks/workflow_checks.rs b/tooling/xtask/src/tasks/workflow_checks.rs new file mode 100644 index 0000000000000000000000000000000000000000..d6be0299327ad2dd4b4a126a61a8b2ae6ddb9fd3 --- /dev/null +++ b/tooling/xtask/src/tasks/workflow_checks.rs @@ -0,0 +1,118 @@ +mod check_run_patterns; + +use std::{fs, path::PathBuf}; + +use annotate_snippets::Renderer; +use anyhow::{Result, anyhow}; +use clap::Parser; +use itertools::{Either, Itertools}; +use serde_yaml::Value; +use strum::IntoEnumIterator; + +use crate::tasks::{ + workflow_checks::check_run_patterns::{ + RunValidationError, WorkflowFile, WorkflowValidationError, + }, + workflows::WorkflowType, +}; + +pub use check_run_patterns::validate_run_command; + +#[derive(Default, Parser)] +pub struct WorkflowValidationArgs {} + +pub fn validate(_: WorkflowValidationArgs) -> Result<()> { + let (parsing_errors, file_errors): (Vec<_>, Vec<_>) = get_all_workflow_files() + .map(check_workflow) + .flat_map(Result::err) + .partition_map(|error| match error { + WorkflowError::ParseError(error) => Either::Left(error), + WorkflowError::ValidationError(error) => Either::Right(error), + }); + + if !parsing_errors.is_empty() { + Err(anyhow!( + "Failed to read or parse some workflow files: {}", + parsing_errors.into_iter().join("\n") + )) + } else if !file_errors.is_empty() { + let errors: Vec<_> = file_errors + .iter() + .map(|error| error.annotation_group()) + .collect(); + + let renderer = + Renderer::styled().decor_style(annotate_snippets::renderer::DecorStyle::Ascii); + println!("{}", renderer.render(errors.as_slice())); + + Err(anyhow!("Workflow checks failed!")) + } else { + Ok(()) + } +} + +enum WorkflowError { + ParseError(anyhow::Error), + ValidationError(Box), +} + +fn get_all_workflow_files() -> impl Iterator { + WorkflowType::iter() + .map(|workflow_type| workflow_type.folder_path()) + .flat_map(|folder_path| { + fs::read_dir(folder_path).into_iter().flat_map(|entries| { + entries + .flat_map(Result::ok) + .map(|entry| entry.path()) + .filter(|path| { + path.extension() + .is_some_and(|ext| ext == "yaml" || ext == "yml") + }) + }) + }) +} + +fn check_workflow(workflow_file_path: PathBuf) -> Result<(), WorkflowError> { + fn collect_errors( + iter: impl Iterator>>, + ) -> Result<(), Vec> { + Some(iter.flat_map(Result::err).flatten().collect::>()) + .filter(|errors| !errors.is_empty()) + .map_or(Ok(()), Err) + } + + fn check_recursive(key: &Value, value: &Value) -> Result<(), Vec> { + match value { + Value::Mapping(mapping) => collect_errors( + mapping + .into_iter() + .map(|(key, value)| check_recursive(key, value)), + ), + Value::Sequence(sequence) => collect_errors( + sequence + .into_iter() + .map(|value| check_recursive(key, value)), + ), + Value::String(string) => check_string(key, string).map_err(|error| vec![error]), + Value::Null | Value::Bool(_) | Value::Number(_) | Value::Tagged(_) => Ok(()), + } + } + + let file_content = + WorkflowFile::load(&workflow_file_path).map_err(WorkflowError::ParseError)?; + + check_recursive(&Value::Null, &file_content.parsed_content).map_err(|errors| { + WorkflowError::ValidationError(Box::new(WorkflowValidationError::new( + errors, + file_content, + workflow_file_path, + ))) + }) +} + +fn check_string(key: &Value, value: &str) -> Result<(), RunValidationError> { + match key { + Value::String(key) if key == "run" => validate_run_command(value), + _ => Ok(()), + } +} diff --git a/tooling/xtask/src/tasks/workflow_checks/check_run_patterns.rs b/tooling/xtask/src/tasks/workflow_checks/check_run_patterns.rs new file mode 100644 index 0000000000000000000000000000000000000000..50c435d033336dd82d2f110f5c880dff0d677e52 --- /dev/null +++ b/tooling/xtask/src/tasks/workflow_checks/check_run_patterns.rs @@ -0,0 +1,124 @@ +use annotate_snippets::{AnnotationKind, Group, Level, Snippet}; +use anyhow::{Result, anyhow}; +use regex::Regex; +use serde_yaml::Value; +use std::{ + collections::HashMap, + fs, + ops::Range, + path::{Path, PathBuf}, + sync::LazyLock, +}; + +static GITHUB_INPUT_PATTERN: LazyLock = LazyLock::new(|| { + Regex::new(r#"\$\{\{[[:blank:]]*([[:alnum:]]|[[:punct:]])+?[[:blank:]]*\}\}"#) + .expect("Should compile") +}); + +pub struct WorkflowFile { + raw_content: String, + pub parsed_content: Value, +} + +impl WorkflowFile { + pub fn load(workflow_file_path: &Path) -> Result { + fs::read_to_string(workflow_file_path) + .map_err(|_| { + anyhow!( + "Could not read workflow file at {}", + workflow_file_path.display() + ) + }) + .and_then(|file_content| { + serde_yaml::from_str(&file_content) + .map(|parsed_content| Self { + raw_content: file_content, + parsed_content, + }) + .map_err(|e| anyhow!("Failed to parse workflow file: {e:?}")) + }) + } +} + +pub struct WorkflowValidationError { + file_path: PathBuf, + contents: WorkflowFile, + errors: Vec, +} + +impl WorkflowValidationError { + pub fn new( + errors: Vec, + contents: WorkflowFile, + file_path: PathBuf, + ) -> Self { + Self { + file_path, + contents, + errors, + } + } + + pub fn annotation_group<'a>(&'a self) -> Group<'a> { + let raw_content = &self.contents.raw_content; + let mut identical_lines = HashMap::new(); + + let ranges = self + .errors + .iter() + .flat_map(|error| error.found_injection_patterns.iter()) + .map(|(line, pattern_range)| { + let initial_offset = identical_lines + .get(&(line.as_str(), pattern_range.start)) + .copied() + .unwrap_or_default(); + + let line_start = raw_content[initial_offset..] + .find(line.as_str()) + .map(|offset| offset + initial_offset) + .unwrap_or_default(); + + let pattern_start = line_start + pattern_range.start; + let pattern_end = pattern_start + pattern_range.len(); + + identical_lines.insert((line.as_str(), pattern_range.start), pattern_end); + + pattern_start..pattern_end + }); + + Level::ERROR + .primary_title("Found GitHub input injection in run command") + .element( + Snippet::source(&self.contents.raw_content) + .path(self.file_path.display().to_string()) + .annotations(ranges.map(|range| { + AnnotationKind::Primary + .span(range) + .label("This should be passed via an environment variable") + })), + ) + } +} + +pub struct RunValidationError { + found_injection_patterns: Vec<(String, Range)>, +} + +pub fn validate_run_command(command: &str) -> Result<(), RunValidationError> { + let patterns: Vec<_> = command + .lines() + .flat_map(move |line| { + GITHUB_INPUT_PATTERN + .find_iter(line) + .map(|m| (line.to_owned(), m.range())) + }) + .collect(); + + if patterns.is_empty() { + Ok(()) + } else { + Err(RunValidationError { + found_injection_patterns: patterns, + }) + } +} diff --git a/tooling/xtask/src/tasks/workflows.rs b/tooling/xtask/src/tasks/workflows.rs index 5663ebec247c4025f7cfbae8e9467733e2c7be2d..9151b9c671ef42e3dc54661f80438a4e31aff1e9 100644 --- a/tooling/xtask/src/tasks/workflows.rs +++ b/tooling/xtask/src/tasks/workflows.rs @@ -4,6 +4,8 @@ use gh_workflow::Workflow; use std::fs; use std::path::{Path, PathBuf}; +use crate::tasks::workflow_checks::{self}; + mod after_release; mod autofix_pr; mod bump_patch_version; @@ -87,8 +89,8 @@ impl WorkflowFile { } } -#[derive(PartialEq, Eq)] -enum WorkflowType { +#[derive(PartialEq, Eq, strum::EnumIter)] +pub enum WorkflowType { /// Workflows living in the Zed repository Zed, /// Workflows living in the `zed-extensions/workflows` repository that are @@ -113,7 +115,7 @@ impl WorkflowType { ) } - fn folder_path(&self) -> PathBuf { + pub fn folder_path(&self) -> PathBuf { match self { WorkflowType::Zed => PathBuf::from(".github/workflows"), WorkflowType::ExtensionCi => PathBuf::from("extensions/workflows"), @@ -155,5 +157,5 @@ pub fn run_workflows(_: GenerateWorkflowArgs) -> Result<()> { workflow_file.generate_file()?; } - Ok(()) + workflow_checks::validate(Default::default()) } diff --git a/tooling/xtask/src/tasks/workflows/after_release.rs b/tooling/xtask/src/tasks/workflows/after_release.rs index 3936e3ffb7754d167c6c39f02e17f758bed0c1ae..07ff1fba0d4799c463128362ad4ba996ccf8cea0 100644 --- a/tooling/xtask/src/tasks/workflows/after_release.rs +++ b/tooling/xtask/src/tasks/workflows/after_release.rs @@ -123,7 +123,7 @@ fn publish_winget() -> NamedJob { "X-GitHub-Api-Version" = "2022-11-28" } $body = @{ branch = "master" } | ConvertTo-Json - $uri = "https://api.github.com/repos/${{ github.repository_owner }}/winget-pkgs/merge-upstream" + $uri = "https://api.github.com/repos/$env:GITHUB_REPOSITORY_OWNER/winget-pkgs/merge-upstream" try { Invoke-RestMethod -Uri $uri -Method Post -Headers $headers -Body $body -ContentType "application/json" Write-Host "Successfully synced winget-pkgs fork" diff --git a/tooling/xtask/src/tasks/workflows/autofix_pr.rs b/tooling/xtask/src/tasks/workflows/autofix_pr.rs index c2c89b7cd05394c225c015a6cc83f48bd35b24a4..2779dc2b01fa873bc050be4d873b9a5d502606bd 100644 --- a/tooling/xtask/src/tasks/workflows/autofix_pr.rs +++ b/tooling/xtask/src/tasks/workflows/autofix_pr.rs @@ -55,7 +55,8 @@ fn download_patch_artifact() -> Step { fn run_autofix(pr_number: &WorkflowInput, run_clippy: &WorkflowInput) -> NamedJob { fn checkout_pr(pr_number: &WorkflowInput) -> Step { - named::bash(&format!("gh pr checkout {pr_number}")) + named::bash(r#"gh pr checkout "$PR_NUMBER""#) + .add_env(("PR_NUMBER", pr_number.to_string())) .add_env(("GITHUB_TOKEN", vars::GITHUB_TOKEN)) } @@ -133,7 +134,9 @@ fn run_autofix(pr_number: &WorkflowInput, run_clippy: &WorkflowInput) -> NamedJo fn commit_changes(pr_number: &WorkflowInput, autofix_job: &NamedJob) -> NamedJob { fn checkout_pr(pr_number: &WorkflowInput, token: &StepOutput) -> Step { - named::bash(&format!("gh pr checkout {pr_number}")).add_env(("GITHUB_TOKEN", token)) + named::bash(r#"gh pr checkout "$PR_NUMBER""#) + .add_env(("PR_NUMBER", pr_number.to_string())) + .add_env(("GITHUB_TOKEN", token)) } fn apply_patch() -> Step { diff --git a/tooling/xtask/src/tasks/workflows/cherry_pick.rs b/tooling/xtask/src/tasks/workflows/cherry_pick.rs index eaa786837f84ebf4d4f7e1a579db0c7b4dcc5040..5680bf6b23b85c17e68e531cecadfb31f091520d 100644 --- a/tooling/xtask/src/tasks/workflows/cherry_pick.rs +++ b/tooling/xtask/src/tasks/workflows/cherry_pick.rs @@ -35,7 +35,10 @@ fn run_cherry_pick( channel: &WorkflowInput, token: &StepOutput, ) -> Step { - named::bash(&format!("./script/cherry-pick {branch} {commit} {channel}")) + named::bash(r#"./script/cherry-pick "$BRANCH" "$COMMIT" "$CHANNEL""#) + .add_env(("BRANCH", branch.to_string())) + .add_env(("COMMIT", commit.to_string())) + .add_env(("CHANNEL", channel.to_string())) .add_env(("GIT_COMMITTER_NAME", "Zed Zippy")) .add_env(("GIT_COMMITTER_EMAIL", "hi@zed.dev")) .add_env(("GITHUB_TOKEN", token)) diff --git a/tooling/xtask/src/tasks/workflows/compare_perf.rs b/tooling/xtask/src/tasks/workflows/compare_perf.rs index 1d111acc4f8a4dc47edea6f45c0b93c845b7cda2..74a1fbdc389e2b0dacdf579d9ee96a0366eb5c01 100644 --- a/tooling/xtask/src/tasks/workflows/compare_perf.rs +++ b/tooling/xtask/src/tasks/workflows/compare_perf.rs @@ -29,14 +29,16 @@ pub fn run_perf( crate_name: &WorkflowInput, ) -> NamedJob { fn cargo_perf_test(ref_name: &WorkflowInput, crate_name: &WorkflowInput) -> Step { - named::bash(&format!( - " - if [ -n \"{crate_name}\" ]; then - cargo perf-test -p {crate_name} -- --json={ref_name}; + named::bash( + r#" + if [ -n "$CRATE_NAME" ]; then + cargo perf-test -p "$CRATE_NAME" -- --json="$REF_NAME"; else - cargo perf-test -p vim -- --json={ref_name}; - fi" - )) + cargo perf-test -p vim -- --json="$REF_NAME"; + fi"#, + ) + .add_env(("REF_NAME", ref_name.to_string())) + .add_env(("CRATE_NAME", crate_name.to_string())) } fn install_hyperfine() -> Step { @@ -44,9 +46,9 @@ pub fn run_perf( } fn compare_runs(head: &WorkflowInput, base: &WorkflowInput) -> Step { - named::bash(&format!( - "cargo perf-compare --save=results.md {base} {head}" - )) + named::bash(r#"cargo perf-compare --save=results.md "$BASE" "$HEAD""#) + .add_env(("BASE", base.to_string())) + .add_env(("HEAD", head.to_string())) } named::job( diff --git a/tooling/xtask/src/tasks/workflows/deploy_collab.rs b/tooling/xtask/src/tasks/workflows/deploy_collab.rs index 58212118c7ba4fa6d44d5f29fac671ca6eb5e662..300680f95b880e9adb14dffd2572d80cb08fd63c 100644 --- a/tooling/xtask/src/tasks/workflows/deploy_collab.rs +++ b/tooling/xtask/src/tasks/workflows/deploy_collab.rs @@ -1,5 +1,5 @@ use gh_workflow::{Container, Event, Port, Push, Run, Step, Use, Workflow}; -use indoc::{formatdoc, indoc}; +use indoc::indoc; use crate::tasks::workflows::runners::{self, Platform}; use crate::tasks::workflows::steps::{ @@ -115,9 +115,10 @@ fn deploy(deps: &[&NamedJob]) -> NamedJob { } fn sign_into_kubernetes() -> Step { - named::bash(formatdoc! {r#" - doctl kubernetes cluster kubeconfig save --expiry-seconds 600 {cluster_name} - "#, cluster_name = vars::CLUSTER_NAME}) + named::bash( + r#"doctl kubernetes cluster kubeconfig save --expiry-seconds 600 "$CLUSTER_NAME""#, + ) + .add_env(("CLUSTER_NAME", vars::CLUSTER_NAME)) } fn start_rollout() -> Step { @@ -139,7 +140,7 @@ fn deploy(deps: &[&NamedJob]) -> NamedJob { echo "Deploying collab:$GITHUB_SHA to $ZED_KUBE_NAMESPACE" source script/lib/deploy-helpers.sh - export_vars_for_environment $ZED_KUBE_NAMESPACE + export_vars_for_environment "$ZED_KUBE_NAMESPACE" ZED_DO_CERTIFICATE_ID="$(doctl compute certificate list --format ID --no-header)" export ZED_DO_CERTIFICATE_ID @@ -149,14 +150,14 @@ fn deploy(deps: &[&NamedJob]) -> NamedJob { export ZED_LOAD_BALANCER_SIZE_UNIT=$ZED_COLLAB_LOAD_BALANCER_SIZE_UNIT export DATABASE_MAX_CONNECTIONS=850 envsubst < crates/collab/k8s/collab.template.yml | kubectl apply -f - - kubectl -n "$ZED_KUBE_NAMESPACE" rollout status deployment/$ZED_SERVICE_NAME --watch + kubectl -n "$ZED_KUBE_NAMESPACE" rollout status "deployment/$ZED_SERVICE_NAME" --watch echo "deployed ${ZED_SERVICE_NAME} to ${ZED_KUBE_NAMESPACE}" export ZED_SERVICE_NAME=api export ZED_LOAD_BALANCER_SIZE_UNIT=$ZED_API_LOAD_BALANCER_SIZE_UNIT export DATABASE_MAX_CONNECTIONS=60 envsubst < crates/collab/k8s/collab.template.yml | kubectl apply -f - - kubectl -n "$ZED_KUBE_NAMESPACE" rollout status deployment/$ZED_SERVICE_NAME --watch + kubectl -n "$ZED_KUBE_NAMESPACE" rollout status "deployment/$ZED_SERVICE_NAME" --watch echo "deployed ${ZED_SERVICE_NAME} to ${ZED_KUBE_NAMESPACE}" "#}) } diff --git a/tooling/xtask/src/tasks/workflows/extension_bump.rs b/tooling/xtask/src/tasks/workflows/extension_bump.rs index 746b842f18dfcc8805be9285facefdfa52085b84..8c31de202ee7ac81b5f5e95fb26ec89452fd077c 100644 --- a/tooling/xtask/src/tasks/workflows/extension_bump.rs +++ b/tooling/xtask/src/tasks/workflows/extension_bump.rs @@ -150,7 +150,7 @@ pub(crate) fn compare_versions() -> (Step, StepOutput, StepOutput) { r#" CURRENT_VERSION="$({VERSION_CHECK})" - if [[ "${{{{ github.event_name }}}}" == "pull_request" ]]; then + if [[ "$GITHUB_EVENT_NAME" == "pull_request" ]]; then PR_FORK_POINT="$(git merge-base origin/main HEAD)" git checkout "$PR_FORK_POINT" elif BRANCH_PARENT_SHA="$(git merge-base origin/main origin/zed-zippy-autobump)"; then @@ -258,8 +258,6 @@ fn install_bump_2_version() -> Step { fn bump_version(current_version: &JobOutput, bump_type: &WorkflowInput) -> (Step, StepOutput) { let step = named::bash(formatdoc! {r#" - OLD_VERSION="{current_version}" - BUMP_FILES=("extension.toml") if [[ -f "Cargo.toml" ]]; then BUMP_FILES+=("Cargo.toml") @@ -269,7 +267,7 @@ fn bump_version(current_version: &JobOutput, bump_type: &WorkflowInput) -> (Step --search "version = \"{{current_version}}"\" \ --replace "version = \"{{new_version}}"\" \ --current-version "$OLD_VERSION" \ - --no-configured-files {bump_type} "${{BUMP_FILES[@]}}" + --no-configured-files "$BUMP_TYPE" "${{BUMP_FILES[@]}}" if [[ -f "Cargo.toml" ]]; then cargo update --workspace @@ -280,7 +278,9 @@ fn bump_version(current_version: &JobOutput, bump_type: &WorkflowInput) -> (Step echo "new_version=${{NEW_VERSION}}" >> "$GITHUB_OUTPUT" "# }) - .id("bump-version"); + .id("bump-version") + .add_env(("OLD_VERSION", current_version.to_string())) + .add_env(("BUMP_TYPE", bump_type.to_string())); let new_version = StepOutput::new(&step, "new_version"); (step, new_version) diff --git a/tooling/xtask/src/tasks/workflows/extension_tests.rs b/tooling/xtask/src/tasks/workflows/extension_tests.rs index de4f1dd94267a55dcc3e1555c1a5673ff813ad26..09f0cadf1c8731f8eed4ef1197a7edd05e0d1558 100644 --- a/tooling/xtask/src/tasks/workflows/extension_tests.rs +++ b/tooling/xtask/src/tasks/workflows/extension_tests.rs @@ -1,5 +1,5 @@ use gh_workflow::*; -use indoc::{formatdoc, indoc}; +use indoc::indoc; use crate::tasks::workflows::{ extension_bump::compare_versions, @@ -142,12 +142,14 @@ pub fn check() -> Step { } fn verify_version_did_not_change(version_changed: StepOutput) -> Step { - named::bash(formatdoc! {r#" - if [[ {version_changed} == "true" && "${{{{ github.event_name }}}}" == "pull_request" && "${{{{ github.event.pull_request.user.login }}}}" != "zed-zippy[bot]" ]] ; then + named::bash(indoc! {r#" + if [[ "$VERSION_CHANGED" == "true" && "$GITHUB_EVENT_NAME" == "pull_request" && "$PR_USER_LOGIN" != "zed-zippy[bot]" ]] ; then echo "Version change detected in your change!" echo "Version changes happen in separate PRs and will be performed by the zed-zippy bot" exit 42 fi "# }) + .add_env(("VERSION_CHANGED", version_changed.to_string())) + .add_env(("PR_USER_LOGIN", "${{ github.event.pull_request.user.login }}")) } diff --git a/tooling/xtask/src/tasks/workflows/extension_workflow_rollout.rs b/tooling/xtask/src/tasks/workflows/extension_workflow_rollout.rs index 2ba6069c273e8a3e9a27885595d2ad5380748cdd..6f03ad1521850fb24c5bad7265ebf913228c5077 100644 --- a/tooling/xtask/src/tasks/workflows/extension_workflow_rollout.rs +++ b/tooling/xtask/src/tasks/workflows/extension_workflow_rollout.rs @@ -105,10 +105,8 @@ fn rollout_workflows_to_extension(fetch_repos_job: &NamedJob) -> NamedJob { } fn get_removed_files(prev_commit: &StepOutput) -> (Step, StepOutput) { - let step = named::bash(formatdoc! {r#" - PREV_COMMIT="{prev_commit}" - - if [ "${{{{ matrix.repo }}}}" = "workflows" ]; then + let step = named::bash(indoc::indoc! {r#" + if [ "$MATRIX_REPO" = "workflows" ]; then WORKFLOW_DIR="extensions/workflows" else WORKFLOW_DIR="extensions/workflows/shared" @@ -119,8 +117,8 @@ fn rollout_workflows_to_extension(fetch_repos_job: &NamedJob) -> NamedJob { # Get deleted files (status D) and renamed files (status R - old name needs removal) # Using -M to detect renames, then extracting files that are gone from their original location REMOVED_FILES=$(git diff --name-status -M "$PREV_COMMIT" HEAD -- "$WORKFLOW_DIR" | \ - awk '/^D/ {{ print $2 }} /^R/ {{ print $2 }}' | \ - xargs -I{{}} basename {{}} 2>/dev/null | \ + awk '/^D/ { print $2 } /^R/ { print $2 }' | \ + xargs -I{} basename {} 2>/dev/null | \ tr '\n' ' ' || echo "") REMOVED_FILES=$(echo "$REMOVED_FILES" | xargs) @@ -129,7 +127,9 @@ fn rollout_workflows_to_extension(fetch_repos_job: &NamedJob) -> NamedJob { echo "removed_files=$REMOVED_FILES" >> "$GITHUB_OUTPUT" "#}) .id("calc-changes") - .working_directory("zed"); + .working_directory("zed") + .add_env(("PREV_COMMIT", prev_commit.to_string())) + .add_env(("MATRIX_REPO", "${{ matrix.repo }}")); let removed_files = StepOutput::new(&step, "removed_files"); @@ -137,9 +137,7 @@ fn rollout_workflows_to_extension(fetch_repos_job: &NamedJob) -> NamedJob { } fn sync_workflow_files(removed_files: &StepOutput) -> Step { - named::bash(formatdoc! {r#" - REMOVED_FILES="{removed_files}" - + named::bash(indoc::indoc! {r#" mkdir -p extension/.github/workflows cd extension/.github/workflows @@ -153,12 +151,14 @@ fn rollout_workflows_to_extension(fetch_repos_job: &NamedJob) -> NamedJob { cd - > /dev/null - if [ "${{{{ matrix.repo }}}}" = "workflows" ]; then + if [ "$MATRIX_REPO" = "workflows" ]; then cp zed/extensions/workflows/*.yml extension/.github/workflows/ else cp zed/extensions/workflows/shared/*.yml extension/.github/workflows/ fi "#}) + .add_env(("REMOVED_FILES", removed_files.to_string())) + .add_env(("MATRIX_REPO", "${{ matrix.repo }}")) } fn get_short_sha() -> (Step, StepOutput) { @@ -205,13 +205,16 @@ fn rollout_workflows_to_extension(fetch_repos_job: &NamedJob) -> NamedJob { fn enable_auto_merge(token: &StepOutput) -> Step { named::bash(indoc::indoc! {r#" - PR_NUMBER="${{ steps.create-pr.outputs.pull-request-number }}" if [ -n "$PR_NUMBER" ]; then cd extension gh pr merge "$PR_NUMBER" --auto --squash fi "#}) .add_env(("GH_TOKEN", token.to_string())) + .add_env(( + "PR_NUMBER", + "${{ steps.create-pr.outputs.pull-request-number }}", + )) } let (authenticate, token) = generate_token( diff --git a/tooling/xtask/src/tasks/workflows/publish_extension_cli.rs b/tooling/xtask/src/tasks/workflows/publish_extension_cli.rs index 549b0fdfcfbb8f44b24ac849e2fe3c13bf5acdb0..2269201a2de383bc5ae7147d9e1d08105c540d15 100644 --- a/tooling/xtask/src/tasks/workflows/publish_extension_cli.rs +++ b/tooling/xtask/src/tasks/workflows/publish_extension_cli.rs @@ -28,7 +28,7 @@ fn publish_job() -> NamedJob { } fn upload_binary() -> Step { - named::bash("script/upload-extension-cli ${{ github.sha }}") + named::bash(r#"script/upload-extension-cli "$GITHUB_SHA""#) .add_env(( "DIGITALOCEAN_SPACES_ACCESS_KEY", vars::DIGITALOCEAN_SPACES_ACCESS_KEY, @@ -60,7 +60,7 @@ fn update_sha_in_zed(publish_job: &NamedJob) -> NamedJob { fn replace_sha() -> Step { named::bash(indoc! {r#" - sed -i "s/ZED_EXTENSION_CLI_SHA: &str = \"[a-f0-9]*\"/ZED_EXTENSION_CLI_SHA: \&str = \"${{ github.sha }}\"/" \ + sed -i "s/ZED_EXTENSION_CLI_SHA: &str = \"[a-f0-9]*\"/ZED_EXTENSION_CLI_SHA: \&str = \"$GITHUB_SHA\"/" \ tooling/xtask/src/tasks/workflows/extension_tests.rs "#}) } @@ -139,7 +139,7 @@ fn update_sha_in_extensions(publish_job: &NamedJob) -> NamedJob { fn replace_sha() -> Step { named::bash(indoc! {r#" - sed -i "s/ZED_EXTENSION_CLI_SHA: [a-f0-9]*/ZED_EXTENSION_CLI_SHA: ${{ github.sha }}/" \ + sed -i "s/ZED_EXTENSION_CLI_SHA: [a-f0-9]*/ZED_EXTENSION_CLI_SHA: $GITHUB_SHA/" \ .github/workflows/ci.yml "#}) } @@ -191,7 +191,7 @@ fn create_pull_request_extensions( fn get_short_sha() -> (Step, StepOutput) { let step = named::bash(indoc::indoc! {r#" - echo "sha_short=$(echo "${{ github.sha }}" | cut -c1-7)" >> "$GITHUB_OUTPUT" + echo "sha_short=$(echo "$GITHUB_SHA" | cut -c1-7)" >> "$GITHUB_OUTPUT" "#}) .id("short-sha"); diff --git a/tooling/xtask/src/tasks/workflows/release.rs b/tooling/xtask/src/tasks/workflows/release.rs index 8241fc58f0821b950e32ee9b1a42473975ec008d..2963bbec24301b85b345461a6ea532a9ac3421c5 100644 --- a/tooling/xtask/src/tasks/workflows/release.rs +++ b/tooling/xtask/src/tasks/workflows/release.rs @@ -272,18 +272,55 @@ pub(crate) fn push_release_update_notification( test_jobs: &[&NamedJob], bundle_jobs: &ReleaseBundleJobs, ) -> NamedJob { - let all_job_names = test_jobs - .into_iter() + fn env_name(name: &str) -> String { + format!("RESULT_{}", name.to_uppercase()) + } + + let all_job_names: Vec<&str> = test_jobs + .iter() .map(|j| j.name.as_ref()) - .chain(bundle_jobs.jobs().into_iter().map(|j| j.name.as_ref())); + .chain(bundle_jobs.jobs().into_iter().map(|j| j.name.as_ref())) + .collect(); + + let env_entries = [ + ( + "DRAFT_RESULT".into(), + format!("${{{{ needs.{}.result }}}}", create_draft_release_job.name), + ), + ( + "UPLOAD_RESULT".into(), + format!("${{{{ needs.{}.result }}}}", upload_assets_job.name), + ), + ( + "VALIDATE_RESULT".into(), + format!("${{{{ needs.{}.result }}}}", validate_assets_job.name), + ), + ( + "AUTO_RELEASE_RESULT".into(), + format!("${{{{ needs.{}.result }}}}", auto_release_preview.name), + ), + ("RUN_URL".into(), CURRENT_ACTION_RUN_URL.to_string()), + ] + .into_iter() + .chain( + all_job_names + .iter() + .map(|name| (env_name(name), format!("${{{{ needs.{name}.result }}}}"))), + ); + + let failure_checks = all_job_names + .iter() + .map(|name| { + format!( + "if [ \"${env_name}\" == \"failure\" ];then FAILED_JOBS=\"$FAILED_JOBS {name}\"; fi", + env_name = env_name(name) + ) + }) + .collect::>() + .join("\n "); let notification_script = formatdoc! {r#" - DRAFT_RESULT="${{{{ needs.{draft_job}.result }}}}" - UPLOAD_RESULT="${{{{ needs.{upload_job}.result }}}}" - VALIDATE_RESULT="${{{{ needs.{validate_job}.result }}}}" - AUTO_RELEASE_RESULT="${{{{ needs.{auto_release_job}.result }}}}" TAG="$GITHUB_REF_NAME" - RUN_URL="{run_url}" if [ "$DRAFT_RESULT" == "failure" ]; then echo "❌ Draft release creation failed for $TAG: $RUN_URL" @@ -319,19 +356,6 @@ pub(crate) fn push_release_update_notification( fi fi "#, - draft_job = create_draft_release_job.name, - upload_job = upload_assets_job.name, - validate_job = validate_assets_job.name, - auto_release_job = auto_release_preview.name, - run_url = CURRENT_ACTION_RUN_URL, - failure_checks = all_job_names - .into_iter() - .map(|name: &str| format!( - "if [ \"${{{{ needs.{name}.result }}}}\" == \"failure\" ];\ - then FAILED_JOBS=\"$FAILED_JOBS {name}\"; fi" - )) - .collect::>() - .join("\n "), }; let mut all_deps: Vec<&NamedJob> = vec![ @@ -347,7 +371,10 @@ pub(crate) fn push_release_update_notification( .runs_on(runners::LINUX_SMALL) .cond(Expression::new("always()")); - for step in notify_slack(MessageType::Evaluated(notification_script)) { + for step in notify_slack(MessageType::Evaluated { + script: notification_script, + env: env_entries.collect(), + }) { job = job.add_step(step); } named::job(job) @@ -368,14 +395,17 @@ pub(crate) fn notify_on_failure(deps: &[&NamedJob]) -> NamedJob { pub(crate) enum MessageType { Static(String), - Evaluated(String), + Evaluated { + script: String, + env: Vec<(String, String)>, + }, } fn notify_slack(message: MessageType) -> Vec> { match message { MessageType::Static(message) => vec![send_slack_message(message)], - MessageType::Evaluated(expression) => { - let (generate_step, generated_message) = generate_slack_message(expression); + MessageType::Evaluated { script, env } => { + let (generate_step, generated_message) = generate_slack_message(script, env); vec![ generate_step, @@ -385,26 +415,32 @@ fn notify_slack(message: MessageType) -> Vec> { } } -fn generate_slack_message(expression: String) -> (Step, StepOutput) { +fn generate_slack_message( + expression: String, + env: Vec<(String, String)>, +) -> (Step, StepOutput) { let script = formatdoc! {r#" MESSAGE=$({expression}) echo "message=$MESSAGE" >> "$GITHUB_OUTPUT" "# }; - let generate_step = named::bash(&script) + let mut generate_step = named::bash(&script) .id("generate-webhook-message") .add_env(("GH_TOKEN", Context::github().token())); + for (name, value) in env { + generate_step = generate_step.add_env((name, value)); + } + let output = StepOutput::new(&generate_step, "message"); (generate_step, output) } fn send_slack_message(message: String) -> Step { - let script = formatdoc! {r#" - curl -X POST -H 'Content-type: application/json'\ - --data '{{"text":"{message}"}}' "$SLACK_WEBHOOK" - "# - }; - named::bash(&script).add_env(("SLACK_WEBHOOK", vars::SLACK_WEBHOOK_WORKFLOW_FAILURES)) + named::bash( + r#"curl -X POST -H 'Content-type: application/json' --data "$(jq -n --arg text "$SLACK_MESSAGE" '{"text": $text}')" "$SLACK_WEBHOOK""# + ) + .add_env(("SLACK_WEBHOOK", vars::SLACK_WEBHOOK_WORKFLOW_FAILURES)) + .add_env(("SLACK_MESSAGE", message)) } diff --git a/tooling/xtask/src/tasks/workflows/run_agent_evals.rs b/tooling/xtask/src/tasks/workflows/run_agent_evals.rs index e83d3a07f079c1f40360f413f3007813dbe552ce..521f419d9b317c42a1106ebe8500ccf0a3f494ec 100644 --- a/tooling/xtask/src/tasks/workflows/run_agent_evals.rs +++ b/tooling/xtask/src/tasks/workflows/run_agent_evals.rs @@ -123,7 +123,7 @@ fn cron_unit_evals() -> NamedJob { const UNIT_EVAL_MODELS: &[&str] = &[ "anthropic/claude-sonnet-4-5-latest", "anthropic/claude-opus-4-5-latest", - "google/gemini-3-pro", + "google/gemini-3.1-pro", "openai/gpt-5", ]; diff --git a/tooling/xtask/src/tasks/workflows/run_tests.rs b/tooling/xtask/src/tasks/workflows/run_tests.rs index d617dda5af0ad51d0e86cfeeb69a035a53c07663..38ba1bd32945f9ba8ee1e08ebc994a1132fb07f2 100644 --- a/tooling/xtask/src/tasks/workflows/run_tests.rs +++ b/tooling/xtask/src/tasks/workflows/run_tests.rs @@ -6,7 +6,10 @@ use indexmap::IndexMap; use indoc::formatdoc; use crate::tasks::workflows::{ - steps::{CommonJobConditions, repository_owner_guard_expression, use_clang}, + steps::{ + CommonJobConditions, cache_rust_dependencies_namespace, repository_owner_guard_expression, + use_clang, + }, vars::{self, PathCondition}, }; @@ -116,7 +119,7 @@ fn orchestrate_impl(rules: &[&PathCondition], include_package_filter: bool) -> N git fetch origin "$GITHUB_BASE_REF" --depth=350 COMPARE_REV="$(git merge-base "origin/${GITHUB_BASE_REF}" HEAD)" fi - CHANGED_FILES="$(git diff --name-only "$COMPARE_REV" ${{ github.sha }})" + CHANGED_FILES="$(git diff --name-only "$COMPARE_REV" "$GITHUB_SHA")" check_pattern() { local output_name="$1" @@ -240,15 +243,20 @@ pub fn tests_pass(jobs: &[NamedJob]) -> NamedJob { "#}); + let env_entries: Vec<_> = jobs + .iter() + .map(|job| { + let env_name = format!("RESULT_{}", job.name.to_uppercase()); + let env_value = format!("${{{{ needs.{}.result }}}}", job.name); + (env_name, env_value) + }) + .collect(); + script.push_str( &jobs .iter() - .map(|job| { - format!( - "check_result \"{}\" \"${{{{ needs.{}.result }}}}\"", - job.name, job.name - ) - }) + .zip(env_entries.iter()) + .map(|(job, (env_name, _))| format!("check_result \"{}\" \"${}\"", job.name, env_name)) .collect::>() .join("\n"), ); @@ -263,7 +271,13 @@ pub fn tests_pass(jobs: &[NamedJob]) -> NamedJob { .collect::>(), ) .cond(repository_owner_guard_expression(true)) - .add_step(named::bash(&script)); + .add_step( + env_entries + .into_iter() + .fold(named::bash(&script), |step, env_item| { + step.add_env(env_item) + }), + ); named::job(job) } @@ -646,9 +660,10 @@ pub(crate) fn check_scripts() -> NamedJob { } fn run_actionlint() -> Step { - named::bash(indoc::indoc! {r#" - ${{ steps.get_actionlint.outputs.executable }} -color - "#}) + named::bash(r#""$ACTIONLINT_BIN" -color"#).add_env(( + "ACTIONLINT_BIN", + "${{ steps.get_actionlint.outputs.executable }}", + )) } fn run_shellcheck() -> Step { @@ -673,6 +688,7 @@ pub(crate) fn check_scripts() -> NamedJob { .add_step(run_shellcheck()) .add_step(download_actionlint().id("get_actionlint")) .add_step(run_actionlint()) + .add_step(cache_rust_dependencies_namespace()) .add_step(check_xtask_workflows()), ) } diff --git a/tooling/xtask/src/tasks/workflows/steps.rs b/tooling/xtask/src/tasks/workflows/steps.rs index 9e54452424dba36d64a209c71b281e3b72eaafc8..4d17be81322277d0093de5d547bf4f0849e38dc3 100644 --- a/tooling/xtask/src/tasks/workflows/steps.rs +++ b/tooling/xtask/src/tasks/workflows/steps.rs @@ -503,9 +503,8 @@ pub mod named { } pub fn git_checkout(ref_name: &dyn std::fmt::Display) -> Step { - named::bash(&format!( - "git fetch origin {ref_name} && git checkout {ref_name}" - )) + named::bash(r#"git fetch origin "$REF_NAME" && git checkout "$REF_NAME""#) + .add_env(("REF_NAME", ref_name.to_string())) } pub fn authenticate_as_zippy() -> (Step, StepOutput) { diff --git a/typos.toml b/typos.toml index 6f76cc75d25add39d841c07bbde82f93514adac5..863fea3822d62a51f737c3d7fa87a4c198710cfa 100644 --- a/typos.toml +++ b/typos.toml @@ -4,6 +4,9 @@ ignore-hidden = false extend-exclude = [ ".git/", + # Typewriter model names used for agent branch names aren't typos. + "crates/agent_ui/src/branch_names.rs", + # Contributor names aren't typos. ".mailmap", @@ -42,6 +45,8 @@ extend-exclude = [ "crates/gpui_windows/src/window.rs", # Some typos in the base mdBook CSS. "docs/theme/css/", + # Automatically generated JS. + "docs/theme/c15t@*.js", # Spellcheck triggers on `|Fixe[sd]|` regex part. "script/danger/dangerfile.ts", # Eval examples for prompts and criteria